├── .gitignore ├── LICENSE ├── README.md ├── cabs.py ├── examples ├── cifar10_adaptive_batchsize.py ├── models │ ├── __init__.py │ ├── cifar10_2conv_3dense.py │ └── mnist_2conv_2dense.py ├── run_cabs_cifar10.py └── run_cabs_mnist.py └── gradient_moment.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | examples/data 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2016 Max Planck Society. All rights reserved. 2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 
31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 
63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. 
You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 
124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. 
In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. 
We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright 2016, Max Planck Society. 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 204 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SGD with Coupled Adaptive Batch Size (CABS) 2 | 3 | This is a TensorFlow implementation of [SGD with Coupled Adaptive Batch Size (CABS)][1]. 4 | 5 | ## The Algorithm in a Nutshell 6 | 7 | CABS is an algorithm to dynamically adapt the batch size when performing 8 | stochastic gradient descent (SGD) on an empirical risk minimization problem. At 9 | each iteration, it computes an empirical measure ``xi`` of the variance of the 10 | stochastic gradient. The batch size for the next iteration is then set to 11 | ``bs_new = lr*xi/loss``, where ``lr`` is the learning rate and ``loss`` is the 12 | current value of the loss function. Refer to the [paper][1] for more information. 
13 | 14 | ## Requirements 15 | 16 | tensorflow 1.0 17 | ## Usage 18 | 19 | Usage of ``CABSOptimizer`` is similar to that of other TensorFlow optimizers, 20 | with the exception that its ``minimize`` function expects a vector of ``losses``, 21 | one for each training example in the batch, instead of an aggregate mean 22 | ``loss``. This is so that the optimizer can easily access the batch size. Moreover, 23 | some measures have to be taken to ensure that batches of appropriate size are 24 | fed into the TensorFlow model. The specifics depend on how you choose to feed 25 | your data. 26 | 27 | ### Manually Feeding Data 28 | If you are manually providing the training data via a 29 | ``feed_dict``, you have to fetch the batch size that CABS suggests and then 30 | provide a batch of that size for the next iteration. This would look roughly 31 | like this: 32 | 33 | ```python 34 | import tensorflow as tf 35 | from cabs import CABSOptimizer 36 | 37 | X, y = ... # your placeholders for data 38 | ... # set up your model 39 | losses = ... # vector of losses, one for each training example in the batch 40 | var_list = ... # list of trainable variables 41 | 42 | opt = CABSOptimizer(lr, bs_min, bs_max) 43 | sgd_step, bs_new, loss = opt.minimize(losses, var_list) 44 | m = initial_batch_size 45 | 46 | sess = tf.Session() 47 | sess.run(tf.initialize_all_variables()) 48 | 49 | for i in range(num_steps): 50 | X_batch, y_batch = ... # Get a batch of size m ready (you have to take care of this yourself) 51 | _, m_new, l = sess.run([sgd_step, bs_new, loss], feed_dict={X: X_batch, y: y_batch}) 52 | print(l) 53 | print(m_new) 54 | m = m_new 55 | ``` 56 | 57 | The MNIST example (examples/run_cabs_mnist.py) is a full working example using 58 | ``feed_dict``. 
59 | 60 | ### Reading Data from Files 61 | If you are reading data from files using TensorFlow's built-in mechanism, the 62 | batch size comes into action when fetching batches of data from an example queue 63 | via ``tf.train.batch`` (or ``tf.train.shuffle_batch``). To use CABS, use a 64 | variable ``global_bs`` as the ``batch_size`` argument of ``tf.train.batch``, 65 | then pass ``global_bs`` to the ``minimize`` method of the ``CABSOptimizer``. The 66 | optimizer will then write the new batch size to the global batch size variable, 67 | directly communicating it to your data loading mechanism. Sketch: 68 | 69 | ```python 70 | import tensorflow as tf 71 | from cabs import CABSOptimizer 72 | 73 | X, y = ... # your example queue 74 | global_bs = tf.Variable(initial_batch_size) # initialize a global batch size variable 75 | X_batch, y_batch = tf.train.batch([X, y], batch_size=global_bs) # global_bs is used as the batch_size argument of tf.train.batch 76 | ... # set up your model 77 | losses = ... # vector of losses, one for each training example in the batch 78 | var_list = ... # list of trainable variables 79 | 80 | opt = CABSOptimizer(lr, bs_min, bs_max) 81 | sgd_step, bs_new, loss = opt.minimize(losses, var_list, global_bs) # pass global_bs here, so that CABSOptimizer can write to it 82 | 83 | sess = tf.Session() 84 | sess.run(tf.initialize_all_variables()) 85 | 86 | for i in range(num_steps): 87 | _, m_new, l = sess.run([sgd_step, bs_new, loss]) 88 | print(l) 89 | print(m_new) 90 | ``` 91 | 92 | Refer to our CIFAR-10 example (examples/run_cabs_cifar10.py) for a full working 93 | example using this mechanism. 94 | 95 | 96 | ## Quick Guide to this Implementation 97 | 98 | The implementation of CABS (see cabs.py) itself is straight-forward. The 99 | ``CABSOptimizer`` class inherits from ``tf.train.GradientDescentOptimizer`` and 100 | implements the identical parameter updates, but adds the necessary additional 101 | computations for the CABS batch size. 
class CABSOptimizer(tf.train.GradientDescentOptimizer):

    """Optimizer that implements stochastic gradient descent with Coupled
    Adaptive Batch Size (CABS) as described in

    Lukas Balles, Javier Romero and Philipp Hennig: Coupling Adaptive Batch
    Sizes with Learning Rates. [url].

    @@__init__
    """

    def __init__(self, learning_rate, bs_min=16, bs_max=2048,
                 running_avg_constant=0.95, eps=0.0, c=1.0, debug=False,
                 name="CABS-SGD"):
        """Construct a new gradient descent optimizer with coupled adaptive
        batch size (CABS).

        Args:
          :learning_rate: A Tensor or a floating point value. The learning
            rate to use.
          :bs_min: Minimum batch size (integer). Defaults to 16.
          :bs_max: Maximum batch size (integer). Defaults to 2048.
          :running_avg_constant: The variance and function value estimates
            are smoothed over iterations using an exponential running average
            with this constant. Defaults to 0.95.
          :eps: Constant added to the denominator of the CABS rule for
            numerical stability. Defaults to 0.0, but might be set to a small
            constant, e.g. eps=1e-8.
          :c: Constant by which to multiply the CABS batch size. Defaults to
            1.0 and we recommend to leave it at this.
          :debug: Boolean to switch on debug mode, where ``minimize()``
            returns additional diagnostic outputs. Default is False.
          :name: Optional name prefix for the operations created when
            applying gradients. Defaults to "CABS-SGD".
        """

        super(CABSOptimizer, self).__init__(learning_rate, name=name)
        self._bs_min = bs_min
        self._bs_max = bs_max
        self._running_avg_constant = running_avg_constant
        self._eps = eps
        self._c = c
        self._debug = debug

    def minimize(self, losses, var_list=None, global_bs=None):
        """Add operations to minimize the mean of `losses` by updating
        `var_list` with SGD and compute the batch size for the next step
        according to the CABS rule.

        Args:
          :losses: A rank 1 `Tensor` containing the individual loss values for
            each example in the batch. You can *not* insert a scalar mean
            loss, as in other optimizers.
          :var_list: Optional list of `Variable` objects to update to minimize
            the loss. Defaults to the list of variables collected in the graph
            under the key `GraphKeys.TRAINABLE_VARIABLES`.
          :global_bs: Optional `Variable` to which the computed batch size is
            assigned. When you feed data using tensorflow queues, use this
            variable as batch size in ``tf.train.batch()`` or
            ``tf.train.shuffle_batch()``. When you feed data via
            ``placeholder``s and ``feed_dict``s, use ``global_bs=None``. In
            this case you have to fetch ``bs_new`` (one of the return values
            of this function, see below) and take care of the batch size
            yourself.
        Returns:
          If ``debug=False``
          :sgd_step: An Operation that updates the variables in `var_list` via
            SGD step.
          :bs_new: A scalar integer tensor containing the CABS batch size for
            the next optimization step.
          :loss: A scalar tensor with the mean of the inserted ``losses``.
          If ``debug=True``
          :sgd_step: An Operation that updates the variables in `var_list` via
            SGD step.
          :bs_new: A scalar integer tensor containing the rounded and capped
            CABS batch size to be used in the next optimization step.
          :bs_new_raw: A scalar tensor containing the raw CABS batch size
            before rounding and capping.
          :loss_avg: A scalar tensor containing the running average of the
            mean loss.
          :loss: A scalar tensor with the mean of the inserted ``losses``,
            i.e. the current loss.
          :xi_avg: A scalar tensor containing the running average of the
            gradient variance.
          :xi: A scalar tensor containing the current gradient variance.
          If ``global_bs`` was not ``None``, the result ``bs_new`` is also
          written to the ``global_bs`` Variable.
        Raises:
          AssertionError: If ``global_bs`` is given but is not a
            ``tf.Variable``.
        """

        if global_bs is not None:
            assert isinstance(global_bs, tf.Variable), (
                "global_bs must be a tf.Variable")

        # Resolve the default variable list before any optimizer state is
        # created below.
        if var_list is None:
            var_list = tf.trainable_variables()

        # Extract input data type and batch size from the provided losses
        input_dtype = losses.dtype.base_dtype
        input_batch_size = tf.cast(tf.gather(tf.shape(losses), 0), input_dtype)

        # Create variables for the moving averages of noise level and loss.
        # They are optimizer state, not model parameters, so they are kept out
        # of the TRAINABLE_VARIABLES collection; they use the dtype of
        # ``losses`` so that the running-average arithmetic below also works
        # for non-float32 inputs.
        xi_avg = tf.Variable(0.0, trainable=False, dtype=input_dtype)
        loss_avg = tf.Variable(1.0, trainable=False, dtype=input_dtype)

        # Convert constant algo parameters to tensors
        mu = tf.convert_to_tensor(self._running_avg_constant, dtype=input_dtype)
        c = tf.convert_to_tensor(self._c, dtype=input_dtype)
        lr = tf.convert_to_tensor(self._learning_rate, dtype=input_dtype)
        eps = tf.convert_to_tensor(self._eps, dtype=input_dtype)
        bs_min = tf.convert_to_tensor(self._bs_min, dtype=input_dtype)
        bs_max = tf.convert_to_tensor(self._bs_max, dtype=input_dtype)

        # Compute mean loss and feed it into a running average
        loss = tf.reduce_mean(losses)
        update_avgs = [loss_avg.assign(mu*loss_avg + (1.0-mu)*loss)]

        # Compute gradients and gradient moments
        # NOTE(review): ``moms`` is presumably the per-variable second moment
        # of the individual gradients -- confirm against gradient_moment.py.
        grads, moms = gm.grads_and_grad_moms(loss, input_batch_size, var_list)
        grads_squared = [tf.square(g) for g in grads]

        # Compute gradient variance (moment minus squared gradient, summed
        # over all entries of all variables) and feed it into a running average
        grad_variances = [(m-g2) for g2, m in zip(grads_squared, moms)]
        xi = tf.add_n([tf.reduce_sum(gv) for gv in grad_variances])
        update_avgs.append(xi_avg.assign(mu*xi_avg + (1.0-mu)*xi))

        # Compute the new batch size (with a dependency that makes sure that
        # the moving averages are updated beforehand)
        with tf.control_dependencies(update_avgs):
            bs_new_raw = c*lr*tf.divide(xi_avg, loss_avg+eps)

            # Round the new batch size and cap it to [bs_min, bs_max]
            bs_new_rounded = tf.round(bs_new_raw)
            bs_new = tf.clip_by_value(bs_new_rounded, bs_min, bs_max)
            bs_new = tf.cast(bs_new, tf.int32)

        # If a global variable to hold the batch size was given by the user,
        # add operation that saves the new batch size to this variable
        deps = [bs_new]
        if global_bs is not None:
            deps.append(global_bs.assign(bs_new))

        # Add SGD update operations; they only run once the new batch size has
        # been computed (and stored, if requested)
        with tf.control_dependencies(deps):
            sgd_updates = [v.assign_sub(lr*g) for v, g in zip(var_list, grads)]
            sgd_step = tf.group(*sgd_updates)

        # Return the SGD update op and the new (rounded) batch size
        # In debug mode, additionally return the various intermediate quantities
        if self._debug:
            return sgd_step, bs_new, bs_new_raw, loss_avg, loss, xi_avg, xi
        else:
            return sgd_step, bs_new, loss
def read_cifar10(filename_queue):
    """Read and parse a single example from the CIFAR-10 binary files.

    Recommendation: for N-way read parallelism, call this function N times.
    Each call creates its own reader, so the N readers work on different
    files and positions, which gives better mixing of examples.

    Args:
      filename_queue: A queue of strings with the filenames to read from.

    Returns:
      An object representing a single example, with the following fields:
        height: number of rows in the result (32)
        width: number of columns in the result (32)
        depth: number of color channels in the result (3)
        key: a scalar string Tensor describing the filename & record number
          for this example.
        label: an int32 Tensor holding the label byte(s) of the record.
        uint8image: a [height, width, depth] uint8 Tensor with the image data
    """

    class CIFAR10Record(object):
        pass

    # Layout of one record in the CIFAR-10 binary format, see
    # http://www.cs.toronto.edu/~kriz/cifar.html: a fixed number of label
    # bytes followed by a fixed number of image bytes.
    num_label_bytes = 1  # 2 for CIFAR-100
    record = CIFAR10Record()
    record.height, record.width, record.depth = 32, 32, 3
    num_image_bytes = record.height * record.width * record.depth
    bytes_per_record = num_label_bytes + num_image_bytes

    # The format has neither header nor footer, so header_bytes and
    # footer_bytes of the reader stay at their default of 0.
    reader = tf.FixedLengthRecordReader(record_bytes=bytes_per_record)
    record.key, raw_value = reader.read(filename_queue)

    # Turn the raw string into a uint8 vector with one entry per byte.
    raw_uint8 = tf.decode_raw(raw_value, tf.uint8)

    # Leading bytes: the label, converted uint8 -> int32.
    label_part = tf.slice(raw_uint8, [0], [num_label_bytes])
    record.label = tf.cast(label_part, tf.int32)

    # Remaining bytes: the image, stored depth-major. Reshape it to
    # [depth, height, width], then reorder to [height, width, depth].
    image_part = tf.slice(raw_uint8, [num_label_bytes], [num_image_bytes])
    depth_major = tf.reshape(
        image_part, [record.depth, record.height, record.width])
    record.uint8image = tf.transpose(depth_major, [1, 2, 0])

    return record
149 | tf.summary.image('images', images) 150 | 151 | return images, tf.reshape(label_batch, [batch_size]) 152 | 153 | # Possibly remove when we only use undistorted images 154 | def distorted_inputs(data_dir=DATA_DIR, batch_size=128): 155 | """Construct distorted input for CIFAR training using the Reader ops. 156 | Args: 157 | data_dir: Path to the CIFAR-10 data directory. 158 | batch_size: Number of images per batch. 159 | Returns: 160 | images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. 161 | labels: Labels. 1D tensor of [batch_size] size. 162 | """ 163 | filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) 164 | for i in xrange(1, 6)] 165 | for f in filenames: 166 | if not tf.gfile.Exists(f): 167 | raise ValueError('Failed to find file: ' + f) 168 | 169 | # Create a queue that produces the filenames to read. 170 | filename_queue = tf.train.string_input_producer(filenames) 171 | 172 | # Read examples from files in the filename queue. 173 | read_input = read_cifar10(filename_queue) 174 | reshaped_image = tf.cast(read_input.uint8image, tf.float32) 175 | 176 | height = IMAGE_SIZE 177 | width = IMAGE_SIZE 178 | 179 | # Image processing for training the network. Note the many random 180 | # distortions applied to the image. 181 | 182 | # Randomly crop a [height, width] section of the image. 183 | distorted_image = tf.random_crop(reshaped_image, [height, width, 3]) 184 | 185 | # Randomly flip the image horizontally. 186 | distorted_image = tf.image.random_flip_left_right(distorted_image) 187 | 188 | # Because these operations are not commutative, consider randomizing 189 | # the order their operation. 190 | distorted_image = tf.image.random_brightness(distorted_image, 191 | max_delta=63) 192 | distorted_image = tf.image.random_contrast(distorted_image, 193 | lower=0.2, upper=1.8) 194 | 195 | # Subtract off the mean and divide by the variance of the pixels. 
def inputs(eval_data, data_dir=DATA_DIR, batch_size=128):
    """Construct undistorted input for CIFAR using the Reader ops.

    Args:
      eval_data: bool. If true, read ``test_batch.bin``; otherwise read the
        five training files ``data_batch_1.bin`` ... ``data_batch_5.bin``.
      data_dir: Path to the CIFAR-10 data directory.
      batch_size: Number of images per batch.

    Returns:
      images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
      labels: Labels. 1D tensor of [batch_size] size.

    Raises:
      ValueError: If one of the required data files does not exist.
    """
    if not eval_data:
        filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
                     for i in xrange(1, 6)]
        num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
    else:
        filenames = [os.path.join(data_dir, 'test_batch.bin')]
        num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL

    for f in filenames:
        if not tf.gfile.Exists(f):
            raise ValueError('Failed to find file: ' + f)

    # Create a queue that produces the filenames to read.
    filename_queue = tf.train.string_input_producer(filenames)

    # Read examples from files in the filename queue.
    read_input = read_cifar10(filename_queue)
    reshaped_image = tf.cast(read_input.uint8image, tf.float32)

    height = IMAGE_SIZE
    width = IMAGE_SIZE

    # Crop the central [height, width] of the image. The API takes
    # (image, target_height, target_width); the original call passed
    # (width, height), which only worked because both equal IMAGE_SIZE.
    resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image,
                                                           height, width)

    # Per-image standardization of the cropped image.
    float_image = tf.image.per_image_standardization(resized_image)

    # The batch is requested with shuffle=False below, so this number is not
    # about shuffle quality here; it is handed to the batching helper, which
    # uses it when sizing the example queue.
    min_fraction_of_examples_in_queue = 0.4
    min_queue_examples = int(num_examples_per_epoch *
                             min_fraction_of_examples_in_queue)

    # Generate a batch of images and labels by building up a queue of examples.
    return _generate_image_and_label_batch(float_image, read_input.label,
                                           min_queue_examples, batch_size,
                                           shuffle=False)
def weight_variable(shape, stddev=1e-2):
    """Return a variable of ``shape`` drawn from a truncated normal."""
    return tf.Variable(tf.truncated_normal(shape, stddev=stddev))


def bias_variable(shape, val=0.05):
    """Return a variable of ``shape`` filled with the constant ``val``."""
    return tf.Variable(tf.constant(val, shape=shape))


def conv2d(x, W):
    """Convolve ``x`` with filter ``W`` using stride 1 and 'SAME' padding."""
    unit_strides = [1, 1, 1, 1]
    return tf.nn.conv2d(x, W, strides=unit_strides, padding='SAME')
def max_pool_3x3(x):
    """Max-pool ``x`` with a 3x3 window, stride 2 and 'SAME' padding."""
    window = [1, 3, 3, 1]
    step = [1, 2, 2, 1]
    return tf.nn.max_pool(x, ksize=window, strides=step, padding='SAME')


def set_up_model(images, labels):
    """Build the 2-conv / 3-dense CIFAR-10 network and its per-example losses.

    Biases are added with ``+`` and weights enter through ``tf.matmul`` /
    ``tf.nn.conv2d``.

    Args:
      images: Batch of input images fed to the first convolution (3 input
        channels).
      labels: Class labels; cast to int64 before computing the loss.

    Returns:
      losses: Tensor of sparse softmax cross-entropy losses.
      variables: List of the ten trainable variables in layer order
        (weight, bias for conv1, conv2, fc1, fc2, fc3).
    """
    # Convolutional block 1: 5x5 conv (3 -> 64), ReLU, 3x3 max-pool.
    conv1_w = weight_variable([5, 5, 3, 64], 5e-2)
    conv1_b = bias_variable([64], 0.0)
    conv1_act = tf.nn.relu(conv2d(images, conv1_w) + conv1_b)
    pool1 = max_pool_3x3(conv1_act)

    # Convolutional block 2: 5x5 conv (64 -> 64), ReLU, 3x3 max-pool.
    conv2_w = weight_variable([5, 5, 64, 64], 5e-2)
    conv2_b = bias_variable([64], 0.1)
    conv2_act = tf.nn.relu(conv2d(pool1, conv2_w) + conv2_b)
    pool2 = max_pool_3x3(conv2_act)

    # Flatten to [batch, 2304]; the batch dimension is read at run time.
    # NOTE(review): 2304 presumably equals 6*6*64 for 24x24 inputs - confirm.
    flat_dim = 2304
    num_examples = tf.gather(tf.shape(images), 0)
    flat = tf.reshape(pool2, tf.stack([num_examples, flat_dim]))

    # Dense layers 2304 -> 384 -> 192 -> 10 (the last one without ReLU).
    fc1_w = weight_variable([flat_dim, 384], 0.04)
    fc1_b = bias_variable([384], 0.1)
    fc1_act = tf.nn.relu(tf.matmul(flat, fc1_w) + fc1_b)

    fc2_w = weight_variable([384, 192], 0.04)
    fc2_b = bias_variable([192], 0.1)
    fc2_act = tf.nn.relu(tf.matmul(fc1_act, fc2_w) + fc2_b)

    fc3_w = weight_variable([192, 10], 1/192.0)
    fc3_b = bias_variable([10], 0.0)
    logits = tf.matmul(fc2_act, fc3_w) + fc3_b

    int_labels = tf.cast(labels, tf.int64)
    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=int_labels)

    variables = [conv1_w, conv1_b, conv2_w, conv2_b,
                 fc1_w, fc1_b, fc2_w, fc2_b, fc3_w, fc3_b]
    return losses, variables
def weight_variable(shape):
    """Return a variable of ``shape`` drawn from a truncated normal (std 1e-2)."""
    initial = tf.truncated_normal(shape, stddev=1e-2)
    return tf.Variable(initial)


def bias_variable(shape):
    """Return a variable of ``shape`` filled with the constant 0.05."""
    initial = tf.constant(0.05, shape=shape)
    return tf.Variable(initial)


def conv2d(x, W):
    """Convolve ``x`` with filter ``W`` using stride 1 and 'SAME' padding."""
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')


def max_pool_2x2(x):
    """Max-pool ``x`` with a 2x2 window, stride 2 and 'SAME' padding."""
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                          strides=[1, 2, 2, 1], padding='SAME')


def set_up_model():
    """Build the 2-conv / 2-dense MNIST network in a fresh default graph.

    Note that this resets the default graph as a side effect.

    Returns:
      losses: Per-example softmax cross-entropy losses, one entry per row of
        the input batch.
      placeholders: ``[X, y]`` where ``X`` is a float32 ``[None, 784]``
        placeholder for flattened images and ``y`` a float32 ``[None, 10]``
        placeholder for the label distributions.
      variables: The eight trainable variables in layer order.
    """
    tf.reset_default_graph()
    X = tf.placeholder(tf.float32, shape=[None, 784])
    y = tf.placeholder(tf.float32, shape=[None, 10])

    # Conv block 1: 5x5 conv (1 -> 32), ReLU, 2x2 max-pool.
    W_conv1 = weight_variable([5, 5, 1, 32])
    b_conv1 = bias_variable([32])
    X_image = tf.reshape(X, [-1, 28, 28, 1])
    h_conv1 = tf.nn.relu(conv2d(X_image, W_conv1) + b_conv1)
    h_pool1 = max_pool_2x2(h_conv1)

    # Conv block 2: 5x5 conv (32 -> 64), ReLU, 2x2 max-pool.
    W_conv2 = weight_variable([5, 5, 32, 64])
    b_conv2 = bias_variable([64])
    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
    h_pool2 = max_pool_2x2(h_conv2)

    # Dense layers 7*7*64 -> 1024 -> 10.
    W_fc1 = weight_variable([7 * 7 * 64, 1024])
    b_fc1 = bias_variable([1024])
    h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

    W_fc2 = weight_variable([1024, 10])
    b_fc2 = bias_variable([10])
    logits = tf.matmul(h_fc1, W_fc2) + b_fc2

    # Fused softmax + cross-entropy. The previous hand-written
    # -sum(y * log(softmax(logits))) produces NaN as soon as a predicted
    # probability underflows to zero; the fused op computes the same quantity
    # in a numerically stable way.
    losses = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits)

    return (losses, [X, y],
            [W_conv1, b_conv1, W_conv2, b_conv2, W_fc1, b_fc1, W_fc2, b_fc2])
# -*- coding: utf-8 -*-
"""
Run CABS on a CIFAR-10 example.

This will download the dataset to data/cifar-10 automatically if necessary.
"""

import os
import sys
sys.path.insert(0, os.path.abspath('..'))

import tensorflow as tf
import cifar10_adaptive_batchsize as cifar10

from cabs import CABSOptimizer

#### Specify training specifics here ##########################################
from models import cifar10_2conv_3dense as model
num_steps = 8000
learning_rate = 0.1
initial_batch_size = 16
bs_min = 16
bs_max = 2048
###############################################################################

# Set up model. The batch size lives in a graph variable that is passed both
# to the input pipeline and to the optimizer.
tf.reset_default_graph()
global_bs = tf.Variable(tf.constant(initial_batch_size, dtype=tf.int32))
images, labels = cifar10.inputs(eval_data=False, batch_size=global_bs)
losses, variables = model.set_up_model(images, labels)

# Set up CABS optimizer
opt = CABSOptimizer(learning_rate, bs_min, bs_max)
sgd_step, bs_new, loss = opt.minimize(losses, variables, global_bs)

# Initialize variables and start queues
sess = tf.Session()
coord = tf.train.Coordinator()
threads = []
try:
    sess.run(tf.global_variables_initializer())
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Run CABS
    for i in range(num_steps):
        _, m_new, l = sess.run([sgd_step, bs_new, loss])
        print(l)
        print(m_new)
finally:
    # Stop the queue threads and release the session even if training is
    # interrupted or fails; otherwise the reader threads keep running.
    coord.request_stop()
    coord.join(threads)
    sess.close()
# -*- coding: utf-8 -*-
"""
Run CABS on a MNIST example.

This will download the dataset to data/mnist automatically if necessary.
"""

import os
import sys
sys.path.insert(0, os.path.abspath('..'))

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('data/mnist', one_hot=True)

from cabs import CABSOptimizer

#### Specify training specifics here ##########################################
from models import mnist_2conv_2dense as model
num_steps = 8000
learning_rate = 0.1
initial_batch_size = 16
bs_min = 16
bs_max = 2048
###############################################################################

# Set up model
losses, placeholders, variables = model.set_up_model()
X, y = placeholders

# Set up CABS optimizer
opt = CABSOptimizer(learning_rate, bs_min, bs_max)
sgd_step, bs_new, loss = opt.minimize(losses, variables)

# Initialize variables
sess = tf.Session()
try:
    sess.run(tf.global_variables_initializer())

    # Run CABS: the batch size returned by one step is used to draw the
    # batch for the next step.
    m = initial_batch_size
    for i in range(num_steps):
        batch = mnist.train.next_batch(m)
        _, m_new, l = sess.run([sgd_step, bs_new, loss],
                               {X: batch[0], y: batch[1]})
        print(l)
        print(m_new)
        m = m_new
finally:
    # Release the session's resources even if training fails midway.
    sess.close()
the variables ``W`` is 13 | the mean of individual gradients ``dloss(W, d_i)``. These individual gradients 14 | are not computed separately when we call ``tf.gradients`` on the aggregate 15 | ``LOSS``. Instead, they are implicitly aggregated by the operations in the 16 | backward graph. This batch processing is crucial for the computational 17 | efficiency of the gradient computation. 18 | 19 | This module provides functionality to compute the ``p``-th moment of the 20 | individual gradients, i.e. the quantity 21 | 22 | MOM(W) = (1/n) * sum{i=1:n}[ dloss(w, d_i)**p ] 23 | 24 | without giving up the efficiency of batch processing. For a more detailed 25 | explanation, see the note [1]. Applications of this are the computation of the 26 | gradient variance estimate in [2] and [3]. 27 | 28 | [1] https://drive.google.com/open?id=0B0adgqwcMJK5aDNaQ2Q4ZmhCQzA 29 | 30 | [2] M. Mahsereci and P. Hennig. Probabilistic line searches for stochastic 31 | optimization. In Advances in Neural Information Processing Systems 28, pages 32 | 181-189, 2015. 33 | 34 | [3] L. Balles, J. Romero and P. Hennig. Coupling Adaptive Batch Sizes with 35 | Learning Rates. In arXiv preprint arXiv:1612.05086, 2016. 36 | https://arxiv.org/abs/1612.05086. 37 | """ 38 | 39 | import tensorflow as tf 40 | from tensorflow.python.ops import gen_array_ops 41 | 42 | VALID_TYPES = ["MatMul", "Conv2D", "Add"] 43 | VALID_REGULARIZATION_TYPES = ["L2Loss"] 44 | 45 | def _check_and_sort_ops(op_list): 46 | """Sort a list of ops according to type into valid types for which we can 47 | compute the gradient moment) and regularizers. 
def grads_and_grad_moms(loss, batch_size, var_list, mom=2):
    """Compute the gradients and gradient moments of ``loss`` w.r.t. to the
    variables in ``var_list``

    Inputs:
        :loss: The tensor containing the scalar loss. The loss has to be the
            mean (``tf.reduce_mean``) of ``batch_size`` individual losses
            induced by individual training data points.
        :batch_size: The batch size. Handed unchanged to the op-specific
            moment functions, which use it as a scaling factor.
        :var_list: The list of variables. Must not contain duplicates, and
            each variable must be consumed by exactly one op of a supported
            type (plus at most one regularization op).
        :mom: The desired moment. Integer. Defaults to 2.

    Returns:
        :v_grads: The gradients of ``loss`` w.r.t. the variables in ``var_list``
            as computed by ``tf.gradients(loss, var_list)``.
        :grad_moms: The gradient moments for each variable in ``var_list``.

    Raises:
        ValueError: If ``var_list`` contains duplicates, or a variable has
            no supported consumer, more than one supported consumer, or more
            than one regularization consumer."""

    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if len(set(var_list)) != len(var_list):
        raise ValueError("var_list must not contain duplicate variables")
    vs = [tf.convert_to_tensor(v) for v in var_list]
    num_vars = len(vs)

    consumers = []
    consumer_outs = []
    for v in vs:
        valid, regularizers = _check_and_sort_ops(v.consumers())
        if not valid:
            # Previously this surfaced as a bare IndexError on ``valid[0]``.
            raise ValueError("Variable {} is not consumed by any operation "
                             "for which the gradient moment can be "
                             "computed".format(v.name))
        if len(valid) > 1:
            raise ValueError("Variable {} is consumed by more than one operation "
                             "(ignoring regularization operations)".format(v.name))
        if len(regularizers) > 1:
            raise ValueError("Variable {} is consumed by more than one "
                             "regularization operation".format(v.name))
        consumers.append(valid[0])
        # Exactly one output per variable keeps ``consumers``, ``vs`` and the
        # output gradients aligned in the zip below.
        consumer_outs.append(valid[0].outputs[0])

    # Use tf.gradients to compute gradients w.r.t. the variables, while also
    # retrieving gradients w.r.t. the outputs
    all_grads = tf.gradients(loss, vs + consumer_outs)
    v_grads = all_grads[:num_vars]
    out_grads = all_grads[num_vars:]

    # Compute the gradient moment for each (op, v, output gradient) triple
    with tf.name_scope("grad_moms"):
        grad_moms = [_GradMom(o, v, out_grad, batch_size, mom)
                     for o, v, out_grad in zip(consumers, vs, out_grads)]

    return (v_grads, grad_moms)
def _GradMom(op, v, out_grad, batch_size, mom=2):
    """Dispatch to the gradient-moment function matching the type of ``op``.

    Inputs:
        :op: A tensorflow operation of type in VALID_TYPES.
        :v: The read-tensor of the trainable variable consumed by this operation.
        :out_grad: The tensor containing the gradient w.r.t. to the output of
            the op (as computed by ``tf.gradients``).
        :batch_size: Batch size ``m`` (constant integer or scalar int tf.Tensor)
        :mom: Integer moment desired (defaults to 2).

    Raises:
        ValueError: If ``op.type`` is none of "MatMul", "Conv2D", "Add"."""

    with tf.name_scope(op.name+"_grad_mom"):
        # One handler per supported op type.
        handlers = {
            "MatMul": _MatMulGradMom,
            "Conv2D": _Conv2DGradMom,
            "Add": _AddGradMom,
        }
        handler = handlers.get(op.type)
        if handler is None:
            raise ValueError("Don't know how to compute gradient moment for "
                             "variable {}, consumed by operation of type {}"
                             .format(v.name, op.type))
        return handler(op, v, out_grad, batch_size, mom)
def _MatMulGradMom(op, W, out_grad, batch_size, mom=2):
    """Computes gradient moment for a weight matrix through a MatMul operation.

    Assumes ``Z=tf.matmul(A, W)``, where ``W`` is a d1xd2 weight matrix, ``A``
    are the nxd1 activations of the previous layer (n being the batch size).
    ``out_grad`` is the gradient w.r.t. ``Z``, as computed by ``tf.gradients()``.
    No transposes in the MatMul operation allowed.

    Since ``out_grad`` belongs to the *mean* loss, the gradient of the i-th
    individual loss is ``n * outer(A[i], out_grad[i])``; averaging its
    ``mom``-th power over the batch gives
    ``n**(mom-1) * matmul(A**mom, out_grad**mom, transpose_a=True)``.
    For the default ``mom=2`` this is the previously used factor ``n``.

    Inputs:
        :op: The MatMul operation
        :W: The weight matrix (the tensor, not the variable)
        :out_grad: The tensor of gradient w.r.t. to the output of the op
        :batch_size: Batch size n (Python number or scalar tf.Tensor of any
            numeric dtype; it is cast to the dtype of the moment)
        :mom: Integer moment desired (defaults to 2)"""

    assert op.type == "MatMul"
    t_a, t_b = op.get_attr("transpose_a"), op.get_attr("transpose_b")
    assert W is op.inputs[1] and not t_a and not t_b

    A = op.inputs[0]
    out_grad_pow = tf.pow(out_grad, mom)
    A_pow = tf.pow(A, mom)
    raw_moment = tf.matmul(A_pow, out_grad_pow, transpose_a=True)

    # Cast so that an integer batch size can be multiplied with the float
    # moment; tf.multiply requires both operands to share a dtype.
    scale = tf.cast(batch_size, raw_moment.dtype)
    if mom != 2:
        scale = tf.pow(scale, mom - 1)
    return tf.multiply(scale, raw_moment)


def _Conv2DGradMom(op, f, out_grad, batch_size, mom=2):
    """Computes gradient moment for the filter of a Conv2D operation.

    Assumes ``Z=tf.nn.conv2d(A, f)``, where ``f`` is a ``[h_f, w_f, c_in, c_out]``
    convolution filter and ``A`` are the ``[n, h_in, w_in, c_in]`` activations of
    the previous layer (``n`` being the batch size). ``out_grad`` is the gradient
    w.r.t. ``Z``, as computed by ``tf.gradients()``.

    The moment is obtained by running the filter-backprop op of the
    convolution on the element-wise ``mom``-th powers of input and output
    gradient, scaled by ``n**(mom-1)`` (``n`` for the default ``mom=2``).

    Inputs:
        :op: The Conv2D operation
        :f: The filter (the tensor, not the variable)
        :out_grad: The tensor of gradient w.r.t. to the output of the op
        :batch_size: Batch size ``n`` (Python number or scalar tf.Tensor of any
            numeric dtype; it is cast to the dtype of the moment)
        :mom: Integer moment desired (defaults to 2)"""

    assert op.type == "Conv2D"
    assert f is op.inputs[1]

    # Re-use the exact configuration of the forward convolution.
    strides = op.get_attr("strides")
    padding = op.get_attr("padding")
    use_cudnn = op.get_attr("use_cudnn_on_gpu")
    data_format = op.get_attr("data_format")

    inp = op.inputs[0]
    inp_pow = tf.pow(inp, mom)

    f_shape = tf.shape(f)
    out_grad_pow = tf.pow(out_grad, mom)

    raw_moment = tf.nn.conv2d_backprop_filter(inp_pow, f_shape, out_grad_pow,
                                              strides, padding, use_cudnn,
                                              data_format)

    # Cast so that an integer batch size can be multiplied with the float
    # moment; tf.multiply requires both operands to share a dtype.
    scale = tf.cast(batch_size, raw_moment.dtype)
    if mom != 2:
        scale = tf.pow(scale, mom - 1)
    return tf.multiply(scale, raw_moment)
def _AddGradMom(op, b, out_grad, batch_size, mom=2):
    """Computes gradient moment for a bias variable through an Add operation.

    Assumes ``Z = tf.add(Zz, b)``, where ``b`` is a bias parameter and ``Zz`` is
    a ``[n, ?]`` tensor (``n`` being the batch size). Broadcasting for all kinds
    of shapes of ``Zz`` (e.g. ``[n, d_in]`` or ``[n, h_in, w_in, c_in]``) is
    supported. ``out_grad`` is the gradient w.r.t. ``Z``, as computed by
    ``tf.gradients()``.

    The ``mom``-th power of ``out_grad`` is summed over the dimensions along
    which ``b`` was broadcast and scaled by ``n**(mom-1)`` (``n`` for the
    default ``mom=2``).

    Inputs:
        :op: The Add operation
        :b: The bias parameter (the tensor, not the variable). May be either
            input of the op.
        :out_grad: The tensor of gradient w.r.t. to the output of the op
        :batch_size: Batch size ``n`` (Python number or scalar tf.Tensor of any
            numeric dtype; it is cast to the dtype of the moment)
        :mom: Integer moment desired (defaults to 2)

    Raises:
        ValueError: If ``b`` is not an input of ``op``."""

    assert op.type == "Add"

    out_grad_pow = tf.pow(out_grad, mom)

    lhs, rhs = op.inputs[0], op.inputs[1]
    if b is lhs:
        b_index = 0
    elif b is rhs:
        b_index = 1
    else:
        # Previously this fell through and failed on an unbound local.
        raise ValueError("Tensor {} is not an input of operation {}"
                         .format(b.name, op.name))

    # Reduction indices that undo the broadcasting of each input.
    sx = tf.shape(lhs)
    sy = tf.shape(rhs)
    rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
    reduce_dims = (rx, ry)[b_index]
    b_shape = (sx, sy)[b_index]
    raw_mom = tf.reshape(tf.reduce_sum(out_grad_pow, reduce_dims), b_shape)

    # Cast so that an integer batch size can be multiplied with the float
    # moment; tf.multiply requires both operands to share a dtype.
    scale = tf.cast(batch_size, raw_mom.dtype)
    if mom != 2:
        scale = tf.pow(scale, mom - 1)
    return tf.multiply(scale, raw_mom)