├── CONTRIBUTING ├── LICENSE ├── README.md ├── convergence_test.py ├── evaluate.py ├── meta.py ├── meta_test.py ├── networks.py ├── networks_test.py ├── preprocess.py ├── preprocess_test.py ├── problems.py ├── problems_test.py ├── train.py └── util.py /CONTRIBUTING: -------------------------------------------------------------------------------- 1 | Want to contribute? Great! First, read this page (including the small print at the end). 2 | 3 | ### Before you contribute 4 | Before we can use your code, you must sign the 5 | [Google Individual Contributor License Agreement](https://cla.developers.google.com/about/google-individual) 6 | (CLA), which you can do online. The CLA is necessary mainly because you own the 7 | copyright to your changes, even after your contribution becomes part of our 8 | codebase, so we need your permission to use and distribute your code. We also 9 | need to be sure of various other things—for instance that you'll tell us if you 10 | know that your code infringes on other people's patents. You don't have to sign 11 | the CLA until after you've submitted your code for review and a member has 12 | approved it, but you must do it before we can put your code into our codebase. 13 | Before you start working on a larger contribution, you should get in touch with 14 | us first through the issue tracker with your idea so that we can help out and 15 | possibly guide you. Coordinating up front makes it much easier to avoid 16 | frustration later on. 17 | 18 | ### Code reviews 19 | All submissions, including submissions by project members, require review. We 20 | use GitHub pull requests for this purpose. 21 | 22 | ### The small print 23 | Contributions made by corporations are covered by a different agreement than 24 | the one above, the 25 | [Software Grant and Corporate Contributor License Agreement](https://cla.developers.google.com/about/google-corporate). 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files.
30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [Learning to Learn](https://arxiv.org/abs/1606.04474) in TensorFlow 2 | 3 | 4 | ## Dependencies 5 | 6 | * [TensorFlow >=1.0](https://www.tensorflow.org/) 7 | * [Sonnet >=1.0](https://github.com/deepmind/sonnet) 8 | 9 | 10 | ## Training 11 | 12 | ``` 13 | python train.py --problem=mnist --save_path=./mnist 14 | ``` 15 | 16 | Command-line flags: 17 | 18 | * `save_path`: If present, the optimizer will be saved to the specified path 19 | every time the evaluation performance is improved. 20 | * `num_epochs`: Number of training epochs. 21 | * `log_period`: Epochs before mean performance and time are reported. 22 | * `evaluation_period`: Epochs before the optimizer is evaluated. 23 | * `evaluation_epochs`: Number of evaluation epochs. 24 | * `problem`: Problem to train on. See [Problems](#problems) section below. 25 | * `num_steps`: Number of optimization steps. 26 | * `unroll_length`: Number of unroll steps for the optimizer. 27 | * `learning_rate`: Learning rate. 28 | * `second_derivatives`: If `true`, the optimizer will try to compute second 29 | derivatives through the loss function specified by the problem. 30 | 31 | 32 | ## Evaluation 33 | 34 | ``` 35 | python evaluate.py --problem=mnist --optimizer=L2L --path=./mnist 36 | ``` 37 | 38 | Command-line flags: 39 | 40 | * `optimizer`: `Adam` or `L2L`. 41 | * `path`: Path to the saved optimizer, only relevant if using the `L2L` optimizer. 42 | * `learning_rate`: Learning rate, only relevant if using the `Adam` optimizer. 43 | * `num_epochs`: Number of evaluation epochs. 44 | * `seed`: Seed for random number generation. 45 | * `problem`: Problem to evaluate on. See [Problems](#problems) section below. 46 | * `num_steps`: Number of optimization steps. 47 | 48 | 49 | ## Problems 50 | 51 | The training and evaluation scripts support the following problems (see 52 | `util.py` for more details): 53 | 54 | * `simple`: One-variable quadratic function. 55 | * `simple-multi`: Two-variable quadratic function, where one of the variables 56 | is optimized using a learned optimizer and the other using Adam. 57 | * `quadratic`: Batched ten-variable quadratic function. 58 | * `mnist`: MNIST classification using a two-layer fully connected network. 59 | * `cifar`: CIFAR-10 classification using a convolutional neural network. 60 | * `cifar-multi`: CIFAR-10 classification using a convolutional neural network, 61 | where two independent learned optimizers are used: one optimizes the 62 | parameters of the convolutional layers and the other those of the 63 | fully connected layers. 64 | 65 | 66 | New problems are easy to implement. As you can see in `train.py`, the 67 | `meta_minimize` method of the `MetaOptimizer` class is given a function that 68 | returns the TensorFlow operation generating the loss we want to minimize (see 69 | `problems.py` for an example, and the sketch at the end of this README). 70 | 71 | All operations with Python side effects (e.g. queue creation) must be done 72 | outside of the function passed to `meta_minimize`. The `cifar10` function in 73 | `problems.py` is a good example of a loss function that uses TensorFlow 74 | queues. 75 | 76 | 77 | Disclaimer: This is not an official Google product.
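To make the `meta_minimize` interface above concrete, here is a minimal sketch of training a learned optimizer on a new problem. It assumes the default `MetaOptimizer` network configuration; the loss function `my_quadratic` and the epoch and unroll counts are illustrative, not part of this codebase:

```
import tensorflow as tf

import meta


def my_quadratic():
  # Build the loss inside this callable; ops with Python side effects
  # (e.g. queue creation) belong outside of it.
  x = tf.get_variable("x", shape=[10],
                      initializer=tf.random_normal_initializer())
  return tf.reduce_sum(tf.square(x), name="loss")


optimizer = meta.MetaOptimizer()  # Default coordinate-wise LSTM network.
# Unroll the optimizee for 20 steps; yields (step, update, reset, fx, x).
step, update, reset, cost_op, _ = optimizer.meta_minimize(my_quadratic, 20)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for _ in range(100):  # Epochs.
    sess.run(reset)
    for _ in range(5):  # Unrolls per epoch.
      cost, _, _ = sess.run([cost_op, update, step])
  print("Final cost:", cost)
```

This mirrors the loop in `convergence_test.py`: `reset` reinitializes the optimizee and the RNN state at the start of each epoch, `update` writes the unrolled parameters and state back into their variables, and `step` applies one Adam step to the meta-optimizer's own weights.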
78 | -------------------------------------------------------------------------------- /convergence_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for L2L TensorFlow implementation.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from six.moves import xrange 22 | import tensorflow as tf 23 | 24 | import meta 25 | import problems 26 | 27 | 28 | def train(sess, minimize_ops, num_epochs, num_unrolls): 29 | """L2L training.""" 30 | step, update, reset, loss_last, x_last = minimize_ops 31 | 32 | for _ in xrange(num_epochs): 33 | sess.run(reset) 34 | for _ in xrange(num_unrolls): 35 | cost, final_x, unused_1, unused_2 = sess.run([loss_last, x_last, 36 | update, step]) 37 | 38 | return cost, final_x 39 | 40 | 41 | class L2LTest(tf.test.TestCase): 42 | """Tests L2L TensorFlow implementation.""" 43 | 44 | def testSimple(self): 45 | """Tests L2L applied to simple problem.""" 46 | problem = problems.simple() 47 | optimizer = meta.MetaOptimizer(net=dict( 48 | net="CoordinateWiseDeepLSTM", 49 | net_options={ 50 | "layers": (), 51 | # Initializing the network to zeros makes learning more stable. 52 | "initializer": "zeros" 53 | })) 54 | minimize_ops = optimizer.meta_minimize(problem, 20, learning_rate=1e-2) 55 | # L2L should solve the simple problem in less than 500 epochs. 56 | with self.test_session() as sess: 57 | sess.run(tf.global_variables_initializer()) 58 | cost, _ = train(sess, minimize_ops, 500, 5) 59 | self.assertLess(cost, 1e-5) 60 | 61 | 62 | if __name__ == "__main__": 63 | tf.test.main() 64 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | """Learning 2 Learn evaluation.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from six.moves import xrange 22 | import tensorflow as tf 23 | 24 | from tensorflow.contrib.learn.python.learn import monitored_session as ms 25 | 26 | import meta 27 | import util 28 | 29 | flags = tf.flags 30 | logging = tf.logging 31 | 32 | 33 | FLAGS = flags.FLAGS 34 | flags.DEFINE_string("optimizer", "L2L", "Optimizer.") 35 | flags.DEFINE_string("path", None, "Path to saved meta-optimizer network.") 36 | flags.DEFINE_integer("num_epochs", 100, "Number of evaluation epochs.") 37 | flags.DEFINE_integer("seed", None, "Seed for TensorFlow's RNG.") 38 | 39 | flags.DEFINE_string("problem", "simple", "Type of problem.") 40 | flags.DEFINE_integer("num_steps", 100, 41 | "Number of optimization steps per epoch.") 42 | flags.DEFINE_float("learning_rate", 0.001, "Learning rate.") 43 | 44 | 45 | def main(_): 46 | # Configuration. 47 | num_unrolls = FLAGS.num_steps 48 | 49 | if FLAGS.seed is not None: 50 | tf.set_random_seed(FLAGS.seed) 51 | 52 | # Problem. 53 | problem, net_config, net_assignments = util.get_config(FLAGS.problem, 54 | FLAGS.path) 55 | 56 | # Optimizer setup. 57 | if FLAGS.optimizer == "Adam": 58 | cost_op = problem() 59 | problem_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) 60 | problem_reset = tf.variables_initializer(problem_vars) 61 | 62 | optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate) 63 | update = optimizer.minimize(cost_op) 64 | # Adam's slot variables (its moment accumulators) are only created by 65 | # minimize(), and get_slot_names() returns strings rather than variables, 66 | # so collect the slot variables explicitly for the reset op. 67 | optimizer_reset = tf.variables_initializer( 68 | [optimizer.get_slot(v, name) for v in problem_vars 69 | for name in optimizer.get_slot_names()]) 70 | reset = [problem_reset, optimizer_reset] 71 | elif FLAGS.optimizer == "L2L": 72 | if FLAGS.path is None: 73 | logging.warning("Evaluating untrained L2L optimizer") 74 | optimizer = meta.MetaOptimizer(**net_config) 75 | meta_loss = optimizer.meta_loss(problem, 1, net_assignments=net_assignments) 76 | _, update, reset, cost_op, _ = meta_loss 77 | else: 78 | raise ValueError("{} is not a valid optimizer".format(FLAGS.optimizer)) 79 | 80 | with ms.MonitoredSession() as sess: 81 | # Prevent accidental changes to the graph. 82 | tf.get_default_graph().finalize() 83 | 84 | total_time = 0 85 | total_cost = 0 86 | for _ in xrange(FLAGS.num_epochs): 87 | # Training. 88 | time, cost = util.run_epoch(sess, cost_op, [update], reset, 89 | num_unrolls) 90 | total_time += time 91 | total_cost += cost 92 | 93 | # Results. 94 | util.print_stats("Epoch {}".format(FLAGS.num_epochs), total_cost, 95 | total_time, FLAGS.num_epochs) 96 | 97 | 98 | if __name__ == "__main__": 99 | tf.app.run() 100 | -------------------------------------------------------------------------------- /meta.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | """Learning to learn (meta) optimizer.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import contextlib 23 | import os 24 | 25 | import mock 26 | import sonnet as snt 27 | import tensorflow as tf 28 | 29 | from tensorflow.python.framework import ops 30 | from tensorflow.python.util import nest 31 | 32 | import networks 33 | 34 | 35 | def _nested_assign(ref, value): 36 | """Returns a nested collection of TensorFlow assign operations. 37 | 38 | Args: 39 | ref: Nested collection of TensorFlow variables. 40 | value: Values to be assigned to the variables. Must have the same structure 41 | as `ref`. 42 | 43 | Returns: 44 | Nested collection (same structure as `ref`) of TensorFlow assign operations. 45 | 46 | Raises: 47 | ValueError: If `ref` and `value` have different structures. 48 | """ 49 | if isinstance(ref, list) or isinstance(ref, tuple): 50 | if len(ref) != len(value): 51 | raise ValueError("ref and value have different lengths.") 52 | result = [_nested_assign(r, v) for r, v in zip(ref, value)] 53 | if isinstance(ref, tuple): 54 | return tuple(result) 55 | return result 56 | else: 57 | return tf.assign(ref, value) 58 | 59 | 60 | def _nested_variable(init, name=None, trainable=False): 61 | """Returns a nested collection of TensorFlow variables. 62 | 63 | Args: 64 | init: Nested collection of TensorFlow initializers. 65 | name: Variable name. 66 | trainable: Make variables trainable (`False` by default). 67 | 68 | Returns: 69 | Nested collection (same structure as `init`) of TensorFlow variables. 70 | """ 71 | if isinstance(init, list) or isinstance(init, tuple): 72 | result = [_nested_variable(i, name, trainable) for i in init] 73 | if isinstance(init, tuple): 74 | return tuple(result) 75 | return result 76 | else: 77 | return tf.Variable(init, name=name, trainable=trainable) 78 | 79 | 80 | def _wrap_variable_creation(func, custom_getter): 81 | """Provides a custom getter for all variable creations.""" 82 | original_get_variable = tf.get_variable 83 | def custom_get_variable(*args, **kwargs): 84 | if "custom_getter" in kwargs:  # kwargs is a dict; hasattr() would never find the key. 85 | raise AttributeError("Custom getters are not supported for optimizee " 86 | "variables.") 87 | return original_get_variable(*args, custom_getter=custom_getter, **kwargs) 88 | 89 | # Mock the get_variable method. 90 | with mock.patch("tensorflow.get_variable", custom_get_variable): 91 | return func() 92 | 93 | 94 | def _get_variables(func): 95 | """Calls func, returning any variables created, but ignoring its return value. 96 | 97 | Args: 98 | func: Function to be called. 99 | 100 | Returns: 101 | A tuple (variables, constants) where the first element is a list of 102 | trainable variables and the second a list of the non-trainable variables. 103 | """ 104 | variables = [] 105 | constants = [] 106 | 107 | def custom_getter(getter, name, **kwargs): 108 | trainable = kwargs["trainable"] 109 | kwargs["trainable"] = False 110 | variable = getter(name, **kwargs) 111 | if trainable: 112 | variables.append(variable) 113 | else: 114 | constants.append(variable) 115 | return variable 116 | 117 | with tf.name_scope("unused_graph"): 118 | _wrap_variable_creation(func, custom_getter) 119 | 120 | return variables, constants 121 | 122 | 123 | def _make_with_custom_variables(func, variables): 124 | """Calls func and replaces any trainable variables.
125 | 126 | This returns the output of func, but whenever `get_variable` is called it 127 | will replace any trainable variables with the tensors in `variables`, in the 128 | same order. Non-trainable variables will re-use any variables already 129 | created. 130 | 131 | Args: 132 | func: Function to be called. 133 | variables: A list of tensors replacing the trainable variables. 134 | 135 | Returns: 136 | The return value of func is returned. 137 | """ 138 | variables = collections.deque(variables) 139 | 140 | def custom_getter(getter, name, **kwargs): 141 | if kwargs["trainable"]: 142 | return variables.popleft() 143 | else: 144 | kwargs["reuse"] = True 145 | return getter(name, **kwargs) 146 | 147 | return _wrap_variable_creation(func, custom_getter) 148 | 149 | 150 | MetaLoss = collections.namedtuple("MetaLoss", "loss, update, reset, fx, x") 151 | MetaStep = collections.namedtuple("MetaStep", "step, update, reset, fx, x") 152 | 153 | 154 | def _make_nets(variables, config, net_assignments): 155 | """Creates the optimizer networks. 156 | 157 | Args: 158 | variables: A list of variables to be optimized. 159 | config: A dictionary of network configurations, each of which will be 160 | passed to networks.factory to construct a single optimizer net. 161 | net_assignments: A list of tuples where each tuple is of the form (netid, 162 | variable_names) and is used to assign variables to networks. netid must 163 | be a key in config. 164 | 165 | Returns: 166 | A tuple (nets, keys, subsets) where nets is a dictionary of created 167 | optimizer nets such that the net with key keys[i] should be applied to the 168 | subset of variables listed in subsets[i]. 169 | 170 | Raises: 171 | ValueError: If net_assignments is None and the configuration defines more 172 | than one network. 173 | """ 174 | # Create a dictionary which maps a variable name to its index within the 175 | # list of variables. 176 | name_to_index = dict((v.name.split(":")[0], i) 177 | for i, v in enumerate(variables)) 178 | 179 | if net_assignments is None: 180 | if len(config) != 1: 181 | raise ValueError("Default net_assignments can only be used if there is " 182 | "a single net config.") 183 | 184 | with tf.variable_scope("vars_optimizer"): 185 | key = next(iter(config)) 186 | kwargs = config[key] 187 | net = networks.factory(**kwargs) 188 | 189 | nets = {key: net} 190 | keys = [key] 191 | subsets = [range(len(variables))] 192 | else: 193 | nets = {} 194 | keys = [] 195 | subsets = [] 196 | with tf.variable_scope("vars_optimizer"): 197 | for key, names in net_assignments: 198 | if key in nets: 199 | raise ValueError("Repeated netid in net_assignments.") 200 | nets[key] = networks.factory(**config[key]) 201 | subset = [name_to_index[name] for name in names] 202 | keys.append(key) 203 | subsets.append(subset) 204 | print("Net: {}, Subset: {}".format(key, subset)) 205 | 206 | # subsets should be a list of disjoint subsets (as lists!) of the variables 207 | # and nets should be a list of networks to apply to each subset. 208 | return nets, keys, subsets 209 | 210 | 211 | class MetaOptimizer(object): 212 | """Learning to learn (meta) optimizer. 213 | 214 | Optimizer with an internal RNN that takes as input, at each iteration, 215 | the gradient of the function being minimized and returns a step direction. 216 | This optimizer can then itself be optimized to learn optimization on a set of 217 | tasks. 218 | """ 219 | 220 | def __init__(self, **kwargs): 221 | """Creates a MetaOptimizer.
222 | 223 | Args: 224 | **kwargs: A set of keyword arguments mapping network identifiers (the 225 | keys) to parameters that will be passed to networks.factory (see docs 226 | for more info). These can be used to assign different optimizee 227 | parameters to different optimizers (see net_assignments in the 228 | meta_loss method). 229 | """ 230 | self._nets = None 231 | 232 | if not kwargs: 233 | # Use a default coordinatewise network if nothing is given. This allows 234 | # for no network spec and no assignments. 235 | self._config = { 236 | "coordinatewise": { 237 | "net": "CoordinateWiseDeepLSTM", 238 | "net_options": { 239 | "layers": (20, 20), 240 | "preprocess_name": "LogAndSign", 241 | "preprocess_options": {"k": 5}, 242 | "scale": 0.01, 243 | }}} 244 | else: 245 | self._config = kwargs 246 | 247 | def save(self, sess, path=None): 248 | """Save meta-optimizer.""" 249 | result = {} 250 | for k, net in self._nets.items(): 251 | if path is None: 252 | filename = None 253 | key = k 254 | else: 255 | filename = os.path.join(path, "{}.l2l".format(k)) 256 | key = filename 257 | net_vars = networks.save(net, sess, filename=filename) 258 | result[key] = net_vars 259 | return result 260 | 261 | def meta_loss(self, 262 | make_loss, 263 | len_unroll, 264 | net_assignments=None, 265 | second_derivatives=False): 266 | """Returns an operator computing the meta-loss. 267 | 268 | Args: 269 | make_loss: Callable which returns the optimizee loss; note that this 270 | should create its ops in the default graph. 271 | len_unroll: Number of steps to unroll. 272 | net_assignments: variable to optimizer mapping. If not None, it should be 273 | a list of (k, names) tuples, where k is a valid key in the kwargs 274 | passed at construction time and names is a list of variable names. 275 | second_derivatives: Use second derivatives (default is false). 276 | 277 | Returns: 278 | namedtuple containing (loss, update, reset, fx, x) 279 | """ 280 | 281 | # Construct an instance of the problem only to grab the variables. This 282 | # loss will never be evaluated. 283 | x, constants = _get_variables(make_loss) 284 | 285 | print("Optimizee variables") 286 | print([op.name for op in x]) 287 | print("Problem variables") 288 | print([op.name for op in constants]) 289 | 290 | # Create the optimizer networks and find the subsets of variables to assign 291 | # to each optimizer. 292 | nets, net_keys, subsets = _make_nets(x, self._config, net_assignments) 293 | 294 | # Store the networks so we can save them later. 295 | self._nets = nets 296 | 297 | # Create hidden state for each subset of variables. 298 | state = [] 299 | with tf.name_scope("states"): 300 | for i, (subset, key) in enumerate(zip(subsets, net_keys)): 301 | net = nets[key] 302 | with tf.name_scope("state_{}".format(i)): 303 | state.append(_nested_variable( 304 | [net.initial_state_for_inputs(x[j], dtype=tf.float32) 305 | for j in subset], 306 | name="state", trainable=False)) 307 | 308 | def update(net, fx, x, state): 309 | """Parameter and RNN state update.""" 310 | with tf.name_scope("gradients"): 311 | gradients = tf.gradients(fx, x) 312 | 313 | # Stopping the gradient here corresponds to what was done in the 314 | # original L2L NIPS submission. However, ops like BatchNorm don't 315 | # support second derivatives, so we still need to be able to stop 316 | # the gradient here.
317 | if not second_derivatives: 318 | gradients = [tf.stop_gradient(g) for g in gradients] 319 | 320 | with tf.name_scope("deltas"): 321 | deltas, state_next = zip(*[net(g, s) for g, s in zip(gradients, state)]) 322 | state_next = list(state_next) 323 | 324 | return deltas, state_next 325 | 326 | def time_step(t, fx_array, x, state): 327 | """While loop body.""" 328 | x_next = list(x) 329 | state_next = [] 330 | 331 | with tf.name_scope("fx"): 332 | fx = _make_with_custom_variables(make_loss, x) 333 | fx_array = fx_array.write(t, fx) 334 | 335 | with tf.name_scope("dx"): 336 | for subset, key, s_i in zip(subsets, net_keys, state): 337 | x_i = [x[j] for j in subset] 338 | deltas, s_i_next = update(nets[key], fx, x_i, s_i) 339 | 340 | for idx, j in enumerate(subset): 341 | x_next[j] += deltas[idx] 342 | state_next.append(s_i_next) 343 | 344 | with tf.name_scope("t_next"): 345 | t_next = t + 1 346 | 347 | return t_next, fx_array, x_next, state_next 348 | 349 | # Define the while loop. 350 | fx_array = tf.TensorArray(tf.float32, size=len_unroll + 1, 351 | clear_after_read=False) 352 | _, fx_array, x_final, s_final = tf.while_loop( 353 | cond=lambda t, *_: t < len_unroll, 354 | body=time_step, 355 | loop_vars=(0, fx_array, x, state), 356 | parallel_iterations=1, 357 | swap_memory=True, 358 | name="unroll") 359 | 360 | with tf.name_scope("fx"): 361 | fx_final = _make_with_custom_variables(make_loss, x_final) 362 | fx_array = fx_array.write(len_unroll, fx_final) 363 | 364 | loss = tf.reduce_sum(fx_array.stack(), name="loss") 365 | 366 | # Reset the state; should be called at the beginning of an epoch. 367 | with tf.name_scope("reset"): 368 | variables = (nest.flatten(state) + 369 | x + constants) 370 | # Empty array as part of the reset process. 371 | reset = [tf.variables_initializer(variables), fx_array.close()] 372 | 373 | # Operator to update the parameters and the RNN state after our loop, but 374 | # during an epoch. 375 | with tf.name_scope("update"): 376 | update = (nest.flatten(_nested_assign(x, x_final)) + 377 | nest.flatten(_nested_assign(state, s_final))) 378 | 379 | # Log internal variables. 380 | for k, net in nets.items(): 381 | print("Optimizer '{}' variables".format(k)) 382 | print([op.name for op in snt.get_variables_in_module(net)]) 383 | 384 | return MetaLoss(loss, update, reset, fx_final, x_final) 385 | 386 | def meta_minimize(self, make_loss, len_unroll, learning_rate=0.01, **kwargs): 387 | """Returns an operator minimizing the meta-loss. 388 | 389 | Args: 390 | make_loss: Callable which returns the optimizee loss; note that this 391 | should create its ops in the default graph. 392 | len_unroll: Number of steps to unroll. 393 | learning_rate: Learning rate for the Adam optimizer. 394 | **kwargs: keyword arguments forwarded to meta_loss. 395 | 396 | Returns: 397 | namedtuple containing (step, update, reset, fx, x) 398 | """ 399 | info = self.meta_loss(make_loss, len_unroll, **kwargs) 400 | optimizer = tf.train.AdamOptimizer(learning_rate) 401 | step = optimizer.minimize(info.loss) 402 | return MetaStep(step, *info[1:]) 403 | -------------------------------------------------------------------------------- /meta_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for L2L meta-optimizer.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import tempfile 23 | 24 | from nose_parameterized import parameterized 25 | import numpy as np 26 | from six.moves import xrange 27 | import sonnet as snt 28 | import tensorflow as tf 29 | 30 | import meta 31 | import problems 32 | 33 | 34 | def train(sess, minimize_ops, num_epochs, num_unrolls): 35 | """L2L training.""" 36 | step, update, reset, loss_last, x_last = minimize_ops 37 | 38 | for _ in xrange(num_epochs): 39 | sess.run(reset) 40 | for _ in xrange(num_unrolls): 41 | cost, final_x, unused_1, unused_2 = sess.run([loss_last, x_last, 42 | update, step]) 43 | 44 | return cost, final_x 45 | 46 | 47 | class L2LTest(tf.test.TestCase): 48 | """Tests L2L meta-optimizer.""" 49 | 50 | def testResults(self): 51 | """Tests reproducibility of Torch results.""" 52 | problem = problems.simple() 53 | optimizer = meta.MetaOptimizer(net=dict( 54 | net="CoordinateWiseDeepLSTM", 55 | net_options={ 56 | "layers": (), 57 | "initializer": "zeros" 58 | })) 59 | minimize_ops = optimizer.meta_minimize(problem, 5) 60 | with self.test_session() as sess: 61 | sess.run(tf.global_variables_initializer()) 62 | cost, final_x = train(sess, minimize_ops, 1, 2) 63 | 64 | # Torch results 65 | torch_cost = 0.7325327 66 | torch_final_x = 0.8559 67 | 68 | self.assertAlmostEqual(cost, torch_cost, places=4) 69 | self.assertAlmostEqual(final_x[0], torch_final_x, places=4) 70 | 71 | @parameterized.expand([ 72 | # Shared optimizer. 73 | ( 74 | None, 75 | { 76 | "net": { 77 | "net": "CoordinateWiseDeepLSTM", 78 | "net_options": {"layers": (1, 1,)} 79 | } 80 | } 81 | ), 82 | # Explicit sharing. 83 | ( 84 | [("net", ["x_0", "x_1"])], 85 | { 86 | "net": { 87 | "net": "CoordinateWiseDeepLSTM", 88 | "net_options": {"layers": (1,)} 89 | } 90 | } 91 | ), 92 | # Different optimizers. 93 | ( 94 | [("net1", ["x_0"]), ("net2", ["x_1"])], 95 | { 96 | "net1": { 97 | "net": "CoordinateWiseDeepLSTM", 98 | "net_options": {"layers": (1,)} 99 | }, 100 | "net2": {"net": "Adam"} 101 | } 102 | ), 103 | # Different optimizers for the same variable. 
104 | ( 105 | [("net1", ["x_0"]), ("net2", ["x_0"])], 106 | { 107 | "net1": { 108 | "net": "CoordinateWiseDeepLSTM", 109 | "net_options": {"layers": (1,)} 110 | }, 111 | "net2": { 112 | "net": "CoordinateWiseDeepLSTM", 113 | "net_options": {"layers": (1,)} 114 | } 115 | } 116 | ), 117 | ]) 118 | def testMultiOptimizer(self, net_assignments, net_config): 119 | """Tests different variable->net mappings in multi-optimizer problem.""" 120 | problem = problems.simple_multi_optimizer(num_dims=2) 121 | optimizer = meta.MetaOptimizer(**net_config) 122 | minimize_ops = optimizer.meta_minimize(problem, 3, 123 | net_assignments=net_assignments) 124 | with self.test_session() as sess: 125 | sess.run(tf.global_variables_initializer()) 126 | train(sess, minimize_ops, 1, 2) 127 | 128 | def testSecondDerivatives(self): 129 | """Tests second derivatives for simple problem.""" 130 | problem = problems.simple() 131 | optimizer = meta.MetaOptimizer(net=dict( 132 | net="CoordinateWiseDeepLSTM", 133 | net_options={"layers": ()})) 134 | minimize_ops = optimizer.meta_minimize(problem, 3, 135 | second_derivatives=True) 136 | with self.test_session() as sess: 137 | sess.run(tf.global_variables_initializer()) 138 | train(sess, minimize_ops, 1, 2) 139 | 140 | def testConvolutional(self): 141 | """Tests L2L applied to problem with convolutions.""" 142 | kernel_shape = 4 143 | def convolutional_problem(): 144 | conv = snt.Conv2D(output_channels=1, 145 | kernel_shape=kernel_shape, 146 | stride=1, 147 | name="conv") 148 | output = conv(tf.random_normal((100, 100, 3, 10))) 149 | return tf.reduce_sum(output) 150 | 151 | net_config = { 152 | "conv": { 153 | "net": "KernelDeepLSTM", 154 | "net_options": { 155 | "kernel_shape": [kernel_shape] * 2, 156 | "layers": (5,) 157 | }, 158 | }, 159 | } 160 | optimizer = meta.MetaOptimizer(**net_config) 161 | minimize_ops = optimizer.meta_minimize( 162 | convolutional_problem, 3, 163 | net_assignments=[("conv", ["conv/w"])] 164 | ) 165 | with self.test_session() as sess: 166 | sess.run(tf.global_variables_initializer()) 167 | train(sess, minimize_ops, 1, 2) 168 | 169 | def testWhileLoopProblem(self): 170 | """Tests L2L applied to problem with while loop.""" 171 | def while_loop_problem(): 172 | x = tf.get_variable("x", shape=[], initializer=tf.ones_initializer()) 173 | 174 | # Strange way of squaring the variable. 175 | _, x_squared = tf.while_loop( 176 | cond=lambda t, _: t < 1, 177 | body=lambda t, x: (t + 1, x * x), 178 | loop_vars=(0, x), 179 | name="loop") 180 | return x_squared 181 | 182 | optimizer = meta.MetaOptimizer(net=dict( 183 | net="CoordinateWiseDeepLSTM", 184 | net_options={"layers": ()})) 185 | minimize_ops = optimizer.meta_minimize(while_loop_problem, 3) 186 | with self.test_session() as sess: 187 | sess.run(tf.global_variables_initializer()) 188 | train(sess, minimize_ops, 1, 2) 189 | 190 | def testSaveAndLoad(self): 191 | """Tests saving and loading a meta-optimizer.""" 192 | layers = (2, 3) 193 | net_options = {"layers": layers, "initializer": "zeros"} 194 | num_unrolls = 2 195 | num_epochs = 1 196 | 197 | problem = problems.simple() 198 | 199 | # Original optimizer. 200 | with tf.Graph().as_default() as g1: 201 | optimizer = meta.MetaOptimizer(net=dict( 202 | net="CoordinateWiseDeepLSTM", 203 | net_options=net_options)) 204 | minimize_ops = optimizer.meta_minimize(problem, 3) 205 | 206 | with self.test_session(graph=g1) as sess: 207 | sess.run(tf.global_variables_initializer()) 208 | train(sess, minimize_ops, 1, 2) 209 | 210 | # Save optimizer. 
211 | tmp_dir = tempfile.mkdtemp() 212 | save_result = optimizer.save(sess, path=tmp_dir) 213 | net_path = next(iter(save_result)) 214 | 215 | # Retrain original optimizer. 216 | cost, x = train(sess, minimize_ops, num_epochs, num_unrolls) 217 | 218 | # Load optimizer and retrain in a new session. 219 | with tf.Graph().as_default() as g2: 220 | optimizer = meta.MetaOptimizer(net=dict( 221 | net="CoordinateWiseDeepLSTM", 222 | net_options=net_options, 223 | net_path=net_path)) 224 | minimize_ops = optimizer.meta_minimize(problem, 3) 225 | 226 | with self.test_session(graph=g2) as sess: 227 | sess.run(tf.global_variables_initializer()) 228 | cost_loaded, x_loaded = train(sess, minimize_ops, num_epochs, num_unrolls) 229 | 230 | # The last cost should be the same. 231 | self.assertAlmostEqual(cost, cost_loaded, places=3) 232 | self.assertAlmostEqual(x[0], x_loaded[0], places=3) 233 | 234 | # Cleanup. 235 | os.remove(net_path) 236 | os.rmdir(tmp_dir) 237 | 238 | 239 | if __name__ == "__main__": 240 | tf.test.main() 241 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Learning 2 Learn meta-optimizer networks.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import abc 22 | import collections 23 | import sys 24 | 25 | import dill as pickle 26 | import numpy as np 27 | import six 28 | import sonnet as snt 29 | import tensorflow as tf 30 | 31 | import preprocess 32 | 33 | 34 | def factory(net, net_options=(), net_path=None): 35 | """Network factory.""" 36 | 37 | net_class = getattr(sys.modules[__name__], net) 38 | net_options = dict(net_options) 39 | 40 | if net_path: 41 | with open(net_path, "rb") as f: 42 | net_options["initializer"] = pickle.load(f) 43 | 44 | return net_class(**net_options) 45 | 46 | 47 | def save(network, sess, filename=None): 48 | """Save the variables contained by a network to disk.""" 49 | to_save = collections.defaultdict(dict) 50 | variables = snt.get_variables_in_module(network) 51 | 52 | for v in variables: 53 | split = v.name.split(":")[0].split("/") 54 | module_name = split[-2] 55 | variable_name = split[-1] 56 | to_save[module_name][variable_name] = v.eval(sess) 57 | 58 | if filename: 59 | with open(filename, "wb") as f: 60 | pickle.dump(to_save, f) 61 | 62 | return to_save 63 | 64 | 65 | @six.add_metaclass(abc.ABCMeta) 66 | class Network(snt.RNNCore): 67 | """Base class for meta-optimizer networks.""" 68 | 69 | @abc.abstractmethod 70 | def initial_state_for_inputs(self, inputs, **kwargs): 71 | """Initial state given inputs.""" 72 | pass 73 | 74 | 75 | def _convert_to_initializer(initializer): 76 | """Returns a TensorFlow initializer.
77 | 78 | * Corresponding TensorFlow initializer when the argument is a string (e.g. 79 | "zeros" -> `tf.zeros_initializer`). 80 | * `tf.constant_initializer` when the argument is a `numpy` `array`. 81 | * Identity when the argument is a TensorFlow initializer. 82 | 83 | Args: 84 | initializer: `string`, `numpy` `array` or TensorFlow initializer. 85 | 86 | Returns: 87 | TensorFlow initializer. 88 | """ 89 | 90 | if isinstance(initializer, str): 91 | return getattr(tf, initializer + "_initializer")(dtype=tf.float32) 92 | elif isinstance(initializer, np.ndarray): 93 | return tf.constant_initializer(initializer) 94 | else: 95 | return initializer 96 | 97 | 98 | def _get_initializers(initializers, fields): 99 | """Produces an nn initialization `dict` (see Linear docs for an example). 100 | 101 | Grabs initializers for relevant fields if the first argument is a `dict` or 102 | reuses the same initializer for all fields otherwise. All initializers are 103 | processed using `_convert_to_initializer`. 104 | 105 | Args: 106 | initializers: Initializer or dictionary. 107 | fields: Fields nn is expecting for module initialization. 108 | 109 | Returns: 110 | nn initialization dictionary. 111 | """ 112 | 113 | result = {} 114 | for f in fields: 115 | if isinstance(initializers, dict): 116 | if f in initializers: 117 | # Variable-specific initializer. 118 | result[f] = _convert_to_initializer(initializers[f]) 119 | else: 120 | # Common initializer for all variables. 121 | result[f] = _convert_to_initializer(initializers) 122 | 123 | return result 124 | 125 | 126 | def _get_layer_initializers(initializers, layer_name, fields): 127 | """Produces an nn initialization dictionary for a layer. 128 | 129 | Calls `_get_initializers` using `initializers[layer_name]` if `layer_name` is a 130 | valid key, or using `initializers` otherwise (reuses initializers between 131 | layers). 132 | 133 | Args: 134 | initializers: Initializer, a <param, initializer> dictionary, or a 135 | <layer_name, initializer> dictionary. 136 | layer_name: Layer name. 137 | fields: Fields nn is expecting for module initialization. 138 | 139 | Returns: 140 | nn initialization dictionary. 141 | """ 142 | 143 | # No initializers specified. 144 | if initializers is None: 145 | return None 146 | 147 | # Layer-specific initializer. 148 | if isinstance(initializers, dict) and layer_name in initializers: 149 | return _get_initializers(initializers[layer_name], fields) 150 | 151 | return _get_initializers(initializers, fields) 152 | 153 | 154 | class StandardDeepLSTM(Network): 155 | """LSTM layers with a Linear layer on top.""" 156 | 157 | def __init__(self, output_size, layers, preprocess_name="identity", 158 | preprocess_options=None, scale=1.0, initializer=None, 159 | name="deep_lstm"): 160 | """Creates an instance of `StandardDeepLSTM`. 161 | 162 | Args: 163 | output_size: Output size of the final linear layer. 164 | layers: Output sizes of LSTM layers. 165 | preprocess_name: Gradient preprocessing class name (in `l2l.preprocess` or 166 | tf modules). Default is `tf.identity`. 167 | preprocess_options: Gradient preprocessing options. 168 | scale: Gradient scaling (default is 1.0). 169 | initializer: Variable initializer for linear layer. See `snt.Linear` and 170 | `snt.LSTM` docs for more info. This parameter can be a string (e.g. 171 | "zeros" will be converted to tf.zeros_initializer). 172 | name: Module name.
173 | """ 174 | super(StandardDeepLSTM, self).__init__(name=name) 175 | 176 | self._output_size = output_size 177 | self._scale = scale 178 | 179 | if hasattr(preprocess, preprocess_name): 180 | preprocess_class = getattr(preprocess, preprocess_name) 181 | self._preprocess = preprocess_class(**preprocess_options) 182 | else: 183 | self._preprocess = getattr(tf, preprocess_name) 184 | 185 | with tf.variable_scope(self._template.variable_scope): 186 | self._cores = [] 187 | for i, size in enumerate(layers, start=1): 188 | name = "lstm_{}".format(i) 189 | init = _get_layer_initializers(initializer, name, 190 | ("w_gates", "b_gates")) 191 | self._cores.append(snt.LSTM(size, name=name, initializers=init)) 192 | self._rnn = snt.DeepRNN(self._cores, skip_connections=False, 193 | name="deep_rnn") 194 | 195 | init = _get_layer_initializers(initializer, "linear", ("w", "b")) 196 | self._linear = snt.Linear(output_size, name="linear", initializers=init) 197 | 198 | def _build(self, inputs, prev_state): 199 | """Connects the `StandardDeepLSTM` module into the graph. 200 | 201 | Args: 202 | inputs: 2D `Tensor` ([batch_size, input_size]). 203 | prev_state: `DeepRNN` state. 204 | 205 | Returns: 206 | `Tensor` shaped as `inputs`. 207 | """ 208 | # Adds preprocessing dimension and preprocess. 209 | inputs = self._preprocess(tf.expand_dims(inputs, -1)) 210 | # Incorporates preprocessing into data dimension. 211 | inputs = tf.reshape(inputs, [inputs.get_shape().as_list()[0], -1]) 212 | output, next_state = self._rnn(inputs, prev_state) 213 | return self._linear(output) * self._scale, next_state 214 | 215 | def initial_state_for_inputs(self, inputs, **kwargs): 216 | batch_size = inputs.get_shape().as_list()[0] 217 | return self._rnn.initial_state(batch_size, **kwargs) 218 | 219 | 220 | class CoordinateWiseDeepLSTM(StandardDeepLSTM): 221 | """Coordinate-wise `DeepLSTM`.""" 222 | 223 | def __init__(self, name="cw_deep_lstm", **kwargs): 224 | """Creates an instance of `CoordinateWiseDeepLSTM`. 225 | 226 | Args: 227 | name: Module name. 228 | **kwargs: Additional `DeepLSTM` args. 229 | """ 230 | super(CoordinateWiseDeepLSTM, self).__init__(1, name=name, **kwargs) 231 | 232 | def _reshape_inputs(self, inputs): 233 | return tf.reshape(inputs, [-1, 1]) 234 | 235 | def _build(self, inputs, prev_state): 236 | """Connects the CoordinateWiseDeepLSTM module into the graph. 237 | 238 | Args: 239 | inputs: Arbitrarily shaped `Tensor`. 240 | prev_state: `DeepRNN` state. 241 | 242 | Returns: 243 | `Tensor` shaped as `inputs`. 244 | """ 245 | input_shape = inputs.get_shape().as_list() 246 | reshaped_inputs = self._reshape_inputs(inputs) 247 | 248 | build_fn = super(CoordinateWiseDeepLSTM, self)._build 249 | output, next_state = build_fn(reshaped_inputs, prev_state) 250 | 251 | # Recover original shape. 252 | return tf.reshape(output, input_shape), next_state 253 | 254 | def initial_state_for_inputs(self, inputs, **kwargs): 255 | reshaped_inputs = self._reshape_inputs(inputs) 256 | return super(CoordinateWiseDeepLSTM, self).initial_state_for_inputs( 257 | reshaped_inputs, **kwargs) 258 | 259 | 260 | class KernelDeepLSTM(StandardDeepLSTM): 261 | """`DeepLSTM` for convolutional filters. 262 | 263 | The inputs are assumed to be shaped as convolutional filters with an extra 264 | preprocessing dimension ([kernel_w, kernel_h, n_input_channels, 265 | n_output_channels]). 266 | """ 267 | 268 | def __init__(self, kernel_shape, name="kernel_deep_lstm", **kwargs): 269 | """Creates an instance of `KernelDeepLSTM`. 
270 | 271 | Args: 272 | kernel_shape: Kernel shape (2D `tuple`). 273 | name: Module name. 274 | **kwargs: Additional `DeepLSTM` args. 275 | """ 276 | self._kernel_shape = kernel_shape 277 | output_size = np.prod(kernel_shape) 278 | super(KernelDeepLSTM, self).__init__(output_size, name=name, **kwargs) 279 | 280 | def _reshape_inputs(self, inputs): 281 | transposed_inputs = tf.transpose(inputs, perm=[2, 3, 0, 1]) 282 | return tf.reshape(transposed_inputs, [-1] + self._kernel_shape) 283 | 284 | def _build(self, inputs, prev_state): 285 | """Connects the KernelDeepLSTM module into the graph. 286 | 287 | Args: 288 | inputs: 4D `Tensor` (convolutional filter). 289 | prev_state: `DeepRNN` state. 290 | 291 | Returns: 292 | Tuple of `Tensor` shaped as `inputs` and next state. 293 | """ 294 | input_shape = inputs.get_shape().as_list() 295 | reshaped_inputs = self._reshape_inputs(inputs) 296 | 297 | build_fn = super(KernelDeepLSTM, self)._build 298 | output, next_state = build_fn(reshaped_inputs, prev_state) 299 | transposed_output = tf.transpose(output, [1, 0]) 300 | 301 | # Recover original shape. 302 | return tf.reshape(transposed_output, input_shape), next_state 303 | 304 | def initial_state_for_inputs(self, inputs, **kwargs): 305 | """Initial state given inputs.""" 306 | reshaped_inputs = self._reshape_inputs(inputs) 307 | return super(KernelDeepLSTM, self).initial_state_for_inputs( 308 | reshaped_inputs, **kwargs) 309 | 310 | 311 | class Sgd(Network): 312 | """Identity network which acts like SGD.""" 313 | 314 | def __init__(self, learning_rate=0.001, name="sgd"): 315 | """Creates an instance of the Identity optimizer network. 316 | 317 | Args: 318 | learning_rate: constant learning rate to use. 319 | name: Module name. 320 | """ 321 | super(Sgd, self).__init__(name=name) 322 | self._learning_rate = learning_rate 323 | 324 | def _build(self, inputs, _): 325 | return -self._learning_rate * inputs, [] 326 | 327 | def initial_state_for_inputs(self, inputs, **kwargs): 328 | return [] 329 | 330 | 331 | def _update_adam_estimate(estimate, value, b): 332 | return (b * estimate) + ((1 - b) * value) 333 | 334 | 335 | def _debias_adam_estimate(estimate, b, t): 336 | return estimate / (1 - tf.pow(b, t)) 337 | 338 | 339 | class Adam(Network): 340 | """Adam algorithm (https://arxiv.org/pdf/1412.6980v8.pdf).""" 341 | 342 | def __init__(self, learning_rate=1e-3, beta1=0.9, beta2=0.999, epsilon=1e-8, 343 | name="adam"): 344 | """Creates an instance of Adam.""" 345 | super(Adam, self).__init__(name=name) 346 | self._learning_rate = learning_rate 347 | self._beta1 = beta1 348 | self._beta2 = beta2 349 | self._epsilon = epsilon 350 | 351 | def _build(self, g, prev_state): 352 | """Connects the Adam module into the graph.""" 353 | b1 = self._beta1 354 | b2 = self._beta2 355 | 356 | g_shape = g.get_shape().as_list() 357 | g = tf.reshape(g, (-1, 1)) 358 | 359 | t, m, v = prev_state 360 | 361 | t_next = t + 1 362 | 363 | m_next = _update_adam_estimate(m, g, b1) 364 | m_hat = _debias_adam_estimate(m_next, b1, t_next) 365 | 366 | v_next = _update_adam_estimate(v, tf.square(g), b2) 367 | v_hat = _debias_adam_estimate(v_next, b2, t_next) 368 | 369 | update = -self._learning_rate * m_hat / (tf.sqrt(v_hat) + self._epsilon) 370 | return tf.reshape(update, g_shape), (t_next, m_next, v_next) 371 | 372 | def initial_state_for_inputs(self, inputs, dtype=tf.float32, **kwargs): 373 | batch_size = int(np.prod(inputs.get_shape().as_list())) 374 | t = tf.zeros((), dtype=dtype) 375 | m = tf.zeros((batch_size, 1), dtype=dtype) 376 | v = tf.zeros((batch_size,
1), dtype=dtype) 377 | return (t, m, v) 378 | -------------------------------------------------------------------------------- /networks_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for L2L networks.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from nose_parameterized import parameterized 22 | import numpy as np 23 | import sonnet as snt 24 | import tensorflow as tf 25 | 26 | import networks 27 | 28 | 29 | class CoordinateWiseDeepLSTMTest(tf.test.TestCase): 30 | """Tests CoordinateWiseDeepLSTM network.""" 31 | 32 | def testShape(self): 33 | shape = [10, 5] 34 | gradients = tf.random_normal(shape) 35 | net = networks.CoordinateWiseDeepLSTM(layers=(1, 1)) 36 | state = net.initial_state_for_inputs(gradients) 37 | update, _ = net(gradients, state) 38 | self.assertEqual(update.get_shape().as_list(), shape) 39 | 40 | def testTrainable(self): 41 | """Tests the network contains trainable variables.""" 42 | shape = [10, 5] 43 | gradients = tf.random_normal(shape) 44 | net = networks.CoordinateWiseDeepLSTM(layers=(1,)) 45 | state = net.initial_state_for_inputs(gradients) 46 | net(gradients, state) 47 | # Weights and biases for the LSTM layer and the final linear layer. 48 | variables = snt.get_variables_in_module(net) 49 | self.assertEqual(len(variables), 4) 50 | 51 | @parameterized.expand([ 52 | ["zeros"], 53 | [{"w": "zeros", "b": "zeros", "bad": "bad"}], 54 | [{"w": tf.zeros_initializer(), "b": np.array([0])}], 55 | [{"linear": {"w": tf.zeros_initializer(), "b": "zeros"}}] 56 | ]) 57 | def testResults(self, initializer): 58 | """Tests zero updates when last layer is initialized to zero.""" 59 | shape = [10] 60 | gradients = tf.random_normal(shape) 61 | net = networks.CoordinateWiseDeepLSTM(layers=(1, 1), 62 | initializer=initializer) 63 | state = net.initial_state_for_inputs(gradients) 64 | update, _ = net(gradients, state) 65 | 66 | with self.test_session() as sess: 67 | sess.run(tf.global_variables_initializer()) 68 | update_np = sess.run(update) 69 | self.assertAllEqual(update_np, np.zeros(shape)) 70 | 71 | 72 | class KernelDeepLSTMTest(tf.test.TestCase): 73 | """Tests KernelDeepLSTM network.""" 74 | 75 | def testShape(self): 76 | kernel_shape = [5, 5] 77 | shape = kernel_shape + [2, 2]  # The input has to be 4-dimensional.
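# KernelDeepLSTM flattens the two trailing channel dimensions into a batch, so this [5, 5, 2, 2] input is processed as four 5x5 kernels.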
78 | gradients = tf.random_normal(shape) 79 | net = networks.KernelDeepLSTM(layers=(1, 1), kernel_shape=kernel_shape) 80 | state = net.initial_state_for_inputs(gradients) 81 | update, _ = net(gradients, state) 82 | self.assertEqual(update.get_shape().as_list(), shape) 83 | 84 | def testTrainable(self): 85 | """Tests the network contains trainable variables.""" 86 | kernel_shape = [5, 5] 87 | shape = kernel_shape + [2, 2]  # The input has to be 4-dimensional. 88 | gradients = tf.random_normal(shape) 89 | net = networks.KernelDeepLSTM(layers=(1,), kernel_shape=kernel_shape) 90 | state = net.initial_state_for_inputs(gradients) 91 | net(gradients, state) 92 | # Weights and biases for the LSTM layer and the final linear layer. 93 | variables = snt.get_variables_in_module(net) 94 | self.assertEqual(len(variables), 4) 95 | 96 | @parameterized.expand([ 97 | ["zeros"], 98 | [{"w": "zeros", "b": "zeros", "bad": "bad"}], 99 | [{"w": tf.zeros_initializer(), "b": np.array([0])}], 100 | [{"linear": {"w": tf.zeros_initializer(), "b": "zeros"}}] 101 | ]) 102 | def testResults(self, initializer): 103 | """Tests zero updates when last layer is initialized to zero.""" 104 | kernel_shape = [5, 5] 105 | shape = kernel_shape + [2, 2]  # The input has to be 4-dimensional. 106 | gradients = tf.random_normal(shape) 107 | net = networks.KernelDeepLSTM(layers=(1, 1), 108 | kernel_shape=kernel_shape, 109 | initializer=initializer) 110 | state = net.initial_state_for_inputs(gradients) 111 | update, _ = net(gradients, state) 112 | 113 | with self.test_session() as sess: 114 | sess.run(tf.global_variables_initializer()) 115 | update_np = sess.run(update) 116 | self.assertAllEqual(update_np, np.zeros(shape)) 117 | 118 | 119 | class SgdTest(tf.test.TestCase): 120 | """Tests Sgd network.""" 121 | 122 | def testShape(self): 123 | shape = [10, 5] 124 | gradients = tf.random_normal(shape) 125 | net = networks.Sgd() 126 | state = net.initial_state_for_inputs(gradients) 127 | update, _ = net(gradients, state) 128 | self.assertEqual(update.get_shape().as_list(), shape) 129 | 130 | def testNonTrainable(self): 131 | """Tests the network doesn't contain trainable variables.""" 132 | shape = [10, 5] 133 | gradients = tf.random_normal(shape) 134 | net = networks.Sgd() 135 | state = net.initial_state_for_inputs(gradients) 136 | net(gradients, state) 137 | variables = snt.get_variables_in_module(net) 138 | self.assertEqual(len(variables), 0) 139 | 140 | def testResults(self): 141 | """Tests that updates equal -learning_rate * gradients.""" 142 | shape = [10] 143 | learning_rate = 0.01 144 | gradients = tf.random_normal(shape) 145 | net = networks.Sgd(learning_rate=learning_rate) 146 | state = net.initial_state_for_inputs(gradients) 147 | update, _ = net(gradients, state) 148 | 149 | with self.test_session() as sess: 150 | gradients_np, update_np = sess.run([gradients, update]) 151 | self.assertAllEqual(update_np, -learning_rate * gradients_np) 152 | 153 | 154 | class AdamTest(tf.test.TestCase): 155 | """Tests Adam network.""" 156 | 157 | def testShape(self): 158 | shape = [10, 5] 159 | gradients = tf.random_normal(shape) 160 | net = networks.Adam() 161 | state = net.initial_state_for_inputs(gradients) 162 | update, _ = net(gradients, state) 163 | self.assertEqual(update.get_shape().as_list(), shape) 164 | 165 | def testNonTrainable(self): 166 | """Tests the network doesn't contain trainable variables.""" 167 | shape = [10, 5] 168 | gradients = tf.random_normal(shape) 169 | net = networks.Adam() 170 | state = net.initial_state_for_inputs(gradients) 171 | net(gradients, state)
172 | variables = snt.get_variables_in_module(net) 173 | self.assertEqual(len(variables), 0) 174 | 175 | def testZeroLearningRate(self): 176 | """Tests network produces zero updates with learning rate equal to zero.""" 177 | shape = [10] 178 | gradients = tf.random_normal(shape) 179 | net = networks.Adam(learning_rate=0) 180 | state = net.initial_state_for_inputs(gradients) 181 | update, _ = net(gradients, state) 182 | 183 | with self.test_session() as sess: 184 | update_np = sess.run(update) 185 | self.assertAllEqual(update_np, np.zeros(shape)) 186 | 187 | 188 | if __name__ == "__main__": 189 | tf.test.main() 190 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Learning 2 Learn preprocessing modules.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import numpy as np 22 | import sonnet as snt 23 | import tensorflow as tf 24 | 25 | 26 | class Clamp(snt.AbstractModule): 27 | """Element-wise clamp of inputs to the optional range [min_value, max_value].""" 28 | def __init__(self, min_value=None, max_value=None, name="clamp"): 29 | super(Clamp, self).__init__(name=name) 30 | self._min = min_value 31 | self._max = max_value 32 | 33 | def _build(self, inputs): 34 | output = inputs 35 | if self._min is not None: 36 | output = tf.maximum(output, self._min) 37 | if self._max is not None: 38 | output = tf.minimum(output, self._max) 39 | return output 40 | 41 | 42 | class LogAndSign(snt.AbstractModule): 43 | """Log and sign preprocessing. 44 | 45 | As described in https://arxiv.org/pdf/1606.04474v1.pdf (Appendix A). 46 | """ 47 | 48 | def __init__(self, k, name="preprocess_log"): 49 | super(LogAndSign, self).__init__(name=name) 50 | self._k = k 51 | 52 | def _build(self, gradients): 53 | """Connects the LogAndSign module into the graph. 54 | 55 | Args: 56 | gradients: `Tensor` of gradients with shape `[d_1, ..., d_n]`. 57 | 58 | Returns: 59 | `Tensor` with shape `[d_1, ..., d_n-1, 2 * d_n]`. The first `d_n` elements 60 | along the nth dimension correspond to the log output and the remaining 61 | `d_n` elements to the sign output.
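For example, with k=5 and a gradient entry of 0.01 (illustrative values): the log output is max(log(0.01) / 5, -1) ≈ -0.92 and the sign output is clamp(0.01 * exp(5), -1, 1) = 1.0.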
62 | """ 63 | eps = np.finfo(gradients.dtype.as_numpy_dtype).eps 64 | ndims = gradients.get_shape().ndims 65 | 66 | log = tf.log(tf.abs(gradients) + eps) 67 | clamped_log = Clamp(min_value=-1.0)(log / self._k) # pylint: disable=not-callable 68 | sign = Clamp(min_value=-1.0, max_value=1.0)(gradients * np.exp(self._k)) # pylint: disable=not-callable 69 | 70 | return tf.concat([clamped_log, sign], ndims - 1) 71 | -------------------------------------------------------------------------------- /preprocess_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for L2L preprocessors.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import numpy as np 22 | import tensorflow as tf 23 | 24 | import preprocess 25 | 26 | 27 | class ClampTest(tf.test.TestCase): 28 | """Tests Clamp module.""" 29 | 30 | def testShape(self): 31 | shape = [2, 3] 32 | inputs = tf.random_normal(shape) 33 | clamp = preprocess.Clamp(min_value=-1.0, max_value=1.0) 34 | output = clamp(inputs) 35 | self.assertEqual(output.get_shape().as_list(), shape) 36 | 37 | def testMin(self): 38 | shape = [100] 39 | inputs = tf.random_normal(shape) 40 | clamp = preprocess.Clamp(min_value=0.0) 41 | output = clamp(inputs) 42 | 43 | with self.test_session() as sess: 44 | output_np = sess.run(output) 45 | self.assertTrue(np.all(np.greater_equal(output_np, np.zeros(shape)))) 46 | 47 | def testMax(self): 48 | shape = [100] 49 | inputs = tf.random_normal(shape) 50 | clamp = preprocess.Clamp(max_value=0.0) 51 | output = clamp(inputs) 52 | 53 | with self.test_session() as sess: 54 | output_np = sess.run(output) 55 | self.assertTrue(np.all(np.less_equal(output_np, np.zeros(shape)))) 56 | 57 | def testMinAndMax(self): 58 | shape = [100] 59 | inputs = tf.random_normal(shape) 60 | clamp = preprocess.Clamp(min_value=0.0, max_value=0.0) 61 | output = clamp(inputs) 62 | 63 | with self.test_session() as sess: 64 | output_np = sess.run(output) 65 | self.assertAllEqual(output_np, np.zeros(shape)) 66 | 67 | 68 | class LogAndSignTest(tf.test.TestCase): 69 | """Tests LogAndSign module.""" 70 | 71 | def testShape(self): 72 | shape = [2, 3] 73 | inputs = tf.random_normal(shape) 74 | module = preprocess.LogAndSign(k=1) 75 | output = module(inputs) 76 | self.assertEqual(output.get_shape().as_list(), shape[:-1] + [shape[-1] * 2]) 77 | 78 | def testLogWithOnes(self): 79 | shape = [1] 80 | inputs = tf.ones(shape) 81 | module = preprocess.LogAndSign(k=10) 82 | output = module(inputs) 83 | 84 | with self.test_session() as sess: 85 | output_np = sess.run(output) 86 | log_np = output_np[0] 87 | self.assertAlmostEqual(log_np, 0.0) 88 | 89 | def testSign(self): 90 | shape = [2, 1] 91 | inputs = tf.random_normal(shape) 92 | module = 
preprocess.LogAndSign(k=1) 93 | output = module(inputs) 94 | 95 | with self.test_session() as sess: 96 | inputs_np, output_np = sess.run([inputs, output]) 97 | sign_np = output_np[:, 1:] 98 | self.assertAllEqual(np.sign(sign_np), np.sign(inputs_np)) 99 | 100 | 101 | if __name__ == "__main__": 102 | tf.test.main() 103 | -------------------------------------------------------------------------------- /problems.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Learning 2 Learn problems.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import sys 23 | import tarfile 24 | 25 | from six.moves import urllib 26 | from six.moves import xrange # pylint: disable=redefined-builtin 27 | import sonnet as snt 28 | import tensorflow as tf 29 | 30 | from tensorflow.contrib.learn.python.learn.datasets import mnist as mnist_dataset 31 | 32 | 33 | _nn_initializers = { 34 | "w": tf.random_normal_initializer(mean=0, stddev=0.01), 35 | "b": tf.random_normal_initializer(mean=0, stddev=0.01), 36 | } 37 | 38 | 39 | def simple(): 40 | """Simple problem: f(x) = x^2.""" 41 | 42 | def build(): 43 | """Builds loss graph.""" 44 | x = tf.get_variable( 45 | "x", 46 | shape=[], 47 | dtype=tf.float32, 48 | initializer=tf.ones_initializer()) 49 | return tf.square(x, name="x_squared") 50 | 51 | return build 52 | 53 | 54 | def simple_multi_optimizer(num_dims=2): 55 | """Multidimensional simple problem.""" 56 | 57 | def get_coordinate(i): 58 | return tf.get_variable("x_{}".format(i), 59 | shape=[], 60 | dtype=tf.float32, 61 | initializer=tf.ones_initializer()) 62 | 63 | def build(): 64 | coordinates = [get_coordinate(i) for i in xrange(num_dims)] 65 | x = tf.concat([tf.expand_dims(c, 0) for c in coordinates], 0) 66 | return tf.reduce_sum(tf.square(x, name="x_squared")) 67 | 68 | return build 69 | 70 | 71 | def quadratic(batch_size=128, num_dims=10, stddev=0.01, dtype=tf.float32): 72 | """Quadratic problem: f(x) = ||Wx - y||_2^2, averaged over the batch.""" 73 | 74 | def build(): 75 | """Builds loss graph.""" 76 | 77 | # Trainable variable. 78 | x = tf.get_variable( 79 | "x", 80 | shape=[batch_size, num_dims], 81 | dtype=dtype, 82 | initializer=tf.random_normal_initializer(stddev=stddev)) 83 | 84 | # Non-trainable variables.
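# w and y are sampled once at graph construction and then held fixed, so each batch element defines its own random least-squares instance.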
85 | w = tf.get_variable("w", 86 | shape=[batch_size, num_dims, num_dims], 87 | dtype=dtype, 88 | initializer=tf.random_uniform_initializer(), 89 | trainable=False) 90 | y = tf.get_variable("y", 91 | shape=[batch_size, num_dims], 92 | dtype=dtype, 93 | initializer=tf.random_uniform_initializer(), 94 | trainable=False) 95 | 96 | product = tf.squeeze(tf.matmul(w, tf.expand_dims(x, -1)))  # [batch_size, num_dims] 97 | return tf.reduce_mean(tf.reduce_sum((product - y) ** 2, 1)) 98 | 99 | return build 100 | 101 | 102 | def ensemble(problems, weights=None): 103 | """Ensemble of problems. 104 | 105 | Args: 106 | problems: List of problems. Each problem is specified by a dict containing 107 | the keys 'name' and 'options'. 108 | weights: Optional list of weights for each problem. 109 | 110 | Returns: 111 | A callable that builds the sum of (weighted) losses. 112 | 113 | Raises: 114 | ValueError: If weights has an incorrect length. 115 | """ 116 | if weights and len(weights) != len(problems): 117 | raise ValueError("len(weights) != len(problems)") 118 | 119 | build_fns = [getattr(sys.modules[__name__], p["name"])(**p["options"]) 120 | for p in problems] 121 | 122 | def build(): 123 | loss = 0 124 | for i, build_fn in enumerate(build_fns): 125 | with tf.variable_scope("problem_{}".format(i)): 126 | loss_p = build_fn() 127 | if weights: 128 | loss_p *= weights[i] 129 | loss += loss_p 130 | return loss 131 | 132 | return build 133 | 134 | 135 | def _xent_loss(output, labels): 136 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=output, 137 | labels=labels) 138 | return tf.reduce_mean(loss) 139 | 140 | 141 | def mnist(layers, # pylint: disable=invalid-name 142 | activation="sigmoid", 143 | batch_size=128, 144 | mode="train"): 145 | """MNIST classification with a multi-layer perceptron.""" 146 | 147 | if activation == "sigmoid": 148 | activation_op = tf.sigmoid 149 | elif activation == "relu": 150 | activation_op = tf.nn.relu 151 | else: 152 | raise ValueError("{} activation not supported".format(activation)) 153 | 154 | # Data. 155 | data = mnist_dataset.load_mnist() 156 | data = getattr(data, mode) 157 | images = tf.constant(data.images, dtype=tf.float32, name="MNIST_images") 158 | images = tf.reshape(images, [-1, 28, 28, 1]) 159 | labels = tf.constant(data.labels, dtype=tf.int64, name="MNIST_labels") 160 | 161 | # Network.
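# BatchFlatten turns the [batch_size, 28, 28, 1] images back into vectors; the MLP's final layer emits the 10 class logits consumed by _xent_loss.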
162 | mlp = snt.nets.MLP(list(layers) + [10], 163 | activation=activation_op, 164 | initializers=_nn_initializers) 165 | network = snt.Sequential([snt.BatchFlatten(), mlp]) 166 | 167 | def build(): 168 | indices = tf.random_uniform([batch_size], 0, data.num_examples, tf.int64) 169 | batch_images = tf.gather(images, indices) 170 | batch_labels = tf.gather(labels, indices) 171 | output = network(batch_images) 172 | return _xent_loss(output, batch_labels) 173 | 174 | return build 175 | 176 | 177 | CIFAR10_URL = "http://www.cs.toronto.edu/~kriz" 178 | CIFAR10_FILE = "cifar-10-binary.tar.gz" 179 | CIFAR10_FOLDER = "cifar-10-batches-bin" 180 | 181 | 182 | def _maybe_download_cifar10(path): 183 | """Downloads and extracts the CIFAR-10 tarball if not already present.""" 184 | if not os.path.exists(path): 185 | os.makedirs(path) 186 | filepath = os.path.join(path, CIFAR10_FILE) 187 | if not os.path.exists(filepath): 188 | print("Downloading CIFAR10 dataset to {}".format(filepath)) 189 | url = "{}/{}".format(CIFAR10_URL, CIFAR10_FILE)  # URLs always use "/" separators. 190 | filepath, _ = urllib.request.urlretrieve(url, filepath) 191 | statinfo = os.stat(filepath) 192 | print("Successfully downloaded {} bytes".format(statinfo.st_size)) 193 | tarfile.open(filepath, "r:gz").extractall(path) 194 | 195 | 196 | def cifar10(path, # pylint: disable=invalid-name 197 | conv_channels=None, 198 | linear_layers=None, 199 | batch_norm=True, 200 | batch_size=128, 201 | num_threads=4, 202 | min_queue_examples=1000, 203 | mode="train"): 204 | """CIFAR-10 classification with a convolutional network.""" 205 | 206 | # Data. 207 | _maybe_download_cifar10(path) 208 | 209 | # Read images and labels from disk. 210 | if mode == "train": 211 | filenames = [os.path.join(path, 212 | CIFAR10_FOLDER, 213 | "data_batch_{}.bin".format(i)) 214 | for i in xrange(1, 6)] 215 | elif mode == "test": 216 | filenames = [os.path.join(path, CIFAR10_FOLDER, "test_batch.bin")] 217 | else: 218 | raise ValueError("Mode {} not recognised".format(mode)) 219 | 220 | depth = 3 221 | height = 32 222 | width = 32 223 | label_bytes = 1 224 | image_bytes = depth * height * width 225 | record_bytes = label_bytes + image_bytes 226 | reader = tf.FixedLengthRecordReader(record_bytes=record_bytes) 227 | _, record = reader.read(tf.train.string_input_producer(filenames)) 228 | raw_record = tf.decode_raw(record, tf.uint8)  # Record as a uint8 vector. 229 | 230 | label = tf.cast(tf.slice(raw_record, [0], [label_bytes]), tf.int32) 231 | raw_image = tf.slice(raw_record, [label_bytes], [image_bytes]) 232 | image = tf.cast(tf.reshape(raw_image, [depth, height, width]), tf.float32) 233 | # height x width x depth. 234 | image = tf.transpose(image, [1, 2, 0]) 235 | image = tf.div(image, 255) 236 | 237 | queue = tf.RandomShuffleQueue(capacity=min_queue_examples + 3 * batch_size, 238 | min_after_dequeue=min_queue_examples, 239 | dtypes=[tf.float32, tf.int32], 240 | shapes=[image.get_shape(), label.get_shape()]) 241 | enqueue_ops = [queue.enqueue([image, label]) for _ in xrange(num_threads)] 242 | tf.train.add_queue_runner(tf.train.QueueRunner(queue, enqueue_ops)) 243 | 244 | # Network.
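# Five-by-five convolutions with ReLU and 2x2 max-pooling (via _conv_activation below), followed by an MLP whose final layer emits the 10 class logits.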
245 | def _conv_activation(x): # pylint: disable=invalid-name 246 | return tf.nn.max_pool(tf.nn.relu(x), 247 | ksize=[1, 2, 2, 1], 248 | strides=[1, 2, 2, 1], 249 | padding="SAME") 250 | 251 | conv = snt.nets.ConvNet2D(output_channels=conv_channels, 252 | kernel_shapes=[5], 253 | strides=[1], 254 | paddings=[snt.SAME], 255 | activation=_conv_activation, 256 | activate_final=True, 257 | initializers=_nn_initializers, 258 | use_batch_norm=batch_norm) 259 | 260 | if batch_norm: 261 | linear_activation = lambda x: tf.nn.relu(snt.BatchNorm()(x)) 262 | else: 263 | linear_activation = tf.nn.relu 264 | 265 | mlp = snt.nets.MLP(list(linear_layers) + [10], 266 | activation=linear_activation, 267 | initializers=_nn_initializers) 268 | network = snt.Sequential([conv, snt.BatchFlatten(), mlp]) 269 | 270 | def build(): 271 | image_batch, label_batch = queue.dequeue_many(batch_size) 272 | label_batch = tf.reshape(label_batch, [batch_size]) 273 | 274 | output = network(image_batch) 275 | return _xent_loss(output, label_batch) 276 | 277 | return build 278 | -------------------------------------------------------------------------------- /problems_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for L2L problems.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from six.moves import xrange 22 | import tensorflow as tf 23 | 24 | from nose_parameterized import parameterized 25 | 26 | import problems 27 | 28 | 29 | class SimpleTest(tf.test.TestCase): 30 | """Tests simple problem.""" 31 | 32 | def testShape(self): 33 | problem = problems.simple() 34 | f = problem() 35 | self.assertEqual(f.get_shape().as_list(), []) 36 | 37 | def testVariables(self): 38 | problem = problems.simple() 39 | problem() 40 | variables = tf.trainable_variables() 41 | self.assertEqual(len(variables), 1) 42 | self.assertEqual(variables[0].get_shape().as_list(), []) 43 | 44 | @parameterized.expand([(-1,), (0,), (1,), (10,)]) 45 | def testValues(self, value): 46 | problem = problems.simple() 47 | f = problem() 48 | 49 | with self.test_session() as sess: 50 | output = sess.run(f, feed_dict={"x:0": value}) 51 | self.assertEqual(output, value**2) 52 | 53 | 54 | class SimpleMultiOptimizerTest(tf.test.TestCase): 55 | """Tests multi-optimizer simple problem.""" 56 | 57 | def testShape(self): 58 | num_dims = 3 59 | problem = problems.simple_multi_optimizer(num_dims=num_dims) 60 | f = problem() 61 | self.assertEqual(f.get_shape().as_list(), []) 62 | 63 | def testVariables(self): 64 | num_dims = 3 65 | problem = problems.simple_multi_optimizer(num_dims=num_dims) 66 | problem() 67 | variables = tf.trainable_variables() 68 | self.assertEqual(len(variables), num_dims) 69 | for v in variables: 70 | self.assertEqual(v.get_shape().as_list(), []) 71 | 72 | @parameterized.expand([(-1,), (0,), (1,), (10,)]) 73 | def testValues(self, value): 74 | problem = problems.simple_multi_optimizer(num_dims=1) 75 | f = problem() 76 | 77 | with self.test_session() as sess: 78 | output = sess.run(f, feed_dict={"x_0:0": value}) 79 | self.assertEqual(output, value**2) 80 | 81 | 82 | class QuadraticTest(tf.test.TestCase): 83 | """Tests Quadratic problem.""" 84 | 85 | def testShape(self): 86 | problem = problems.quadratic() 87 | f = problem() 88 | self.assertEqual(f.get_shape().as_list(), []) 89 | 90 | def testVariables(self): 91 | batch_size = 5 92 | num_dims = 3 93 | problem = problems.quadratic(batch_size=batch_size, num_dims=num_dims) 94 | problem() 95 | variables = tf.trainable_variables() 96 | self.assertEqual(len(variables), 1) 97 | self.assertEqual(variables[0].get_shape().as_list(), [batch_size, num_dims]) 98 | 99 | @parameterized.expand([(-1,), (0,), (1,), (10,)]) 100 | def testValues(self, value): 101 | problem = problems.quadratic(batch_size=1, num_dims=1) 102 | f = problem() 103 | 104 | w = 2.0 105 | y = 3.0 106 | 107 | with self.test_session() as sess: 108 | output = sess.run(f, feed_dict={"x:0": [[value]], 109 | "w:0": [[[w]]], 110 | "y:0": [[y]]}) 111 | self.assertEqual(output, ((w * value) - y)**2) 112 | 113 | 114 | class EnsembleTest(tf.test.TestCase): 115 | """Tests Ensemble problem.""" 116 | 117 | def testShape(self): 118 | num_dims = 3 119 | problem_defs = [{"name": "simple", "options": {}} for _ in xrange(num_dims)] 120 | ensemble = problems.ensemble(problem_defs) 121 | f = ensemble() 122 | self.assertEqual(f.get_shape().as_list(), []) 123 | 124 | def testVariables(self): 125 | num_dims = 3 126 | problem_defs = [{"name": "simple", "options": {}} for _ in xrange(num_dims)] 127 | ensemble = problems.ensemble(problem_defs) 128 | ensemble() 
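# Each "simple" sub-problem creates one scalar variable inside its own problem_{i} scope, hence num_dims trainable variables in total.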
129 | variables = tf.trainable_variables() 130 | self.assertEqual(len(variables), num_dims) 131 | for v in variables: 132 | self.assertEqual(v.get_shape().as_list(), []) 133 | 134 | @parameterized.expand([(-1,), (0,), (1,), (10,)]) 135 | def testValues(self, value): 136 | num_dims = 1 137 | weight = 0.5 138 | problem_defs = [{"name": "simple", "options": {}} for _ in xrange(num_dims)] 139 | ensemble = problems.ensemble(problem_defs, weights=[weight]) 140 | f = ensemble() 141 | 142 | with self.test_session() as sess: 143 | output = sess.run(f, feed_dict={"problem_0/x:0": value}) 144 | self.assertEqual(output, weight * value**2) 145 | 146 | 147 | if __name__ == "__main__": 148 | tf.test.main() 149 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Learning 2 Learn training.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | 23 | from six.moves import xrange 24 | import tensorflow as tf 25 | 26 | from tensorflow.contrib.learn.python.learn import monitored_session as ms 27 | 28 | import meta 29 | import util 30 | 31 | flags = tf.flags 32 | logging = tf.logging 33 | 34 | 35 | FLAGS = flags.FLAGS 36 | flags.DEFINE_string("save_path", None, "Path for saved meta-optimizer.") 37 | flags.DEFINE_integer("num_epochs", 10000, "Number of training epochs.") 38 | flags.DEFINE_integer("log_period", 100, "Log period.") 39 | flags.DEFINE_integer("evaluation_period", 1000, "Evaluation period.") 40 | flags.DEFINE_integer("evaluation_epochs", 20, "Number of evaluation epochs.") 41 | 42 | flags.DEFINE_string("problem", "simple", "Type of problem.") 43 | flags.DEFINE_integer("num_steps", 100, 44 | "Number of optimization steps per epoch.") 45 | flags.DEFINE_integer("unroll_length", 20, "Meta-optimizer unroll length.") 46 | flags.DEFINE_float("learning_rate", 0.001, "Learning rate.") 47 | flags.DEFINE_boolean("second_derivatives", False, "Use second derivatives.") 48 | 49 | 50 | def main(_): 51 | # Configuration. 52 | num_unrolls = FLAGS.num_steps // FLAGS.unroll_length 53 | 54 | if FLAGS.save_path is not None: 55 | if os.path.exists(FLAGS.save_path): 56 | raise ValueError("Folder {} already exists".format(FLAGS.save_path)) 57 | else: 58 | os.mkdir(FLAGS.save_path) 59 | 60 | # Problem. 61 | problem, net_config, net_assignments = util.get_config(FLAGS.problem) 62 | 63 | # Optimizer setup. 
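# As unpacked below, meta_minimize returns, in order: the meta-training step, the optimizee update op, a state reset op, the cost op, and one further op that this script ignores.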
64 | optimizer = meta.MetaOptimizer(**net_config) 65 | minimize = optimizer.meta_minimize( 66 | problem, FLAGS.unroll_length, 67 | learning_rate=FLAGS.learning_rate, 68 | net_assignments=net_assignments, 69 | second_derivatives=FLAGS.second_derivatives) 70 | step, update, reset, cost_op, _ = minimize 71 | 72 | with ms.MonitoredSession() as sess: 73 | # Prevent accidental changes to the graph. 74 | tf.get_default_graph().finalize() 75 | 76 | best_evaluation = float("inf") 77 | total_time = 0 78 | total_cost = 0 79 | for e in xrange(FLAGS.num_epochs): 80 | # Training. 81 | time, cost = util.run_epoch(sess, cost_op, [update, step], reset, 82 | num_unrolls) 83 | total_time += time 84 | total_cost += cost 85 | 86 | # Logging. 87 | if (e + 1) % FLAGS.log_period == 0: 88 | util.print_stats("Epoch {}".format(e + 1), total_cost, total_time, 89 | FLAGS.log_period) 90 | total_time = 0 91 | total_cost = 0 92 | 93 | # Evaluation. 94 | if (e + 1) % FLAGS.evaluation_period == 0: 95 | eval_cost = 0 96 | eval_time = 0 97 | for _ in xrange(FLAGS.evaluation_epochs): 98 | time, cost = util.run_epoch(sess, cost_op, [update], reset, 99 | num_unrolls) 100 | eval_time += time 101 | eval_cost += cost 102 | 103 | util.print_stats("EVALUATION", eval_cost, eval_time, 104 | FLAGS.evaluation_epochs) 105 | 106 | if FLAGS.save_path is not None and eval_cost < best_evaluation: 107 | print("Removing previously saved meta-optimizer") 108 | for f in os.listdir(FLAGS.save_path): 109 | os.remove(os.path.join(FLAGS.save_path, f)) 110 | print("Saving meta-optimizer to {}".format(FLAGS.save_path)) 111 | optimizer.save(sess, FLAGS.save_path) 112 | best_evaluation = eval_cost 113 | 114 | 115 | if __name__ == "__main__": 116 | tf.app.run() 117 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Learning 2 Learn utils.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | from timeit import default_timer as timer 23 | 24 | import numpy as np 25 | from six.moves import xrange 26 | 27 | import problems 28 | 29 | 30 | def run_epoch(sess, cost_op, ops, reset, num_unrolls): 31 | """Runs one optimization epoch.""" 32 | start = timer() 33 | sess.run(reset) 34 | for _ in xrange(num_unrolls): 35 | cost = sess.run([cost_op] + ops)[0] 36 | return timer() - start, cost 37 | 38 | 39 | def print_stats(header, total_error, total_time, n): 40 | """Prints experiment statistics.""" 41 | print(header) 42 | print("Log Mean Final Error: {:.2f}".format(np.log10(total_error / n))) 43 | print("Mean epoch time: {:.2f} s".format(total_time / n)) 44 | 45 | 46 | def get_net_path(name, path): 47 | return None if path is None else os.path.join(path, name + ".l2l") 48 | 49 | 50 | def get_default_net_config(name, path): 51 | return { 52 | "net": "CoordinateWiseDeepLSTM", 53 | "net_options": { 54 | "layers": (20, 20), 55 | "preprocess_name": "LogAndSign", 56 | "preprocess_options": {"k": 5}, 57 | "scale": 0.01, 58 | }, 59 | "net_path": get_net_path(name, path) 60 | } 61 | 62 | 63 | def get_config(problem_name, path=None): 64 | """Returns problem configuration.""" 65 | if problem_name == "simple": 66 | problem = problems.simple() 67 | net_config = {"cw": { 68 | "net": "CoordinateWiseDeepLSTM", 69 | "net_options": {"layers": (), "initializer": "zeros"}, 70 | "net_path": get_net_path("cw", path) 71 | }} 72 | net_assignments = None 73 | elif problem_name == "simple-multi": 74 | problem = problems.simple_multi_optimizer() 75 | net_config = { 76 | "cw": { 77 | "net": "CoordinateWiseDeepLSTM", 78 | "net_options": {"layers": (), "initializer": "zeros"}, 79 | "net_path": get_net_path("cw", path) 80 | }, 81 | "adam": { 82 | "net": "Adam", 83 | "net_options": {"learning_rate": 0.1} 84 | } 85 | } 86 | net_assignments = [("cw", ["x_0"]), ("adam", ["x_1"])] 87 | elif problem_name == "quadratic": 88 | problem = problems.quadratic(batch_size=128, num_dims=10) 89 | net_config = {"cw": { 90 | "net": "CoordinateWiseDeepLSTM", 91 | "net_options": {"layers": (20, 20)}, 92 | "net_path": get_net_path("cw", path) 93 | }} 94 | net_assignments = None 95 | elif problem_name == "mnist": 96 | mode = "train" if path is None else "test" 97 | problem = problems.mnist(layers=(20,), mode=mode) 98 | net_config = {"cw": get_default_net_config("cw", path)} 99 | net_assignments = None 100 | elif problem_name == "cifar": 101 | mode = "train" if path is None else "test" 102 | problem = problems.cifar10("cifar10", 103 | conv_channels=(16, 16, 16), 104 | linear_layers=(32,), 105 | mode=mode) 106 | net_config = {"cw": get_default_net_config("cw", path)} 107 | net_assignments = None 108 | elif problem_name == "cifar-multi": 109 | mode = "train" if path is None else "test" 110 | problem = problems.cifar10("cifar10", 111 | conv_channels=(16, 16, 16), 112 | linear_layers=(32,), 113 | mode=mode) 114 | net_config = { 115 | "conv": get_default_net_config("conv", path), 116 | "fc": get_default_net_config("fc", path) 117 | } 118 | conv_vars = ["conv_net_2d/conv_2d_{}/w".format(i) for i in xrange(3)] 119 | fc_vars = ["conv_net_2d/conv_2d_{}/b".format(i) for i in xrange(3)] 120 | fc_vars += ["conv_net_2d/batch_norm_{}/beta".format(i) for i in xrange(3)] 121 | fc_vars += 
["mlp/linear_{}/w".format(i) for i in xrange(2)] 122 | fc_vars += ["mlp/linear_{}/b".format(i) for i in xrange(2)] 123 | fc_vars += ["mlp/batch_norm/beta"] 124 | net_assignments = [("conv", conv_vars), ("fc", fc_vars)] 125 | else: 126 | raise ValueError("{} is not a valid problem".format(problem_name)) 127 | 128 | return problem, net_config, net_assignments 129 | --------------------------------------------------------------------------------