├── .gitignore ├── LICENSE ├── README.md ├── main.py ├── yellowfin.py └── yellowfin_test.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.gz 2 | *.ipynb 3 | data/ 4 | *.pyc 5 | .DS_Store 6 | .idea/ 7 | backup/ 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # YellowFin 2 | 3 | YellowFin is an auto-tuning optimizer based on momentum SGD **which requires no manual specification of learning rate and momentum**. It measures the objective landscape on the fly and tunes momentum as well as the learning rate using a local quadratic approximation.
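Because `YFOptimizer` registers itself with MXNet's optimizer registry (via `@mx.optimizer.Optimizer.register`), it can be selected by name like any built-in optimizer. Below is a minimal usage sketch adapted from `main.py` in this repo; the small MLP and the hyperparameters are only illustrative.

```python
import mxnet as mx
from yellowfin import *  # importing the module registers YFOptimizer with MXNet

# Toy MNIST setup, mirroring the NDArrayIter path shown (commented out) in main.py.
mnist = mx.test_utils.get_mnist()
batch_size = 100
train_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'], batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)

# A small MLP; any symbol works here.
data = mx.sym.var('data')
fc1 = mx.sym.FullyConnected(data=data, num_hidden=128)
act1 = mx.sym.Activation(data=fc1, act_type="relu")
fc2 = mx.sym.FullyConnected(data=act1, num_hidden=10)
out = mx.sym.SoftmaxOutput(data=fc2, name='softmax')

model = mx.mod.Module(symbol=out, context=mx.cpu())
model.fit(train_iter,
          eval_data=val_iter,
          optimizer='YFOptimizer',                 # pass the registered name
          optimizer_params={'rescale_grad': 1. / batch_size,
                            'learning_rate': 0.1,  # initial values only;
                            'momentum': 0.0,       # both are tuned on the fly
                            'zero_debias': False},
          eval_metric='acc',
          num_epoch=10)
```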
4 | 5 | The implementation here can be **a drop-in replacement for any optimizer in MXNet** (so far it has only been implemented and tested on top of SGD; other optimizers are on the to-do list). 6 | 7 | For more technical details, please refer to the paper [YellowFin and the Art of Momentum Tuning](https://arxiv.org/abs/1706.03471). 8 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from yellowfin import * 2 | 3 | mnist = mx.test_utils.get_mnist() 4 | batch_size = 100 5 | # train_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'], batch_size, shuffle=True) 6 | # val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size) 7 | 8 | train_iter = mx.io.MNISTIter( 9 | image="data/train-images-idx3-ubyte", 10 | label="data/train-labels-idx1-ubyte", 11 | data_shape=(28, 28), #data_shape=(784,), 12 | batch_size=batch_size, shuffle=False, flat=False, silent=False, seed=10) 13 | val_iter = mx.io.MNISTIter( 14 | image="data/t10k-images-idx3-ubyte", 15 | label="data/t10k-labels-idx1-ubyte", 16 | data_shape=(28, 28), #data_shape=(784,), 17 | batch_size=batch_size, shuffle=False, flat=False, silent=False) 18 | 19 | data = mx.sym.var('data') 20 | def get_lenet(data): 21 | # first conv layer 22 | conv1 = mx.sym.Convolution(data=data, kernel=(5,5), num_filter=20) 23 | tanh1 = mx.sym.Activation(data=conv1, act_type="tanh") 24 | pool1 = mx.sym.Pooling(data=tanh1, pool_type="max", kernel=(2,2), stride=(2,2)) 25 | # second conv layer 26 | conv2 = mx.sym.Convolution(data=pool1, kernel=(5,5), num_filter=50) 27 | tanh2 = mx.sym.Activation(data=conv2, act_type="tanh") 28 | pool2 = mx.sym.Pooling(data=tanh2, pool_type="max", kernel=(2,2), stride=(2,2)) 29 | # first fully connected layer 30 | flatten = mx.sym.flatten(data=pool2) 31 | fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=500) 32 | tanh3 = mx.sym.Activation(data=fc1, act_type="tanh") 33 | # second fully connected layer 34 | fc2 = mx.sym.FullyConnected(data=tanh3, num_hidden=10) 35 | # softmax loss 36 | lenet = mx.sym.SoftmaxOutput(data=fc2, name='softmax') 37 | return lenet 38 | 39 | def get_mlp(data): 40 | # The first fully-connected layer and the corresponding activation function 41 | fc1 = mx.sym.FullyConnected(data=data, num_hidden=128) 42 | act1 = mx.sym.Activation(data=fc1, act_type="relu") 43 | 44 | # The second fully-connected layer and the corresponding activation function 45 | fc2 = mx.sym.FullyConnected(data=act1, num_hidden=64) 46 | act2 = mx.sym.Activation(data=fc2, act_type="relu") 47 | # MNIST has 10 classes 48 | fc3 = mx.sym.FullyConnected(data=act2, num_hidden=10) 49 | # Softmax with cross entropy loss 50 | mlp = mx.sym.SoftmaxOutput(data=fc3, name='softmax') 51 | return mlp 52 | 53 | model = get_mlp(data) 54 | import logging 55 | logging.getLogger().setLevel(logging.DEBUG) # logging to stdout 56 | # create a trainable module on CPU 57 | model = mx.mod.Module(symbol=model, context=mx.cpu()) 58 | 59 | initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2) 60 | 61 | model.fit(train_iter, # train data 62 | eval_data=val_iter, # validation data 63 | optimizer='YFOptimizer', # use the registered YellowFin optimizer to train 64 | optimizer_params={'rescale_grad': 1./batch_size, 'learning_rate':0.1, 'momentum': 0.0, 'zero_debias': False}, # initial values; YellowFin tunes learning rate and momentum on the fly 65 | eval_metric='acc', # report accuracy during training 66 | batch_end_callback = mx.callback.Speedometer(batch_size, 100), # output progress for each
100 data batches 67 | initializer=initializer, 68 | num_epoch=1) # train for 1 pass over the training set 69 | 70 | test_iter = mx.io.NDArrayIter(mnist['test_data'], None, batch_size) 71 | prob = model.predict(test_iter) 72 | test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size) 73 | # predict accuracy for mlp 74 | acc = mx.metric.Accuracy() 75 | model.score(test_iter, acc) 76 | print(acc) 77 | 78 | # assert acc.get()[1] > 0.98 -------------------------------------------------------------------------------- /yellowfin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import mxnet as mx 4 | 5 | @mx.optimizer.Optimizer.register 6 | class YFOptimizer(mx.optimizer.Optimizer): 7 | """The YellowFin (YF) optimizer, built on top of the SGD optimizer with momentum and weight decay. 8 | The optimizer updates the weight by:: 9 | state = momentum * state + lr * rescale_grad * clip(grad, clip_gradient) + wd * weight 10 | weight = weight - state 11 | For details of the update algorithm see :class:`~mxnet.ndarray.sgd_update` and 12 | :class:`~mxnet.ndarray.sgd_mom_update`. 13 | This optimizer accepts the following parameters in addition to those accepted 14 | by :class:`.Optimizer`. 15 | Parameters 16 | ---------- 17 | momentum : float, optional 18 | The initial momentum value. 19 | beta : float, optional 20 | The smoothing (EMA decay) parameter for the running estimates. 21 | curv_win_width : int, optional 22 | Width of the sliding window used to estimate the curvature range (h_min, h_max). 23 | zero_debias : bool, optional. If True, zero-debias the exponential moving averages. 24 | """ 25 | 26 | def __init__(self, momentum=0.0, beta=0.999, curv_win_width=20, zero_debias=True, **kwargs): 27 | super(YFOptimizer, self).__init__(**kwargs) 28 | self.momentum = momentum 29 | self.beta = beta 30 | self.curv_win_width = curv_win_width 31 | self.zero_debias = zero_debias 32 | # The following are global states for the YF tuner 33 | # 1. Calculate grad norm for all indices 34 | self._grad_norm = None 35 | # 2. Calculate grad norm squared for all indices 36 | self._grad_norm_squared = None 37 | # 3. Update state parameters for YF after each iteration 38 | # a. Used in curvature estimation 39 | self._h_min = 0.0 40 | self._h_max = 0.0 41 | self._h_window = np.zeros(curv_win_width) 42 | # b. Used in grad_variance 43 | self._grad_var = None 44 | # c. Used in distance to opt. estimation 45 | self._grad_norm_avg = 0.0 46 | self._grad_norm_squared_avg = 0.0 47 | self._h_avg = 0.0 48 | self._dist_to_opt_avg = 0.0 49 | # For testing purpose only 50 | self._test_res = [] 51 | 52 | def create_state(self, index, weight): 53 | momentum = mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype) 54 | grad_avg = mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype) 55 | grad_avg_squared = mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype) 56 | return momentum, grad_avg, grad_avg_squared 57 | 58 | def zero_debias_factor(self): 59 | if not self.zero_debias: 60 | return 1.0 61 | return 1.0 - self.beta ** (self.num_update) 62 | 63 | def clear_grad_norm_info(self): 64 | # self._grad_norm = None 65 | self._grad_norm_squared = None 66 | self._grad_var = None 67 | 68 | def update_grad_norm_and_var(self, index, grad, state): 69 | _, grad_avg, grad_avg_squared = state 70 | # _, grad_avg = state 71 | grad_avg[:] = self.beta * grad_avg + (1. - self.beta) * grad 72 | grad_avg_squared[:] = self.beta * grad_avg_squared + (1.
- self.beta) * mx.nd.square(grad) 73 | 74 | # grad_norm_squared = sum(grad * grad) 75 | grad_norm_squared = mx.ndarray.sum(grad * grad) 76 | # print(grad_norm_squared.shape) 77 | if self._grad_norm_squared is None: 78 | self._grad_norm_squared = grad_norm_squared 79 | else: 80 | self._grad_norm_squared += grad_norm_squared 81 | 82 | if self._grad_var is None: 83 | self._grad_var = mx.ndarray.sum(grad_avg * grad_avg) 84 | else: 85 | self._grad_var += mx.ndarray.sum(grad_avg * grad_avg) 86 | 87 | def curvature_range(self): 88 | curv_win = self._h_window 89 | beta = self.beta 90 | curv_win[(self.num_update-1) % self.curv_win_width] = self._grad_norm_squared 91 | valid_end = min(self.curv_win_width, self.num_update) 92 | self._h_min = beta * self._h_min + (1 - beta) * curv_win[:valid_end].min() 93 | self._h_max = beta * self._h_max + (1 - beta) * curv_win[:valid_end].max() 94 | debias_factor = self.zero_debias_factor() 95 | return self._h_min / debias_factor, self._h_max / debias_factor 96 | 97 | def grad_variance(self): 98 | debias_factor = self.zero_debias_factor() 99 | self._grad_var /= -(debias_factor ** 2) 100 | self._grad_var += self._grad_norm_squared_avg/debias_factor 101 | return self._grad_var 102 | 103 | def dist_to_opt(self): 104 | beta = self.beta 105 | self._grad_norm_avg = beta * self._grad_norm_avg + (1 - beta) * math.sqrt(self._grad_norm_squared) 106 | self._dist_to_opt_avg = beta * self._dist_to_opt_avg + (1 - beta) * self._grad_norm_avg / self._grad_norm_squared_avg 107 | debias_factor = self.zero_debias_factor() 108 | return self._dist_to_opt_avg / debias_factor 109 | 110 | def single_step_mu_lr(self, C, D, h_min, h_max): 111 | coef = np.array([-1.0, 3.0, 0.0, 1.0]) 112 | coef[2] = -(3 + D ** 2 * h_min ** 2 / 2 / C) 113 | roots = np.roots(coef) 114 | root = roots[np.logical_and(np.logical_and(np.real(roots) > 0.0, 115 | np.real(roots) < 1.0), np.imag(roots) < 1e-5)] 116 | assert root.size == 1 117 | dr = h_max / h_min 118 | mu_t = max(np.real(root)[0] ** 2, ((np.sqrt(dr) - 1) / (np.sqrt(dr) + 1)) ** 2) 119 | lr_t = (1.0 - math.sqrt(mu_t)) ** 2 / h_min 120 | return mu_t, lr_t 121 | 122 | def after_apply(self): 123 | beta = self.beta 124 | 125 | self._grad_norm_squared = self._grad_norm_squared.asscalar() 126 | self._grad_norm_squared_avg = self.beta * self._grad_norm_squared_avg + (1 - self.beta) * self._grad_norm_squared 127 | 128 | h_min, h_max = self.curvature_range() 129 | C = self.grad_variance().asscalar() 130 | D = self.dist_to_opt() 131 | if self.num_update > 1: 132 | mu_t, lr_t = self.single_step_mu_lr(C, D, h_min, h_max) 133 | self.momentum = beta * self.momentum + (1 - beta) * mu_t 134 | self.lr = beta * self.lr + (1 - beta) * lr_t 135 | self._test_res = [h_max, h_min, C, D, self.lr, self.momentum] 136 | self.clear_grad_norm_info() 137 | 138 | def is_end_iter(self): 139 | if (self.num_update == 1) and (len(self._index_update_count) == len(self.idx2name)): 140 | return True 141 | elif (self.num_update > 1) and (np.min(self._index_update_count.values()) == self.num_update): 142 | return True 143 | else: 144 | return False 145 | 146 | def update(self, index, weight, grad, state): 147 | assert (isinstance(weight, mx.nd.NDArray)) 148 | assert (isinstance(grad, mx.nd.NDArray)) 149 | lr = self._get_lr(index) 150 | wd = self._get_wd(index) 151 | momentum = self.momentum 152 | self._update_count(index) 153 | 154 | kwargs = {'rescale_grad': self.rescale_grad} 155 | if self.momentum > 0: 156 | kwargs['momentum'] = momentum 157 | if self.clip_gradient: 158 | 
kwargs['clip_gradient'] = self.clip_gradient 159 | 160 | if state is not None: 161 | mx.optimizer.sgd_mom_update(weight, grad, state[0], out=weight, 162 | lr=lr, wd=wd, **kwargs) 163 | self.update_grad_norm_and_var(index, grad*self.rescale_grad, state) 164 | if self.is_end_iter(): 165 | self.after_apply() 166 | else: 167 | mx.optimizer.sgd_update(weight, grad, out=weight, 168 | lr=lr, wd=wd, **kwargs) 169 | -------------------------------------------------------------------------------- /yellowfin_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import mxnet as mx 3 | import numpy as np 4 | from yellowfin import * 5 | import time 6 | 7 | 8 | n_dim = 100 9 | n_iter = 100 10 | 11 | class MyLoss(mx.operator.NumpyOp): 12 | def __init__(self): 13 | super(MyLoss, self).__init__(False) 14 | 15 | def list_arguments(self): 16 | return ['data', 'label'] 17 | 18 | def list_outputs(self): 19 | return ['output'] 20 | 21 | def infer_shape(self, in_shape): 22 | data_shape = in_shape[0] 23 | label_shape = (in_shape[0][0],) 24 | output_shape = in_shape[0] 25 | return [data_shape, label_shape], [output_shape] 26 | 27 | def forward(self, in_data, out_data): 28 | x = in_data[0] 29 | y = out_data[0] 30 | y[:] = x 31 | 32 | def backward(self, out_grad, in_data, out_data, in_grad): 33 | dx = in_grad[0] 34 | dx[:] = np.ones(dx.shape) # constant all-ones gradient, independent of the output 35 | 36 | @mx.initializer.register 37 | class CustomInit(mx.initializer.Initializer): 38 | """Initializes all weights and biases to the constant 1. 39 | This initializer takes no parameters; together with the all-ones input 40 | and the constant-gradient MyLoss above, it makes the gradients seen by 41 | the optimizer deterministic, so the tests can compute their target 42 | statistics in closed form. 43 | """ 44 | def __init__(self): 45 | super(CustomInit, self).__init__() 46 | 47 | def _init_weight(self, _, arr): 48 | arr[:] = 1 49 | 50 | def _init_bias(self, _, arr): 51 | arr[:] = 1 52 | 53 | def tune_everything(x0squared, C, T, gmin, gmax): 54 | # First tune based on dynamic range 55 | if C == 0: 56 | dr = gmax / gmin 57 | mustar = ((np.sqrt(dr) - 1) / (np.sqrt(dr) + 1)) ** 2 58 | alpha_star = (1 + np.sqrt(mustar)) ** 2 / gmax 59 | 60 | return alpha_star, mustar 61 | 62 | dist_to_opt = x0squared 63 | grad_var = C 64 | max_curv = gmax 65 | min_curv = gmin 66 | const_fact = dist_to_opt * min_curv ** 2 / 2 / grad_var 67 | coef = [-1, 3, -(3 + const_fact), 1] 68 | roots = np.roots(coef) 69 | roots = roots[np.real(roots) > 0] 70 | roots = roots[np.real(roots) < 1] 71 | root = roots[np.argmin(np.imag(roots))] 72 | 73 | assert root > 0 and root < 1 and np.absolute(root.imag) < 1e-6 74 | 75 | dr = max_curv / min_curv 76 | assert max_curv >= min_curv 77 | mu = max(((np.sqrt(dr) - 1) / (np.sqrt(dr) + 1)) ** 2, root ** 2) 78 | 79 | lr_min = (1 - np.sqrt(mu)) ** 2 / min_curv 80 | lr_max = (1 + np.sqrt(mu)) ** 2 / max_curv 81 | 82 | alpha_star = lr_min 83 | mustar = mu 84 | 85 | return alpha_star, mustar 86 | 87 | def test_measurement(zero_debias=True): 88 | 89 | data = np.array([np.ones(n_dim)]) 90 | label = np.array([0]) 91 | batch_size = 1 92 | train_iter = mx.io.NDArrayIter(data, label, batch_size, label_name='linear_output_label') 93 | 94 | net = mx.sym.Variable('data') 95 | weight = mx.sym.Variable(name='fc1_weight') 96 | bias = mx.sym.Variable(name='fc1_bias') 97 | net = mx.sym.FullyConnected(data=net, weight=weight, bias=bias, name='fc1', num_hidden=1) 98 | myloss = MyLoss() 99 | net = myloss(data=net, name='linear_output') 100 | 101 | mod = mx.mod.Module(symbol=net, 102 | context=mx.cpu(), 103 | data_names=['data'], 104 | label_names=['linear_output_label']) 105 | # allocate memory given the input data and label shapes
106 | mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label) 107 | # initialize all parameters to the constant 1 via CustomInit 108 | mod.init_params(CustomInit()) 109 | # use YFOptimizer with initial learning rate 1.0 and momentum 0.0 110 | mod.init_optimizer(optimizer='YFOptimizer', optimizer_params=(('learning_rate', 1.0), ('momentum', 0.0), ('zero_debias', zero_debias))) 111 | # use MSE as the metric 112 | metric = mx.metric.create('mse') 113 | 114 | target_h_max = 0.0 115 | target_h_min = 0.0 116 | g_norm_squared_avg = 0.0 117 | g_norm_avg = 0.0 118 | g_avg = 0.0 119 | target_dist = 0.0 120 | 121 | for epoch in range(n_iter): 122 | train_iter.reset() 123 | metric.reset() 124 | for batch in train_iter: 125 | i = epoch 126 | mod.forward(batch, is_train=True) # compute predictions 127 | mod.update_metric(metric, batch.label) # accumulate prediction MSE 128 | mod.backward() # compute gradients 129 | mod._exec_group.grad_arrays[0][0] *= i + 1 # scale the gradients so the target statistics below have a closed form 130 | mod._exec_group.grad_arrays[1][0] *= i + 1 131 | mod.update() # update parameters 132 | 133 | res = mod._optimizer._test_res 134 | 135 | g_norm_squared_avg = 0.999 * g_norm_squared_avg \ 136 | + 0.001 * np.sum(((i + 1) * np.ones([n_dim + 1, ])) ** 2) 137 | g_norm_avg = 0.999 * g_norm_avg \ 138 | + 0.001 * np.linalg.norm((i + 1) * np.ones([n_dim + 1, ])) 139 | g_avg = 0.999 * g_avg + 0.001 * (i + 1) 140 | 141 | target_h_max = 0.999 * target_h_max + 0.001 * (i + 1) ** 2 * (n_dim + 1) 142 | target_h_min = 0.999 * target_h_min + 0.001 * max(1, i + 2 - 20) ** 2 * (n_dim + 1) 143 | if zero_debias: 144 | target_var = g_norm_squared_avg / (1 - 0.999 ** (i + 1)) \ 145 | - g_avg ** 2 * (n_dim + 1) / (1 - 0.999 ** (i + 1)) ** 2 146 | else: 147 | target_var = g_norm_squared_avg - g_avg ** 2 * (n_dim + 1) 148 | target_dist = 0.999 * target_dist + 0.001 * g_norm_avg / g_norm_squared_avg 149 | 150 | if i == 0: 151 | continue 152 | if zero_debias: 153 | # print "iter ", i, " h max ", res[0], target_h_max/(1-0.999**(i + 1) ), \ 154 | # " h min ", res[1], target_h_min/(1-0.999**(i + 1) ), \ 155 | # " var ", res[2], target_var, \ 156 | # " dist ", res[3], target_dist/(1-0.999**(i + 1) ) 157 | assert np.abs(target_h_max / (1 - 0.999 ** (i + 1)) - res[0]) < np.abs(res[0]) * 1e-3 158 | assert np.abs(target_h_min / (1 - 0.999 ** (i + 1)) - res[1]) < np.abs(res[1]) * 1e-3 159 | assert np.abs(target_var - res[2]) < np.abs(target_var) * 1e-3 160 | assert np.abs(target_dist / (1 - 0.999 ** (i + 1)) - res[3]) < np.abs(res[3]) * 1e-3 161 | else: 162 | # print "iter ", i, " h max ", res[0], target_h_max, " h min ", res[1], target_h_min, \ 163 | # " var ", res[2], target_var, " dist ", res[3], target_dist 164 | assert np.abs(target_h_max - res[0]) < np.abs(target_h_max) * 1e-3 165 | assert np.abs(target_h_min - res[1]) < np.abs(target_h_min) * 1e-3 166 | assert np.abs(target_var - res[2]) < np.abs(res[2]) * 1e-3 167 | assert np.abs(target_dist - res[3]) < np.abs(res[3]) * 1e-3 168 | 169 | print "sync measurement test passed!"
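# test_lr_mu below cross-checks the closed-form single-step tuning rule implemented in
# YFOptimizer.single_step_mu_lr (and mirrored by tune_everything above): with
# p = D ** 2 * h_min ** 2 / (2 * C), x is the real root in (0, 1) of
#     (1 - x) ** 3 == p * x
# and the tuner sets, with dr = h_max / h_min,
#     mu = max(x ** 2, ((np.sqrt(dr) - 1) / (np.sqrt(dr) + 1)) ** 2)
#     lr = (1 - np.sqrt(mu)) ** 2 / h_min
# before smoothing both with the same beta-EMA used for the measurement statistics.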
170 | 171 | 172 | def test_lr_mu(zero_debias=False): 173 | 174 | data = np.array([np.ones(n_dim)]) 175 | label = np.array([0]) 176 | batch_size = 1 177 | train_iter = mx.io.NDArrayIter(data, label, batch_size, label_name='linear_output_label') 178 | 179 | net = mx.sym.Variable('data') 180 | weight = mx.sym.Variable(name='fc1_weight') 181 | bias = mx.sym.Variable(name='fc1_bias') 182 | net = mx.sym.FullyConnected(data=net, weight=weight, bias=bias, name='fc1', num_hidden=1) 183 | myloss = MyLoss() 184 | net = myloss(data=net, name='linear_output') 185 | 186 | mod = mx.mod.Module(symbol=net, 187 | context=mx.cpu(), 188 | data_names=['data'], 189 | label_names=['linear_output_label']) 190 | # allocate memory given the input data and label shapes 191 | mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label) 192 | # initialize all parameters to the constant 1 via CustomInit 193 | mod.init_params(CustomInit()) 194 | # use YFOptimizer with initial learning rate 1.0 and momentum 0.0 195 | mod.init_optimizer(optimizer='YFOptimizer', optimizer_params=(('learning_rate', 1.0), ('momentum', 0.0), ('zero_debias', zero_debias))) 196 | # use MSE as the metric 197 | metric = mx.metric.create('mse') 198 | 199 | target_h_max = 0.0 200 | target_h_min = 0.0 201 | g_norm_squared_avg = 0.0 202 | g_norm_avg = 0.0 203 | g_avg = 0.0 204 | target_dist = 0.0 205 | target_lr = 1.0 206 | target_mu = 0.0 207 | 208 | for epoch in range(n_iter): 209 | train_iter.reset() 210 | metric.reset() 211 | for batch in train_iter: 212 | i = epoch 213 | 214 | mod.forward(batch, is_train=True) # compute predictions 215 | mod.update_metric(metric, batch.label) # accumulate prediction MSE 216 | mod.backward() # compute gradients 217 | mod._exec_group.grad_arrays[0][0] *= i + 1 # scale the gradients so the target statistics below have a closed form 218 | mod._exec_group.grad_arrays[1][0] *= i + 1 219 | mod.update() # update parameters 220 | 221 | res = mod._optimizer._test_res 222 | 223 | g_norm_squared_avg = 0.999 * g_norm_squared_avg \ 224 | + 0.001 * np.sum(((i + 1) * np.ones([n_dim + 1, ])) ** 2) 225 | g_norm_avg = 0.999 * g_norm_avg \ 226 | + 0.001 * np.linalg.norm((i + 1) * np.ones([n_dim + 1, ])) 227 | g_avg = 0.999 * g_avg + 0.001 * (i + 1) 228 | 229 | target_h_max = 0.999 * target_h_max + 0.001 * (i + 1) ** 2 * (n_dim + 1) 230 | target_h_min = 0.999 * target_h_min + 0.001 * max(1, i + 2 - 20) ** 2 * (n_dim + 1) 231 | if zero_debias: 232 | target_var = g_norm_squared_avg / (1 - 0.999 ** (i + 1)) \ 233 | - g_avg ** 2 * (n_dim + 1) / (1 - 0.999 ** (i + 1)) ** 2 234 | else: 235 | target_var = g_norm_squared_avg - g_avg ** 2 * (n_dim + 1) 236 | target_dist = 0.999 * target_dist + 0.001 * g_norm_avg / g_norm_squared_avg 237 | 238 | if i == 0: 239 | continue 240 | if zero_debias: 241 | # print "iter ", i, " h max ", res[0], target_h_max/(1-0.999**(i + 1) ), \ 242 | # " h min ", res[1], target_h_min/(1-0.999**(i + 1) ), \ 243 | # " var ", res[2], target_var, \ 244 | # " dist ", res[3], target_dist/(1-0.999**(i + 1) ) 245 | assert np.abs(target_h_max / (1 - 0.999 ** (i + 1)) - res[0]) < np.abs(res[0]) * 1e-3 246 | assert np.abs(target_h_min / (1 - 0.999 ** (i + 1)) - res[1]) < np.abs(res[1]) * 1e-3 247 | assert np.abs(target_var - res[2]) < np.abs(target_var) * 1e-3 248 | assert np.abs(target_dist / (1 - 0.999 ** (i + 1)) - res[3]) < np.abs(res[3]) * 1e-3 249 | else: 250 | # print "iter ", i, " h max ", res[0], target_h_max, " h min ", res[1], target_h_min, \ 251 | # " var ", res[2], target_var, " dist ", res[3], target_dist 252 | assert np.abs(target_h_max - res[0]) < np.abs(target_h_max) * 1e-3 253 |
assert np.abs(target_h_min - res[1]) < np.abs(target_h_min) * 1e-3 254 | assert np.abs(target_var - res[2]) < np.abs(res[2]) * 1e-3 255 | assert np.abs(target_dist - res[3]) < np.abs(res[3]) * 1e-3 256 | 257 | if i > 0: 258 | if zero_debias: 259 | lr, mu = tune_everything((target_dist / (1 - 0.999 ** (i + 1))) ** 2, 260 | target_var, 1, target_h_min / (1 - 0.999 ** (i + 1)), 261 | target_h_max / (1 - 0.999 ** (i + 1))) 262 | else: 263 | lr, mu = tune_everything(target_dist ** 2, target_var, 1, target_h_min, target_h_max) 264 | lr = np.real(lr) 265 | mu = np.real(mu) 266 | target_lr = 0.999 * target_lr + 0.001 * lr 267 | target_mu = 0.999 * target_mu + 0.001 * mu 268 | print "lr ", target_lr, res[4], " mu ", target_mu, res[5] 269 | assert target_lr == 0.0 or np.abs(target_lr - res[4]) < np.abs(res[4]) * 1e-3 270 | assert target_mu == 0.0 or np.abs(target_mu - res[5]) < np.abs(res[5]) * 5e-3 271 | print "lr and mu computing test passed!" 272 | 273 | 274 | if __name__ == "__main__": 275 | start = time.time() 276 | test_measurement(zero_debias=False) 277 | end = time.time() 278 | print "measurement test without zero_debias done in ", (end - start) / float(n_iter), " s/iter!" 279 | 280 | start = time.time() 281 | test_measurement(zero_debias=True) 282 | end = time.time() 283 | print "measurement test with zero_debias done in ", (end - start) / float(n_iter), " s/iter!" 284 | 285 | start = time.time() 286 | test_lr_mu(zero_debias=False) 287 | end = time.time() 288 | print "lr and mu test done without zero_debias in ", (end - start) / float(n_iter), " s/iter!" 289 | 290 | start = time.time() 291 | test_lr_mu(zero_debias=True) 292 | end = time.time() 293 | print "lr and mu test done with zero_debias in ", (end - start) / float(n_iter), " s/iter!" 294 | --------------------------------------------------------------------------------