├── .gitmodules ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── WORKSPACE ├── tensorflow_fold ├── BUILD ├── blocks │ ├── BUILD │ ├── block_compiler.py │ ├── block_compiler_test.py │ ├── blocks.py │ ├── blocks_test.py │ ├── examples │ │ ├── BUILD │ │ ├── calculator │ │ │ ├── BUILD │ │ │ ├── model.py │ │ │ └── train.py │ │ ├── fizzbuzz │ │ │ ├── BUILD │ │ │ └── fizzbuzz.py │ │ ├── language_id │ │ │ ├── BUILD │ │ │ ├── fetch_datasets.sh │ │ │ └── language_id.py │ │ ├── mnist │ │ │ ├── BUILD │ │ │ └── mnist.py │ │ ├── plan.py │ │ └── sentiment │ │ │ ├── BUILD │ │ │ ├── eval.py │ │ │ ├── filter_glove.py │ │ │ ├── sentiment.py │ │ │ └── train.py │ ├── layers.py │ ├── layers_test.py │ ├── loom_ops.py │ ├── metrics.py │ ├── metrics_test.py │ ├── plan.py │ ├── plan_test.py │ ├── result_types.py │ ├── result_types_test.py │ ├── test_lib.py │ ├── util.py │ └── util_test.py ├── fold.bzl ├── g3doc │ ├── animation.gif │ ├── blocks.md │ ├── cc │ │ ├── ClassWeaver.md │ │ ├── ClassWeaverOpBase.md │ │ └── index.md │ ├── index.md │ ├── proto.md │ ├── py │ │ ├── loom.md │ │ ├── td.md │ │ └── wiring.png │ ├── quick.ipynb │ ├── running.md │ ├── sentiment.ipynb │ ├── setup.md │ ├── sources.md │ └── types.md ├── llgtm │ ├── BUILD │ ├── README.md │ ├── backend │ │ ├── eigen_evaluator.cc │ │ ├── eigen_evaluator.h │ │ ├── eigen_evaluator_client.h │ │ ├── eigen_graph_implementation.cc │ │ ├── eigen_graph_implementation.h │ │ ├── llgtm_nodes.inc │ │ ├── tf_evaluator.cc │ │ ├── tf_evaluator.h │ │ ├── tf_evaluator_client.h │ │ └── tf_nodes.inc │ ├── device.cc │ ├── device.h │ ├── dimensions.h │ ├── examples │ │ ├── BUILD │ │ ├── character_rnn.cc │ │ ├── parsetree.h │ │ ├── run_all_examples.sh │ │ └── tree_rnn.cc │ ├── gradients.cc │ ├── gradients.h │ ├── graph.cc │ ├── graph.h │ ├── graph_evaluator.h │ ├── graph_implementation.h │ ├── layers.cc │ ├── layers.h │ ├── llgtm.h │ ├── platform │ │ ├── external.h │ │ └── platform.h │ ├── tensor.cc │ ├── tensor.h │ ├── tensor_opcodes.cc │ ├── 
tensor_opcodes.h │ ├── tensor_ops.h │ ├── tensor_ops_impl.h │ ├── test │ │ ├── evaluator_test.h │ │ ├── evaluator_test_eigen.cc │ │ ├── evaluator_test_eigen_gpu.cc │ │ ├── evaluator_test_tf.cc │ │ ├── gradients_test.h │ │ ├── gradients_test_eigen.cc │ │ ├── gradients_test_eigen_gpu.cc │ │ ├── gradients_test_tf.cc │ │ ├── graph_nocompile.cc │ │ ├── graph_nocompile_test.py │ │ ├── graph_test.cc │ │ └── test_framework.h │ ├── trainer.cc │ ├── trainer.h │ ├── util.h │ └── variable_initializers.h ├── loom │ ├── BUILD │ ├── benchmarks │ │ ├── BUILD │ │ └── iclr_2017_benchmark.py │ ├── calculator_example │ │ ├── BUILD │ │ ├── calculator.proto │ │ ├── calculator.py │ │ ├── calculator_test.py │ │ ├── eval.py │ │ ├── helpers.py │ │ ├── make_dataset.py │ │ ├── model.py │ │ ├── model_test.py │ │ └── train.py │ ├── deserializing_weaver_op.cc │ ├── deserializing_weaver_op.py │ ├── loom.proto │ ├── loom.py │ ├── loom_test.py │ ├── platform.h │ ├── python │ │ └── weaver.swig │ ├── weaver.cc │ ├── weaver.h │ ├── weaver_op_base.cc │ ├── weaver_op_base.h │ ├── weaver_op_base.py │ └── weaver_test.cc ├── public │ ├── BUILD │ ├── blocks.py │ ├── loom.h │ └── loom.py ├── run_all_examples.sh ├── util │ ├── BUILD │ ├── build_pip_package.sh │ ├── pip.sh │ ├── proto_test.py │ ├── proto_tools.cc │ ├── setup.py │ ├── test.proto │ ├── test3.proto │ └── test_main.cc └── workspace.bzl └── tools └── bazel.rc /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "tensorflow"] 2 | path = tensorflow 3 | url = https://github.com/tensorflow/tensorflow 4 | branch = r1.0 5 | [submodule "abseil-cpp"] 6 | path = abseil-cpp 7 | url = https://github.com/abseil/abseil-cpp.git 8 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | Want to contribute? Great! First, read this page (including the small print at the end). 
2 | 3 | ### Before you contribute 4 | Before we can use your code, you must sign the 5 | [Google Individual Contributor License Agreement](https://cla.developers.google.com/about/google-individual) 6 | (CLA), which you can do online. 7 | The CLA is necessary mainly because you own the 8 | copyright to your changes, even after your contribution becomes part of our 9 | codebase, so we need your permission to use and distribute your code. We also 10 | need to be sure of various other things—for instance that you'll tell us if you 11 | know that your code infringes on other people's patents. You don't have to sign 12 | the CLA until after you've submitted your code for review and a member has 13 | approved it, but you must do it before we can put your code into our codebase. 14 | Before you start working on a larger contribution, you should get in touch with 15 | us first through the issue tracker with your idea so that we can help out and 16 | possibly guide you. Coordinating up front makes it much easier to avoid 17 | frustration later on. 18 | 19 | ### Code reviews 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. 22 | 23 | ### The small print 24 | Contributions made by corporations are covered by a different agreement than 25 | the one above, the 26 | [Software Grant and Corporate Contributor License Agreement](https://cla.developers.google.com/about/google-corporate). 27 | 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow Fold 2 | 3 | TensorFlow Fold is a library for 4 | creating [TensorFlow](https://www.tensorflow.org) models that consume structured 5 | data, where the structure of the computation graph depends on the structure of 6 | the input data.
For example, [this model](tensorflow_fold/g3doc/sentiment.ipynb) 7 | implements [TreeLSTMs](https://arxiv.org/abs/1503.00075) for sentiment analysis 8 | on parse trees of arbitrary shape/size/depth. 9 | 10 | Fold implements [*dynamic batching*](https://arxiv.org/abs/1702.02181). 11 | Batches of arbitrarily shaped computation graphs are transformed to produce a 12 | static computation graph. This graph has the same structure regardless of what 13 | input it receives, and can be executed efficiently by TensorFlow. 14 | 15 | * [Download and Setup](tensorflow_fold/g3doc/setup.md) 16 | * [Quick Start Notebook](tensorflow_fold/g3doc/quick.ipynb) 17 | * [Documentation](tensorflow_fold/g3doc/index.md) 18 | 19 | ![animation](tensorflow_fold/g3doc/animation.gif) 20 | 21 | This animation shows a [recursive neural network](https://en.wikipedia.org/wiki/Recursive_neural_network) run with dynamic batching. Operations of the same type appearing at the same depth in the computation graph (indicated by color in the animation) are batched together regardless of whether or not they appear in the same parse tree. The [Embed](tensorflow_fold/g3doc/py/td.md#td.Embedding) operation converts [words to vector representations](https://www.tensorflow.org/tutorials/word2vec/). The fully connected ([FC](tensorflow_fold/g3doc/py/td.md#td.FC)) operation combines word vectors to form vector representations of phrases. The output of the network is a vector representation of an entire sentence. Although only a single parse tree of a sentence is shown, the same network can run, and batch together operations, over multiple parse trees of arbitrary shapes and sizes. The TensorFlow `concat`, `while_loop`, and `gather` ops are created once, prior to variable initialization, by [Loom](tensorflow_fold/g3doc/py/loom.md), the low-level API for TensorFlow Fold. 22 | 23 | If you'd like to contribute to TensorFlow Fold, please review the 24 | [contribution guidelines](CONTRIBUTING.md).
25 | 26 | TensorFlow Fold is not an official Google product. 27 | -------------------------------------------------------------------------------- /WORKSPACE: -------------------------------------------------------------------------------- 1 | workspace(name = "org_tensorflow_fold") 2 | 3 | local_repository( 4 | name = "org_tensorflow", 5 | path = "tensorflow", 6 | ) 7 | 8 | local_repository( 9 | name = "com_google_absl", 10 | path = "abseil-cpp", 11 | ) 12 | 13 | # TensorFlow depends on "io_bazel_rules_closure" so we need this here. 14 | # Needs to be kept in sync with the same target in TensorFlow's WORKSPACE file. 15 | http_archive( 16 | name = "io_bazel_rules_closure", 17 | sha256 = "110fe68753413777944b473c25eed6368c4a0487cee23a7bac1b13cc49d3e257", 18 | strip_prefix = "rules_closure-4af89ef1db659eb41f110df189b67d4cf14073e1", 19 | urls = [ 20 | "http://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz", 21 | "https://github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz", # 2017-08-28 22 | ], 23 | ) 24 | 25 | # Import all of the tensorflow dependencies. 26 | load('//tensorflow_fold:workspace.bzl', 'tf_fold_workspace') 27 | tf_fold_workspace() 28 | -------------------------------------------------------------------------------- /tensorflow_fold/BUILD: -------------------------------------------------------------------------------- 1 | # TensorFlow Fold. 2 | 3 | licenses(["notice"]) # Apache 2.0 4 | 5 | # This is needed for open-source Fold installation. 
6 | filegroup( 7 | name = "headers", 8 | srcs = glob(["**/*.h"]), 9 | visibility = ["//tensorflow_fold:__subpackages__"], 10 | ) 11 | -------------------------------------------------------------------------------- /tensorflow_fold/blocks/BUILD: -------------------------------------------------------------------------------- 1 | # TensorFlow Fold is a high-level library for constructing models that is built 2 | # on top of TensorLoom. 3 | 4 | licenses(["notice"]) # Apache 2.0 5 | 6 | load("//tensorflow_fold:fold.bzl", "fold_py_library", "fold_py_test") 7 | 8 | package( 9 | default_visibility = [ 10 | "//tensorflow_fold/blocks/examples:__subpackages__", 11 | "//tensorflow_fold/public:__subpackages__", 12 | ], 13 | ) 14 | 15 | fold_py_library( 16 | name = "blocks", 17 | srcs = [ 18 | "block_compiler.py", 19 | "blocks.py", 20 | "layers.py", 21 | "loom_ops.py", 22 | "metrics.py", 23 | ], 24 | cc_deps = [ 25 | "//tensorflow_fold/util:proto_tools", 26 | ], 27 | deps = [ 28 | ":result_types", 29 | ":util", 30 | # numpy", 31 | "@six_archive//:six", 32 | "@org_tensorflow//tensorflow:tensorflow_py", 33 | "//tensorflow_fold/public:loom", 34 | ], 35 | ) 36 | 37 | fold_py_library( 38 | name = "test_lib", 39 | testonly = 1, 40 | srcs = [ 41 | "test_lib.py", 42 | ], 43 | cc_deps = [ 44 | "//tensorflow_fold/util:proto_tools", 45 | ], 46 | data = [ 47 | "//tensorflow_fold/util:test_proto_files", 48 | ], 49 | deps = [ 50 | "@org_tensorflow//tensorflow:tensorflow_py", 51 | ], 52 | ) 53 | 54 | fold_py_test( 55 | name = "blocks_test", 56 | srcs = ["blocks_test.py"], 57 | data = [ 58 | "//tensorflow_fold/util:test_proto_files", 59 | ], 60 | deps = [ 61 | ":blocks", 62 | ":test_lib", 63 | "@org_tensorflow//tensorflow:tensorflow_py", 64 | "//tensorflow_fold/util:cpp_test_proto_lib", 65 | "//tensorflow_fold/util:test_py_pb2", 66 | ], 67 | ) 68 | 69 | fold_py_test( 70 | name = "block_compiler_test", 71 | srcs = ["block_compiler_test.py"], 72 | deps = [ 73 | ":blocks", 74 | ":test_lib", 75 | 
"@org_tensorflow//tensorflow:tensorflow_py", 76 | ], 77 | ) 78 | 79 | fold_py_test( 80 | name = "layers_test", 81 | srcs = ["layers_test.py"], 82 | deps = [ 83 | ":blocks", 84 | ":test_lib", 85 | "@org_tensorflow//tensorflow:tensorflow_py", 86 | ], 87 | ) 88 | 89 | fold_py_test( 90 | name = "metrics_test", 91 | srcs = ["metrics_test.py"], 92 | deps = [ 93 | ":blocks", 94 | ":test_lib", 95 | "@org_tensorflow//tensorflow:tensorflow_py", 96 | ], 97 | ) 98 | 99 | fold_py_library( 100 | name = "result_types", 101 | srcs = ["result_types.py"], 102 | deps = [ 103 | # numpy", 104 | "@six_archive//:six", 105 | "@org_tensorflow//tensorflow:tensorflow_py", 106 | "//tensorflow_fold/public:loom", 107 | ], 108 | ) 109 | 110 | fold_py_test( 111 | name = "result_types_test", 112 | srcs = ["result_types_test.py"], 113 | deps = [ 114 | ":result_types", 115 | ":test_lib", 116 | "@org_tensorflow//tensorflow:tensorflow_py", 117 | ], 118 | ) 119 | 120 | fold_py_library( 121 | name = "util", 122 | srcs = ["util.py"], 123 | deps = [ 124 | # numpy", 125 | ], 126 | ) 127 | 128 | fold_py_test( 129 | name = "util_test", 130 | srcs = ["util_test.py"], 131 | deps = [ 132 | ":test_lib", 133 | ":util", 134 | "@org_tensorflow//tensorflow:tensorflow_py", 135 | ], 136 | ) 137 | 138 | fold_py_library( 139 | name = "plan", 140 | srcs = ["plan.py"], 141 | deps = [ 142 | ":util", 143 | # numpy", 144 | "@org_tensorflow//tensorflow:tensorflow_py", 145 | "@org_tensorflow//tensorflow/python/debug:debug_py", 146 | ], 147 | ) 148 | 149 | fold_py_test( 150 | name = "plan_test", 151 | srcs = ["plan_test.py"], 152 | deps = [ 153 | ":blocks", 154 | ":plan", 155 | ":test_lib", 156 | "@six_archive//:six", 157 | "@org_tensorflow//tensorflow:tensorflow_py", 158 | ], 159 | ) 160 | -------------------------------------------------------------------------------- /tensorflow_fold/blocks/examples/BUILD: -------------------------------------------------------------------------------- 1 | # Description: 2 | # Example models 
using blocks. 3 | 4 | licenses(["notice"]) # Apache 2.0 5 | 6 | package( 7 | default_visibility = [ 8 | "//tensorflow_fold/blocks/examples:__subpackages__", 9 | ], 10 | ) 11 | 12 | load("//tensorflow_fold:fold.bzl", "fold_py_binary") 13 | 14 | fold_py_binary( 15 | name = "plan", 16 | srcs = ["plan.py"], 17 | deps = [ 18 | # numpy", 19 | "@org_tensorflow//tensorflow:tensorflow_py", 20 | "@org_tensorflow//tensorflow/python/debug:debug_py", 21 | "//tensorflow_fold/blocks:util", 22 | ], 23 | ) 24 | -------------------------------------------------------------------------------- /tensorflow_fold/blocks/examples/calculator/BUILD: -------------------------------------------------------------------------------- 1 | # Calculator example for TensorFlow Fold. 2 | 3 | licenses(["notice"]) # Apache 2.0 4 | 5 | load("//tensorflow_fold:fold.bzl", "fold_py_binary", "fold_py_library") 6 | 7 | fold_py_library( 8 | name = "model", 9 | srcs = ["model.py"], 10 | deps = [ 11 | "@org_tensorflow//tensorflow:tensorflow_py", 12 | "//tensorflow_fold/public:blocks", 13 | ], 14 | ) 15 | 16 | fold_py_binary( 17 | name = "train", 18 | srcs = ["train.py"], 19 | data = [ 20 | "//tensorflow_fold/loom/calculator_example:calculator_proto_file", 21 | ], 22 | deps = [ 23 | ":model", 24 | "@six_archive//:six", 25 | "@org_tensorflow//tensorflow:tensorflow_py", 26 | "//tensorflow_fold/loom/calculator_example:calculator_py_pb2", 27 | # "//tensorflow_fold/loom/calculator_example:cpp_calculator_proto", 28 | "//tensorflow_fold/public:blocks", 29 | "//tensorflow_fold/util:proto", 30 | ], 31 | ) 32 | -------------------------------------------------------------------------------- /tensorflow_fold/blocks/examples/calculator/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """This is the model for the TensorFlow Fold calculator example.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | # import google3 20 | import tensorflow as tf 21 | import tensorflow_fold.public.blocks as td 22 | 23 | # The protobuf we're using here is from 24 | # //tensorflow_fold/loom/calculator_example/calculator.proto 25 | 26 | NUM_LABELS = 3 # negative, zero and positive. 27 | 28 | 29 | def preprocess_expression(expr): # Gives number leaves the op {'name': 'NUM'} (mutating and returning expr) so every node can dispatch on expr['op']['name']. 30 | # Set the op field for numbers, so we can handle cases uniformly. 31 | if expr['number'] is not None: 32 | expr['op'] = {'name': 'NUM'} 33 | return expr 34 | 35 | 36 | def result_sign(result): # Maps a result to a class index: negative -> 0, zero -> 1, positive -> 2. 37 | if result < 0: return 0 38 | if result == 0: return 1 39 | return 2 40 | 41 | 42 | class CalculatorModel(object): 43 | """A Fold model predicting the sign (negative/zero/positive) of calculator expression results.""" 44 | 45 | def __init__(self, state_size): # state_size: length of the vector computed for every (sub)expression. 46 | # Expressions are either constants, or calculator ops that take other 47 | # expressions as their arguments. Since an Expression is a recursive type, 48 | # the model must likewise be recursive. A ForwardDeclaration declares the 49 | # type of expression, so it can be used before it is defined. 
50 | expr_decl = td.ForwardDeclaration(td.PyObjectType(), state_size) 51 | 52 | # Create a block for each type of expression. 53 | # The terminals are the digits 0-9, which we map to vectors using 54 | # an embedding table.
55 | digit = (td.GetItem('number') >> td.Scalar(dtype='int32') >> 56 | td.Function(td.Embedding(10, state_size, name='terminal_embed'))) 57 | 58 | # For non terminals, recursively apply expression to the left/right sides, 59 | # concatenate the results, and pass them through a fully-connected layer. 60 | # Each operation uses different weights in the FC layer. 61 | def bin_op(name): # name selects a separate FC layer (FC_<name>) per operator. 62 | return (td.Record([('left', expr_decl()), ('right', expr_decl())]) >> 63 | td.Concat() >> 64 | td.FC(state_size, name='FC_'+name)) 65 | 66 | # OneOf will dispatch its input to the appropriate case, based on the value 67 | # in the ['op']['name'] field. 68 | cases = td.OneOf(lambda x: x['op']['name'], 69 | {'NUM': digit, 70 | 'PLUS': bin_op('PLUS'), 71 | 'MINUS': bin_op('MINUS'), 72 | 'TIMES': bin_op('TIMES'), 73 | 'DIV': bin_op('DIV')}) 74 | 75 | # We do preprocessing to add 'NUM' as a distinct case. 76 | expression = td.InputTransform(preprocess_expression) >> cases 77 | expr_decl.resolve_to(expression) 78 | 79 | # Get logits from the root of the expression tree. 80 | expression_logits = (expression >> 81 | td.FC(NUM_LABELS, activation=None, name='FC_logits')) 82 | 83 | # The result is stored in the expression itself. 84 | # We ignore it in td.Record above, and pull it out here. 85 | expression_label = (td.GetItem('result') >> 86 | td.InputTransform(result_sign) >> 87 | td.OneHot(NUM_LABELS)) 88 | 89 | # For the overall model, return a pair of (logits, labels) 90 | # The AllOf block will run each of its children on the same input. 91 | model = td.AllOf(expression_logits, expression_label) 92 | self._compiler = td.Compiler.create(model) 93 | 94 | # Get the tensorflow tensors that correspond to the outputs of model. 95 | # `logits` and `labels` are TF tensors, and we can use them to 96 | # compute losses in the usual way.
97 | (logits, labels) = self._compiler.output_tensors 98 | 99 | self._loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits( 100 | logits=logits, labels=labels)) 101 | 102 | self._accuracy = tf.reduce_mean( 103 | tf.cast(tf.equal(tf.argmax(labels, 1), 104 | tf.argmax(logits, 1)), 105 | dtype=tf.float32)) 106 | 107 | self._global_step = tf.Variable(0, name='global_step', trainable=False) 108 | optr = tf.train.GradientDescentOptimizer(0.01) # plain gradient descent with a fixed learning rate of 0.01 109 | self._train_op = optr.minimize(self._loss, global_step=self._global_step) 110 | 111 | @property 112 | def loss(self): # Scalar tensor: batch-mean softmax cross-entropy. 113 | return self._loss 114 | 115 | @property 116 | def accuracy(self): # Scalar tensor: fraction of the batch whose argmax logit matches the label. 117 | return self._accuracy 118 | 119 | @property 120 | def train_op(self): # Op that takes one gradient-descent step and increments global_step. 121 | return self._train_op 122 | 123 | @property 124 | def global_step(self): # Non-trainable step-counter Variable. 125 | return self._global_step 126 | 127 | def build_feed_dict(self, expressions): # Delegates to the Fold compiler; expressions are presumably CalculatorExpression dicts (see preprocess_expression) - confirm. 128 | return self._compiler.build_feed_dict(expressions) 129 | -------------------------------------------------------------------------------- /tensorflow_fold/blocks/examples/calculator/train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Runs the trainer for the calculator example. 16 | 17 | This file is a minor modification to loom/calculator_example/train.py.
18 | To run, first make the data set: 19 | 20 | ./tensorflow_fold/loom/calculator_example/make_dataset \ 21 | --output_path=DIR/calc_data.dat 22 | 23 | Then run the trainer: 24 | 25 | ./tensorflow_fold/blocks/examples/calculator/train \ 26 | --train_data_path=DIR/calc_data.dat 27 | """ 28 | from __future__ import absolute_import 29 | from __future__ import division 30 | from __future__ import print_function 31 | import os 32 | # import google3 33 | import six 34 | from six.moves import xrange # pylint: disable=redefined-builtin 35 | import tensorflow as tf 36 | from tensorflow_fold.blocks.examples.calculator import model 37 | from tensorflow_fold.loom.calculator_example import calculator_pb2 38 | from tensorflow_fold.util import proto_tools 39 | 40 | tf.flags.DEFINE_string( 41 | 'train_data_path', '', 42 | 'TF Record file containing the training dataset of expressions.') 43 | tf.flags.DEFINE_integer( 44 | 'batch_size', 1000, 'How many samples to read per batch.') 45 | tf.flags.DEFINE_integer( 46 | 'embedding_length', 5, 47 | 'How long to make the expression embedding vectors.') 48 | tf.flags.DEFINE_integer( 49 | 'max_steps', 1000000, 50 | 'The maximum number of batches to run the trainer for.') 51 | 52 | # Replication flags: 53 | tf.flags.DEFINE_string('logdir', '/tmp/calculator_example', 54 | 'Directory in which to write event logs.') 55 | tf.flags.DEFINE_string('master', '', 56 | 'Tensorflow master to use.') 57 | tf.flags.DEFINE_integer('task', 0, 58 | 'Task ID of the replica running the training.') 59 | tf.flags.DEFINE_integer('ps_tasks', 0, 60 | 'Number of PS tasks in the job.') 61 | FLAGS = tf.flags.FLAGS 62 | 63 | 64 | # Find the root of the bazel repository by stripping this file's path (tensorflow_fold/blocks/examples/calculator/train.py) back up five levels.
65 | def source_root(): # Returns __file__ with its last five path components removed. 66 | root = __file__ 67 | for _ in xrange(5): 68 | root = os.path.dirname(root) 69 | return root 70 | 71 | CALCULATOR_SOURCE_ROOT = source_root() 72 | CALCULATOR_PROTO_FILE = ('tensorflow_fold/loom/' 73 | 'calculator_example/calculator.proto') 74 | CALCULATOR_EXPRESSION_PROTO = ('tensorflow_fold.loom.' 75 | 'calculator_example.CalculatorExpression') 76 | 77 | 78 | # Make sure serialized_message_to_tree can find the calculator example proto: 79 | proto_tools.map_proto_source_tree_path('', CALCULATOR_SOURCE_ROOT) 80 | proto_tools.import_proto_file(CALCULATOR_PROTO_FILE) 81 | 82 | 83 | def iterate_over_tf_record_protos(table_path, unused_message_type): # Endlessly re-reads table_path, yielding each record decoded as a CalculatorExpression tree; the message-type argument is ignored. NOTE(review): loops forever without yielding if the file holds no records. 84 | while True: 85 | for v in tf.python_io.tf_record_iterator(table_path): 86 | yield proto_tools.serialized_message_to_tree( 87 | CALCULATOR_EXPRESSION_PROTO, v) 88 | 89 | 90 | def emit_values(supervisor, session, step, values): # Reports each name -> value entry of `values` as a scalar summary at `step`. 91 | summary = tf.Summary() 92 | for name, value in six.iteritems(values): 93 | summary_value = summary.value.add() 94 | summary_value.tag = name 95 | summary_value.simple_value = float(value) 96 | supervisor.summary_computed(session, summary, global_step=step) 97 | 98 | 99 | def main(unused_argv): # Trains CalculatorModel for --max_steps batches of --batch_size expressions. 100 | train_iterator = iterate_over_tf_record_protos( 101 | FLAGS.train_data_path, calculator_pb2.CalculatorExpression) 102 | 103 | with tf.Graph().as_default(): 104 | with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)): 105 | 106 | # Build the graph. 
107 | classifier = model.CalculatorModel(FLAGS.embedding_length) 108 | loss = classifier.loss 109 | accuracy = classifier.accuracy 110 | train_op = classifier.train_op 111 | global_step = classifier.global_step 112 | 113 | # Set up the supervisor. 114 | supervisor = tf.train.Supervisor( 115 | logdir=FLAGS.logdir, 116 | is_chief=(FLAGS.task == 0), 117 | save_summaries_secs=10, 118 | save_model_secs=30) 119 | sess = supervisor.PrepareSession(FLAGS.master) # NOTE(review): the session is never closed and supervisor.stop() is never called. 120 | 121 | # Run the trainer.
122 | for _ in xrange(FLAGS.max_steps): 123 | batch = [next(train_iterator) for _ in xrange(FLAGS.batch_size)] 124 | fdict = classifier.build_feed_dict(batch) 125 | 126 | _, step, loss_v, accuracy_v = sess.run( 127 | [train_op, global_step, loss, accuracy], 128 | feed_dict=fdict) 129 | print('step=%d: loss=%f accuracy=%f' % (step, loss_v, accuracy_v)) 130 | emit_values(supervisor, sess, step, 131 | {'Batch Loss': loss_v, 132 | 'Batch Accuracy': accuracy_v}) 133 | 134 | if __name__ == '__main__': 135 | tf.app.run() 136 | -------------------------------------------------------------------------------- /tensorflow_fold/blocks/examples/fizzbuzz/BUILD: -------------------------------------------------------------------------------- 1 | # Fizzbuzz example for TensorFlow Fold. 2 | 3 | licenses(["notice"]) # Apache 2.0 4 | 5 | load("//tensorflow_fold:fold.bzl", "fold_py_binary") 6 | 7 | fold_py_binary( 8 | name = "fizzbuzz", 9 | srcs = ["fizzbuzz.py"], 10 | deps = [ 11 | "@six_archive//:six", 12 | "@org_tensorflow//tensorflow:tensorflow_py", 13 | "//tensorflow_fold/public:blocks", 14 | ], 15 | ) 16 | -------------------------------------------------------------------------------- /tensorflow_fold/blocks/examples/language_id/BUILD: -------------------------------------------------------------------------------- 1 | # Language ID example for TensorFlow Fold.
2 | 3 | licenses(["notice"]) # Apache 2.0 4 | 5 | load("//tensorflow_fold:fold.bzl", "fold_py_binary") 6 | 7 | fold_py_binary( 8 | name = "language_id", 9 | srcs = ["language_id.py"], 10 | deps = [ 11 | "@org_tensorflow//tensorflow:tensorflow_py", 12 | "//tensorflow_fold/blocks/examples:plan", 13 | "//tensorflow_fold/public:blocks", 14 | ], 15 | ) 16 | -------------------------------------------------------------------------------- /tensorflow_fold/blocks/examples/language_id/fetch_datasets.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright 2017 Google Inc. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | # See https://tatoeba.org/eng/downloads for details. 18 | 19 | # Produces a file named sentences.csv (tab-separated: id, language, text) in /tmp; abort if /tmp is unavailable. 20 | cd /tmp/ || exit 1 21 | wget http://downloads.tatoeba.org/exports/sentences.tar.bz2 -O - | tar -xjf - 22 | 23 | # Keep only the sentences in the eight languages listed whose text is pure 24 | # ASCII (bytes 0x00-0x7F); punctuation is then stripped, text lowercased, and lines shuffled. 
25 | awk -F"\t" '$2 ~ /(deu|eng|epo|fra|ita|nld|por|spa)/ && $3 ~ /^[\x00-\x7f]+$/' < sentences.csv \ 26 | | tr -d '[:punct:]' | tr '[:upper:]' '[:lower:]' | shuf \ 27 | > roman_sentences.csv 28 | 29 | # Do a 20/80 dev/train split deterministically by the last digit of the 30 | # sentence ID number (IDs ending in 1 or 2 go to dev; `$` anchor works in any POSIX awk).
31 | awk -F"\t" '$1 ~ /(1|2)$/ {print $2","$3}' < roman_sentences.csv > roman_sentences_dev.csv 32 | awk -F"\t" '$1 !~ /(1|2)$/ {print $2","$3}' < roman_sentences.csv > roman_sentences_train.csv 33 | -------------------------------------------------------------------------------- /tensorflow_fold/blocks/examples/mnist/BUILD: -------------------------------------------------------------------------------- 1 | # MNIST example for TensorFlow Fold. 2 | 3 | licenses(["notice"]) # Apache 2.0 4 | 5 | load("//tensorflow_fold:fold.bzl", "fold_py_binary") 6 | 7 | fold_py_binary( 8 | name = "mnist", 9 | srcs = ["mnist.py"], 10 | deps = [ 11 | "@six_archive//:six", 12 | "@org_tensorflow//tensorflow:tensorflow_py", 13 | "//tensorflow_fold/blocks/examples:plan", 14 | "//tensorflow_fold/public:blocks", 15 | ], 16 | ) 17 | -------------------------------------------------------------------------------- /tensorflow_fold/blocks/examples/mnist/mnist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | r"""TensorFlow Fold for MNIST with fully connected layers and dropout. 15 | 16 | With default settings the test accuracy after 20 epochs is ~ 98.4%.
17 | 18 | Build: 19 | bazel build --config=opt \ 20 | //tensorflow_fold/blocks/examples/mnist 21 | 22 | Train: 23 | ./bazel-bin/tensorflow_fold/blocks/examples/mnist/mnist 24 | 25 | Eval: 26 | ./bazel-bin/tensorflow_fold/blocks/examples/mnist/mnist \ 27 | --mode=eval --eval_interval_secs=10 # set to 0 to evaluate once and exit 28 | 29 | Inference: 30 | ./bazel-bin/tensorflow_fold/blocks/examples/mnist/mnist \ 31 | --mode=infer 32 | 33 | See below and 34 | for additional flag options. 35 | """ 36 | from __future__ import absolute_import 37 | from __future__ import division 38 | from __future__ import print_function 39 | # import google3 40 | from six.moves import xrange # pylint: disable=redefined-builtin 41 | from six.moves import zip # pylint: disable=redefined-builtin 42 | import tensorflow as tf 43 | import tensorflow_fold.blocks.examples.plan as pl 44 | import tensorflow_fold.public.blocks as td 45 | 46 | 47 | NUM_LABELS = 10 48 | INPUT_LENGTH = 784 # 28 x 28 49 | 50 | flags = tf.app.flags 51 | FLAGS = flags.FLAGS 52 | flags.DEFINE_integer('num_layers', 2, 'Number of hidden layers.') 53 | flags.DEFINE_integer('num_units', 500, 'Number of units per hidden layer.') 54 | flags.DEFINE_float('keep_prob', 0.75, 'Keep probability for dropout.') 55 | 56 | 57 | def setup_plan(plan): 58 | """Sets up a TensorFlow Fold plan for MNIST. 59 | 60 | The inputs are 28 x 28 images represented as 784-dimensional float32 61 | vectors (scaled to [0, 1] and categorical digit labels in [0, 9]. 62 | 63 | The training loss is softmax cross-entropy. There is only one 64 | metric, accuracy. In inference mode, the output is a class label. 65 | 66 | Dropout is applied before every layer (including on the inputs). 67 | 68 | Args: 69 | plan: A TensorFlow Fold plan to set up. 70 | """ 71 | # Convert the input NumPy array into a tensor. 72 | model_block = td.Vector(INPUT_LENGTH) 73 | 74 | # Create a placeholder for dropout, if we are in train mode. 
75 | keep_prob = (tf.placeholder_with_default(1.0, [], name='keep_prob') 76 | if plan.mode == plan.mode_keys.TRAIN else None) 77 | 78 | # Add the fully connected hidden layers. 79 | for _ in xrange(FLAGS.num_layers): 80 | model_block >>= td.FC(FLAGS.num_units, input_keep_prob=keep_prob) 81 | 82 | # Add the linear output layer. 83 | model_block >>= td.FC(NUM_LABELS, activation=None, input_keep_prob=keep_prob) 84 | 85 | if plan.mode == plan.mode_keys.INFER: 86 | # In inference mode, we run the model directly on images. 87 | plan.compiler = td.Compiler.create(model_block) 88 | logits, = plan.compiler.output_tensors 89 | else: 90 | # In training/eval mode, we run the model on (image, label) pairs. 91 | plan.compiler = td.Compiler.create( 92 | td.Record((model_block, td.Scalar(tf.int64)))) 93 | logits, y_ = plan.compiler.output_tensors 94 | 95 | y = tf.argmax(logits, 1) # create the predicted output tensor 96 | 97 | datasets = tf.contrib.learn.datasets.mnist.load_mnist(FLAGS.logdir_base) 98 | if plan.mode == plan.mode_keys.INFER: 99 | plan.examples = datasets.test.images 100 | plan.outputs = [y] 101 | else: 102 | # Create loss and accuracy tensors, and add them to the plan. 103 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits( 104 | logits=logits, labels=y_) 105 | plan.losses['cross_entropy'] = loss 106 | accuracy = tf.reduce_mean(tf.cast(tf.equal(y, y_), tf.float32)) 107 | plan.metrics['accuracy'] = accuracy 108 | if plan.mode == plan.mode_keys.TRAIN: 109 | plan.examples = zip(datasets.train.images, datasets.train.labels) 110 | plan.dev_examples = zip(datasets.validation.images, 111 | datasets.validation.labels) 112 | # Turn dropout on for training, off for validation. 
113 | plan.train_feeds[keep_prob] = FLAGS.keep_prob 114 | else: 115 | assert plan.mode == plan.mode_keys.EVAL 116 | plan.examples = zip(datasets.test.images, datasets.test.labels) 117 | 118 | 119 | def main(_): 120 | assert 0 < FLAGS.keep_prob <= 1, '--keep_prob must be in (0, 1]' 121 | pl.Plan.create_from_flags(setup_plan).run() 122 | 123 | 124 | if __name__ == '__main__': 125 | tf.app.run() 126 | -------------------------------------------------------------------------------- /tensorflow_fold/blocks/examples/sentiment/BUILD: -------------------------------------------------------------------------------- 1 | # Experiments applying TensorFlow Fold to Stanford Sentiment Treebank data 2 | # 3 | 4 | licenses(["notice"]) # Apache 2.0 5 | 6 | load("//tensorflow_fold:fold.bzl", "fold_py_binary", "fold_py_library") 7 | 8 | fold_py_library( 9 | name = "sentiment", 10 | srcs = ["sentiment.py"], 11 | deps = [ 12 | # nltk 13 | "@org_tensorflow//tensorflow:tensorflow_py", 14 | "//tensorflow_fold/public:blocks", 15 | ], 16 | ) 17 | 18 | fold_py_binary( 19 | name = "train", 20 | srcs = ["train.py"], 21 | deps = [ 22 | ":sentiment", 23 | "@org_tensorflow//tensorflow:tensorflow_py", 24 | "//tensorflow_fold/public:blocks", 25 | ], 26 | ) 27 | 28 | fold_py_binary( 29 | name = "eval", 30 | srcs = ["eval.py"], 31 | deps = [ 32 | ":sentiment", 33 | "@org_tensorflow//tensorflow:tensorflow_py", 34 | ], 35 | ) 36 | 37 | fold_py_binary( 38 | name = "filter_glove", 39 | srcs = ["filter_glove.py"], 40 | deps = [ 41 | "@org_tensorflow//tensorflow:tensorflow_py", 42 | ], 43 | ) 44 | -------------------------------------------------------------------------------- /tensorflow_fold/blocks/examples/sentiment/eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Eval for TensorFlow Fold sentiment models."""
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 | import os.path
19 | # import google3
20 | import tensorflow as tf
21 | from tensorflow_fold.blocks.examples.sentiment import sentiment
22 |
23 | flags = tf.app.flags
24 | FLAGS = flags.FLAGS
25 | flags.DEFINE_string('checkpoint_file', None, 'Model checkpoint file.')
26 | # Splits to evaluate, in order; each is read from tree_dir/SPLIT.txt.
27 | FOLDS = ['test', 'dev', 'train']
28 |
29 |
30 | def main(_):
31 |   """Restores a saved sentiment model and prints its metrics on each fold.
32 |
33 |   The --embedding_file, --lstm_num_units and --tree_dir flags are read here
34 |   but not defined in this file; presumably they are defined by the
35 |   sentiment module.
36 |   """
37 |   print('loading word embeddings from %s' % FLAGS.embedding_file)
38 |   weight_matrix, word_idx = sentiment.load_embeddings(FLAGS.embedding_file)
39 |
40 |   with tf.Session() as sess:
41 |     print('restoring the model')
42 |     # Rebuild the model graph first: tf.train.Saver() with no var_list
43 |     # covers the variables that exist when it is constructed.
44 |     word_embedding = sentiment.create_embedding(weight_matrix)
45 |     compiler, metrics = sentiment.create_model(
46 |         word_embedding, word_idx, FLAGS.lstm_num_units)
47 |     saver = tf.train.Saver()
48 |     saver.restore(sess, FLAGS.checkpoint_file)
49 |     print('model restored from file: %s' % FLAGS.checkpoint_file)
50 |
51 |     print('evaluating on trees from %s' % FLAGS.tree_dir)
52 |     with compiler.multiprocessing_pool():
53 |       filenames = [os.path.join(FLAGS.tree_dir, '%s.txt' % f) for f in FOLDS]
54 |       for filename in filenames:
55 |         trees = sentiment.load_trees(filename)
56 |         print('file: %s, #trees: %d' % (filename, len(trees)))
57 |         # All trees of the fold go into a single feed; results are sorted by
58 |         # metric name so the printed order is stable.
59 |         res = sorted(sess.run(metrics, compiler.build_feed_dict(trees)).items())
60 |         # Metric names end in '_loss' or '_hits'; each value is printed under
61 |         # the name with that suffix removed, hits scaled by 100.
62 |         print(' loss: [%s]' %
63 |               ' '.join('%s: %.3e' % (name.rsplit('_', 1)[0], v)
64 |                        for name, v in res if name.endswith('_loss')))
65 |         print(' accuracy: [%s]' %
66 |               ' '.join('%s: %.2f' % (name.rsplit('_', 1)[0], v * 100)
67 |                        for name, v in res if name.endswith('_hits')))
68 |
69 |
70 | if __name__ == '__main__':
71 |   tf.app.run()
72 |
--------------------------------------------------------------------------------
/tensorflow_fold/blocks/examples/sentiment/filter_glove.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Utility to filter GloVe vectors by vocabulary.
15 |
16 | Writes to --output_file each (whitespace-stripped) line of --glove_file whose
17 | first space-separated token is a word occurring in --sentence_file.
18 |
19 | Vectors: a GloVe text file (--glove_file).
20 | Sentences: one sentence per line, words joined by --word_separator
21 | (e.g. the Stanford Sentiment Treebank file SOStr.txt).
22 | """
23 | from __future__ import absolute_import
24 | from __future__ import division
25 | from __future__ import print_function
26 | import codecs
27 | # import google3
28 | import tensorflow as tf
29 |
30 | flags = tf.app.flags
31 | FLAGS = flags.FLAGS
32 | flags.DEFINE_string('glove_file', None, 'GloVe file.')
33 | flags.DEFINE_string('sentence_file', None, 'Sentence file (one per line).')
34 | flags.DEFINE_string('word_separator', '|', 'Word separator.')
35 | flags.DEFINE_string('output_file', None, 'Output file')
36 |
37 |
38 | def main(_):
39 |   """Copies the GloVe lines whose word occurs in the sentence file."""
40 |   vocab = set()
41 |   with codecs.open(FLAGS.sentence_file, encoding='utf-8') as f:
42 |     for line in f:
43 |       # Strip whitespace (incl. the newline) and backslashes; split into words.
44 |       vocab.update(line.strip().replace('\\', '').split(FLAGS.word_separator))
45 |
46 |   nread = 0
47 |   nwrote = 0
48 |   with codecs.open(FLAGS.glove_file, encoding='utf-8') as f:
49 |     with codecs.open(FLAGS.output_file, 'w', encoding='utf-8') as out:
50 |       for line in f:
51 |         nread += 1  # counts every input line, including blank ones
52 |         line = line.strip()
53 |         if not line: continue
54 |         # The word is the first space-separated token of the line.
55 |         if line.split(u' ', 1)[0] in vocab:
56 |           out.write(line + '\n')
57 |           nwrote += 1
58 |
59 |   print('read %s lines, wrote %s' % (nread, nwrote))
60 |
61 | if __name__ == '__main__':
62 |   tf.app.run()
63 |
--------------------------------------------------------------------------------
/tensorflow_fold/blocks/examples/sentiment/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Training for TensorFlow Fold sentiment models.
15 |
16 | Data: Stanford Sentiment Treebank trees read from --tree_dir (train.txt and
17 | dev.txt), plus word embeddings read from --embedding_file.
18 | """
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 | import os.path
23 | # import google3
24 | import tensorflow as tf
25 | from tensorflow_fold.blocks.examples.sentiment import sentiment
26 | import tensorflow_fold.public.blocks as td
27 |
28 |
29 | flags = tf.app.flags
30 | FLAGS = flags.FLAGS
31 | flags.DEFINE_string(
32 |     'checkpoint_base', None, 'Path prefix for saving checkpoints')
33 | flags.DEFINE_integer(
34 |     'epochs', 20, 'Number of training epochs. Zero to run forever.')
35 | flags.DEFINE_integer(
36 |     'batch_size', 100, 'Number of examples per batch.')
37 | flags.DEFINE_float(
38 |     'learning_rate', 0.05, 'Learning rate for Adagrad optimizer.')
39 | flags.DEFINE_float(
40 |     'embedding_learning_rate_factor', 0.1,
41 |     'Scaling factor for gradient updates to word embedding vectors.')
42 | flags.DEFINE_float(
43 |     'keep_prob', 0.75, 'Keep probability for dropout.')
44 |
45 |
46 | def main(_):
47 |   """Trains a sentiment model and checkpoints it on dev improvements.
48 |
49 |   After each epoch the model is evaluated on the dev trees, and a checkpoint
50 |   with path prefix --checkpoint_base is saved whenever dev 'root_hits'
51 |   improves.
52 |   """
53 |   print('loading word embeddings from %s' % FLAGS.embedding_file)
54 |   weight_matrix, word_idx = sentiment.load_embeddings(FLAGS.embedding_file)
55 |
56 |   train_file = os.path.join(FLAGS.tree_dir, 'train.txt')
57 |   print('loading training trees from %s' % train_file)
58 |   train_trees = sentiment.load_trees(train_file)
59 |
60 |   dev_file = os.path.join(FLAGS.tree_dir, 'dev.txt')
61 |   print('loading dev trees from %s' % dev_file)
62 |   dev_trees = sentiment.load_trees(dev_file)
63 |
64 |   with tf.Session() as sess:
65 |     print('creating the model')
66 |     # keep_prob defaults to 1.0 when not fed, so only train_feed_dict enables
67 |     # dropout; the dev feed dict leaves it at the default.
68 |     keep_prob = tf.placeholder_with_default(1.0, [])
69 |     train_feed_dict = {keep_prob: FLAGS.keep_prob}
70 |     word_embedding = sentiment.create_embedding(weight_matrix)
71 |     compiler, metrics = sentiment.create_model(
72 |         word_embedding, word_idx, FLAGS.lstm_num_units, keep_prob)
73 |     loss = tf.reduce_sum(compiler.metric_tensors['all_loss'])
74 |     opt = tf.train.AdagradOptimizer(FLAGS.learning_rate)
75 |     grads_and_vars = opt.compute_gradients(loss)
76 |     # Scale the word-embedding gradient by --embedding_learning_rate_factor so
77 |     # the embeddings train at a different rate from the rest of the model.
78 |     # NOTE(review): relies on `==` between variables being an identity check
79 |     # (as in TF 1.x); `is` would state that explicitly.
80 |     found = 0
81 |     for i, (grad, var) in enumerate(grads_and_vars):
82 |       if var == word_embedding.weights:
83 |         found += 1
84 |         grad = tf.scalar_mul(FLAGS.embedding_learning_rate_factor, grad)
85 |         grads_and_vars[i] = (grad, var)
86 |     assert found == 1  # internal consistency check
87 |     train = opt.apply_gradients(grads_and_vars)
88 |     saver = tf.train.Saver()
89 |
90 |     print('initializing tensorflow')
91 |     sess.run(tf.global_variables_initializer())
92 |
93 |     with compiler.multiprocessing_pool():
94 |       print('training the model')
95 |       train_set = compiler.build_loom_inputs(train_trees)
96 |       dev_feed_dict = compiler.build_feed_dict(dev_trees)
97 |       # Starts at 0.0, so nothing is saved until root_hits exceeds zero.
98 |       dev_hits_best = 0.0
99 |       for epoch, shuffled in enumerate(td.epochs(train_set, FLAGS.epochs), 1):
100 |         train_loss = 0.0  # summed over all batches of the epoch
101 |         for batch in td.group_by_batches(shuffled, FLAGS.batch_size):
102 |           train_feed_dict[compiler.loom_input_tensor] = batch
103 |           _, batch_loss = sess.run([train, loss], train_feed_dict)
104 |           train_loss += batch_loss
105 |         dev_metrics = sess.run(metrics, dev_feed_dict)
106 |         dev_loss = dev_metrics['all_loss']
107 |         dev_accuracy = ['%s: %.2f' % (k, v * 100) for k, v in
108 |                         sorted(dev_metrics.items()) if k.endswith('hits')]
109 |         print('epoch:%4d, train_loss: %.3e, dev_loss: %.3e, dev_accuracy: [%s]'
110 |               % (epoch, train_loss, dev_loss, ' '.join(dev_accuracy)))
111 |         dev_hits = dev_metrics['root_hits']
112 |         if dev_hits > dev_hits_best:
113 |           dev_hits_best = dev_hits
114 |           save_path = saver.save(sess, FLAGS.checkpoint_base, global_step=epoch)
115 |           print('model saved in file: %s' % save_path)
116 |
117 | if __name__ == '__main__':
118 |   tf.app.run()
119 |
--------------------------------------------------------------------------------
/tensorflow_fold/blocks/loom_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Loom Ops used by Fold high-level API. This is an internal module."""
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 | # import google3
19 | import tensorflow_fold.blocks.result_types as tdt
20 | from tensorflow_fold.public import loom
21 |
22 |
23 | def _get_typeshapes(tensor_ts):
24 |   """Returns the loom TypeShape stored on each of the given tensor types."""
25 |   return [t._type_shape for t in tensor_ts]  # pylint: disable=protected-access
26 |
27 |
28 | class TaggingPassThroughOp(loom.LoomOp):
29 |   """A pass-through op that adds a tag to its output type.
30 |
31 |   When constructing a Fold Compiler, its root output and metrics
32 |   are routed through a tagging pass-through. This is necessary to
33 |   ensure that output tensors of the same type are uniquely identifiable.
34 |   """
35 |
36 |   def __init__(self, passthrough_types, tags):
37 |     """Creates the op.
38 |
39 |     Args:
40 |       passthrough_types: Tensor types to pass through; presumably
41 |         tdt.TensorType instances, since their `_type_shape` is read.
42 |       tags: One tag per entry of `passthrough_types`, in the same order. The
43 |         two are paired with zip(), which truncates silently on a mismatch.
44 |     """
45 |     self.passthrough_types = passthrough_types
46 |     in_ts = _get_typeshapes(passthrough_types)
47 |     # Each output keeps its input's dtype and shape but carries the new tag.
48 |     out_ts = [loom.TypeShape(ts.dtype, ts.shape, tag)
49 |               for ts, tag in zip(in_ts, tags)]
50 |     super(TaggingPassThroughOp, self).__init__(in_ts, out_ts)
51 |
52 |   def instantiate_batch(self, inputs):
53 |     """Returns the batched inputs unchanged; only the output tags differ."""
54 |     return inputs
55 |
56 |
57 | class FuncallOp(loom.LoomOp):
58 |   """Loom Op that wraps a Function."""
59 |
60 |   def __init__(self, tf_fn, input_type, output_type):
61 |     """Creates the op.
62 |
63 |     Args:
64 |       tf_fn: Callable applied to the (possibly regrouped) input tensors.
65 |       input_type: Result type of the inputs; its terminal types define the
66 |         op's input TypeShapes.
67 |       output_type: Result type of the outputs; its terminal types define the
68 |         op's output TypeShapes.
69 |     """
70 |     self.tf_fn = tf_fn
71 |     in_ts = _get_typeshapes(input_type.terminal_types())
72 |     out_ts = _get_typeshapes(output_type.terminal_types())
73 |     super(FuncallOp, self).__init__(in_ts, out_ts)
74 |     self._unflatten_inputs = None
75 |     self._flatten_outputs = None
76 |     # The op's signature lists terminal types flat. Inputs need regrouping
77 |     # only when the input tuple contains nested tuples; outputs need
78 |     # flattening when they are not a tuple or contain nested tuples.
79 |     if (isinstance(input_type, tdt.TupleType) and
80 |         any(isinstance(t, tdt.TupleType) for t in input_type)):
81 |       self._unflatten_inputs = input_type.unflatten
82 |     if (not isinstance(output_type, tdt.TupleType) or
83 |         any(isinstance(t, tdt.TupleType) for t in output_type)):
84 |       self._flatten_outputs = output_type.flatten
85 |
86 |   def instantiate_batch(self, inputs):
87 |     """Applies `tf_fn` to the batched inputs, regrouping/flattening as needed."""
88 |     if self._unflatten_inputs:
89 |       inputs = self._unflatten_inputs(iter(inputs), None)
90 |     outputs = self.tf_fn(*inputs)
91 |     return self._flatten_outputs(outputs) if self._flatten_outputs else outputs
92 |
--------------------------------------------------------------------------------
/tensorflow_fold/blocks/metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Metrics for TensorFlow Fold.""" 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | # import google3 19 | import six 20 | import tensorflow_fold.blocks.blocks as tdb 21 | import tensorflow_fold.blocks.result_types as tdt 22 | 23 | 24 | class Metric(tdb.Block): 25 | """A block that computes a metric. 26 | 27 | Metrics are used in Fold when the size of a model's output is not 28 | fixed, but varies as a function of the input data. They are also 29 | handy for accumulating results across sequential and recursive 30 | computations without having the thread them through explicitly as 31 | return values. 32 | 33 | For example, to create a block `y` that takes a (label, prediction) 34 | as input, adds an L2 `'loss'` metric, and returns the prediction as 35 | its output, you could say: 36 | 37 | ```python 38 | y = Composition() 39 | with y.scope(): 40 | label = y.input[0] 41 | prediction = y.input[1] 42 | l2 = (Function(tf.sub) >> Function(tf.nn.l2_loss)).reads(label, prediction) 43 | Metric('loss').reads(l2) 44 | y.output.reads(prediction) 45 | ``` 46 | 47 | The input type of the block must be a `TensorType`, or a 48 | `(TensorType, PyObjectType)` tuple. 49 | The output type is always `VoidType`. In the tuple input case, the 50 | second item of the tuple becomes a label for the tensor value, which 51 | can be used to identify where the value came from in a nested data 52 | structure and/or batch of inputs. 
53 | 54 | For example: 55 | 56 | ```python 57 | sess = tf.InteractiveSession() 58 | # We pipe Map() to Void() because blocks with sequence output types 59 | # cannot be compiled. 60 | block = td.Map(td.Scalar() >> td.Metric('foo')) >> td.Void() 61 | compiler = td.Compiler.create(block) 62 | sess.run(compiler.metric_tensors['foo'], 63 | compiler.build_feed_dict([range(3), range(4)])) => 64 | array([ 0., 1., 2., 0., 1., 2., 3.], dtype=float32) 65 | ``` 66 | 67 | Or with labels: 68 | 69 | ```python 70 | sess = tf.InteractiveSession() 71 | block = td.Map((td.Scalar(), td.Identity()) >> td.Metric('bar')) >> td.Void() 72 | compiler = td.Compiler.create(block) 73 | feed_dict, metric_labels = compiler.build_feed_dict( 74 | [[(0, 'zero'), (1, 'one')], [(2, 'two')]], 75 | metric_labels=True) 76 | metric_labels => {'bar': ['zero', 'one', 'two']} 77 | sess.run(compiler.metric_tensors['bar'], feed_dict) => 78 | array([ 0., 1., 2.], dtype=float32) 79 | ``` 80 | """ 81 | 82 | def __init__(self, metric_name): 83 | if not isinstance(metric_name, six.string_types): 84 | raise TypeError('metric_name must be a string: %s' % (metric_name,)) 85 | self._metric_name = metric_name 86 | super(Metric, self).__init__(name=str(metric_name), 87 | output_type=tdt.VoidType()) 88 | 89 | _expected_input_types = (tdt.TensorType, tdt.TupleType) 90 | 91 | def _update_input_type(self): 92 | if isinstance(self.input_type, tdt.TupleType): 93 | if len(self.input_type) != 2: 94 | raise TypeError('metric tuple input must have 2 items: %s' % 95 | self.input_type) 96 | if not isinstance(self.input_type[0], tdt.TensorType): 97 | raise TypeError('expected a tensor type, saw: %s' % self.input_type[0]) 98 | if not isinstance(self.input_type[1], tdt.PyObjectType): 99 | raise TypeError('expected a pyobj type, saw: %s' % self.input_type[1]) 100 | self._evaluate = self._evaluate_labeled 101 | self._metric_type = self.input_type[0] 102 | else: 103 | self._evaluate = self._evaluate_unlabeled 104 | self._metric_type = 
self.input_type 105 | 106 | def _compile(self, compiler_ctx): 107 | compiler_ctx.register_metric_op(self._metric_name, self._metric_type) 108 | 109 | def _evaluate_labeled(self, eval_ctx, x): 110 | eval_ctx.add_output(eval_ctx.op(self._metric_name, [x[0]])[0]) 111 | eval_ctx.metric_labels[self._metric_name].append(x[1]) 112 | 113 | def _evaluate_unlabeled(self, eval_ctx, x): 114 | eval_ctx.add_output(eval_ctx.op(self._metric_name, [x])[0]) 115 | -------------------------------------------------------------------------------- /tensorflow_fold/blocks/metrics_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tests for tensorflow_fold.blocks.metrics."""
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 | # import google3
19 | import numpy as np
20 | import six
21 | import tensorflow as tf
22 | from tensorflow_fold.blocks import test_lib
23 | import tensorflow_fold.blocks.block_compiler as tdc
24 | import tensorflow_fold.blocks.blocks as tdb
25 | import tensorflow_fold.blocks.metrics as tdm
26 |
27 |
28 | def _ispositive(x):
29 |   """Returns whether x > 0; for a list, only its first element is checked."""
30 |   if isinstance(x, list):
31 |     return x[0] > 0
32 |   else:
33 |     return x > 0
34 |
35 |
36 | def _pos_neg_block(shape):
37 |   """Returns a Tensor block of shape, adding positive/negative metrics."""
38 |   c = tdb.Composition()
39 |   with c.scope():
40 |     # The boolean key indexes the case tuple: False (including zero) feeds
41 |     # the 'negative' metric and True feeds the 'positive' metric.
42 |     tdb.OneOf(_ispositive,
43 |               (tdm.Metric('negative'), tdm.Metric('positive')),
44 |               pre_block=tdb.Tensor(shape)).reads(c.input)
45 |     c.output.reads(tdb.Tensor(shape).reads(c.input))
46 |   return c
47 |
48 |
49 | class MetricsTest(test_lib.TestCase):
50 |
51 |   def test_metrics_scalar(self):
52 |     # Each scalar is summed into the output and also routed to a metric.
53 |     block = tdb.Map(_pos_neg_block([])) >> tdb.Sum()
54 |
55 |     with self.test_session() as sess:
56 |       compiler = tdc.Compiler.create(block)
57 |       sess.run(tf.global_variables_initializer())
58 |
59 |       fd = compiler.build_feed_dict([[1, 2, 3, 4]])
60 |       self.assertSameStructure(
61 |           [10.], sess.run(compiler.output_tensors[0], fd).tolist())
62 |
63 |       positive = compiler.metric_tensors['positive']
64 |       negative = compiler.metric_tensors['negative']
65 |
66 |       fd = compiler.build_feed_dict([[1, -2, 3, -4, 5, -6]])
67 |       pos, neg = sess.run([positive, negative], fd)
68 |       np.testing.assert_equal(pos, [1, 3, 5])
69 |       np.testing.assert_equal(neg, [-2, -4, -6])
70 |
71 |       fd = compiler.build_feed_dict([[-1, -2, 3, -4, -5, -6]])
72 |       pos, neg = sess.run([positive, negative], fd)
73 |       np.testing.assert_equal(pos, [3])  # test single value
74 |       np.testing.assert_equal(neg, [-1, -2, -4, -5, -6])
75 |
76 |       fd = compiler.build_feed_dict([[1, 2, 3, 4, 5, 6]])
77 |       pos, neg = sess.run([positive, negative], fd)
78 |       np.testing.assert_equal(pos, [1, 2, 3, 4, 5, 6])
79 |       np.testing.assert_equal(neg, [])  # test no values
80 |
81 |       # test batches: values from all inputs are concatenated in input order
82 |       fd = compiler.build_feed_dict([[1, 2, -3, -4], [5, 6, -7, -8, 0]])
83 |       pos, neg = sess.run([positive, negative], fd)
84 |       np.testing.assert_equal(pos, [1, 2, 5, 6])
85 |       np.testing.assert_equal(neg, [-3, -4, -7, -8, 0])
86 |
87 |   def test_metrics_vector(self):
88 |     # Same as above with length-2 vectors, classified by their first element.
89 |     block = tdb.Map(_pos_neg_block([2])) >> tdb.Sum()
90 |
91 |     with self.test_session() as sess:
92 |       compiler = tdc.Compiler.create(block)
93 |       sess.run(tf.global_variables_initializer())
94 |
95 |       positive = compiler.metric_tensors['positive']
96 |       negative = compiler.metric_tensors['negative']
97 |
98 |       fd = compiler.build_feed_dict([[[1, 2], [-2, -3], [4, 5]]])
99 |       pos, neg = sess.run([positive, negative], fd)
100 |       np.testing.assert_equal(pos, [[1, 2], [4, 5]])
101 |       np.testing.assert_equal(neg, [[-2, -3]])
102 |
103 |   def test_metrics_raises(self):
104 |     # Both fields register 'positive' and 'negative' metrics, but with scalar
105 |     # vs. length-2 tensor types, so compilation must fail.
106 |     # NOTE(review): the trailing colon in the 'bar:' key looks accidental; it
107 |     # does not affect what this test checks.
108 |     sp0 = _pos_neg_block([])
109 |     spn = _pos_neg_block([2])
110 |     block = {'foo': sp0, 'bar:': spn} >> tdb.Concat()
111 |     six.assertRaisesRegex(
112 |         self, TypeError, 'Metric [a-z]+tive has incompatible types',
113 |         tdc.Compiler.create, block)
114 |
115 |   def test_metrics_labeled(self):
116 |     # Trees are [value, label, child...]; nodes with children feed the
117 |     # 'internal' metric and leaves feed 'leaf', each with its label.
118 |     tree1 = [1, 'a', [2, 'b'], [3, 'c'], [4, 'd']]
119 |     tree2 = [5, 'e', [6, 'f', [7, 'g']]]
120 |     fwd = tdb.ForwardDeclaration()
121 |
122 |     leaf = (tdb.Scalar('int32'), tdb.Identity()) >> tdm.Metric('leaf')
123 |     internal = tdb.AllOf(
124 |         (tdb.Scalar('int32'), tdb.Identity()) >> tdm.Metric('internal'),
125 |         tdb.Slice(start=2) >> tdb.Map(fwd())) >> tdb.Void()
126 |     tree = tdb.OneOf(key_fn=lambda expr: len(expr) > 2,
127 |                      case_blocks=(leaf, internal))
128 |     fwd.resolve_to(tree)
129 |
130 |     with self.test_session() as sess:
131 |       c = tdc.Compiler.create(tree)
132 |       feed_dict, labels = c.build_feed_dict([tree1, tree2], metric_labels=True)
133 |       self.assertEqual(['b', 'c', 'd', 'g'], labels['leaf'])
134 |       self.assertEqual(['a', 'e', 'f'], labels['internal'])
135 |       leaf_values, internal_values = sess.run(
136 |           [c.metric_tensors['leaf'], c.metric_tensors['internal']], feed_dict)
137 |       np.testing.assert_equal([2, 3, 4, 7], leaf_values)
138 |       np.testing.assert_equal([1, 5, 6], internal_values)
139 |
140 | if __name__ == '__main__':
141 |   test_lib.main()
142 |
--------------------------------------------------------------------------------
/tensorflow_fold/blocks/test_lib.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Common methods for testing TensorFlow Fold."""
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 | import collections
19 | import os
20 | # import google3
21 | import six
22 | from six.moves import xrange  # pylint: disable=redefined-builtin
23 | import tensorflow as tf
24 | from tensorflow_fold.util import proto_tools
25 |
26 | # pylint: disable=g-import-not-at-top,unused-import
27 | if six.PY3:
28 |   import unittest.mock as mock
29 | else:
30 |   import mock
31 | # pylint: enable=g-import-not-at-top,unused-import
32 |
33 | # Make sure SerializedMessageToTree can see our proto files.
34 | proto_tools.map_proto_source_tree_path( 35 | '', os.getcwd()) # Tests run in the bazel root directory. 36 | proto_tools.import_proto_file('tensorflow_fold/util/test.proto') 37 | proto_tools.import_proto_file('tensorflow_fold/util/test3.proto') 38 | 39 | 40 | class TestCase(tf.test.TestCase): 41 | 42 | def assertRaisesWithLiteralMatch(self, exception, literal, callable_obj, 43 | *args, **kwargs): 44 | with self.assertRaises(exception) as ctx: 45 | callable_obj(*args, **kwargs) 46 | self.assertEqual(str(ctx.exception), literal) 47 | 48 | # Open-sourced here: 49 | # 50 | def assertSameStructure(self, a, b, aname='a', bname='b', msg=None): 51 | """Asserts that two values contain the same structural content. 52 | 53 | The two arguments should be data trees consisting of trees of dicts and 54 | lists. They will be deeply compared by walking into the contents of dicts 55 | and lists; other items will be compared using the == operator. 56 | If the two structures differ in content, the failure message will indicate 57 | the location within the structures where the first difference is found. 58 | This may be helpful when comparing large structures. 59 | 60 | Args: 61 | a: The first structure to compare. 62 | b: The second structure to compare. 63 | aname: Variable name to use for the first structure in assertion messages. 64 | bname: Variable name to use for the second structure. 65 | msg: Additional text to include in the failure message. 
66 | """ 67 | # Accumulate all the problems found so we can report all of them at once 68 | # rather than just stopping at the first 69 | problems = [] 70 | 71 | _walk_structure_for_problems(a, b, aname, bname, problems) 72 | 73 | # Avoid spamming the user toooo much 74 | max_problems_to_show = self.maxDiff // 80 75 | if len(problems) > max_problems_to_show: 76 | problems = problems[0:max_problems_to_show-1] + ['...'] 77 | 78 | if problems: 79 | failure_message = '; '.join(problems) 80 | if msg: 81 | failure_message += (': ' + msg) 82 | self.fail(failure_message) 83 | 84 | 85 | # Open-sourced here: 86 | # 87 | def _walk_structure_for_problems(a, b, aname, bname, problem_list): 88 | """The recursive comparison behind assertSameStructure.""" 89 | if type(a) != type(b): # pylint: disable=unidiomatic-typecheck 90 | problem_list.append('%s is a %r but %s is a %r' % 91 | (aname, type(a), bname, type(b))) 92 | # If they have different types there's no point continuing 93 | return 94 | 95 | if isinstance(a, collections.Mapping): 96 | for k in a: 97 | if k in b: 98 | _walk_structure_for_problems( 99 | a[k], b[k], '%s[%r]' % (aname, k), '%s[%r]' % (bname, k), 100 | problem_list) 101 | else: 102 | problem_list.append('%s has [%r] but %s does not' % (aname, k, bname)) 103 | for k in b: 104 | if k not in a: 105 | problem_list.append('%s lacks [%r] but %s has it' % (aname, k, bname)) 106 | 107 | # Strings are Sequences but we'll just do those with regular != 108 | elif (isinstance(a, collections.Sequence) and 109 | not isinstance(a, six.string_types)): 110 | minlen = min(len(a), len(b)) 111 | for i in xrange(minlen): 112 | _walk_structure_for_problems( 113 | a[i], b[i], '%s[%d]' % (aname, i), '%s[%d]' % (bname, i), 114 | problem_list) 115 | for i in xrange(minlen, len(a)): 116 | problem_list.append('%s has [%i] but %s does not' % (aname, i, bname)) 117 | for i in xrange(minlen, len(b)): 118 | problem_list.append('%s lacks [%i] but %s has it' % (aname, i, bname)) 119 | 120 | 
else: 121 | if a != b: 122 | problem_list.append('%s is %r but %s is %r' % (aname, a, bname, b)) 123 | 124 | 125 | def main(): 126 | tf.test.main() 127 | -------------------------------------------------------------------------------- /tensorflow_fold/blocks/util_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tests for tensorflow_fold.blocks.util."""
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 | # import google3
19 | from six.moves import xrange  # pylint: disable=redefined-builtin
20 | import tensorflow as tf
21 | from tensorflow_fold.blocks import test_lib
22 | from tensorflow_fold.blocks import util
23 |
24 |
25 | class UtilTest(test_lib.TestCase):
26 |   """Tests for util.EdibleIterator, util.group_by_batches and util.epochs."""
27 |
28 |   def test_edible_iterator_int(self):
29 |     with self.test_session() as sess:
30 |       i = util.EdibleIterator(x for x in [2, 4, 6])
31 |       x = tf.placeholder(tf.int32)
32 |       # The same iterator is fed twice; each run sees all of its items.
33 |       self.assertEqual([4, 8, 12], sess.run(x + x, {x: i}).tolist())
34 |       self.assertEqual([3, 5, 7], sess.run(x + 1, {x: i}).tolist())
35 |
36 |   def test_edible_iterator_str(self):
37 |     with self.test_session() as sess:
38 |       i = util.EdibleIterator(x for x in ['foo', 'bar'])
39 |       x = tf.placeholder(tf.string)
40 |       # String items come back from the session as bytes.
41 |       self.assertEqual(b'foo', sess.run(x[0], {x: i}))
42 |       self.assertEqual(b'bar', sess.run(x[1], {x: i}))
43 |       self.assertEqual([b'foo', b'bar', b'foo', b'bar'],
44 |                        sess.run(tf.concat([x, x], 0), {x: i}).tolist())
45 |
46 |   def test_edible_iterator_empty(self):
47 |     with self.test_session() as sess:
48 |       i = util.EdibleIterator(iter([]))
49 |       x = tf.placeholder(tf.string)
50 |       # An empty iterator is fed as a length-0 vector.
51 |       self.assertEqual([[]], sess.run(tf.expand_dims(x, 0), {x: i}).tolist())
52 |       self.assertEqual([b'foo'],
53 |                        sess.run(tf.concat([['foo'], x], 0), {x: i}).tolist())
54 |
55 |   def test_group_by_batches(self):
56 |     # The final batch may be shorter than the batch size.
57 |     self.assertEqual([], list(util.group_by_batches([], 2)))
58 |     self.assertEqual([[1], [2], [3]], list(util.group_by_batches([1, 2, 3], 1)))
59 |     self.assertEqual([[1, 2], [3]], list(util.group_by_batches([1, 2, 3], 2)))
60 |
61 |   def test_group_by_batches_truncated(self):
62 |     # With truncate=True a short final batch is dropped.
63 |     self.assertEqual([], list(util.group_by_batches([], 2, truncate=True)))
64 |     self.assertEqual([[1], [2], [3]],
65 |                      list(util.group_by_batches([1, 2, 3], 1, truncate=True)))
66 |     self.assertEqual([[1, 2]],
67 |                      list(util.group_by_batches([1, 2, 3], 2, truncate=True)))
68 |
69 |   def test_epochs(self):
70 |     # A one-shot generator can still be replayed for several epochs.
71 |     self.assertEqual([[0, 0]] * 5,
72 |                      [list(x) for x in util.epochs((0 for _ in xrange(2)), 5)])
73 |     epochs = util.epochs(xrange(5), shuffle=False)
74 |     self.assertSequenceEqual(list(next(epochs)), xrange(5))
75 |     self.assertSequenceEqual(list(next(epochs)), xrange(5))
76 |     self.assertSequenceEqual(list(next(epochs)), xrange(5))
77 |     # Without shuffle=False, the first epoch still keeps the input order and
78 |     # later epochs contain the same items, possibly reordered.
79 |     epochs = util.epochs(xrange(5))
80 |     self.assertSequenceEqual(list(next(epochs)), xrange(5))
81 |     self.assertEqual(set(next(epochs)), set(xrange(5)))
82 |     self.assertEqual(set(next(epochs)), set(xrange(5)))
83 |
84 |   def test_epochs_n_is_one(self):
85 |     # For a single epoch, the input object itself is yielded.
86 |     items = [1]
87 |     result, = list(util.epochs(items, 1))
88 |     self.assertIs(items, result)
89 |
90 |
91 | if __name__ == '__main__':
92 |   test_lib.main()
93 |
--------------------------------------------------------------------------------
/tensorflow_fold/g3doc/animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/fold/0e7ca14832a14a5f2009d4e0424783a80e7d7a2c/tensorflow_fold/g3doc/animation.gif
--------------------------------------------------------------------------------
/tensorflow_fold/g3doc/cc/ClassWeaverOpBase.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # `class tensorflow::fold::WeaverOpBase`
4 |
5 |
6 |
7 | ` WeaverOpBase ` is a base class for writing TensorFlow ops kernels that schedule ops for Loom.
8 |
9 | Operations created as subclasses of ` WeaverOpBase ` should be registered with the `REGISTER_WEAVER_OP` macro.
For example, ` DeserializingWeaverOp ` is registered using: 10 | 11 | ```c++ 12 | REGISTER_WEAVER_OP("DeserializingWeaver").Input("weaver_messages: string"); 13 | ``` 14 | 15 | And 16 | 17 | ```c++ 18 | REGISTER_KERNEL_BUILDER( 19 | Name("DeserializingWeaver").Device(tensorflow::DEVICE_CPU), 20 | DeserializingWeaverOp); 21 | ``` 22 | 23 | ### Member Details 24 | 25 | 26 | #### `tensorflow::fold::WeaverOpBase::WeaverOpBase(tensorflow::OpKernelConstruction *c)` 27 | 28 | 29 | 30 | Reads the `metadata`, `constant_types`, and `num_types_shapes` attributes and makes sure they're consistent. Dies if they're not. 31 | 32 | 33 | #### `virtual tensorflow::Status tensorflow::fold::WeaverOpBase::Weave(tensorflow::OpKernelContext *c, Weaver *weaver)=0` 34 | 35 | 36 | 37 | `Weave` is a pure virtual method that subclasses must override. Its responsibility is to read the op's inputs and use the weaver to schedule LoomOps to be executed on the loom. `Weave` should not call ` Weaver::Finalize `. 38 | 39 | 40 | #### `void tensorflow::fold::WeaverOpBase::Compute(tensorflow::OpKernelContext *c) override` 41 | 42 | 43 | 44 | Dispatches to `Weave` to build a ` Weaver `, which is then used to build the wiring diagram and constant tensors that the loom needs. 45 | -------------------------------------------------------------------------------- /tensorflow_fold/g3doc/cc/index.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # TensorFlow Fold C++ Weaver API 4 | 5 | ## Weaver 6 | 7 | The Weaver API allows the user to build a schedule to be run by a Loom, 8 | producing vectors of integers (wiring diagrams) to be fed to Loom's `tf.gather` 9 | ops in order to drive the loom according to the schedule the user specified. 
10 | 11 | * [tensorflow::fold::Weaver](ClassWeaver.md) 12 | 13 | ## WeaverOpBase 14 | 15 | WeaverOpBase contains the common code required to encapsulate a function 16 | that uses a weaver to build a graph for a Loom to run inside of a TensorFlow 17 | operation capable of driving the aforementioned Loom. 18 | 19 | * [tensorflow::fold::WeaverOpBase](ClassWeaverOpBase.md) 20 | 21 | 22 | -------------------------------------------------------------------------------- /tensorflow_fold/g3doc/index.md: -------------------------------------------------------------------------------- 1 | # TensorFlow Fold: Deep Learning with Dynamic Computation Graphs 2 | 3 | TensorFlow Fold is a library for creating TensorFlow models that consume 4 | structured data, such as nested lists, dictionaries, 5 | and 6 | [protocol buffers](https://developers.google.com/protocol-buffers/). Examples of 7 | such models 8 | are 9 | [tree-recursive neural networks](https://en.wikipedia.org/wiki/Recursive_neural_network) 10 | such as models of the 11 | [Stanford sentiment treebank](http://nlp.stanford.edu/sentiment/index.html), 12 | [tree LSTMs](https://arxiv.org/pdf/1503.00075.pdf), 13 | [hierarchical LSTMs](https://arxiv.org/pdf/1506.01057v2.pdf), and 14 | [graph-convolutional neural networks](https://arxiv.org/pdf/1603.00856v3.pdf). 15 | 16 | TensorFlow by itself was not designed to work with tree or graph structured 17 | data. It does not natively support any data types other than tensors, nor does 18 | it support the complex control flow, such as recursive functions, that are 19 | typically used to run models like tree-RNNs. When the input consists of trees 20 | (e.g. parse trees from a natural language model), each tree may have a different 21 | size and shape. A standard TensorFlow model consists of a fixed graph of 22 | operations, which cannot accommodate variable-shaped data. Fold overcomes this 23 | limitation by using 24 | the [dynamic batching algorithm](https://arxiv.org/abs/1702.02181). 
25 | 26 | Fold consists of a high-level API called Blocks, and a low-level API called 27 | Loom. Blocks are pure Python, whereas Loom is a mixture of Python and 28 | C++. Internally, Blocks uses Loom as its execution engine. Loom is an 29 | abstraction layer on top of TensorFlow that makes it possible to easily express 30 | computations over structures of varying sizes and shapes without the need to 31 | modify the underlying computation graph at run-time. 32 | 33 | ## Quick Links 34 | 35 | * [Blocks Tutorial](blocks.md) 36 | * [Running Blocks in TensorFlow](running.md) 37 | * [Blocks Type System](types.md) 38 | * [Python Blocks API](py/td.md) 39 | * [Python Loom API](py/loom.md) 40 | * [C++ Weaver API](cc/index.md) 41 | * [Protocol Buffer Decoding](proto.md) 42 | -------------------------------------------------------------------------------- /tensorflow_fold/g3doc/proto.md: -------------------------------------------------------------------------------- 1 | # Protocol buffers 2 | 3 | ## Introduction 4 | 5 | The `tensorflow_fold.util.proto_tools` Python C++ extension provides a function 6 | called `serialized_message_to_tree` which takes the message type of a protocol 7 | buffer and its content (serialized as a string) and converts it into a nested 8 | native Python data-structure composed of dictionaries and lists. This 9 | function's behavior is analogous to `json.loads` (except that enum values are 10 | treated specially, which will be described in more detail below.) 11 | 12 | The Fold blocks API provides a 13 | [`SerializedMessageToTree`](py/td.md#tdserializedmessagetotreemessage_type_name) 14 | block that serves as a convenient wrapper for this function. 
15 | 16 | ## Rationale 17 | 18 | The outputs of `serialized_message_to_tree` can be traversed faster than the 19 | Python protocol buffer API, and the resulting traversal code is more Pythonic; in 20 | particular, it eliminates the need for separate Fold blocks for dealing with 21 | protocol buffers and data loaded from JSON or other sources. 22 | 23 | ## Setup 24 | 25 | Before `serialized_message_to_tree` can be called, `proto_tools` must be told 26 | the locations of the `.proto` files (which define the schema of the protocol 27 | buffers) using `map_proto_source_tree_path(virtual_path, disk_path)`. One or 28 | more calls to `map_proto_source_tree_path` will build up a virtual source tree 29 | (in a manner analogous to Unix's `mount` command with the arguments reversed.) 30 | If all your proto files are in a single directory and their absolute import 31 | statements are written relative to that directory, then a single call to 32 | `map_proto_source_tree_path("", dir_path)` will suffice. 33 | 34 | Next, the protocol buffer message types that you care about should be imported 35 | using `proto_tools.import_proto_file(virtual_path)`. One of the calls to 36 | `map_proto_source_tree_path` must have taken a virtual path which is a prefix of 37 | `virtual_path` for the import to resolve. `virtual_path` should point to a 38 | valid `.proto` file (after the path has been resolved), as should any paths in 39 | any import statements the `.proto` file might contain, etc. 40 | 41 | Once this is done, `proto_tools.serialized_message_to_tree(message_type, 42 | str)` should work properly with any protocol buffer message types declared in 43 | the imported proto files. (Here `message_type` is the fully qualified message 44 | type which includes the package name, e.g. `tensorflow.fold.LoomMetadata`.) 45 | 46 | See [util/proto\_test.py](../util/proto_test.py) for example usage. 47 | 48 | ## Outputs 49 | 50 | Most types of proto fields are dealt with straightforwardly. 
String fields 51 | become Python strings, integers become Python integers, floats become Python 52 | floats and so on. Repeated fields are rendered as Python lists. Submessages 53 | are rendered as Python dictionaries whose keys are strings. 54 | 55 | Enum fields are converted into Python dictionaries containing the name, index, 56 | and number of the enum value (keyed by "name", "index", and "number" 57 | respectively.) 58 | -------------------------------------------------------------------------------- /tensorflow_fold/g3doc/py/wiring.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/fold/0e7ca14832a14a5f2009d4e0424783a80e7d7a2c/tensorflow_fold/g3doc/py/wiring.png -------------------------------------------------------------------------------- /tensorflow_fold/g3doc/setup.md: -------------------------------------------------------------------------------- 1 | # Download and Setup 2 | 3 | Fold runs under Linux; we have not tested it on other platforms. Python 2.7 and 4 | 3.3+ are both supported. We recommend installing 5 | using [Virtualenv](http://docs.python-guide.org/en/latest/dev/virtualenvs/) 6 | and [pip](https://pip.pypa.io/en/stable/). See [here](sources.md) for instructions on installing from 7 | sources, if that's how you roll. If you run into trouble, the TensorFlow main 8 | site has a list 9 | of 10 | [common problems](https://www.tensorflow.org/versions/r1.0/get_started/os_setup#common_problems) with 11 | some solutions that might be helpful. 12 | 13 | Please note that Fold requires TensorFlow 1.0; it is not compatible with earlier 14 | versions due to breaking API changes. 
15 | 16 | First install Python, pip, and Virtualenv: 17 | 18 | ``` 19 | sudo apt-get install python-pip python-dev python-virtualenv 20 | ``` 21 | 22 | Create a Virtualenv environment in the directory `foo`: 23 | 24 | ``` 25 | virtualenv foo # for Python 2.7 26 | virtualenv -p python3 foo # for Python 3.3+ 27 | ``` 28 | 29 | Activate the environment: 30 | 31 | ``` 32 | source ./foo/bin/activate # if using bash 33 | source ./foo/bin/activate.csh # if using csh 34 | ``` 35 | 36 | Install the pip package for TensorFlow. For Python 2.7 CPU-only, this will be: 37 | 38 | ``` 39 | pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.0.0rc0-cp27-none-linux_x86_64.whl 40 | ``` 41 | 42 | For Python 3.3+ and/or GPU, 43 | see 44 | [here](https://www.tensorflow.org/versions/r1.0/get_started/os_setup#using_pip) 45 | for the full list of available TF binaries. 46 | 47 | Check that TensorFlow can load: 48 | 49 | ``` 50 | python -c 'import tensorflow' 51 | ``` 52 | 53 | Install the pip package for Fold. For Python 2.7, this will be: 54 | 55 | ``` 56 | pip install https://storage.googleapis.com/tensorflow_fold/tensorflow_fold-0.0.1-cp27-none-linux_x86_64.whl 57 | ``` 58 | 59 | For Python 3.3: 60 | 61 | ``` 62 | pip install https://storage.googleapis.com/tensorflow_fold/tensorflow_fold-0.0.1-py3-none-linux_x86_64.whl 63 | ``` 64 | 65 | Check that Fold can load: 66 | 67 | ``` 68 | python -c 'import tensorflow_fold' 69 | ``` 70 | 71 | Success! 72 | 73 | ## Next steps 74 | 75 | * try out the [quick start notebook](quick.ipynb) 76 | * browse the [documentation](index.md) 77 | * hacks and glory 78 | -------------------------------------------------------------------------------- /tensorflow_fold/g3doc/sources.md: -------------------------------------------------------------------------------- 1 | # Source Installation 2 | 3 | Building Fold requires Bazel; get 4 | it [here](https://bazel.build/versions/master/docs/install.html). 
Then do: 5 | 6 | ``` 7 | virtualenv foo 8 | source ./foo/bin/activate 9 | pip install pip --upgrade 10 | pip install wheel --upgrade 11 | pip install numpy --upgrade 12 | git clone --recurse-submodules https://github.com/tensorflow/fold 13 | cd fold/tensorflow 14 | ./configure 15 | cd .. 16 | ``` 17 | 18 | Follow the 19 | instructions 20 | [here](https://www.tensorflow.org/get_started/os_setup#configure_the_installation) if 21 | you need help with the `configure` script; Fold inherits its configuration 22 | options, such as the location of Python and which optimization flags to use 23 | during compilation, from TensorFlow. 24 | 25 | ### Running the tests (optional) 26 | 27 | To run the unit tests, do: 28 | 29 | ``` 30 | pip install mock --upgrade 31 | bazel test --config=opt tensorflow_fold/... 32 | ``` 33 | 34 | When using CUDA on GPU, tests must be run sequentially: 35 | ``` 36 | bazel test --config=opt --config=cuda --jobs=1 tensorflow_fold/... 37 | ``` 38 | 39 | There is also a smoke test that runs all of the included examples: 40 | 41 | ``` 42 | pip install nltk --upgrade 43 | ./tensorflow_fold/run_all_examples.sh --config=opt 44 | ``` 45 | 46 | ### Building and installing pip wheels 47 | 48 | Build a pip wheel for Fold like so: 49 | 50 | ``` 51 | bazel build --config=opt //tensorflow_fold/util:build_pip_package 52 | ./bazel-bin/tensorflow_fold/util/build_pip_package /tmp/fold_pkg 53 | ``` 54 | 55 | You also need to build a pip wheel for TensorFlow. Unfortunately this means we 56 | need to rebuild all of TensorFlow, due to known Bazel limitations 57 | ([#1248](https://github.com/bazelbuild/bazel/issues/1248)). If you want to skip 58 | this step and reuse an existing TensorFlow wheel file, make sure that its 59 | configuration and version are the same as Fold's, to ensure consistency. 
60 | 61 | ``` 62 | cd tensorflow 63 | bazel build --config=opt //tensorflow/tools/pip_package:build_pip_package 64 | ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg 65 | cd .. 66 | ``` 67 | 68 | Now install the wheels. The precise names of the `.whl` files will 69 | depend on your platform. 70 | 71 | ``` 72 | pip install /tmp/fold_pkg/tensorflow_fold-0.0.1-cp27-cp27mu-linux_x86_64.whl 73 | pip install /tmp/tensorflow_pkg/tensorflow-1.0.0rc0-cp27-cp27mu-linux_x86_64.whl 74 | ``` 75 | -------------------------------------------------------------------------------- /tensorflow_fold/llgtm/README.md: -------------------------------------------------------------------------------- 1 | ### Introduction 2 | 3 | LLGTM *(The Low-level Library for Gradients, Tensors, and Matrices)* is a C++ 4 | library for deep learning models that use dynamic computation graphs. 5 | 6 | LLGTM is intended to be an alternative to the Loom library. Loom is written in 7 | Python, and implements dynamic computation graphs by emulating such graphs on 8 | top of TensorFlow. The advantage of Loom is that it integrates cleanly with the 9 | rest of TensorFlow, but the disadvantage is that the Python interpreter becomes 10 | part of the main evaluation loop. Since Python is a single-threaded, interpreted 11 | language, it can become a significant bottleneck. 12 | 13 | LLGTM has two major design goals. First, it makes graph construction and 14 | differentiation very fast. Graphs are lightweight and allocated in an arena, 15 | so they can be created and destroyed quickly. 16 | 17 | Second, LLGTM separates graph construction and differentiation from graph 18 | evaluation. In other words, it supports multiple evaluation backends. The 19 | initial release supports two backends: a reference implementation that uses 20 | [Eigen](http://eigen.tuxfamily.org), and a TensorFlow backend that invokes 21 | TensorFlow kernels. 
Additional backends may be provided in the future, with no 22 | change to the user-facing API. 23 | 24 | ### Known Issues 25 | 26 | LLGTM is currently in *pre-alpha*. Simple examples compile and run, but many 27 | features are still missing and/or incomplete. LLGTM is being released as open 28 | source in order to solicit feedback and contributions from users of TensorFlow 29 | Fold. However, we ***strongly*** suggest that users continue to use the Loom 30 | library for the time being. 31 | 32 | Known issues include the following: 33 | 34 | * Very few operations are supported. 35 | * No way to save or restore models. 36 | * Lots of other missing features. 37 | * The TensorFlow backend relies on non-public APIs, and thus does not compile without visibility changes to the TensorFlow BUILD file. 38 | * The Eigen backend includes several kernels that have not been optimized. Do not expect performance to be competitive with state-of-the-art models. 39 | 40 | ### Usage 41 | 42 | To build and test LLGTM, run the following from the *llgtm* directory: 43 | 44 | ``` 45 | bazel test :cpu_tests 46 | bazel test --config=cuda :gpu_tests 47 | source examples/run_all_examples.sh 48 | ``` 49 | -------------------------------------------------------------------------------- /tensorflow_fold/llgtm/backend/eigen_evaluator_client.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_FOLD_LLGTM_BACKEND_EIGEN_EVALUATOR_CLIENT_H_ 17 | #define TENSORFLOW_FOLD_LLGTM_BACKEND_EIGEN_EVALUATOR_CLIENT_H_ 18 | 19 | // Clients should include this header file if graphs are to be evaluated with 20 | // Eigen. Only one backend may be selected per compilation unit. 21 | 22 | // This evaluator is the "reference" implementation with implementations for 23 | // all operations. Other evaluators can fall back onto this evaluator if they 24 | // don't provide an implementation for an operation. 25 | 26 | // The entry point into the evaluator during graph evaluation is InvokeKernel 27 | // in TensorNodeSelf, where Self is the type of the node, e.g., 28 | // Add
. This method dispatches to the appropriate kernel implementation 29 | // EigenKernel::Kernel, e.g., EigenKernel>::Kernel. Kernels are 30 | // declared in eigen_evaluator.h but defined in eigen_evaluator.cc, such that 31 | // including (selecting) an evaluator header file does not transitively include 32 | // any implementation-specific header such as the Eigen headers. 33 | 34 | // Kernels are wrapped in EigenKernel such that other evaluators can 35 | // include eigen_evaluator.h and fall back on its kernel implementations. 36 | // (As opposed to putting kernel code in InvokeKernel directly.) 37 | 38 | // Only one evaluator header file may be included per translation unit because 39 | // every evaluator provides different implementations of InvokeKernel for 40 | // TensorNode subclasses. Including multiple evaluators is in violation of the 41 | // C++ One Definition Rule and leads to a compile error if detected. 42 | 43 | 44 | #include "tensorflow_fold/llgtm/backend/eigen_evaluator.h" 45 | #include "tensorflow_fold/llgtm/tensor.h" 46 | 47 | #ifdef LLGTM_BACKEND_SELECTED 48 | #error "Multiple backends selected. Include only one backend header file." 49 | #else 50 | 51 | #define LLGTM_BACKEND_SELECTED Eigen // Makes any backend header included after this one hit the #error above. 52 | 53 | namespace llgtm { 54 | 55 | class Graph; 56 | 57 | template 58 | void nodes::TensorNodeBaseSelf::InvokeKernel(Graph* graph) { // Eigen backend: launch the Eigen kernel for this node type. 59 | LaunchEigenKernel(reinterpret_cast(this), graph); 60 | } 61 | 62 | template 63 | void nodes::TensorNodeSelf::InvokeKernel(Graph* graph) { // Eigen backend: launch the Eigen kernel for this node type. 64 | LaunchEigenKernel(reinterpret_cast(this), graph); 65 | } 66 | 67 | #define LLGTM_NODE_DEFINITION(NODE) \ 68 | template class NODE; \ 69 | template class NODE; 70 | #define LLGTM_NODE_DEFINITION_FP(NODE) \ 71 | template class NODE; 72 | #include "tensorflow_fold/llgtm/backend/llgtm_nodes.inc" 73 | #undef LLGTM_NODE_DEFINITION // LLGTM_NODE_DEFINITION_FP is #undef'd by llgtm_nodes.inc itself. 74 | 75 | // Handle classes that are not templatized by data type separately. 
76 | template class nodes::TensorNodeBaseSelf; 77 | 78 | } // namespace llgtm 79 | 80 | #endif // LLGTM_BACKEND_SELECTED 81 | #endif // TENSORFLOW_FOLD_LLGTM_BACKEND_EIGEN_EVALUATOR_CLIENT_H_ 82 | -------------------------------------------------------------------------------- /tensorflow_fold/llgtm/backend/eigen_graph_implementation.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include "tensorflow_fold/llgtm/backend/eigen_graph_implementation.h" 17 | 18 | #ifdef GOOGLE_CUDA 19 | #define EIGEN_USE_GPU // Defined before the Eigen include below so GpuDevice is available. 20 | #endif // GOOGLE_CUDA 21 | 22 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 23 | 24 | namespace llgtm { 25 | 26 | size_t EigenGraphImplementation::AllocateResultData( 27 | nodes::TensorNodeBase* node) { // Returns the number of bytes allocated (0 when nothing is allocated). 28 | if (!node->allocates_result_data()) { 29 | return 0; 30 | } 31 | 32 | size_t size = node->result_data_size(); 33 | if (size > 0) { 34 | void* rdata = nullptr; 35 | switch (node->device()) { 36 | case kDeviceIDCPU: 37 | rdata = default_device_->allocate(size); 38 | default_allocations_.push_back(rdata); // Freed in the destructor. 39 | break; 40 | case kDeviceIDGPU: 41 | #ifdef GOOGLE_CUDA 42 | rdata = gpu_device_->allocate(size); 43 | gpu_allocations_.push_back(rdata); // Freed in the destructor. 44 | #else 45 | LOG(FATAL) << "LLGTM built without CUDA support. Use --config=cuda."; 46 | #endif // GOOGLE_CUDA 47 | break; 48 | default: 49 | LOG(FATAL) << "Device not supported: " << device_name(node->device()) 50 | << "."; 51 | } 52 | DCHECK(rdata); 53 | node->set_result_data(rdata); 54 | } 55 | return size; 56 | } 57 | 58 | 59 | EigenGraphImplementation::~EigenGraphImplementation() { // Releases every buffer recorded by AllocateResultData. 60 | for (void* d : default_allocations_) { 61 | default_device_->deallocate(d); 62 | } 63 | #ifdef GOOGLE_CUDA 64 | for (void* d : gpu_allocations_) { 65 | gpu_device_->deallocate(d); 66 | } 67 | #endif // GOOGLE_CUDA 68 | } 69 | 70 | } // namespace llgtm 71 | -------------------------------------------------------------------------------- /tensorflow_fold/llgtm/backend/eigen_graph_implementation.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 Google Inc. All Rights Reserved. 
2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_FOLD_LLGTM_BACKEND_EIGEN_GRAPH_IMPLEMENTATION_H_ 17 | #define TENSORFLOW_FOLD_LLGTM_BACKEND_EIGEN_GRAPH_IMPLEMENTATION_H_ 18 | 19 | #include "tensorflow_fold/llgtm/graph_implementation.h" // NOTE(review): std::vector is used below without a direct <vector> include. 20 | 21 | namespace Eigen { // Forward declarations avoid including Eigen headers here. 22 | class DefaultDevice; 23 | class GpuDevice; 24 | } // namespace Eigen 25 | 26 | namespace llgtm { 27 | 28 | class EigenGraphImplementation : public GraphImplementation { // Allocates node result buffers on Eigen devices; frees them on destruction. 29 | public: 30 | EigenGraphImplementation(Eigen::DefaultDevice* default_device, 31 | Eigen::GpuDevice* gpu_device) 32 | : GraphImplementation(), 33 | default_device_(default_device), gpu_device_(gpu_device) {} 34 | 35 | size_t AllocateResultData(nodes::TensorNodeBase* node) override; 36 | 37 | ~EigenGraphImplementation() override; 38 | 39 | private: 40 | // Devices belong to EigenEvaluator. 41 | Eigen::DefaultDevice* default_device_; 42 | Eigen::GpuDevice* gpu_device_; 43 | 44 | std::vector default_allocations_; // Buffers allocated on default_device_. 45 | std::vector gpu_allocations_; // Buffers allocated on gpu_device_ (GOOGLE_CUDA builds only). 
46 | }; 47 | 48 | } // namespace llgtm 49 | 50 | #endif // TENSORFLOW_FOLD_LLGTM_BACKEND_EIGEN_GRAPH_IMPLEMENTATION_H_ 51 | -------------------------------------------------------------------------------- /tensorflow_fold/llgtm/backend/llgtm_nodes.inc: -------------------------------------------------------------------------------- 1 | // This file lists all node types that are templatized by data type. 2 | // Nodes that are not templatized by data type (DT) are handled separately. 
3 | 4 | // Since the Eigen backend is the reference backend implementation, it must 5 | // support all LLGTM node types. 6 | 7 | // This header file is used to generate source code for all node types in 8 | // various places without having to write down all node types over and over. 9 | // It should be used as follows: 10 | // 1. Define a preprocessor macro LLGTM_NODE_DEFINITION(NODE) expanding to the 11 | // source code that should be generated for every node. 12 | // 2. (Optional) Define a preprocessor macro LLGTM_NODE_DEFINITION_FP(NODE) 13 | // for node types that are floating-point only. If unspecified, this will 14 | // default to LLGTM_NODE_DEFINITION. 15 | // 3. Include llgtm_nodes.inc. Note: It can be included multiple times. 16 | // 4. Undefine LLGTM_NODE_DEFINITION (LLGTM_NODE_DEFINITION_FP is undefined by this file). 17 | // See also "Textual Headers" in go/totw/127. 18 | 19 | // The main use case of this file is to force template expansion of all node 20 | // types/classes. 21 | 22 | #ifndef LLGTM_NODE_DEFINITION_FP 23 | #define LLGTM_NODE_DEFINITION_FP(NODE) LLGTM_NODE_DEFINITION(NODE) 24 | #endif 25 | 26 | LLGTM_NODE_DEFINITION(nodes::Add); 27 | LLGTM_NODE_DEFINITION(nodes::AssignAdd); 28 | LLGTM_NODE_DEFINITION(nodes::Broadcast); 29 | LLGTM_NODE_DEFINITION(nodes::Concat); 30 | LLGTM_NODE_DEFINITION(nodes::ConstantFromScalar); 31 | LLGTM_NODE_DEFINITION(nodes::CopyToDevice); 32 | LLGTM_NODE_DEFINITION(nodes::Gather); 33 | LLGTM_NODE_DEFINITION(nodes::Matmul); 34 | LLGTM_NODE_DEFINITION(nodes::Multiply); 35 | LLGTM_NODE_DEFINITION(nodes::Negative); 36 | LLGTM_NODE_DEFINITION_FP(nodes::NormalRandom); 37 | LLGTM_NODE_DEFINITION_FP(nodes::Reciprocal); 38 | LLGTM_NODE_DEFINITION(nodes::Relu); 39 | LLGTM_NODE_DEFINITION(nodes::ReluGrad); 40 | LLGTM_NODE_DEFINITION(nodes::ReduceSum); 41 | LLGTM_NODE_DEFINITION(nodes::Reshape); 42 | LLGTM_NODE_DEFINITION(nodes::Scatter); 43 | LLGTM_NODE_DEFINITION_FP(nodes::Sigmoid); 44 | LLGTM_NODE_DEFINITION_FP(nodes::Softmax); 45 | 
LLGTM_NODE_DEFINITION_FP(nodes::SoftmaxCrossEntropy); 46 | LLGTM_NODE_DEFINITION_FP(nodes::SoftmaxSparseCrossEntropy); 47 | LLGTM_NODE_DEFINITION_FP(nodes::SoftmaxSparseCrossEntropyGrad); 48 | LLGTM_NODE_DEFINITION(nodes::Split); 49 | LLGTM_NODE_DEFINITION_FP(nodes::Tanh); 50 | LLGTM_NODE_DEFINITION(nodes::TensorValue); 51 | LLGTM_NODE_DEFINITION(nodes::TensorVariable); 52 | LLGTM_NODE_DEFINITION(nodes::Transpose); 53 | LLGTM_NODE_DEFINITION_FP(nodes::UniformRandom); 54 | LLGTM_NODE_DEFINITION(nodes::Zeros); 55 | 56 | #undef LLGTM_NODE_DEFINITION_FP // Includers only need to #undef LLGTM_NODE_DEFINITION. 57 | 58 | // Not listed here (due to different template parameters): 59 | // LLGTM_NODE_DEFINITION(nodes::GetOutput) 60 | // LLGTM_NODE_DEFINITION(nodes::ConstantFromFunction) 61 | -------------------------------------------------------------------------------- /tensorflow_fold/llgtm/backend/tf_evaluator.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_FOLD_LLGTM_BACKEND_TF_EVALUATOR_H_ 17 | #define TENSORFLOW_FOLD_LLGTM_BACKEND_TF_EVALUATOR_H_ 18 | 19 | #include "tensorflow_fold/llgtm/backend/eigen_evaluator.h" 20 | #include "tensorflow_fold/llgtm/graph_evaluator.h" 21 | #include "tensorflow_fold/llgtm/tensor.h" 22 | #include "tensorflow_fold/llgtm/tensor_ops.h" 23 | 24 | namespace tensorflow { 25 | class Device; 26 | class DeviceMgr; 27 | }; 28 | 29 | namespace llgtm { 30 | 31 | class TfKernelAdapter; 32 | 33 | class TfGraphEvaluator : public EigenEvaluator { // Uses TF kernels where available, otherwise falls back to Eigen kernels. 34 | public: 35 | using EigenEvaluator::NewGraph; 36 | 37 | TfGraphEvaluator(DeviceID device, uint64 seed); 38 | explicit TfGraphEvaluator(DeviceID device); 39 | explicit TfGraphEvaluator(); 40 | ~TfGraphEvaluator() override; 41 | 42 | void Init(); 43 | 44 | GraphImplementation* NewGraphImpl() override; 45 | 46 | private: 47 | // TfGraphImplementation is defined in tf_evaluator.cc to avoid polluting this 48 | // header file with TF dependencies. 49 | friend class TfGraphImplementation; 50 | template friend class nodes::TensorNodeBaseSelf; 51 | template friend class nodes::TensorNodeSelf; 52 | 53 | void register_tf_tensor(nodes::TensorNodeBase* node, Graph* graph); 54 | 55 | // Default implementation: If no other specialization is available, fall back 56 | // to the corresponding Eigen kernel. 57 | template 58 | void InvokeKernel(NodeType* node, Graph* graph) { 59 | // Print a log message if a TF kernel was requested for a certain node type 60 | // but is not currently implemented. Exceptions are ConstantFromFunction, 61 | // GetOutput and Value: These operations will always use the Eigen kernel. 
62 | if (node->opcode() != kOpConstantFromFunction && 63 | node->opcode() != kOpGetOutput && 64 | node->opcode() != kOpValue) { 65 | LOG(INFO) << "No TF kernel available for " << node->type_str() << "."; 66 | } 67 | 68 | // Just a sanity check: Ensure that we do not fall back to an Eigen kernel 69 | // if a node type is in tf_nodes.inc, i.e., it should have a TF kernel. 70 | #define LLGTM_TF_KERNEL_DEFINITION(NODETYPE) \ 71 | static_assert(!std::is_same, NodeType>::value, \ 72 | "Node in tf_nodes.inc, but no specialization applies."); 73 | #include "tensorflow_fold/llgtm/backend/tf_nodes.inc" 74 | #undef LLGTM_TF_KERNEL_DEFINITION 75 | 76 | LaunchEigenKernel(node, graph); 77 | register_tf_tensor(node, graph); 78 | } 79 | 80 | // TF kernel declarations. 81 | #define LLGTM_TF_KERNEL_DEFINITION(NODETYPE) \ 82 | void InvokeKernel(NODETYPE* node, Graph* graph); 83 | #include "tensorflow_fold/llgtm/backend/tf_nodes.inc" 84 | #undef LLGTM_TF_KERNEL_DEFINITION 85 | 86 | // Helper method that can invoke either UniformRandom or NormalRandom. 87 | template 88 | void InvokeRandomKernel(NodeType* node, Graph* graph, 89 | TfKernelAdapter* kernel); 90 | 91 | std::unique_ptr device_mgr_; 92 | 93 | // Note: these devices are owned by the above device_mgr_ and cached here. 94 | std::vector devices_; 95 | 96 | // Map from opcode to TensorFlow kernel. 97 | std::vector> kernels_; 98 | }; 99 | 100 | } // namespace llgtm 101 | 102 | #endif // TENSORFLOW_FOLD_LLGTM_BACKEND_TF_EVALUATOR_H_ 103 | -------------------------------------------------------------------------------- /tensorflow_fold/llgtm/backend/tf_evaluator_client.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_FOLD_LLGTM_BACKEND_TF_EVALUATOR_CLIENT_H_ 17 | #define TENSORFLOW_FOLD_LLGTM_BACKEND_TF_EVALUATOR_CLIENT_H_ 18 | 19 | // Clients should include this header file if graphs are to be evaluated with 20 | // TensorFlow. Only one backend may be selected per compilation unit. 21 | 22 | // See eigen_evaluator.h (which is the reference backend implementation) for 23 | // more documentation on implementation details. 24 | 25 | #include "tensorflow_fold/llgtm/backend/tf_evaluator.h" 26 | #include "tensorflow_fold/llgtm/tensor.h" 27 | 28 | #ifdef LLGTM_BACKEND_SELECTED 29 | #error "Multiple backends selected. Include only one backend header file." 30 | #else 31 | 32 | #define LLGTM_BACKEND_SELECTED Tf 33 | 34 | namespace llgtm { 35 | 36 | class Graph; 37 | 38 | template 39 | void nodes::TensorNodeBaseSelf::InvokeKernel(Graph* graph) { 40 | CHECK_EQ(this->device(), kDeviceIDCPU); // TODO(matthiasspringer): GPU support. 41 | auto* evaluator = reinterpret_cast(graph->evaluator()); 42 | evaluator->InvokeKernel(reinterpret_cast(this), graph); 43 | } 44 | 45 | template 46 | void nodes::TensorNodeSelf::InvokeKernel(Graph* graph) { 47 | CHECK_EQ(this->device(), kDeviceIDCPU); // TODO(matthiasspringer): GPU support. 
48 | auto* evaluator = reinterpret_cast(graph->evaluator()); 49 | evaluator->InvokeKernel(reinterpret_cast(this), graph); 50 | } 51 | 52 | #define LLGTM_NODE_DEFINITION(NODE) \ 53 | template class NODE; \ 54 | template class NODE; 55 | #include "tensorflow_fold/llgtm/backend/llgtm_nodes.inc" 56 | #undef LLGTM_NODE_DEFINITION 57 | 58 | // Handle classes that are not templatized by data type separately. 59 | template class nodes::TensorNodeBaseSelf; 60 | 61 | } // namespace llgtm 62 | 63 | #endif // LLGTM_BACKEND_SELECTED 64 | #endif // TENSORFLOW_FOLD_LLGTM_BACKEND_TF_EVALUATOR_CLIENT_H_ 65 | -------------------------------------------------------------------------------- /tensorflow_fold/llgtm/backend/tf_nodes.inc: -------------------------------------------------------------------------------- 1 | // This file lists all node types that are supported by the TF backend. 2 | // Nodes that are not templatized by data type (DT) are handled separately. 3 | 4 | // See llgtm_nodes.inc for detailed explanations. 
5 | 6 | LLGTM_TF_KERNEL_DEFINITION(nodes::Add); 7 | LLGTM_TF_KERNEL_DEFINITION(nodes::ConstantFromScalar); 8 | LLGTM_TF_KERNEL_DEFINITION(nodes::Matmul); 9 | LLGTM_TF_KERNEL_DEFINITION(nodes::Multiply); 10 | LLGTM_TF_KERNEL_DEFINITION(nodes::Negative); 11 | LLGTM_TF_KERNEL_DEFINITION(nodes::NormalRandom); 12 | LLGTM_TF_KERNEL_DEFINITION(nodes::Reciprocal); 13 | LLGTM_TF_KERNEL_DEFINITION(nodes::ReduceSum); 14 | LLGTM_TF_KERNEL_DEFINITION(nodes::Relu); 15 | LLGTM_TF_KERNEL_DEFINITION(nodes::ReluGrad); 16 | LLGTM_TF_KERNEL_DEFINITION(nodes::Reshape); 17 | LLGTM_TF_KERNEL_DEFINITION(nodes::Sigmoid); 18 | LLGTM_TF_KERNEL_DEFINITION(nodes::Tanh); 19 | LLGTM_TF_KERNEL_DEFINITION(nodes::Transpose); 20 | LLGTM_TF_KERNEL_DEFINITION(nodes::UniformRandom); 21 | LLGTM_TF_KERNEL_DEFINITION(nodes::Zeros); 22 | -------------------------------------------------------------------------------- /tensorflow_fold/llgtm/device.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include "tensorflow_fold/llgtm/device.h" 17 | 18 | namespace llgtm { 19 | 20 | absl::string_view device_name(DeviceID device) { 21 | switch (device) { 22 | case kDeviceIDCPU: 23 | return "CPU (Device 0)"; 24 | case kDeviceIDGPU: 25 | return "GPU (Device 1)"; 26 | case kDeviceIDUnspecified: 27 | return "Unspecified Device"; 28 | default: 29 | return "Invalid Device"; 30 | } 31 | } 32 | 33 | } // namespace llgtm 34 | -------------------------------------------------------------------------------- /tensorflow_fold/llgtm/device.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_FOLD_LLGTM_DEVICE_H_ 17 | #define TENSORFLOW_FOLD_LLGTM_DEVICE_H_ 18 | 19 | #include "absl/strings/string_view.h" 20 | 21 | namespace llgtm { 22 | 23 | // Identifies a device such as GPU or CPU. 24 | using DeviceID = uint8_t; 25 | 26 | // TODO(matthiasspringer): Replace hard-coded device IDs with proper device 27 | // management. 28 | enum Devices : DeviceID { 29 | kDeviceIDCPU = 0, 30 | kDeviceIDGPU = 1, 31 | 32 | // GPU device is available only in CUDA builds. Maximum ID is an 33 | // exclusive bound. 
34 | #ifdef GOOGLE_CUDA 35 | kDeviceMaximumID = 2, 36 | #else 37 | kDeviceMaximumID = 1, 38 | #endif 39 | 40 | kDeviceIDUnspecified = 255 41 | }; 42 | 43 | // Return the name of the given device. 44 | absl::string_view device_name(DeviceID device); 45 | 46 | } // namespace llgtm 47 | 48 | #endif // TENSORFLOW_FOLD_LLGTM_DEVICE_H_ 49 | -------------------------------------------------------------------------------- /tensorflow_fold/llgtm/examples/BUILD: -------------------------------------------------------------------------------- 1 | # Example models for LLGTM. 2 | 3 | licenses(["notice"]) # Apache 2.0 4 | 5 | cc_binary( 6 | name = "character_rnn", 7 | srcs = [ 8 | "character_rnn.cc", 9 | ], 10 | deps = [ 11 | "@com_google_absl//absl/memory", 12 | "@com_google_absl//absl/strings", 13 | "//tensorflow_fold/llgtm:llgtm_eigen", 14 | ], 15 | ) 16 | 17 | cc_binary( 18 | name = "tree_rnn", 19 | srcs = [ 20 | "parsetree.h", 21 | "tree_rnn.cc", 22 | ], 23 | deps = [ 24 | "@com_google_absl//absl/memory", 25 | "@com_google_absl//absl/strings", 26 | "//tensorflow_fold/llgtm:llgtm_eigen", 27 | ], 28 | ) 29 | -------------------------------------------------------------------------------- /tensorflow_fold/llgtm/examples/parsetree.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
==============================================================================*/

// Simple tree data structure used by tree_rnn.cc.

#ifndef TENSORFLOW_FOLD_LLGTM_EXAMPLES_PARSETREE_H_
#define TENSORFLOW_FOLD_LLGTM_EXAMPLES_PARSETREE_H_

#include
#include
#include

#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "tensorflow_fold/llgtm/llgtm.h"

namespace llgtm {
namespace parse_trees {

class LeafNode;
class PhraseNode;

// Base class for nodes in the parse tree.
// Trees are an algebraic data type:
//   data TreeNode = LeafNode string | PhraseNode [TreeNode]
// where a PhraseNode owns an ordered list of one to three child nodes.
// [https://en.wikipedia.org/wiki/Algebraic_data_type]
// Nodes are neither copyable nor copy-assignable.
class TreeNode {
 public:
  TreeNode() {}
  TreeNode(const TreeNode&) = delete;
  virtual ~TreeNode() {}

  TreeNode& operator=(const TreeNode&) = delete;

  // Derived classes must override one of the following to return this.
  // The base versions return nullptr, meaning "not this kind of node".
  virtual LeafNode* get_leaf() { return nullptr; }
  virtual PhraseNode* get_phrase() { return nullptr; }

  // Pattern match on the type of the TreeNode.
  // Calls leaf_case with this node as a LeafNode* if it is a leaf, otherwise
  // calls phrase_case with it as a PhraseNode* (CHECK-fails if it is
  // neither). The return type is that of leaf_case; phrase_case must return
  // something convertible to it.
  template
  auto SwitchOnType(F1 leaf_case, F2 phrase_case)
      -> decltype(leaf_case(static_cast(nullptr))) {
    if (auto* leaf = get_leaf()) {
      return leaf_case(leaf);
    } else {
      auto* phrase = get_phrase();
      CHECK(phrase != nullptr);
      return phrase_case(phrase);
    }
  }
};

// A leaf (terminal) node in the parse tree, which contains a single word.
class LeafNode : public TreeNode {
 public:
  explicit LeafNode(string word) : word_(std::move(word)) {}
  ~LeafNode() override {}

  // The returned view refers to this node's storage; it is valid only while
  // the node is alive.
  absl::string_view word() const { return word_; }

  LeafNode* get_leaf() override { return this; }

 private:
  const string word_;
};

// A phrase (non-terminal) in the parse tree. It owns one to three child nodes
// (leaves or phrases), kept in constructor-argument order.
class PhraseNode : public TreeNode {
 public:
  using NodeType = std::unique_ptr;

  PhraseNode() = delete;
  ~PhraseNode() override {}

  explicit PhraseNode(NodeType a) {
    sub_nodes_.emplace_back(std::move(a));
  }

  PhraseNode(NodeType a, NodeType b) {
    sub_nodes_.emplace_back(std::move(a));
    sub_nodes_.emplace_back(std::move(b));
  }

  PhraseNode(NodeType a, NodeType b, NodeType c) {
    // std::initializer_list doesn't work with move-only types, so we have
    // to do this the hard way.
    sub_nodes_.emplace_back(std::move(a));
    sub_nodes_.emplace_back(std::move(b));
    sub_nodes_.emplace_back(std::move(c));
  }

  const std::vector& sub_nodes() const { return sub_nodes_; }

  PhraseNode* get_phrase() override { return this; }

 private:
  std::vector sub_nodes_;
};

// Creates a new LeafNode.
inline std::unique_ptr Leaf(string str) {
  return absl::make_unique(std::move(str));
}

// Creates a new PhraseNode. The arguments (one to three child nodes) are
// taken by value and moved into the matching PhraseNode constructor.
template
inline std::unique_ptr Phrase(Args... args) {
  return absl::make_unique(std::move(args)...);
}

}  // namespace parse_trees
}  // namespace llgtm

#endif  // TENSORFLOW_FOLD_LLGTM_EXAMPLES_PARSETREE_H_

--------------------------------------------------------------------------------
/tensorflow_fold/llgtm/examples/run_all_examples.sh:
--------------------------------------------------------------------------------
# To run, type "source examples/run_all_examples.sh"
# Runs each example binary for 100 steps in opt mode, also logging to stderr.

bazel run -c opt examples:character_rnn -- --num_steps=100 --alsologtostderr
bazel run -c opt examples:tree_rnn -- --num_steps=100 --alsologtostderr

--------------------------------------------------------------------------------
/tensorflow_fold/llgtm/graph.cc:
--------------------------------------------------------------------------------
/* Copyright 2017 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow_fold/llgtm/graph.h"

#include "tensorflow_fold/llgtm/layers.h"
#include "tensorflow_fold/llgtm/tensor_ops_impl.h"

namespace llgtm {

// Invokes `layer` on `inputs` and returns its result.
// On the first call for a given layer, the number of inputs is CHECKed against
// layer->num_inputs(), and the layer is initialized from the actual inputs and
// result. On later calls, the inputs are validated with CheckInputs(), and the
// result is validated with CheckOutputs() against the batch size it returned.
TensorBase GraphImplementation::Layer(
    Graph* g, class Layer* layer, InputList inputs, DeviceID device) {
  // Default implementation merely adds the nodes to the graph.
  if (!layer->initialized()) {
    // On first invocation, initialize the layer and figure out the types.
    CHECK_EQ(inputs.size(), layer->num_inputs());
    TensorBase result = layer->Invoke(g, inputs, device);
    layer->Initialize(g->evaluator(), inputs, result);
    return result;
  } else {
    // On subsequent invocations, check the types.
    int batch_size = layer->CheckInputs(inputs);
    TensorBase result = layer->Invoke(g, inputs, device);
    layer->CheckOutputs(result, batch_size);
    return result;
  }
}


// Evaluates the nodes that have not been evaluated yet, i.e. the nodes with
// ids in [graph->current_node(), graph->size()), in id order. A node whose
// device is still kDeviceIDUnspecified is first assigned the graph's default
// device. Afterwards current_node is advanced to graph->size().
// Returns the number of nodes evaluated by this call.
// Eventually we'll block on the GraphExecutor to do this, but the initial
// prototype version is single-threaded.
int GraphImplementation::Eval(Graph* graph) {
  int current = graph->current_node();
  int gsize = graph->size();
  for (int i = current; i < gsize; ++i) {
    nodes::TensorNodeBase* node = graph->node(i);

    // If device is still unspecified, use default graph device.
    if (node->device() == kDeviceIDUnspecified) {
      graph->promote_node_device_shallow(node, graph->default_device());
    }

    DCHECK_EQ(node->id(), i);
    node->InvokeKernel(graph);
  }
  int num_evaluated = gsize - current;
  set_current_node(graph, gsize);
  return num_evaluated;
}


// Forwards to Graph::set_current_node; this lets subclasses of
// GraphImplementation reach it through the friend relationship.
void GraphImplementation::set_current_node(Graph* g, int i) {
  g->set_current_node(i);
}


// In debug builds (NDEBUG not defined), this releases every input reference
// and then runs each node's destructor so that refcounts can be checked. In
// NDEBUG builds it does nothing.
// NOTE(review): presumably node memory is reclaimed with the implementation's
// arena rather than here; confirm.
Graph::~Graph() {
#ifndef NDEBUG
  // Check that there are no dangling references.
  // Release refcounts from out of order nodes.
  for (nodes::TensorNodeBase* node : out_of_order_nodes_) {
    for (int j = 0, ni = node->num_inputs(); j < ni; ++j) {
      node->sub_expression(j).release();
    }
  }
  // Destroy nodes in inverse order so the ref counts work out.
  for (int i = nodes_.size() - 1; i >= 0; --i) {
    nodes::TensorNodeBase* node = nodes_[i];
    for (int j = 0, ni = node->num_inputs(); j < ni; ++j) {
      node->sub_expression(j).release();
    }
    node->~TensorNodeBase();  // Destructor checks refcount.
  }
#endif  // NDEBUG
}


// Writes one line per node, in node order, of the form
//   n_<id> = <opcode>(n_<input id>, ...) [<num_uses>]
void Graph::Dump(std::ostream& out) {
  for (nodes::TensorNodeBase* node : nodes_) {
    out << "n_" << node->id() << " = " << opcode_name(node->opcode());
    out << "(";

    auto* s_iter = node->sub_expressions();
    auto* s_end = s_iter + node->num_inputs();
    const char* sep = "";
    for (; s_iter != s_end; ++s_iter) {
      out << sep << "n_" << s_iter->get()->id();
      sep = ", ";
    }
    out << ")";
    out << " [" << node->num_uses_ << "]";
    out << "\n";
  }
}

}  // namespace llgtm

--------------------------------------------------------------------------------
/tensorflow_fold/llgtm/graph_implementation.h:
--------------------------------------------------------------------------------
/* Copyright 2017 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Base class for the underlying implementation of a Graph.
// Each Evaluator may extend GraphImplementation with additional information
// necessary to evaluate a graph.

#ifndef TENSORFLOW_FOLD_LLGTM_GRAPH_IMPLEMENTATION_H_
#define TENSORFLOW_FOLD_LLGTM_GRAPH_IMPLEMENTATION_H_

#include
#include
#include

#include "tensorflow_fold/llgtm/platform/platform.h"
#include "tensorflow_fold/llgtm/tensor.h"
#include "absl/types/span.h"


namespace llgtm {

class Graph;
class GraphEvaluator;
class Layer;


// Underlying implementation class for a Graph. A Graph contains a pointer
// to an implementation, so that classes derived from GraphEvaluator may
// instantiate derived versions of GraphImplementation. The implementation
// is responsible for allocating memory and evaluating nodes.
// Instances are neither copyable nor copy-assignable, and are constructed
// only by friends (Graph, GraphEvaluator) or subclasses.
// TODO(delesley): Make this thread-safe.
class GraphImplementation {
 public:
  // InputList is used to pass a list of inputs to Layer() and Layer::Invoke().
  // It is guaranteed to support operator[], and be constructible from
  // std::initializer_list. Clients should not rely on other operations.
  using InputList = absl::Span;
  template using InputListT = absl::Span>;

  static const int kArenaBlockSize = 1 << 16;  // 65536; presumably bytes per arena block.
  static const int kNodeAlignment = 8;  // Align to 64-bit boundaries
  static const int kResultAlignment = TensorType::kResultAlignment;

  GraphImplementation(const GraphImplementation& g) = delete;
  virtual ~GraphImplementation() {}

  GraphImplementation& operator=(const GraphImplementation& g) = delete;

  // Allocates `size` bytes from the arena, aligned to kNodeAlignment.
  void* AllocateInArena(size_t size) {
    return arena_.AllocAligned(size, kNodeAlignment);
  }

  // Allocates result data for node, and returns size of allocation.
  virtual size_t AllocateResultData(nodes::TensorNodeBase* node) = 0;

  // Invokes a layer. The default implementation (graph.cc) initializes the
  // layer on its first invocation and checks input/output types on later
  // invocations.
  virtual TensorBase Layer(Graph* g, class Layer* layer, InputList inputs,
                           DeviceID device);

  // Evaluates the nodes of `graph` that have not been evaluated yet. The
  // default implementation (graph.cc) evaluates from graph's current node up to
  // its size, advances the current node, and returns the number of nodes
  // evaluated.
  virtual int Eval(Graph* graph);

 protected:
  friend class Graph;
  friend class GraphEvaluator;

  GraphImplementation() : arena_(kArenaBlockSize) {}

  // Exposes set_current_node to subclasses of GraphImplementation.
  void set_current_node(Graph* g, int i);

  // Arena for allocating nodes, inputs, and types.
  platform::Arena arena_;
};


}  // namespace llgtm

#endif  // TENSORFLOW_FOLD_LLGTM_GRAPH_IMPLEMENTATION_H_

--------------------------------------------------------------------------------
/tensorflow_fold/llgtm/llgtm.h:
--------------------------------------------------------------------------------
/* Copyright 2017 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Include all important files in the llgtm library, except the backend.

#ifndef TENSORFLOW_FOLD_LLGTM_LLGTM_H_
#define TENSORFLOW_FOLD_LLGTM_LLGTM_H_

#include "tensorflow_fold/llgtm/gradients.h"
#include "tensorflow_fold/llgtm/graph.h"
#include "tensorflow_fold/llgtm/layers.h"
#include "tensorflow_fold/llgtm/tensor.h"
#include "tensorflow_fold/llgtm/tensor_ops_impl.h"
#include "tensorflow_fold/llgtm/trainer.h"
#include "tensorflow_fold/llgtm/variable_initializers.h"

#endif  // TENSORFLOW_FOLD_LLGTM_LLGTM_H_

--------------------------------------------------------------------------------
/tensorflow_fold/llgtm/platform/external.h:
--------------------------------------------------------------------------------
/* Copyright 2017 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_FOLD_LLGTM_PLATFORM_EXTERNAL_H_
#define TENSORFLOW_FOLD_LLGTM_PLATFORM_EXTERNAL_H_

#include
#include

#include "tensorflow/core/lib/core/arena.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace llgtm {

// Short aliases for standard types, used throughout llgtm.
using string = std::string;
using int64 = int64_t;
using uint64 = uint64_t;

namespace platform {

// Use arena from TF Core. Use different arena inside Google.
using Arena = tensorflow::core::Arena;

// Tensorflow only allows flags of type float, whereas internal Google
// libraries only allow flags of type double, so we convert as appropriate.
// FlagType<T>::type is the storage type used for a flag declared as T.
template
struct FlagType { using type = T; };

// Specialization that maps to float (presumably FlagType<double>, per the
// comment above; the template argument is not visible here).
template<>
struct FlagType { using type = float; };

// The following macros provide a convenient mechanism for declaring,
// parsing, and using command line flags. See examples for usage.
// The macros are the public interface; the classes are private.

// Defines the registry object named global_commandline_flag_registry that the
// macros below refer to; it must be in scope where they are used.
#define BEGIN_COMMAND_LINE_FLAGS \
  llgtm::platform::CommandLineFlagRegistry global_commandline_flag_registry

// Storage type used for a flag declared with type TYPE (see FlagType).
#define GET_FLAG_TYPE(TYPE) \
  llgtm::platform::FlagType::type

// Defines a flag object FLAGS_<NAME> with default DEFVAL and registers it with
// global_commandline_flag_registry. NAME, TYPE and DEFVAL are also stringized
// for the usage message.
#define DEFINE_FLAG(TYPE, NAME, DEFVAL, DOCSTR) \
  llgtm::platform::CommandLineFlag FLAGS_ ## NAME( \
      #NAME, #TYPE, #DEFVAL, DOCSTR, DEFVAL, &global_commandline_flag_registry)

// Reads the current value of flag NAME.
#define GET_CL_FLAG(NAME) \
  FLAGS_ ## NAME.value()

// Parses the command line into the registered flags. ARGC must be an lvalue
// (its address is taken, so parsing can update it).
#define PARSE_COMMAND_LINE_FLAGS(ARGC, ARGV) \
  global_commandline_flag_registry.Parse(&ARGC, ARGV)


// NOTE(review): an unnamed namespace in a header gives every translation unit
// that includes it its own distinct copy of these classes; consider a named
// `internal` namespace instead (the macros above would need updating).
namespace {  //
class CommandLineFlagRegistry;

// Base class for command line flags.
// Flags will register themselves with a global registry upon construction.
71 | class CommandLineFlagBase { 72 | public: 73 | CommandLineFlagBase() = delete; 74 | CommandLineFlagBase(const CommandLineFlagBase& other) = delete; 75 | 76 | CommandLineFlagBase(const char* name, const char* type, 77 | const char* valstr, const char* docstr) 78 | : name_(name), type_(type), default_value_(valstr), docstring_(docstr) 79 | {} 80 | 81 | void PrintUsage() { 82 | std::cout << "--" << name_ << "=" << default_value_ 83 | << " {" << type_ << "} " << docstring_ << "\n"; 84 | } 85 | 86 | private: 87 | friend class CommandLineFlagRegistry; 88 | 89 | const char* name_; 90 | const char* type_; 91 | const char* default_value_; 92 | const char* docstring_; 93 | }; 94 | 95 | 96 | // Derived class for command line flags of a particular type. 97 | template 98 | class CommandLineFlag : CommandLineFlagBase { 99 | public: 100 | CommandLineFlag(const char* name, const char* type, 101 | const char* valstr, const char* docstr, 102 | T default_value, CommandLineFlagRegistry* registry); 103 | 104 | const T& value() const { return value_; } 105 | 106 | private: 107 | friend class CommandLineFlagRegistry; 108 | 109 | T value_; 110 | }; 111 | 112 | 113 | // Class which holds a set of command line flags. 114 | // The registry is responsible for parsing command line flags, and for 115 | // printing out usage information. 
116 | class CommandLineFlagRegistry { 117 | public: 118 | CommandLineFlagRegistry() {} 119 | ~CommandLineFlagRegistry() {} 120 | 121 | template 122 | void Register(CommandLineFlag* flag) { 123 | flags_.push_back(flag); 124 | tf_flags_.emplace_back(flag->name_, &flag->value_, flag->docstring_); 125 | } 126 | 127 | void Parse(int* argc, char** argv) { 128 | if (*argc > 1) { 129 | if (string(argv[1]) == string("--help")) { 130 | PrintUsage(); 131 | exit(0); 132 | } 133 | } 134 | tensorflow::Flags::Parse(argc, argv, tf_flags_); 135 | } 136 | 137 | void PrintUsage() { 138 | std::cout << "Usage: " << "\n"; 139 | for (auto* f : flags_) f->PrintUsage(); 140 | } 141 | 142 | private: 143 | std::vector flags_; 144 | std::vector tf_flags_; 145 | }; 146 | 147 | 148 | template 149 | inline CommandLineFlag::CommandLineFlag(const char* name, 150 | const char* type, 151 | const char* valstr, 152 | const char* docstr, 153 | T default_value, 154 | CommandLineFlagRegistry* registry) 155 | : CommandLineFlagBase(name, type, valstr, docstr), value_(default_value) { 156 | registry->Register(this); 157 | } 158 | } // namespace 159 | 160 | } // namespace platform 161 | } // namespace llgtm 162 | 163 | #endif // TENSORFLOW_FOLD_LLGTM_PLATFORM_EXTERNAL_H_ 164 | -------------------------------------------------------------------------------- /tensorflow_fold/llgtm/platform/platform.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_FOLD_LLGTM_PLATFORM_PLATFORM_H_ 17 | #define TENSORFLOW_FOLD_LLGTM_PLATFORM_PLATFORM_H_ 18 | 19 | // A platform defines aliases for classes such as Arena and includes headers 20 | // for "base" functionality such as logging. 21 | 22 | // There are two platforms: GOOGLE and EXTERNAL (open sourced version). 23 | // EXTERNAL falls back to TF as a dependency, whereas GOOGLE can use internal 24 | // headers and dependencies. 25 | 26 | #if defined(LLGTM_PLATFORM_GOOGLE) && defined(LLGTM_PLATFORM_EXTERNAL) 27 | #error "Multiple platforms selected. Include one platform in BUILD target." 28 | #endif 29 | 30 | #if defined(LLGTM_PLATFORM_GOOGLE) 31 | #include "tensorflow_fold/llgtm/platform/google.h" 32 | #elif defined(LLGTM_PLATFORM_EXTERNAL) 33 | #include "tensorflow_fold/llgtm/platform/external.h" 34 | #else 35 | #error "No platform selected. Include one platform in BUILD target." 36 | #endif 37 | 38 | #endif // TENSORFLOW_FOLD_LLGTM_PLATFORM_PLATFORM_H_ 39 | -------------------------------------------------------------------------------- /tensorflow_fold/llgtm/tensor.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "tensorflow_fold/llgtm/tensor.h" 17 | #include "tensorflow_fold/llgtm/gradients.h" 18 | #include "tensorflow_fold/llgtm/graph.h" 19 | 20 | namespace llgtm { 21 | namespace nodes { 22 | 23 | void TensorNodeBase::set_device(DeviceID device) { 24 | device_ = device; 25 | } 26 | 27 | // Tuples cannot be consumed directly. 28 | void Tuple::InvokeKernel(Graph* graph) {} 29 | 30 | // Tuples cannot be consumed directly. 31 | void Tuple::ComputeGradients(TensorBase error, Gradients* gradients) { 32 | LOG(FATAL) << "Cannot pass gradients through a Tuple."; 33 | } 34 | 35 | void GetOutput::ComputeGradients(TensorBase error, Gradients* gradients) { 36 | TensorNodeBase* multi = sub_expression(0).get(); 37 | if (multi->is_differentiable()) { 38 | gradients->PropagateMultiError(multi, error, output_index_); 39 | } 40 | } 41 | 42 | 43 | } // namespace nodes 44 | } // namespace llgtm 45 | 46 | -------------------------------------------------------------------------------- /tensorflow_fold/llgtm/tensor_opcodes.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include "tensorflow_fold/llgtm/tensor_opcodes.h" 17 | 18 | namespace llgtm { 19 | 20 | #define LLGTM_OPCODE_NAME_CASE(OP) case kOp ## OP: return #OP; 21 | 22 | absl::string_view opcode_name(TensorOpcode op) { 23 | switch (op) { 24 | LLGTM_OPCODE_NAME_CASE(Tuple); 25 | LLGTM_OPCODE_NAME_CASE(GetOutput); 26 | LLGTM_OPCODE_NAME_CASE(CopyToDevice); 27 | 28 | LLGTM_OPCODE_NAME_CASE(Value); 29 | LLGTM_OPCODE_NAME_CASE(Variable); 30 | LLGTM_OPCODE_NAME_CASE(AssignAdd); 31 | LLGTM_OPCODE_NAME_CASE(Zeros); 32 | LLGTM_OPCODE_NAME_CASE(ConstantFromFunction); 33 | LLGTM_OPCODE_NAME_CASE(ConstantFromScalar); 34 | LLGTM_OPCODE_NAME_CASE(UniformRandom); 35 | LLGTM_OPCODE_NAME_CASE(NormalRandom); 36 | 37 | LLGTM_OPCODE_NAME_CASE(Broadcast); 38 | LLGTM_OPCODE_NAME_CASE(ReduceSum); 39 | LLGTM_OPCODE_NAME_CASE(Transpose); 40 | LLGTM_OPCODE_NAME_CASE(Reshape); 41 | LLGTM_OPCODE_NAME_CASE(Concat); 42 | LLGTM_OPCODE_NAME_CASE(Split); 43 | LLGTM_OPCODE_NAME_CASE(Gather); 44 | LLGTM_OPCODE_NAME_CASE(Scatter); 45 | 46 | LLGTM_OPCODE_NAME_CASE(Negative); 47 | LLGTM_OPCODE_NAME_CASE(Reciprocal); 48 | 49 | LLGTM_OPCODE_NAME_CASE(Add); 50 | LLGTM_OPCODE_NAME_CASE(Multiply); 51 | 52 | LLGTM_OPCODE_NAME_CASE(Matmul); 53 | 54 | LLGTM_OPCODE_NAME_CASE(Relu); 55 | LLGTM_OPCODE_NAME_CASE(ReluGrad); 56 | LLGTM_OPCODE_NAME_CASE(Sigmoid); 57 | LLGTM_OPCODE_NAME_CASE(Tanh); 58 | LLGTM_OPCODE_NAME_CASE(Softmax); 59 | LLGTM_OPCODE_NAME_CASE(SoftmaxCrossEntropy); 60 | LLGTM_OPCODE_NAME_CASE(SoftmaxSparseCrossEntropy); 61 | LLGTM_OPCODE_NAME_CASE(SoftmaxSparseCrossEntropyGrad); 62 | 63 | case kMaximumTensorOpcode: return "Invalid"; 64 | } 65 | } 66 | #undef LLGTM_OPCODE_NAME_CASE 67 | 68 | 69 | #define LLGTM_DTYPE_NAME_CASE(TYPE) case kDT ## TYPE: return #TYPE; 70 | 71 | absl::string_view dtype_name(TensorDataType type) { 72 | switch (type) { 73 | LLGTM_DTYPE_NAME_CASE(void); 74 | LLGTM_DTYPE_NAME_CASE(bool); 75 
| LLGTM_DTYPE_NAME_CASE(int8); 76 | LLGTM_DTYPE_NAME_CASE(int16); 77 | LLGTM_DTYPE_NAME_CASE(int32); 78 | LLGTM_DTYPE_NAME_CASE(int64); 79 | LLGTM_DTYPE_NAME_CASE(float32); 80 | LLGTM_DTYPE_NAME_CASE(float64); 81 | 82 | case kMaximumTensorDataType: return "Invalid"; 83 | } 84 | } 85 | 86 | #undef LLGTM_DTYPE_NAME_CASE 87 | 88 | } // end namespace llgtm 89 | -------------------------------------------------------------------------------- /tensorflow_fold/llgtm/tensor_opcodes.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | // Enumerations and helper functions that are used in tensor.h. 17 | 18 | #ifndef TENSORFLOW_FOLD_LLGTM_TENSOR_OPCODES_H_ 19 | #define TENSORFLOW_FOLD_LLGTM_TENSOR_OPCODES_H_ 20 | 21 | #include "tensorflow_fold/llgtm/platform/platform.h" 22 | #include "absl/strings/string_view.h" 23 | 24 | namespace llgtm { 25 | 26 | // TODO(delesley): Consider using tensorflow/core/framework/types.proto 27 | // For now, we are trying to limit dependencies on tensorflow. 28 | enum TensorDataType : uint8_t { 29 | kDTvoid, 30 | kDTbool, 31 | kDTint8, 32 | kDTint16, 33 | kDTint32, 34 | kDTint64, 35 | // TODO(delesley): decide if we want to support unsigned ints. 
36 | kDTfloat32, 37 | kDTfloat64, 38 | kMaximumTensorDataType 39 | }; 40 | 41 | 42 | // Return the size of the given dtype. 43 | inline int sizeof_dtype(TensorDataType dtype) { 44 | switch (dtype) { 45 | case kDTvoid: return 0; 46 | case kDTbool: return sizeof(bool); 47 | case kDTint8: return sizeof(int8_t); 48 | case kDTint16: return sizeof(int16_t); 49 | case kDTint32: return sizeof(int32_t); 50 | case kDTint64: return sizeof(int64_t); 51 | case kDTfloat32: return sizeof(float); 52 | case kDTfloat64: return sizeof(double); 53 | case kMaximumTensorDataType: 54 | break; 55 | } 56 | LOG(FATAL) << "Invalid DType."; 57 | } 58 | 59 | 60 | // Template function converts a C++ type to TensorDataType at compile time. 61 | template struct CppTypeToDType; 62 | 63 | #define TYPE2DTYPEDEF(Type, DTypeVal) \ 64 | template<> struct CppTypeToDType { \ 65 | static const TensorDataType dtype = DTypeVal; \ 66 | }; 67 | 68 | TYPE2DTYPEDEF(bool, kDTbool); 69 | TYPE2DTYPEDEF(int8_t, kDTint8); 70 | TYPE2DTYPEDEF(int16_t, kDTint16); 71 | TYPE2DTYPEDEF(int32_t, kDTint32); 72 | TYPE2DTYPEDEF(int64_t, kDTint64); 73 | TYPE2DTYPEDEF(float, kDTfloat32); 74 | TYPE2DTYPEDEF(double, kDTfloat64); 75 | 76 | #undef TYPE2DTYPEDEF 77 | 78 | 79 | enum TensorOpcode : int8_t { 80 | kOpGetOutput, // Get the nth output. 81 | kOpTuple, // Collect a set of gradients into a single node. 82 | kOpCopyToDevice, // Copy result data to a different device. 83 | 84 | // Constants. 85 | kOpValue, // A tensor value (fully allocated in memory). 86 | kOpVariable, // A reference to a variable. 87 | kOpAssignAdd, // Adds a gradient to variable. 88 | kOpZeros, // A tensor full of zeros. 89 | kOpConstantFromFunction, // A tensor constant, initialized from a function. 90 | kOpConstantFromScalar, // A tensor initialized with a scalar. 91 | kOpUniformRandom, // A tensor of random numbers (uniform distr.). 92 | kOpNormalRandom, // A tensor of random numbers (normal distr.). 93 | 94 | // Tensor operations. 
95 | kOpBroadcast, 96 | kOpReduceSum, 97 | kOpTranspose, 98 | kOpReshape, 99 | kOpConcat, 100 | kOpSplit, 101 | kOpGather, 102 | kOpScatter, 103 | 104 | // Element-wise unary arithmetic operations 105 | kOpNegative, 106 | kOpReciprocal, 107 | 108 | // Element-wise arithmetic operations. 109 | kOpAdd, 110 | kOpMultiply, 111 | 112 | // Matrix operations. 113 | kOpMatmul, 114 | 115 | // Neural network activations. 116 | kOpRelu, 117 | kOpReluGrad, 118 | kOpSigmoid, 119 | kOpTanh, 120 | kOpSoftmax, 121 | kOpSoftmaxCrossEntropy, 122 | kOpSoftmaxSparseCrossEntropy, 123 | kOpSoftmaxSparseCrossEntropyGrad, 124 | 125 | kMaximumTensorOpcode 126 | }; 127 | 128 | // Return the name of the given opcode. 129 | absl::string_view opcode_name(TensorOpcode op); 130 | 131 | // Return the name of the given type. 132 | absl::string_view dtype_name(TensorDataType type); 133 | 134 | } // end namespace llgtm 135 | 136 | #endif // TENSORFLOW_FOLD_LLGTM_TENSOR_OPCODES_H_ 137 | -------------------------------------------------------------------------------- /tensorflow_fold/llgtm/test/evaluator_test_eigen.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include "tensorflow_fold/llgtm/backend/eigen_evaluator_client.h" 17 | #include "tensorflow_fold/llgtm/test/evaluator_test.h" 18 | 19 | namespace llgtm { 20 | namespace { 21 | using Configuration = TestConfiguration; 22 | 23 | INSTANTIATE_TYPED_TEST_CASE_P(Eigen, EvaluatorTest, Configuration); 24 | } // namespace 25 | } // namespace llgtm 26 | -------------------------------------------------------------------------------- /tensorflow_fold/llgtm/test/evaluator_test_eigen_gpu.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "tensorflow_fold/llgtm/backend/eigen_evaluator_client.h" 17 | #include "tensorflow_fold/llgtm/test/evaluator_test.h" 18 | #include "tensorflow_fold/llgtm/test/test_framework.h" 19 | 20 | #ifndef GOOGLE_CUDA 21 | #error "LLGTM built without CUDA support. Use --config=cuda." 22 | #endif // GOOGLE_CUDA 23 | 24 | #ifndef NDEBUG 25 | // TODO(matthiasspringer): Investigate why this is not working. 26 | // Workarounds: a) Run in optimized mode. 27 | // b) Possibly use --config=nvcc8 (does not compile). 28 | #error "Cannot run Eigen with CUDA support in debug mode. Run in opt mode." 
29 | #endif // NDEBUG 30 | 31 | namespace llgtm { 32 | namespace { 33 | using Configuration = TestConfiguration; 34 | 35 | INSTANTIATE_TYPED_TEST_CASE_P(Eigen, EvaluatorTest, Configuration); 36 | 37 | class EvaluatorTestEigenGPUExtra : public DeviceAwareTest {}; 38 | 39 | TEST_F(EvaluatorTestEigenGPUExtra, TestCopyToDevice) { 40 | EigenEvaluator evaluator; 41 | Dimensions dims; 42 | 43 | float one = 1.0; 44 | 45 | { 46 | Graph g = evaluator.NewGraph(kDeviceIDGPU); 47 | auto v_one = g.Value(dims, &one); 48 | auto v_one_cpu = g.Value(dims, &one, kDeviceIDCPU); 49 | 50 | // Calculate on GPU, copy result back to CPU. 51 | auto result = g.CopyToDevice(v_one, kDeviceIDCPU); 52 | g.Eval(); 53 | 54 | this->ExpectEq(result, v_one_cpu); 55 | } 56 | } 57 | 58 | TEST_F(EvaluatorTestEigenGPUExtra, TestDevicePromotion) { 59 | EigenEvaluator evaluator; 60 | Dimensions dims; 61 | 62 | float one = 1.0; 63 | float two = 2.0; 64 | float five = 5.0; 65 | 66 | { 67 | Graph g = evaluator.NewGraph(kDeviceIDCPU); 68 | auto v_one = g.Value(dims, &one); // unspecified dev. 69 | auto v_two = g.Value(dims, &two); // unspecified dev. 70 | auto v_two_gpu = g.Value(dims, &two, kDeviceIDGPU); 71 | auto v_five = g.Value(dims, &five); 72 | 73 | // Calculate on GPU, copy result back to CPU. 74 | auto sum = g.Add(g.Add(v_one, v_two), v_two_gpu); 75 | EXPECT_EQ(v_one.device(), kDeviceIDGPU); 76 | 77 | auto result = g.CopyToDevice(sum, kDeviceIDCPU); 78 | g.Eval(); 79 | 80 | this->ExpectEq(result, v_five); 81 | } 82 | } 83 | 84 | } // namespace 85 | } // namespace llgtm 86 | -------------------------------------------------------------------------------- /tensorflow_fold/llgtm/test/evaluator_test_tf.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "tensorflow_fold/llgtm/backend/tf_evaluator_client.h" 17 | #include "tensorflow_fold/llgtm/test/evaluator_test.h" 18 | 19 | namespace llgtm { 20 | namespace { 21 | using Configuration = TestConfiguration; 22 | 23 | INSTANTIATE_TYPED_TEST_CASE_P(Tf, EvaluatorTest, Configuration); 24 | } // namespace 25 | } // namespace llgtm 26 | -------------------------------------------------------------------------------- /tensorflow_fold/llgtm/test/gradients_test_eigen.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include "tensorflow_fold/llgtm/backend/eigen_evaluator_client.h" 17 | #include "tensorflow_fold/llgtm/test/gradients_test.h" 18 | 19 | namespace llgtm { 20 | namespace { 21 | using Configuration = TestConfiguration; 22 | 23 | INSTANTIATE_TYPED_TEST_CASE_P(Eigen, GradientsTest, Configuration); 24 | } // namespace 25 | } // namespace llgtm 26 | -------------------------------------------------------------------------------- /tensorflow_fold/llgtm/test/gradients_test_eigen_gpu.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "tensorflow_fold/llgtm/backend/eigen_evaluator_client.h" 17 | #include "tensorflow_fold/llgtm/test/gradients_test.h" 18 | #include "tensorflow_fold/llgtm/test/test_framework.h" 19 | 20 | #ifndef GOOGLE_CUDA 21 | #error "LLGTM built without CUDA support. Use --config=cuda." 22 | #endif // GOOGLE_CUDA 23 | 24 | #ifndef NDEBUG 25 | // See evaluator_test_eigen_gpu.cc 26 | #error "Cannot run Eigen with CUDA support in debug mode. Run in opt mode." 
27 | #endif // NDEBUG 28 | 29 | namespace llgtm { 30 | namespace { 31 | using Configuration = TestConfiguration; 32 | 33 | INSTANTIATE_TYPED_TEST_CASE_P(EigenGPU, GradientsTest, Configuration); 34 | 35 | class GradientsTestExtra : public DeviceAwareTest {}; 36 | 37 | TEST_F(GradientsTestExtra, TestCopyToDevice) { 38 | // This test is based on GradientsTest::TestNegative. 39 | Dimensions dims; 40 | 41 | EigenEvaluator evaluator; 42 | VariableSet model(&evaluator); 43 | 44 | VarNameSpace* a_space = model.NewNameSpace("a_layer"); 45 | VarNameSpace* b_space = model.NewNameSpace("b_layer", a_space); 46 | 47 | auto* va = model.NewVariable("a", dims, a_space, 48 | ScalarInitializer(4.0f)); 49 | auto* vb = model.NewVariable("b", dims, b_space, 50 | ScalarInitializer(5.0f)); 51 | 52 | { 53 | Graph g = evaluator.NewGraph(kDeviceIDCPU); 54 | 55 | auto a = g.Variable(va); 56 | auto b = g.Variable(vb); 57 | 58 | // Promote multiply to GPU. 59 | auto multiply_gpu = g.Add(g.Zeros(dims, kDeviceIDGPU), 60 | g.Multiply(a, g.Negative(b))); 61 | auto expr = g.Add(g.CopyToDevice(g.Negative(a), kDeviceIDCPU), 62 | g.CopyToDevice(multiply_gpu, kDeviceIDCPU)); 63 | 64 | Gradients grads(&model); 65 | g.ComputeGradients(&grads, expr); 66 | 67 | auto a_grad = g.Gradient(grads, va); 68 | auto b_grad = g.Gradient(grads, vb); 69 | 70 | // a = 4, b = 5 71 | // expr = -a + a*(-b) 72 | // grad_a = -1 - 5 = -6 73 | // grad_b = 0 - 4 = -4 74 | auto a_grad_expected = g.ConstantFromScalar(dims, -6.0f); 75 | auto b_grad_expected = g.ConstantFromScalar(dims, -4.0f); 76 | 77 | g.Eval(); 78 | 79 | this->ExpectEq(a_grad, a_grad_expected); 80 | this->ExpectEq(b_grad, b_grad_expected); 81 | } 82 | } 83 | 84 | } // namespace 85 | } // namespace llgtm 86 | -------------------------------------------------------------------------------- /tensorflow_fold/llgtm/test/gradients_test_tf.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 Google Inc. 
All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "tensorflow_fold/llgtm/backend/tf_evaluator_client.h" 17 | #include "tensorflow_fold/llgtm/test/gradients_test.h" 18 | 19 | namespace llgtm { 20 | namespace { 21 | using Configuration = TestConfiguration; 22 | 23 | INSTANTIATE_TYPED_TEST_CASE_P(Tf, GradientsTest, Configuration); 24 | } // namespace 25 | } // namespace llgtm 26 | -------------------------------------------------------------------------------- /tensorflow_fold/llgtm/test/graph_nocompile.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include "tensorflow_fold/llgtm/graph.h" 17 | #include "tensorflow_fold/llgtm/tensor.h" 18 | 19 | void failed_type_checks() { 20 | #ifdef TEST_DIMENSIONS 21 | // These are compile-time errors. 22 | std::array dim_list = {{1, 2, 3, 4, 5}}; 23 | llgtm::Dimensions dims(dim_list); 24 | #endif 25 | } 26 | 27 | #ifdef TEST_MULTIPLE_BACKENDS 28 | // Including more than one backend is forbidden. In that case, the linker has 29 | // multiple different TensorNodeSelf::InvokeKernel implementations to choose 30 | // from, violating ODR (One Definition Rule). This test ensures that we show 31 | // a useful compile error message. 32 | #include "tensorflow_fold/llgtm/backend/eigen_evaluator_client.h" 33 | #include "tensorflow_fold/llgtm/backend/tf_evaluator_client.h" 34 | #endif 35 | -------------------------------------------------------------------------------- /tensorflow_fold/llgtm/test/graph_nocompile_test.py: -------------------------------------------------------------------------------- 1 | """Negative compilation unit tests for LLGTM graphs.""" 2 | 3 | class GraphNoCompileTest(object): 4 | """Negative compilation tests LLGTM graphs.""" 5 | 6 | def testCompilerErrors(self): 7 | # Not currently implemented in open source. 8 | pass 9 | 10 | if __name__ == '__main__': 11 | pass 12 | -------------------------------------------------------------------------------- /tensorflow_fold/llgtm/trainer.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | // Interface for training neural networks. 17 | // A Trainer applies gradients to a VariableSet. 18 | 19 | #ifndef TENSORFLOW_FOLD_LLGTM_TRAINER_H_ 20 | #define TENSORFLOW_FOLD_LLGTM_TRAINER_H_ 21 | 22 | #include "tensorflow_fold/llgtm/gradients.h" 23 | #include "tensorflow_fold/llgtm/tensor.h" 24 | #include "tensorflow_fold/llgtm/tensor_ops.h" 25 | 26 | namespace llgtm { 27 | 28 | // Base class for Trainers. 29 | class Trainer { 30 | public: 31 | Trainer() {} 32 | Trainer(const Trainer&) = delete; 33 | Trainer(Trainer&&) = delete; 34 | virtual ~Trainer() {} 35 | 36 | Trainer& operator=(const Trainer&) = delete; 37 | Trainer& operator=(Trainer&&) = delete; 38 | 39 | // Adds nodes to the graph which apply the gradients to the variables. 40 | // The model will not be updated until Graph.Eval() is called. 41 | virtual void ApplyGradients(const Gradients& grads) = 0; 42 | 43 | // Computes gradients for the given loss, and add nodes to the graph which 44 | // applies them to the model. The model will not be updated until 45 | // Graph.Eval() is called. It will fail if graph or model are nullptr. 46 | void ComputeAndApplyGradients(Graph* graph, VariableSet* model, 47 | Tensor loss); 48 | 49 | VariableSet* variables() { return variables_; } 50 | 51 | protected: 52 | // Initializes the trainer, which happens on the first call to ApplyGradients. 
53 | // Derived classes may override this method to allocate additional state 54 | // associated with training (e.g. momentum). The overridden version should 55 | // call the base class version. 56 | virtual void Initialize(const Gradients& grads); 57 | 58 | // Initializes the trainer if it has not been initialized already. 59 | // Derived classes must call this method at the start of ApplyGradients. 60 | void CheckInitialize(const Gradients& grads); 61 | 62 | private: 63 | VariableSet* variables_ = nullptr; 64 | }; 65 | 66 | 67 | // Trainer which implements stochastic gradient descent. 68 | class SGDTrainer : public Trainer { 69 | public: 70 | SGDTrainer() = delete; 71 | explicit SGDTrainer(float learning_rate) : learning_rate_(learning_rate) {} 72 | 73 | float learning_rate() const { return learning_rate_; } 74 | 75 | void ApplyGradients(const Gradients& grads) override; 76 | 77 | private: 78 | float learning_rate_; 79 | }; 80 | 81 | 82 | // Trainer which implements the stochastic gradient descent with momentum. 83 | class MomentumTrainer : public Trainer { 84 | public: 85 | MomentumTrainer() = delete; 86 | 87 | // Creates a new MomentumTrainer. 88 | explicit MomentumTrainer(float learning_rate, float momentum = 0.9f) 89 | : learning_rate_(learning_rate), momentum_(momentum) {} 90 | 91 | float learning_rate() const { return learning_rate_; } 92 | float momentum() const { return momentum_; } 93 | 94 | // For every differentiable variable in the model, MomentumTrainer 95 | // creates another variable of the same shape to track its momentum. 
96 | VariableSet* momentum_variables() { return momentum_variables_.get(); } 97 | 98 | Variable* momentum_variable(VariableBase* var) { 99 | DCHECK_EQ(var->variable_set(), variables()); 100 | DCHECK_LE(var->id(), momentum_var_map_.size()); 101 | return momentum_var_map_[var->id()]; 102 | } 103 | 104 | void ApplyGradients(const Gradients& grads) override; 105 | 106 | protected: 107 | void Initialize(const Gradients& grads) override; 108 | 109 | private: 110 | // The rate at which the momentum ramps up from 0 to its proper value. 111 | static constexpr float kMomentumRampRate = 0.05f; 112 | 113 | std::unique_ptr momentum_variables_; 114 | std::vector*> momentum_var_map_; 115 | 116 | float learning_rate_; 117 | float momentum_; 118 | float active_momentum_ = 0.0f; 119 | }; 120 | 121 | } // namespace llgtm 122 | 123 | #endif // TENSORFLOW_FOLD_LLGTM_TRAINER_H_ 124 | -------------------------------------------------------------------------------- /tensorflow_fold/llgtm/util.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_FOLD_LLGTM_UTIL_H_ 17 | #define TENSORFLOW_FOLD_LLGTM_UTIL_H_ 18 | 19 | #include 20 | #include 21 | #include 22 | #include "tensorflow_fold/llgtm/tensor.h" 23 | 24 | namespace llgtm { 25 | 26 | // Wraps an iterator by applying a function to each element. 27 | // The function is called on-demand, essentially implementing a lazy 28 | // functional-style map over a collection of elements. The class F must 29 | // support an operator() that returns a reference to T. 30 | // Note that the wrapped iterator is an input iterator only. 31 | template 32 | class IteratorWrapper : public std::iterator { 33 | public: 34 | // Creates an IteratorWrapper from an iterator and a function f. 35 | explicit IteratorWrapper(IterType iter, F f = F()) 36 | : iter_(iter), function_(f) {} 37 | 38 | IteratorWrapper() = default; 39 | IteratorWrapper(const IteratorWrapper& p) = default; 40 | IteratorWrapper(IteratorWrapper&& p) = default; 41 | 42 | IteratorWrapper& operator=(const IteratorWrapper& p) = default; 43 | IteratorWrapper& operator=(IteratorWrapper&& p) = default; 44 | 45 | T& operator*() { return function_(*iter_); } 46 | const T& operator*() const { return function_(*iter_); } 47 | 48 | T* operator->() { return &function_(*iter_); } 49 | const T* operator->() const { return &function_(*iter_); } 50 | 51 | IteratorWrapper& operator++() { 52 | ++iter_; 53 | return *this; 54 | } 55 | 56 | IteratorWrapper operator++(int) { 57 | return IteratorWrapper(iter_++, function_); 58 | } 59 | 60 | bool operator==(IteratorWrapper other) { return iter_ == other.iter_; } 61 | bool operator!=(IteratorWrapper other) { return iter_ != other.iter_; } 62 | 63 | private: 64 | IterType iter_; 65 | F function_; 66 | }; 67 | 68 | // A functor that calls get() on a unique_ptr. 
69 | // For use with IteratorWrapper, to produce iterators over containers that 70 | // don't expose the implementation details (ownership) of the container. 71 | template 72 | struct UniquePtrGetter { 73 | T& operator()(std::unique_ptr& ptr) { return *ptr; } // NOLINT 74 | const T& operator()(const std::unique_ptr& ptr) { return *ptr; } 75 | }; 76 | 77 | 78 | // A helper class that can calculate mean, stddev, and variance of a 79 | // distribution. Uses Welford's algorithm. 80 | // See also: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance. 81 | template 82 | class OnlineStats { 83 | public: 84 | // Initialize with data. 85 | explicit OnlineStats(const DT* data, size_t num_elements) { 86 | reset(); 87 | 88 | for (int i = 0; i < num_elements; ++i) { 89 | this->add(data[i]); 90 | } 91 | } 92 | 93 | OnlineStats() : OnlineStats(/*data=*/ nullptr, 0) {} 94 | 95 | void add(DT element) { 96 | ++n_; 97 | 98 | // Welford's algorithm (to compute variance). 99 | double d1 = element - mean_; 100 | mean_ += d1 / n_; 101 | double d2 = element - mean_; 102 | m2_ += d1 * d2; 103 | } 104 | 105 | void reset() { 106 | n_ = 0; 107 | mean_ = m2_ = 0.0; 108 | min_ = std::numeric_limits
::max(); 109 | max_ = std::numeric_limits
::min(); 110 | } 111 | 112 | double stddev() { 113 | return std::sqrt(this->variance()); 114 | } 115 | 116 | double variance() { 117 | return m2_ / n_; 118 | } 119 | 120 | double mean() { 121 | return mean_; 122 | } 123 | 124 | DT min() { 125 | return min_; 126 | } 127 | 128 | DT max() { 129 | return max_; 130 | } 131 | 132 | private: 133 | int n_; 134 | double mean_; 135 | double m2_; 136 | DT min_; 137 | DT max_; 138 | }; 139 | 140 | } // end namespace llgtm 141 | 142 | #endif // TENSORFLOW_FOLD_LLGTM_UTIL_H_ 143 | -------------------------------------------------------------------------------- /tensorflow_fold/llgtm/variable_initializers.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_FOLD_LLGTM_VARIABLE_INITIALIZERS_H_ 17 | #define TENSORFLOW_FOLD_LLGTM_VARIABLE_INITIALIZERS_H_ 18 | 19 | #include "tensorflow_fold/llgtm/dimensions.h" 20 | #include "tensorflow_fold/llgtm/graph.h" 21 | 22 | namespace llgtm { 23 | 24 | 25 | // Initialize all elements of a variable with arbitrary dimensions to zero. 26 | template 27 | class ZerosInitializer { 28 | public: 29 | ZerosInitializer() {} 30 | 31 | Tensor
operator()(Graph* g, const Dimensions& dims, DeviceID device) { 32 | return g->Zeros
(dims, device); 33 | } 34 | }; 35 | 36 | 37 | // Initialize all elements of a variable with arbitrary dimensions to a given 38 | // scalar value. 39 | template 40 | class ScalarInitializer_ { 41 | public: 42 | explicit ScalarInitializer_(DT val) : val_(val) {} 43 | 44 | Tensor
operator()(Graph* g, const Dimensions& dims, DeviceID device) { 45 | return g->ConstantFromScalar(dims, val_, device); 46 | } 47 | 48 | private: 49 | DT val_; 50 | }; 51 | 52 | // Separate entry point for automatic template argument deduction. 53 | template 54 | ScalarInitializer_
ScalarInitializer(DT val) { 55 | return ScalarInitializer_
(val); 56 | } 57 | 58 | 59 | // Initialize all elements of a tensor variable using an array of values. The 60 | // array must be at least as big as the variable. 61 | template 62 | class TensorInitializer_ { 63 | public: 64 | explicit TensorInitializer_(DT *f) : val_(f) {} 65 | 66 | Tensor
operator()(Graph *g, const Dimensions& dims, DeviceID device) { 67 | return g->Value(dims, val_, device); 68 | } 69 | 70 | private: 71 | DT *val_; 72 | }; 73 | 74 | // Separate entry point for automatic template argument deduction. 75 | template 76 | TensorInitializer_
TensorInitializer(DT *val) { 77 | return TensorInitializer_
(val); 78 | } 79 | 80 | 81 | // Initialize all elements of a variable with arbitrary dimensions using a 82 | // function. 83 | template 84 | class FunctionInitializer_ { 85 | public: 86 | explicit FunctionInitializer_(F f) : function_(f) {} 87 | 88 | Tensor
operator()(Graph *g, const Dimensions& dims, DeviceID device) { 89 | return g->ConstantFromFunction(dims, function_, device); 90 | } 91 | 92 | private: 93 | F function_; 94 | }; 95 | 96 | // Separate entry point for automatic template argument deduction of DT. 97 | template 98 | FunctionInitializer_ FunctionInitializer(F f) { 99 | return FunctionInitializer_(f); 100 | } 101 | 102 | 103 | // Initialize all elements of a vector variable with uniformly distributed 104 | // random numbers. 105 | template 106 | class UniformRandomInitializer { 107 | public: 108 | explicit UniformRandomInitializer(uint64_t seed) 109 | : seed_(seed), has_seed_(true) {} 110 | 111 | UniformRandomInitializer() : has_seed_(false) {} 112 | 113 | Tensor
operator()(Graph *g, const Dimensions& dims, DeviceID device) { 114 | return g->UniformRandom
(dims, has_seed_ ? seed_ : 0, device); 115 | } 116 | 117 | private: 118 | uint64_t seed_; 119 | bool has_seed_; 120 | }; 121 | 122 | 123 | // Initialize all elements of a vector variable with normally distributed 124 | // random numbers. 125 | template 126 | class NormalRandomInitializer { 127 | public: 128 | explicit NormalRandomInitializer(uint64_t seed, DT mean = 0.0f, 129 | DT stddev = 1.0f) 130 | : seed_(seed), has_seed_(true), mean_(mean), stddev_(stddev) {} 131 | 132 | explicit NormalRandomInitializer(DT mean = 0.0f, DT stddev = 1.0f) 133 | : has_seed_(false), mean_(mean), stddev_(stddev) {} 134 | 135 | Tensor
operator()(Graph *g, const Dimensions& dims, DeviceID device) { 136 | Tensor
rand = g->NormalRandom
(dims, has_seed_ ? seed_ : 0, device); 137 | 138 | // TODO(delesley): Generalize. This only works if DT is float or double. 139 | auto mean = g->ConstantFromScalar
(dims, mean_); 140 | auto stddev = g->ConstantFromScalar
(dims, stddev_); 141 | return g->Add(mean, g->Multiply(rand, stddev)); 142 | } 143 | 144 | private: 145 | uint64_t seed_; 146 | bool has_seed_; 147 | DT mean_; 148 | DT stddev_; 149 | }; 150 | 151 | 152 | } // namespace llgtm 153 | 154 | #endif // TENSORFLOW_FOLD_LLGTM_VARIABLE_INITIALIZERS_H_ 155 | -------------------------------------------------------------------------------- /tensorflow_fold/loom/BUILD: -------------------------------------------------------------------------------- 1 | # Loom enables TensorFlow to be easily and efficiently used to train 2 | # networks whose shape differs on a per training example basis. This allows 3 | # TensorFlow to handle architectures including Tree RNNs, high dimensional 4 | # sparse variants of GridLSTM and arbitrary architectures containing multiple 5 | # LSTMs of example-dependant depth. Loom is implemented as a library on 6 | # top of TensorFlow. 7 | 8 | licenses(["notice"]) # Apache 2.0 9 | 10 | load( 11 | "//tensorflow_fold:fold.bzl", 12 | "fold_cc_library", 13 | "fold_cc_test", 14 | "fold_proto_library", 15 | "fold_py_binary", 16 | "fold_py_library", 17 | "fold_py_test", 18 | "fold_py_wrap_cc", 19 | "fold_tf_op_py", 20 | ) 21 | 22 | package( 23 | default_visibility = [ 24 | "//tensorflow_fold/public:__subpackages__", 25 | ], 26 | ) 27 | 28 | fold_proto_library( 29 | srcs = ["loom.proto"], 30 | cc_deps = [ 31 | "@org_tensorflow//tensorflow/core:protos_all_cc", 32 | ], 33 | cc_name = "loom_proto", 34 | py_deps = [ 35 | "@org_tensorflow//tensorflow/core:protos_all_py", 36 | ], 37 | py_name = "loom_py_pb2", 38 | ) 39 | 40 | fold_cc_library( 41 | name = "platform", 42 | hdrs = ["platform.h"], 43 | deps = [ 44 | "@org_tensorflow//tensorflow/core:framework_headers_lib", 45 | ], 46 | ) 47 | 48 | fold_cc_library( 49 | name = "weaver", 50 | srcs = ["weaver.cc"], 51 | hdrs = ["weaver.h"], 52 | deps = [ 53 | ":loom_proto", 54 | ":platform", 55 | "@org_tensorflow//third_party/eigen3", 56 | 
"@org_tensorflow//tensorflow/core:framework_headers_lib", 57 | "@org_tensorflow//tensorflow/core:protos_all_cc", 58 | ], 59 | ) 60 | 61 | fold_cc_test( 62 | name = "weaver_test", 63 | srcs = ["weaver_test.cc"], 64 | deps = [ 65 | ":loom_proto", 66 | ":weaver", 67 | "@org_tensorflow//tensorflow/core:protos_all_cc", 68 | ], 69 | ) 70 | 71 | fold_py_wrap_cc( 72 | name = "pywrapweaver", 73 | srcs = ["python/weaver.swig"], 74 | deps = [ 75 | ":platform", 76 | ":weaver", 77 | ], 78 | ) 79 | 80 | fold_cc_library( 81 | name = "weaver_op_base", 82 | srcs = ["weaver_op_base.cc"], 83 | hdrs = ["weaver_op_base.h"], 84 | deps = [ 85 | ":weaver", 86 | "@org_tensorflow//tensorflow/core:framework_headers_lib", 87 | ], 88 | ) 89 | 90 | fold_py_library( 91 | name = "weaver_op_base_py", 92 | srcs = ["weaver_op_base.py"], 93 | deps = [ 94 | "@org_tensorflow//tensorflow:tensorflow_py", 95 | ], 96 | ) 97 | 98 | # The following two BUILD rules are an example of how you would use 99 | # weaver_op_base to define an op. 

fold_cc_library(
    name = "deserializing_weaver_op_cc",
    srcs = ["deserializing_weaver_op.cc"],
    deps = [
        ":weaver",
        ":weaver_op_base",
        "@org_tensorflow//tensorflow/core:framework_headers_lib",
    ],
    # deserializing_weaver_op.cc registers its op and kernel through
    # REGISTER_* macros that nothing references by symbol; alwayslink keeps
    # the linker from discarding those registrations.
    alwayslink = 1,
)

# Python wrapper for the op above. deserializing_weaver_op.py loads the
# compiled kernel library (_deserializing_weaver_op.so) at import time.
fold_tf_op_py(
    name = "deserializing_weaver_op",
    srcs = ["deserializing_weaver_op.py"],
    cc_deps = [":deserializing_weaver_op_cc"],
    py_deps = [
        ":weaver_op_base_py",
        "@org_tensorflow//tensorflow:tensorflow_py",
    ],
)

fold_py_library(
    name = "loom",
    srcs = ["loom.py"],
    deps = [
        ":deserializing_weaver_op",
        ":loom_py_pb2",
        ":pywrapweaver",
        # NOTE(review): the numpy dependency below is commented out (looks like
        # a placeholder left by the open-source export); verify whether
        # loom.py needs an explicit numpy dep.
        # numpy",
        "@org_tensorflow//tensorflow:tensorflow_py",
    ],
)

fold_py_test(
    name = "loom_test",
    srcs = ["loom_test.py"],
    deps = [
        ":loom",
        "@org_tensorflow//tensorflow:tensorflow_py",
    ],
)
--------------------------------------------------------------------------------
/tensorflow_fold/loom/benchmarks/BUILD:
--------------------------------------------------------------------------------
# Description:
#   Benchmarks for the Loom library.

licenses(["notice"])  # Apache 2.0

load("//tensorflow_fold:fold.bzl", "fold_py_binary")

fold_py_binary(
    name = "iclr_2017_benchmark",
    srcs = ["iclr_2017_benchmark.py"],
    deps = [
        "@org_tensorflow//tensorflow:tensorflow_py",
        "//tensorflow_fold/public:loom",
    ],
)
-------------------------------------------------------------------------------- /tensorflow_fold/loom/calculator_example/BUILD: -------------------------------------------------------------------------------- 1 | # Tree RNN smoke test (learn a calculator.)

licenses(["notice"])  # Apache 2.0

load(
    "//tensorflow_fold:fold.bzl",
    "fold_proto_library",
    "fold_py_binary",
    "fold_py_extension",
    "fold_py_library",
    "fold_py_test",
)

package(
    default_visibility = [
        "//tensorflow_fold:__subpackages__",
    ],
)

# Generates the C++ (":calculator_proto") and Python (":calculator_py_pb2")
# bindings for calculator.proto.
fold_proto_library(
    srcs = ["calculator.proto"],
    cc_name = "calculator_proto",
    py_name = "calculator_py_pb2",
)

filegroup(
    name = "calculator_proto_file",
    srcs = ["calculator.proto"],
)

# NOTE(review): empty srcs/outs; this target appears to exist only to pull the
# C++ proto implementation in via deps — confirm against fold.bzl.
fold_py_extension(
    name = "cpp_calculator_proto",
    srcs = [],
    outs = [],
    deps = [":calculator_proto"],
)

fold_py_library(
    name = "calculator",
    srcs = ["calculator.py"],
    deps = [
        ":calculator_py_pb2",
    ],
)

fold_py_test(
    name = "calculator_test",
    srcs = ["calculator_test.py"],
    deps = [
        ":calculator",
        ":calculator_py_pb2",
        "@protobuf_archive//:protobuf_python",
        "@six_archive//:six",
        "@org_tensorflow//tensorflow:tensorflow_py",
    ],
)

# Writes a TFRecord file of random CalculatorExpressions (see make_dataset.py).
fold_py_binary(
    name = "make_dataset",
    srcs = ["make_dataset.py"],
    deps = [
        ":calculator",
        "@six_archive//:six",
        "@org_tensorflow//tensorflow:tensorflow_py",
    ],
)

fold_py_library(
    name = "model",
    srcs = ["model.py"],
    deps = [
        ":calculator",
        ":calculator_py_pb2",
        # numpy",
        "@six_archive//:six",
        "@org_tensorflow//tensorflow:tensorflow_py",
        "//tensorflow_fold/public:loom",
    ],
)

fold_py_test(
    name = "model_test",
    srcs = ["model_test.py"],
    deps = [
        ":calculator",
        ":model",
        "@six_archive//:six",
        "@org_tensorflow//tensorflow:tensorflow_py",
    ],
)

fold_py_library(
    name = "helpers",
    srcs = ["helpers.py"],
    deps = [
        "@six_archive//:six",
        "@org_tensorflow//tensorflow:tensorflow_py",
    ],
)

fold_py_binary(
    name = "train",
    srcs = ["train.py"],
    deps = [
        ":calculator_py_pb2",
        ":helpers",
        ":model",
        "@six_archive//:six",
        "@org_tensorflow//tensorflow:tensorflow_py",
    ],
)

fold_py_binary(
    name = "eval",
    srcs = ["eval.py"],
    deps = [
        ":calculator_py_pb2",
        ":helpers",
        ":model",
        "@org_tensorflow//tensorflow:tensorflow_py",
    ],
)
--------------------------------------------------------------------------------
/tensorflow_fold/loom/calculator_example/calculator.proto:
--------------------------------------------------------------------------------
syntax = "proto2";

package tensorflow_fold.loom.calculator_example;

// A binary expression tree over integers. calculator.validate_expression
// requires each node to be either a leaf (`number` set) or an operator node
// (`op`, `left` and `right` all set), never both.
message CalculatorExpression {
  // Leaf value; random_expression draws it from 0..9.
  optional int64 number = 1;
  enum OpCode {
    PLUS = 0;
    MINUS = 1;
    TIMES = 2;
    // Floor division; calculator.evaluate_expression defines x / 0 as 0.
    DIV = 3;
  }
  optional OpCode op = 2;
  // Operands; present if and only if `op` is set.
  optional CalculatorExpression left = 3;
  optional CalculatorExpression right = 4;

  // Value of the whole expression; make_dataset.py fills it in with
  // calculator.evaluate_expression.
  optional int64 result = 5;
}
-------------------------------------------------------------------------------- /tensorflow_fold/loom/calculator_example/calculator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Expression evaluation and generation.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import random 21 | 22 | # import google3 23 | from tensorflow_fold.loom.calculator_example import calculator_pb2 24 | 25 | 26 | def random_expression(max_depth): 27 | """Recursively build a random CalculatorExpression.""" 28 | def build(expression, max_depth): 29 | if max_depth == 0 or random.uniform(0, 1) < 1.0 / 3.0: 30 | expression.number = random.choice(range(10)) 31 | else: 32 | expression.op = random.choice( 33 | calculator_pb2.CalculatorExpression.OpCode.values()) 34 | build(expression.left, max_depth - 1) 35 | build(expression.right, max_depth - 1) 36 | expression = calculator_pb2.CalculatorExpression() 37 | build(expression, max_depth) 38 | return expression 39 | 40 | 41 | def validate_expression(expression, recurse=True): 42 | """Check that 'expression' has the correct subset of fields set. 43 | 44 | Args: 45 | expression: An expression to validate. 46 | recurse: whether to recurse to the left and right fields if present. 47 | (Default: true). 48 | 49 | Raises: 50 | NameError: If an unknown op is found. 51 | TypeError: If expr contains things that aren't tuples or ints, or if it 52 | contains a tuple of the wrong arity. 
53 | """ 54 | if expression.HasField('number') == expression.HasField('op'): 55 | raise TypeError('Exactly one of {number, op} is required.') 56 | 57 | if expression.HasField('op') != expression.HasField('left'): 58 | raise TypeError('left should be present if and only if op is.') 59 | 60 | if expression.HasField('op') != expression.HasField('right'): 61 | raise TypeError('right should be present if and only if op is.') 62 | 63 | if expression.HasField('op'): 64 | if expression.op not in calculator_pb2.CalculatorExpression.OpCode.values(): 65 | raise NameError('Unrecognized op : ', expression.op) 66 | if recurse: 67 | validate_expression(expression.left, True) 68 | validate_expression(expression.right, True) 69 | 70 | 71 | def expression_depth(expression): 72 | validate_expression(expression, recurse=False) 73 | if expression.HasField('op'): 74 | return 1 + max(expression_depth(expression.left), 75 | expression_depth(expression.right)) 76 | return 0 # expression is a terminal (number). 77 | 78 | 79 | def evaluate_expression(expression): 80 | """Computes an integer from an expression by performing the operations.""" 81 | validate_expression(expression, recurse=False) 82 | if expression.HasField('number'): 83 | return expression.number 84 | a = evaluate_expression(expression.left) 85 | b = evaluate_expression(expression.right) 86 | if expression.op == calculator_pb2.CalculatorExpression.PLUS: 87 | return a + b 88 | if expression.op == calculator_pb2.CalculatorExpression.MINUS: 89 | return a - b 90 | if expression.op == calculator_pb2.CalculatorExpression.TIMES: 91 | return a * b 92 | if expression.op == calculator_pb2.CalculatorExpression.DIV: 93 | if b == 0: 94 | return 0 95 | else: 96 | return a // b 97 | else: 98 | raise NameError('Unrecognized op: ' + expression.op) 99 | -------------------------------------------------------------------------------- /tensorflow_fold/loom/calculator_example/calculator_test.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for calculator example.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import random 21 | 22 | # import google3 23 | from six.moves import xrange # pylint: disable=redefined-builtin 24 | import tensorflow as tf 25 | from google.protobuf import text_format 26 | from tensorflow_fold.loom.calculator_example import calculator 27 | from tensorflow_fold.loom.calculator_example import calculator_pb2 28 | 29 | 30 | def evaluate_expression(string): 31 | return calculator.evaluate_expression( 32 | text_format.Parse(string, calculator_pb2.CalculatorExpression())) 33 | 34 | 35 | class CalculatorTest(tf.test.TestCase): 36 | 37 | def test_generated_expression_depth(self): 38 | random.seed(0xdeadbeef) # Make RandomExpression deterministic. 39 | for _ in xrange(1000): 40 | expression = calculator.random_expression(5) 41 | calculator.validate_expression(expression) 42 | self.assertTrue(calculator.expression_depth(expression) <= 5) 43 | 44 | def test_eval(self): 45 | self.assertEqual(0, evaluate_expression( 46 | """op: DIV left right""")) 47 | 48 | # Division by zero defaults to zero. 
49 | for n in xrange(10): 50 | self.assertEqual(n, evaluate_expression( 51 | """number: {n}""".format(n=n))) 52 | self.assertEqual(3 + n, evaluate_expression( 53 | """op: PLUS left right""".format(n=n))) 54 | self.assertEqual(2 * n, evaluate_expression( 55 | """op: PLUS left right""".format(n=n))) 56 | self.assertEqual(0, evaluate_expression( 57 | """op: MINUS left right""".format(n=n))) 58 | self.assertEqual(n * n * n, evaluate_expression( 59 | """op: TIMES 60 | left 61 | right right> 62 | """.format(n=n))) 63 | self.assertEqual(n, evaluate_expression( 64 | """op: DIV left right""".format( 65 | x=5 * n + 3))) 66 | 67 | if __name__ == '__main__': 68 | tf.test.main() 69 | -------------------------------------------------------------------------------- /tensorflow_fold/loom/calculator_example/eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Runs the evaluator for the calculator smoke test.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | import time 20 | 21 | # import google3 22 | import tensorflow as tf 23 | from tensorflow_fold.loom.calculator_example import calculator_pb2 24 | from tensorflow_fold.loom.calculator_example import helpers 25 | from tensorflow_fold.loom.calculator_example import model 26 | 27 | 28 | tf.flags.DEFINE_string( 29 | 'validation_data_path', 30 | '', 31 | 'TF Record containing the validation dataset of expressions.') 32 | tf.flags.DEFINE_integer( 33 | 'embedding_length', 5, 34 | 'How long to make the expression embedding vectors.') 35 | tf.flags.DEFINE_string( 36 | 'eval_master', '', 37 | 'Tensorflow master to use.') 38 | tf.flags.DEFINE_string( 39 | 'logdir', '/tmp/calculator_smoketest', 40 | 'Directory where we read models and write event logs.') 41 | tf.flags.DEFINE_integer( 42 | 'eval_interval_secs', 60, 43 | 'Time interval between eval runs. 
Zero to do a single eval then exit.') 44 | FLAGS = tf.flags.FLAGS 45 | 46 | 47 | def string_to_expression(string): 48 | expression = calculator_pb2.CalculatorExpression() 49 | expression.ParseFromString(string) 50 | return expression 51 | 52 | 53 | def main(unused_argv): 54 | validation_table = tf.python_io.tf_record_iterator(FLAGS.validation_data_path) 55 | print('Reading validation table...') 56 | validation_data = [string_to_expression(v) for v in validation_table] 57 | print('Done reading validation table...') 58 | 59 | with tf.Graph().as_default(): 60 | global_step = tf.Variable(0, name='global_step', trainable=False) 61 | classifier = model.CalculatorSignClassifier(FLAGS.embedding_length) 62 | 63 | loss = classifier.loss() 64 | accuracy = classifier.accuracy() 65 | 66 | saver = tf.train.Saver() 67 | supervisor = tf.train.Supervisor( 68 | logdir=FLAGS.logdir, 69 | recovery_wait_secs=FLAGS.eval_interval_secs) 70 | sess = supervisor.PrepareSession( 71 | FLAGS.eval_master, 72 | wait_for_checkpoint=True, 73 | start_standard_services=False) 74 | 75 | while not supervisor.ShouldStop(): 76 | ckpt = tf.train.get_checkpoint_state(FLAGS.logdir) 77 | if ckpt and ckpt.model_checkpoint_path: 78 | saver.restore(sess, ckpt.model_checkpoint_path) 79 | else: 80 | continue 81 | step, validation_loss, validation_accuracy = sess.run( 82 | [global_step, loss, accuracy], 83 | feed_dict=classifier.build_feed_dict(validation_data)) 84 | print('Step %d: loss=%f accuracy=%f' % ( 85 | step, validation_loss, validation_accuracy)) 86 | helpers.EmitValues(supervisor, sess, step, 87 | {'Validation Loss': validation_loss, 88 | 'Validation Accuracy': validation_accuracy}) 89 | if not FLAGS.eval_interval_secs: break 90 | time.sleep(FLAGS.eval_interval_secs) 91 | 92 | supervisor.Stop() 93 | 94 | 95 | if __name__ == '__main__': 96 | tf.app.run() 97 | -------------------------------------------------------------------------------- /tensorflow_fold/loom/calculator_example/helpers.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Helper functions for dealing with TensorFlow.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | # import google3 21 | import six 22 | import tensorflow as tf 23 | 24 | 25 | def EmitValues(supervisor, session, step, values): 26 | summary = tf.Summary() 27 | for name, value in six.iteritems(values): 28 | summary_value = summary.value.add() 29 | summary_value.tag = name 30 | summary_value.simple_value = float(value) 31 | supervisor.SummaryComputed(session, summary, global_step=step) 32 | -------------------------------------------------------------------------------- /tensorflow_fold/loom/calculator_example/make_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """make_dataset randomly constructs datsets for the calculator smoketest. 16 | 17 | The datasets are in TF Record format. The contents of the TF record are 18 | CalculatorExpression protos. 19 | """ 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | # import google3 25 | from six.moves import xrange # pylint: disable=redefined-builtin 26 | import tensorflow as tf 27 | from tensorflow_fold.loom.calculator_example import calculator 28 | 29 | tf.flags.DEFINE_string('output_path', '', 30 | 'Where to write the TFRecord file with the expressions.') 31 | tf.flags.DEFINE_integer('max_expression_depth', 5, 32 | 'Maximum expression depth.') 33 | tf.flags.DEFINE_integer('num_samples', 1000, 34 | 'How many samples to put into the table.') 35 | FLAGS = tf.flags.FLAGS 36 | 37 | 38 | def make_expression(): 39 | expression = calculator.random_expression(FLAGS.max_expression_depth) 40 | expression.result = calculator.evaluate_expression(expression) 41 | return expression.SerializeToString() 42 | 43 | 44 | def main(unused_argv): 45 | record_output = tf.python_io.TFRecordWriter(FLAGS.output_path) 46 | for _ in xrange(FLAGS.num_samples): 47 | record_output.write(make_expression()) 48 | record_output.close() 49 | 50 | 51 | if __name__ == '__main__': 52 | tf.app.run() 53 | -------------------------------------------------------------------------------- /tensorflow_fold/loom/calculator_example/model_test.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Smoke test for Loom Calculator Model.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | import random 20 | 21 | # import google3 22 | from six.moves import xrange # pylint: disable=redefined-builtin 23 | import tensorflow as tf 24 | 25 | from tensorflow_fold.loom.calculator_example import calculator 26 | from tensorflow_fold.loom.calculator_example import model 27 | 28 | 29 | class ModelTest(tf.test.TestCase): 30 | 31 | def test_loss_goes_down(self): 32 | # Ensure determinism: 33 | tf.set_random_seed(0xdeadbeef) 34 | random.seed(0xdeadbeef) 35 | 36 | expression_list = [] 37 | for _ in xrange(1000): 38 | expression = calculator.random_expression(5) 39 | expression.result = calculator.evaluate_expression(expression) 40 | expression_list.append(expression) 41 | 42 | classifier = model.CalculatorSignClassifier(embedding_length=3) 43 | 44 | variables = classifier.variables() 45 | loss = classifier.loss() 46 | 47 | optr = tf.train.GradientDescentOptimizer(0.001) 48 | trainer = optr.minimize(loss, var_list=variables) 49 | 50 | with tf.Session() as sess: 51 | sess.run(tf.global_variables_initializer()) 52 | old_loss = classifier.loss().eval( 53 | 
feed_dict=classifier.build_feed_dict(expression_list)) 54 | 55 | for _ in xrange(20): 56 | sess.run([trainer], 57 | feed_dict=classifier.build_feed_dict(expression_list)) 58 | 59 | new_loss = classifier.loss().eval( 60 | feed_dict=classifier.build_feed_dict(expression_list)) 61 | 62 | self.assertLess(new_loss, old_loss) 63 | 64 | 65 | if __name__ == '__main__': 66 | tf.test.main() 67 | -------------------------------------------------------------------------------- /tensorflow_fold/loom/calculator_example/train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Runs the trainer for the calculator smoketest.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | # import google3 21 | from six.moves import xrange # pylint: disable=redefined-builtin 22 | import tensorflow as tf 23 | from tensorflow_fold.loom.calculator_example import calculator_pb2 24 | from tensorflow_fold.loom.calculator_example import helpers 25 | from tensorflow_fold.loom.calculator_example import model 26 | 27 | 28 | tf.flags.DEFINE_string( 29 | 'train_data_path', '', 30 | 'TF Record file containing the training dataset of expressions.') 31 | tf.flags.DEFINE_integer( 32 | 'batch_size', 1000, 'How many samples to read per batch.') 33 | tf.flags.DEFINE_integer( 34 | 'embedding_length', 5, 35 | 'How long to make the expression embedding vectors.') 36 | tf.flags.DEFINE_integer( 37 | 'max_steps', 1000000, 38 | 'The maximum number of batches to run the trainer for.') 39 | 40 | # Replication flags: 41 | tf.flags.DEFINE_string('logdir', '/tmp/calculator_smoketest', 42 | 'Directory in which to write event logs.') 43 | tf.flags.DEFINE_string('master', '', 44 | 'Tensorflow master to use.') 45 | tf.flags.DEFINE_integer('task', 0, 46 | 'Task ID of the replica running the training.') 47 | tf.flags.DEFINE_integer('ps_tasks', 0, 48 | 'Number of PS tasks in the job.') 49 | FLAGS = tf.flags.FLAGS 50 | 51 | 52 | def iterate_over_tf_record_protos(table_path, message_type): 53 | while True: 54 | for v in tf.python_io.tf_record_iterator(table_path): 55 | message = message_type() 56 | message.ParseFromString(v) 57 | yield message 58 | 59 | 60 | def main(unused_argv): 61 | train_iterator = iterate_over_tf_record_protos( 62 | FLAGS.train_data_path, calculator_pb2.CalculatorExpression) 63 | 64 | with tf.Graph().as_default(): 65 | with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)): 66 | 67 | # Build the graph. 
68 | global_step = tf.Variable(0, name='global_step', trainable=False) 69 | classifier = model.CalculatorSignClassifier(FLAGS.embedding_length) 70 | 71 | variables = classifier.variables() 72 | loss = classifier.loss() 73 | accuracy = classifier.accuracy() 74 | 75 | optr = tf.train.GradientDescentOptimizer(0.01) 76 | trainer = optr.minimize(loss, global_step=global_step, var_list=variables) 77 | 78 | # Set up the supervisor. 79 | supervisor = tf.train.Supervisor( 80 | logdir=FLAGS.logdir, 81 | is_chief=(FLAGS.task == 0), 82 | save_summaries_secs=10, 83 | save_model_secs=30) 84 | sess = supervisor.PrepareSession(FLAGS.master) 85 | 86 | # Run the trainer. 87 | for _ in xrange(FLAGS.max_steps): 88 | batch = [next(train_iterator) for _ in xrange(FLAGS.batch_size)] 89 | 90 | _, step, batch_loss, batch_accuracy = sess.run( 91 | [trainer, global_step, loss, accuracy], 92 | feed_dict=classifier.build_feed_dict(batch)) 93 | print('step=%d: batch loss=%f accuracy=%f' % ( 94 | step, batch_loss, batch_accuracy)) 95 | helpers.EmitValues(supervisor, sess, step, 96 | {'Batch Loss': batch_loss, 97 | 'Batch Accuracy': batch_accuracy}) 98 | 99 | if __name__ == '__main__': 100 | tf.app.run() 101 | -------------------------------------------------------------------------------- /tensorflow_fold/loom/deserializing_weaver_op.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 Google Inc. All Rights Reserved. 2 | Licensed under the Apache License, Version 2.0 (the "License"); 3 | you may not use this file except in compliance with the License. 4 | You may obtain a copy of the License at 5 | http://www.apache.org/licenses/LICENSE-2.0 6 | Unless required by applicable law or agreed to in writing, software 7 | distributed under the License is distributed on an "AS IS" BASIS, 8 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow_fold/loom/weaver.h"
#include "tensorflow_fold/loom/weaver_op_base.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {
namespace fold {

// Declares the op's interface. This op adds one input, the serialized
// WeaverMessage(s). NOTE(review): REGISTER_WEAVER_OP presumably comes from
// weaver_op_base.h and adds the outputs shared by all Weaver ops — confirm.
REGISTER_WEAVER_OP("DeserializingWeaver")
    .Input("weaver_messages: string");

// A Weaver op which:
// 1. Reads one or more serialized WeaverMessages from `weaver_messages`, its
//    input tensor.
// 2. Merges them if there is more than one, and
// 3. Creates output tensors that can drive the Loom using the resulting Weaver.
//
// (Item 3 is handled by WeaverOpBase.)
//
// Note: merges are supported in this op to allow the user to pre-compute many
// WeaverMessages (for example, one per element of the training set) and then
// group them together into random mini-batches at run-time.
//
// A second reason merges are supported is that, for large input examples and
// large batch sizes, merges done in advance of `DeserializingWeaverOp` could
// push the resulting `WeaverMessage` over the protocol buffer size limit.
39 | class DeserializingWeaverOp : public WeaverOpBase { 40 | public: 41 | explicit DeserializingWeaverOp(tensorflow::OpKernelConstruction *c) 42 | : WeaverOpBase(c) {} 43 | 44 | tensorflow::Status Weave( 45 | tensorflow::OpKernelContext *c, Weaver* weaver) override { 46 | auto weaver_messages = c->input(0).flat(); 47 | if (weaver_messages.size() < 1) { 48 | return tensorflow::errors::InvalidArgument( 49 | "weaver_messages must contain at least one value."); 50 | } 51 | if (!weaver->Deserialize(weaver_messages(0))) { 52 | return tensorflow::errors::Internal( 53 | "Failed to deserialize WeaverMessage: ", weaver->error_string()); 54 | } 55 | 56 | // Note: If necessary, this loop could be sped up by merging the messages in 57 | // a multi-threaded way instead of in sequence. 58 | for (int64 i = 1; i < weaver_messages.size(); ++i) { 59 | if (!weaver->MergeFromSerialized(weaver_messages(i))) { 60 | return tensorflow::errors::Internal( 61 | "Failed to merge WeaverMessage", i, ":", weaver->error_string()); 62 | } 63 | } 64 | return tensorflow::Status::OK(); 65 | } 66 | }; 67 | 68 | REGISTER_KERNEL_BUILDER( 69 | Name("DeserializingWeaver").Device(tensorflow::DEVICE_CPU), 70 | DeserializingWeaverOp); 71 | 72 | } // namespace fold 73 | } // namespace tensorflow 74 | -------------------------------------------------------------------------------- /tensorflow_fold/loom/deserializing_weaver_op.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """DeserializingWeaver.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import os.path 21 | 22 | # import google3 23 | import tensorflow as tf 24 | from tensorflow_fold.loom.weaver_op_base import RegisterWeaverOp 25 | 26 | # This is a gross hack that (apparently) prevents Python from 27 | # occasionally segfaulting at shutdown when unlinking dynamic 28 | # libraries, possibly related to a known bug (its reference was lost). We need to 29 | # call some function in tf.pywrap_tensorflow that munges protos, *before* 30 | # we dlopen the deserializing weaver library (which also munges protos). 31 | tf.pywrap_tensorflow.list_devices() 32 | 33 | _deserializing_weaver = tf.load_op_library(os.path.join( 34 | tf.resource_loader.get_data_files_path(), '_deserializing_weaver_op.so')) 35 | deserializing_weaver = _deserializing_weaver.deserializing_weaver # Python wrapper for the DeserializingWeaver op. 36 | 37 | RegisterWeaverOp('DeserializingWeaver') # Marks the op as non-differentiable. 38 | -------------------------------------------------------------------------------- /tensorflow_fold/loom/loom.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | import "tensorflow/core/framework/types.proto"; 4 | import "tensorflow/core/framework/tensor.proto"; 5 | 6 | package tensorflow.fold; 7 | 8 | message LoomMetadata { 9 | // max_depth: maximum depth to nest loom operations (or -1 if the graph is 10 | // being constructed with a while loop).
11 | optional int32 max_depth = 1; 12 | 13 | message TypeShapeMetadata { 14 | // The type of the TypeShape. 15 | optional tensorflow.DataType dtype = 1; 16 | 17 | // The shape of the TypeShape (size of each dimension, in order). 18 | repeated int64 shape = 2 [packed=true]; 19 | 20 | // A TypeShape's `tag` field is used to ensure the distinctness of 21 | // TypeShapes that might otherwise collide in some cases. (This is 22 | // important because Loom's output tensors are indexed by TypeShape). 23 | optional string tag = 3; 24 | 25 | // `name` is used to refer to the TypeShape when printing debug messages. 26 | optional string name = 4; 27 | 28 | // The names of particular tensors of this TypeShape given as inputs to the 29 | // loom (via `named_tensors`). 30 | repeated string tensor_names = 5; 31 | 32 | // If true then the Weaver does not track constants of this type-shape, 33 | // either because they are being supplied by a mechanism external to the 34 | // scheduler (direct_feed_dict mode) or because they are being fed in as a 35 | // single batch tensor of unspecified length. 36 | optional bool is_batch_input = 6; 37 | } 38 | repeated TypeShapeMetadata type_shape_metadata = 2; 39 | 40 | message OpMetadata { 41 | // The name of this operation. 42 | optional string name = 1; 43 | 44 | // The types of the inputs to this operation. (Values are indices into 45 | // TypeShapeMetadata). 46 | repeated int32 input_ts_idx = 2; 47 | 48 | // The types of the outputs of this operation. (Values are indices into 49 | // TypeShapeMetadata). 50 | repeated int32 output_ts_idx = 3; 51 | } 52 | repeated OpMetadata op_metadata = 3; 53 | } 54 | 55 | // WeaverMessage contains the wiring diagram (or schedule) which drives the 56 | // TensorFlow network created in Loom._setup_network in loom.py. 57 | message WeaverMessage { 58 | // This block of fields should mirror the contents of LoomResult (see 59 | // weaver.h) and all of them should have the same length. 
60 | // Those LoomResults are in turn populated by calls to _Weaver's methods. 61 | repeated int32 depth = 1 [packed=true]; 62 | repeated int32 ts_idx = 2 [packed=true]; 63 | repeated int32 op_idx = 3 [packed=true]; 64 | repeated int32 op_output_idx = 4 [packed=true]; 65 | repeated int32 pos_idx = 5 [packed=true]; 66 | repeated int32 cached_passthrough = 6 [packed=true]; 67 | 68 | repeated int32 num_constants_by_type_shape = 7 [packed=true]; 69 | 70 | // The constants of each type-shape are stacked along their 0'th dimensions 71 | // and then converted to TensorProto to serialize. 72 | repeated tensorflow.TensorProto constant_values_by_type_shape = 10; 73 | 74 | 75 | message WiringMessage { 76 | // The values of the `wiring_results_` field from Weaver 77 | repeated int32 result_id = 1 [packed=true]; 78 | 79 | // The keys of the `wiring_results_` field from Weaver: 80 | optional int32 depth = 2; 81 | optional int32 op_idx = 3; 82 | optional int32 arg_idx = 4; 83 | } 84 | 85 | repeated WiringMessage wiring = 8; 86 | 87 | // The result IDs marked as outputs using `Weaver::AddOutput`. 88 | repeated int32 output_result_id = 9 [packed=true]; 89 | } 90 | -------------------------------------------------------------------------------- /tensorflow_fold/loom/platform.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 Google Inc. All Rights Reserved. 2 | Licensed under the Apache License, Version 2.0 (the "License"); 3 | you may not use this file except in compliance with the License. 4 | You may obtain a copy of the License at 5 | http://www.apache.org/licenses/LICENSE-2.0 6 | Unless required by applicable law or agreed to in writing, software 7 | distributed under the License is distributed on an "AS IS" BASIS, 8 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | See the License for the specific language governing permissions and 10 | limitations under the License. 
11 | ==============================================================================*/ 12 | 13 | #ifndef TENSORFLOW_FOLD_LOOM_PLATFORM_H_ 14 | #define TENSORFLOW_FOLD_LOOM_PLATFORM_H_ 15 | 16 | #if !defined(PLATFORM_GOOGLE) 17 | #include 18 | 19 | #include "tensorflow/core/platform/types.h" 20 | 21 | namespace tensorflow { 22 | namespace fold { 23 | 24 | using std::string; 25 | 26 | } // namespace fold 27 | } // namespace tensorflow 28 | 29 | #endif // !defined(PLATFORM_GOOGLE) 30 | 31 | #endif // TENSORFLOW_FOLD_LOOM_PLATFORM_H_ 32 | -------------------------------------------------------------------------------- /tensorflow_fold/loom/weaver_op_base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Utility code for defining an op based on WeaverOpBase.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | # import google3 21 | import tensorflow as tf 22 | 23 | 24 | def RegisterWeaverOp(op_name): # Registers `op_name` with tf.NotDifferentiable so TensorFlow treats the op as having no gradient. 25 | tf.NotDifferentiable(op_name) 26 | -------------------------------------------------------------------------------- /tensorflow_fold/public/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | load("//tensorflow_fold:fold.bzl", "fold_cc_library", "fold_py_library") 4 | 5 | package( 6 | default_visibility = [ 7 | "//visibility:public", 8 | ], 9 | ) 10 | 11 | fold_py_library( 12 | name = "loom", 13 | srcs = ["loom.py"], 14 | deps = [ 15 | "//tensorflow_fold/loom", 16 | ], 17 | ) 18 | 19 | fold_cc_library( 20 | name = "loom_cc", 21 | hdrs = ["loom.h"], 22 | deps = [ 23 | "//tensorflow_fold/loom:deserializing_weaver_op_cc", 24 | "//tensorflow_fold/loom:weaver", 25 | "//tensorflow_fold/loom:weaver_op_base", 26 | ], 27 | ) 28 | 29 | fold_py_library( 30 | name = "blocks", 31 | srcs = ["blocks.py"], 32 | deps = [ 33 | "//tensorflow_fold/blocks", 34 | "//tensorflow_fold/blocks:plan", 35 | "//tensorflow_fold/blocks:result_types", 36 | "//tensorflow_fold/blocks:util", 37 | ], 38 | ) 39 | -------------------------------------------------------------------------------- /tensorflow_fold/public/blocks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """High-level Blocks API for [TensorFlow Fold](../index.md). 16 | 17 | ## Compiler 18 | 19 | @@Compiler 20 | 21 | ## Blocks for input 22 | 23 | @@Tensor 24 | @@Scalar 25 | @@Vector 26 | @@InputTransform 27 | @@SerializedMessageToTree 28 | @@OneHot 29 | @@OneHotFromList 30 | @@Optional 31 | 32 | ## Blocks for composition 33 | 34 | @@Composition 35 | @@Pipe 36 | @@Record 37 | @@AllOf 38 | 39 | ## Blocks for tensors 40 | 41 | @@FromTensor 42 | @@Function 43 | @@Concat 44 | @@Zeros 45 | 46 | ## Blocks for sequences 47 | 48 | @@Map 49 | @@Fold 50 | @@RNN 51 | @@Reduce 52 | @@Sum 53 | @@Min 54 | @@Max 55 | @@Mean 56 | @@Broadcast 57 | @@Zip 58 | @@ZipWith 59 | @@NGrams 60 | @@Nth 61 | @@GetItem 62 | @@Length 63 | @@Slice 64 | 65 | ## Other blocks 66 | 67 | @@ForwardDeclaration 68 | @@OneOf 69 | @@Metric 70 | @@Identity 71 | @@Void 72 | 73 | ## Layers 74 | 75 | @@FC 76 | @@Embedding 77 | @@FractalNet 78 | @@ScopedLayer 79 | 80 | ## Types 81 | 82 | @@TensorType 83 | @@VoidType 84 | @@PyObjectType 85 | @@TupleType 86 | @@SequenceType 87 | @@BroadcastSequenceType 88 | 89 | ## Plans 90 | 91 | @@Plan 92 | @@TrainPlan 93 | @@EvalPlan 94 | @@InferPlan 95 | @@define_plan_flags 96 | @@plan_default_params 97 | 98 | ## Conversion functions 99 | 100 | @@convert_to_block 101 | @@convert_to_type 102 | @@canonicalize_type 103 | 104 | ## Utilities 105 | 106 | @@EdibleIterator 107 | @@group_by_batches 108 | @@epochs 109 | @@parse_spec 110 | @@build_optimizer_from_params 111 | @@create_variable_scope 112 | 113 | ## Abstract classes 114 | 115 
| @@IOBase 116 | @@Block 117 | @@Layer 118 | @@ResultType 119 | """ 120 | 121 | # This is the entrypoint for importing the TensorFlow Fold Blocks library. 122 | # We suggest importing it as: 123 | # import tensorflow_fold.public.blocks as td 124 | 125 | ## Regenerating the Docs 126 | # 127 | # Fold's API docs are extracted from the toplevel docstring of 128 | # third_party.tensorflow_fold.public.blocks and docstrings from the other 129 | # files that it refers to. 130 | # 131 | 132 | # pylint: disable=wildcard-import, unused-import 133 | from tensorflow_fold.blocks.block_compiler import * 134 | from tensorflow_fold.blocks.blocks import * 135 | from tensorflow_fold.blocks.layers import * 136 | from tensorflow_fold.blocks.metrics import * 137 | from tensorflow_fold.blocks.plan import * 138 | from tensorflow_fold.blocks.result_types import * 139 | from tensorflow_fold.blocks.util import * 140 | # pylint: enable=wildcard-import, unused-import 141 | -------------------------------------------------------------------------------- /tensorflow_fold/public/loom.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 Google Inc. All Rights Reserved. 2 | Licensed under the Apache License, Version 2.0 (the "License"); 3 | you may not use this file except in compliance with the License. 4 | You may obtain a copy of the License at 5 | http://www.apache.org/licenses/LICENSE-2.0 6 | Unless required by applicable law or agreed to in writing, software 7 | distributed under the License is distributed on an "AS IS" BASIS, 8 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | See the License for the specific language governing permissions and 10 | limitations under the License. 
11 | ==============================================================================*/ 12 | 13 | #ifndef TENSORFLOW_FOLD_PUBLIC_LOOM_H_ 14 | #define TENSORFLOW_FOLD_PUBLIC_LOOM_H_ 15 | 16 | #include "tensorflow_fold/loom/weaver.h" 17 | #include "tensorflow_fold/loom/weaver_op_base.h" 18 | 19 | #endif // TENSORFLOW_FOLD_PUBLIC_LOOM_H_ 20 | -------------------------------------------------------------------------------- /tensorflow_fold/public/loom.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """This is the low-level Loom API for [TensorFlow Fold](../index.md). 16 | 17 | As a simple example here's a loom that lets you evaluate arbitrary trees of 18 | element-wise adds and multiplies on floating point vectors of length 3. 
19 | 20 | ```python 21 | class BinaryLoomOp(loom.LoomOp): 22 | 23 | def __init__(self, type_shape, op): 24 | self._op = op 25 | super(BinaryLoomOp, self).__init__( 26 | [type_shape, type_shape], [type_shape]) 27 | 28 | def instantiate_batch(self, inputs): 29 | return [self._op(inputs[0], inputs[1])] 30 | 31 | # Set up some loom ops: 32 | 33 | x_var = tf.Variable(tf.zeros(3, dtype='float64'), name='x') 34 | y_var = tf.Variable(tf.zeros(3, dtype='float64'), name='y') 35 | vec_3 = loom.TypeShape('float64', (3,)) 36 | vec_loom = loom.Loom( 37 | named_tensors={'x': x_var, 'y': y_var}, 38 | named_ops={'add': BinaryLoomOp(vec_3, tf.add), 39 | 'mul': BinaryLoomOp(vec_3, tf.multiply)}) 40 | 41 | vec_3_out = vec_loom.output_tensor(vec_3) 42 | 43 | loss = tf.nn.l2_loss(vec_3_out) 44 | 45 | # Then, when parsing a particular example we can do something like: 46 | with tf.Session() as sess: 47 | sess.run(tf.global_variables_initializer()) 48 | # Loop over examples (the expression we build should depend on the example but 49 | # for simplicity we'll just build x + [1, 5, 8]) 50 | weaver = vec_loom.make_weaver() 51 | x = weaver.x 52 | c = weaver(np.array([1, 5, 8], dtype='float64')) 53 | x_plus_c = weaver.add(x, c) 54 | # In this case vec_3_out will contain a 1x3 matrix whose single row is x+c 55 | print("loss=", loss.eval(feed_dict=weaver.build_feed_dict([x_plus_c]))) 56 | 57 | # Note: you could also evaluate multiple expressions at once. For example: 58 | weaver = vec_loom.make_weaver() 59 | x = weaver.x 60 | y = weaver.y 61 | c_squared = weaver.add(weaver.mul(x, x), weaver.mul(y, y)) 62 | x_plus_y = weaver.add(x, y) 63 | x_plus_y_squared = weaver.mul(x_plus_y, x_plus_y) 64 | print("loss=", loss.eval( 65 | feed_dict=weaver.build_feed_dict([c_squared, x_plus_y_squared]))) 66 | # In this case vec_3_out will contain a 2x3 matrix whose rows are x^2+y^2 and 67 | # (x+y)^2 (with multiplication being component-wise).
68 | ``` 69 | 70 | @@TypeShape 71 | @@LoomOp 72 | @@PassThroughLoomOp 73 | @@Loom 74 | @@Weaver 75 | """ 76 | 77 | # This is the entrypoint for importing the TensorFlow Fold Loom library. 78 | # We suggest importing it as: 79 | # import tensorflow_fold.public.loom as loom 80 | 81 | ## Regenerating the Docs 82 | # 83 | # Fold's API docs are extracted from the toplevel docstring of 84 | # third_party.tensorflow_fold.public.loom (this module) and docstrings from the other 85 | # files that it refers to. 86 | # 87 | 88 | # pylint: disable=wildcard-import, unused-import 89 | from tensorflow_fold.loom.loom import * 90 | # pylint: enable=wildcard-import, unused-import 91 | -------------------------------------------------------------------------------- /tensorflow_fold/run_all_examples.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Smoke test that runs all examples. Command-line arguments are 3 | # forwarded to 'bazel build', e.g. 4 | # ./tensorflow_fold/run_all_examples.sh --config=opt 5 | set -o verbose 6 | set -e 7 | bazel build "$@" tensorflow_fold/...
8 | 9 | # loom benchmark 10 | ./bazel-bin/tensorflow_fold/loom/benchmarks/iclr_2017_benchmark \ 11 | --vector_size=8 --tree_size=4 --num_repeats=1 --alsologtostderr 12 | 13 | ## loom calculator 14 | TMP=$(mktemp -d) 15 | ./bazel-bin/tensorflow_fold/loom/calculator_example/make_dataset \ 16 | --output_path="${TMP}"/examples --num_samples=15 --alsologtostderr 17 | ./bazel-bin/tensorflow_fold/loom/calculator_example/train \ 18 | --train_data_path="${TMP}"/examples --batch_size=10 --max_steps=2 \ 19 | --logdir="${TMP}" --alsologtostderr 20 | ./bazel-bin/tensorflow_fold/loom/calculator_example/eval \ 21 | --validation_data_path="${TMP}"/examples --eval_interval_secs=0 \ 22 | --logdir="${TMP}" --alsologtostderr 23 | 24 | # blocks calculator 25 | TMP2=$(mktemp -d) 26 | ./bazel-bin/tensorflow_fold/blocks/examples/calculator/train \ 27 | --train_data_path="${TMP}"/examples --batch_size=10 --max_steps=2 \ 28 | --logdir="${TMP2}" --alsologtostderr 29 | 30 | # blocks fizzbuzz 31 | ./bazel-bin/tensorflow_fold/blocks/examples/fizzbuzz/fizzbuzz \ 32 | --validation_size=5 --steps=3 --batches_per_step=2 --alsologtostderr 33 | 34 | # blocks mnist 35 | TMP=$(mktemp -d) 36 | ./bazel-bin/tensorflow_fold/blocks/examples/mnist/mnist \ 37 | --logdir_base="${TMP}" --epochs=1 --alsologtostderr 38 | 39 | # blocks sentiment 40 | TMP=$(mktemp -d) 41 | echo '( 0.1 0.2' > "${TMP}"/glove 42 | echo ') 0.3 0.4' >> "${TMP}"/glove 43 | echo 'foo 0.1 0.2' >> "${TMP}"/glove 44 | echo 'bar 0.3 0.4' >> "${TMP}"/glove 45 | echo 'foo|bar|)|(|baz' > "${TMP}"/sents 46 | echo '(3 (1 bar) (2 mu))' > "${TMP}"/train.txt 47 | echo '(3 (1 bar) (2 mu))' > "${TMP}"/dev.txt 48 | echo '(3 (1 bar) (2 mu))' > "${TMP}"/test.txt 49 | ./bazel-bin/tensorflow_fold/blocks/examples/sentiment/filter_glove \ 50 | --glove_file="${TMP}"/glove --sentence_file="${TMP}"/sents \ 51 | --output_file="${TMP}"/glove_filtered --alsologtostderr 52 | ./bazel-bin/tensorflow_fold/blocks/examples/sentiment/train \ 53 | 
--checkpoint_base="${TMP}"/model --epochs=2 --tree_dir="${TMP}" \ 54 | --embedding_file="${TMP}"/glove_filtered --alsologtostderr 55 | ./bazel-bin/tensorflow_fold/blocks/examples/sentiment/eval \ 56 | --checkpoint_file="${TMP}"/model-1 --tree_dir="${TMP}" \ 57 | --embedding_file="${TMP}"/glove_filtered --alsologtostderr 58 | 59 | # blocks language_id 60 | TMP=$(mktemp -d) 61 | if [[ ! -f /tmp/roman_sentences.csv ]]; then 62 | ./tensorflow_fold/blocks/examples/language_id/fetch_datasets.sh 63 | fi 64 | ./bazel-bin/tensorflow_fold/blocks/examples/language_id/language_id \ 65 | --logdir_base="${TMP}" --epochs=2 --alsologtostderr 66 | -------------------------------------------------------------------------------- /tensorflow_fold/util/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | load( 4 | "//tensorflow_fold:fold.bzl", 5 | "fold_cc_library", 6 | "fold_proto_library", 7 | "fold_py_extension", 8 | "fold_py_library", 9 | "fold_py_test", 10 | ) 11 | 12 | package( 13 | default_visibility = [ 14 | "//visibility:public", 15 | ], 16 | ) 17 | 18 | fold_proto_library( 19 | srcs = ["test.proto"], 20 | cc_name = "test_proto", 21 | py_name = "test_py_pb2", 22 | ) 23 | 24 | fold_proto_library( 25 | srcs = ["test3.proto"], 26 | cc_name = "test3_proto", 27 | py_name = "test3_py_pb2", 28 | ) 29 | 30 | filegroup( 31 | name = "test_proto_files", 32 | srcs = [ 33 | "test.proto", 34 | "test3.proto", 35 | ], 36 | ) 37 | 38 | # main() for C++ tests 39 | fold_cc_library( 40 | name = "test_main", 41 | testonly = 1, 42 | srcs = ["test_main.cc"], 43 | linkopts = ["-lm"], 44 | deps = [ 45 | "@com_google_googletest//:gtest", 46 | "@org_tensorflow//tensorflow/core:framework_lite", 47 | "@org_tensorflow//tensorflow/core:testlib", 48 | ], 49 | ) 50 | 51 | fold_py_extension( 52 | name = "proto_tools", 53 | srcs = ["proto_tools.cc"], 54 | outs = ["proto_tools.so"], 55 | deps = [ 56 | 
"@org_tensorflow//tensorflow/core:framework_headers_lib", 57 | "@org_tensorflow//util/python:python_headers", 58 | "@protobuf_archive//:protobuf", 59 | ], 60 | ) 61 | 62 | fold_py_library( 63 | name = "proto", 64 | srcs = [], 65 | cc_deps = [":proto_tools"], 66 | ) 67 | 68 | fold_py_test( 69 | name = "proto_test", 70 | srcs = ["proto_test.py"], 71 | data = [ 72 | "test.proto", 73 | "test3.proto", 74 | ], 75 | deps = [ 76 | ":proto", 77 | ":test3_py_pb2", 78 | ":test_py_pb2", 79 | "@org_tensorflow//tensorflow:tensorflow_py", 80 | ], 81 | ) 82 | 83 | fold_py_library( 84 | name = "cpp_test_proto_lib", 85 | srcs = [], 86 | deps = [ 87 | ":test3_py_pb2", 88 | ":test_py_pb2", 89 | ], 90 | ) 91 | 92 | sh_binary( 93 | name = "build_pip_package", 94 | srcs = ["build_pip_package.sh"], 95 | data = [ 96 | "//tensorflow_fold:headers", 97 | "//tensorflow_fold/public:blocks", 98 | "//tensorflow_fold/public:loom", 99 | "//tensorflow_fold/public:loom_cc", 100 | ], 101 | ) 102 | -------------------------------------------------------------------------------- /tensorflow_fold/util/build_pip_package.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright 2017 Google Inc. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | # Script for building a pip package. 18 | # 19 | # Based on tensorflow/tools/pip_package/build_pip_package.sh. 
20 | set -e 21 | 22 | function main() { 23 | if [ $# -lt 1 ] ; then 24 | echo "No destination dir provided" 25 | exit 1 26 | fi 27 | 28 | DEST=$1 29 | TMPDIR=$(mktemp -d -t tmp.XXXXXXXXXX) 30 | 31 | echo $(date) : "=== Using tmpdir: ${TMPDIR}" 32 | 33 | if [ ! -d bazel-bin/tensorflow_fold ]; then 34 | echo "Could not find bazel-bin. Did you run from the root of the build tree?" 35 | exit 1 36 | fi 37 | 38 | cp -R \ 39 | bazel-bin/tensorflow_fold/util/build_pip_package.runfiles/org_tensorflow_fold/tensorflow_fold \ 40 | "${TMPDIR}" 41 | 42 | cp -f "${TMPDIR}"/tensorflow_fold/public/blocks.py "${TMPDIR}"/tensorflow_fold/__init__.py 43 | cp -f "${TMPDIR}"/tensorflow_fold/public/loom.py "${TMPDIR}"/tensorflow_fold/loom/__init__.py 44 | 45 | cp tensorflow_fold/util/setup.py ${TMPDIR} 46 | 47 | # Before we leave the top-level directory, make sure we know how to 48 | # call python. 49 | source tensorflow/tools/python_bin_path.sh 50 | 51 | pushd ${TMPDIR} 52 | echo $(date) : "=== Building wheel" 53 | "${PYTHON_BIN_PATH:-python}" setup.py bdist_wheel >/dev/null 54 | mkdir -p ${DEST} 55 | cp dist/* ${DEST} 56 | popd 57 | rm -rf ${TMPDIR} 58 | echo $(date) : "=== Output wheel file is in: ${DEST}" 59 | } 60 | 61 | main "$@" 62 | -------------------------------------------------------------------------------- /tensorflow_fold/util/test.proto: -------------------------------------------------------------------------------- 1 | // This file contains protocol buffers that are generically useful for tests. 
2 | syntax = "proto2"; 3 | 4 | package tensorflow.fold.util; 5 | 6 | message SomeEnumScope { 7 | enum SomeEnum { 8 | THIS = 0; 9 | THAT = 1; 10 | THE_OTHER = 2; 11 | } 12 | } 13 | 14 | message EveryTypeList { 15 | repeated EveryType item = 1; 16 | } 17 | 18 | message EveryType { 19 | optional int32 some_int32 = 1; 20 | optional int64 some_int64 = 2; 21 | optional uint32 some_uint32 = 3; 22 | optional uint64 some_uint64 = 4; 23 | optional double some_double = 5; 24 | optional float some_float = 6; 25 | optional bool some_bool = 7; 26 | 27 | repeated int32 many_int32 = 11; 28 | repeated int64 many_int64 = 12; 29 | repeated uint32 many_uint32 = 13; 30 | repeated uint64 many_uint64 = 14; 31 | repeated double many_double = 15; 32 | repeated float many_float = 16; 33 | repeated bool many_bool = 17; 34 | 35 | optional SomeEnumScope.SomeEnum some_enum = 20; 36 | repeated SomeEnumScope.SomeEnum many_enum = 21; 37 | 38 | optional string some_string = 30; 39 | repeated string many_string = 31; 40 | 41 | optional EveryAtom some_nest = 40; 42 | repeated EveryAtom many_nest = 41; 43 | } 44 | 45 | message EveryAtom { 46 | optional int32 some_int32 = 1; 47 | optional int64 some_int64 = 2; 48 | optional uint32 some_uint32 = 3; 49 | optional uint64 some_uint64 = 4; 50 | optional double some_double = 5; 51 | optional float some_float = 6; 52 | optional bool some_bool = 7; 53 | optional SomeEnumScope.SomeEnum some_enum = 8; 54 | optional string some_string = 9; 55 | } 56 | 57 | message OneAtom { 58 | oneof atom_type { 59 | int32 some_int32 = 1; 60 | int64 some_int64 = 2; 61 | uint32 some_uint32 = 3; 62 | uint64 some_uint64 = 4; 63 | double some_double = 5; 64 | float some_float = 6; 65 | bool some_bool = 7; 66 | SomeEnumScope.SomeEnum some_enum = 8; 67 | string some_string = 9; 68 | } 69 | } 70 | 71 | message CyclicType { 72 | optional int32 some_int32 = 1; 73 | optional bool some_bool = 7; 74 | repeated int32 many_int32 = 11; 75 | repeated bool many_bool = 17; 76 | optional 
SomeEnumScope.SomeEnum some_enum = 20; 77 | optional CyclicType some_same = 40; 78 | repeated CyclicType many_same = 41; 79 | } 80 | 81 | message Nested1 { 82 | optional string baz = 1; 83 | } 84 | 85 | message Nested2 { 86 | optional string bar = 1; 87 | optional Nested1 nested1 = 2; 88 | } 89 | 90 | message Nested3 { 91 | optional string foo = 1; 92 | optional Nested2 nested2 = 2; 93 | } 94 | 95 | message NonConsecutiveEnumMessage { 96 | enum NonConsecutiveEnum { 97 | SEVEN = 7; 98 | THREE = 3; 99 | } 100 | optional NonConsecutiveEnum the_enum = 1; 101 | } 102 | -------------------------------------------------------------------------------- /tensorflow_fold/util/test3.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorflow.fold.util; 4 | 5 | message SomeEnum3Scope { 6 | enum SomeEnum3 { 7 | THIS = 0; 8 | THAT = 1; 9 | THE_OTHER = 3; 10 | } 11 | } 12 | 13 | message EveryType3 { 14 | int32 some_int32 = 1; 15 | int64 some_int64 = 2; 16 | uint32 some_uint32 = 3; 17 | uint64 some_uint64 = 4; 18 | double some_double = 5; 19 | float some_float = 6; 20 | bool some_bool = 7; 21 | 22 | repeated int32 many_int32 = 11; 23 | repeated int64 many_int64 = 12; 24 | repeated uint32 many_uint32 = 13; 25 | repeated uint64 many_uint64 = 14; 26 | repeated double many_double = 15; 27 | repeated float many_float = 16; 28 | repeated bool many_bool = 17; 29 | 30 | SomeEnum3Scope.SomeEnum3 some_enum = 20; 31 | repeated SomeEnum3Scope.SomeEnum3 many_enum = 21; 32 | 33 | string some_string = 30; 34 | repeated string many_string = 31; 35 | 36 | EveryAtom3 some_nest = 40; 37 | repeated EveryAtom3 many_nest = 41; 38 | } 39 | 40 | message CyclicType3 { 41 | int32 some_int32 = 1; 42 | bool some_bool = 7; 43 | repeated int32 many_int32 = 11; 44 | repeated bool many_bool = 17; 45 | SomeEnum3Scope.SomeEnum3 some_enum = 20; 46 | CyclicType3 some_same = 40; 47 | repeated CyclicType3 many_same = 41; 48 | } 49 | 50 | 
50 | message EveryAtom3 { 51 | int32 some_int32 = 1; 52 | int64 some_int64 = 2; 53 | uint32 some_uint32 = 3; 54 | uint64 some_uint64 = 4; 55 | double some_double = 5; 56 | float some_float = 6; 57 | bool some_bool = 7; 58 | SomeEnum3Scope.SomeEnum3 some_enum = 8; 59 | string some_string = 9; 60 | } 61 | -------------------------------------------------------------------------------- /tensorflow_fold/util/test_main.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 Google Inc. All Rights Reserved. 2 | Licensed under the Apache License, Version 2.0 (the "License"); 3 | you may not use this file except in compliance with the License. 4 | You may obtain a copy of the License at 5 | http://www.apache.org/licenses/LICENSE-2.0 6 | Unless required by applicable law or agreed to in writing, software 7 | distributed under the License is distributed on an "AS IS" BASIS, 8 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | See the License for the specific language governing permissions and 10 | limitations under the License. 11 | ==============================================================================*/ 12 | 13 | // A program with a main that is suitable for unittests, including those 14 | // that also define microbenchmarks. Based on whether the user specified 15 | // a --benchmarks=PATTERN flag (PATTERN selects which benchmarks to run), 16 | // we will either run benchmarks or run the gtest tests in the program.
17 | 18 | #include "tensorflow/core/platform/platform.h" 19 | #include "tensorflow/core/platform/types.h" 20 | 21 | #if defined(PLATFORM_GOOGLE) || defined(__ANDROID__) 22 | // main() is supplied by gunit_main 23 | #else 24 | #include "gtest/gtest.h" 25 | #include "tensorflow/core/lib/core/stringpiece.h" 26 | #include "tensorflow/core/platform/test_benchmark.h" 27 | 28 | GTEST_API_ int main(int argc, char** argv) { 29 | std::cout << "Running main() from test_main.cc\n"; 30 | 31 | testing::InitGoogleTest(&argc, argv); 32 | for (int i = 1; i < argc; i++) { 33 | if (tensorflow::StringPiece(argv[i]).starts_with("--benchmarks=")) { 34 | const char* pattern = argv[i] + strlen("--benchmarks="); 35 | tensorflow::testing::Benchmark::Run(pattern); 36 | return 0; 37 | } 38 | } 39 | return RUN_ALL_TESTS(); 40 | } 41 | #endif 42 | -------------------------------------------------------------------------------- /tensorflow_fold/workspace.bzl: -------------------------------------------------------------------------------- 1 | # TensorFlow Fold external dependencies that can be loaded in WORKSPACE 2 | # files. 3 | 4 | load('@org_tensorflow//tensorflow:workspace.bzl', 'tf_workspace') 5 | 6 | # All TensorFlow Fold external dependencies. 7 | # tf_fold_workspace() takes no arguments. It calls TensorFlow's tf_workspace()
8 | # with tf_repo_name = "org_tensorflow", then adds the native.bind() aliases below. 9 | def tf_fold_workspace(): 10 | tf_workspace(tf_repo_name = "org_tensorflow") 11 | 12 | # ===== gRPC dependencies ===== 13 | native.bind( 14 | name = "libssl", 15 | actual = "@boringssl//:ssl", 16 | ) 17 | 18 | native.bind( 19 | name = "zlib", 20 | actual = "@zlib_archive//:zlib", 21 | ) 22 | 23 | native.bind( 24 | name = "gmock", 25 | actual = "@gmock_archive//:gmock", 26 | ) 27 | -------------------------------------------------------------------------------- /tools/bazel.rc: -------------------------------------------------------------------------------- 1 | import %workspace%/tensorflow/tools/bazel.rc 2 | build --package_path=%workspace%:%workspace%/tensorflow/ 3 | 4 | --------------------------------------------------------------------------------