├── .DS_Store ├── .gitignore ├── LICENSE ├── RAdam.py ├── README.md └── assets ├── alg.png └── result.png /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/RAdam-Tensorflow/c990c86158373d3d0d4083a47983d12ee57b1d0b/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Junho Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /RAdam.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.eager import context 3 | from tensorflow.python.framework import ops 4 | from tensorflow.python.ops import control_flow_ops 5 | from tensorflow.python.ops import math_ops 6 | from tensorflow.python.ops import resource_variable_ops 7 | from tensorflow.python.ops import state_ops 8 | from tensorflow.python.training import optimizer 9 | 10 | class RAdamOptimizer(optimizer.Optimizer): 11 | 12 | """ 13 | RAdam optimizer : On The Variance Of The Adaptive Learning Rate And Beyond 14 | https://arxiv.org/abs/1908.03265 15 | """ 16 | 17 | def __init__(self, 18 | learning_rate=0.001, 19 | beta1=0.9, 20 | beta2=0.999, 21 | epsilon=1e-8, 22 | weight_decay=0., 23 | use_locking=False, 24 | name="RAdam"): 25 | 26 | super(RAdamOptimizer, self).__init__(use_locking, name) 27 | self._lr = learning_rate 28 | self._beta1 = beta1 29 | self._beta2 = beta2 30 | self._epsilon = epsilon 31 | self._weight_decay = weight_decay 32 | 33 | self._lr_t = None 34 | self._step_t = None 35 | self._beta1_t = None 36 | self._beta2_t = None 37 | self._epsilon_t = None 38 | self._weight_decay_t = None 39 | 40 | def _get_beta_accumulators(self): 41 | with ops.init_scope(): 42 | if context.executing_eagerly(): 43 | graph = None 44 | else: 45 | graph = ops.get_default_graph() 46 | return (self._get_non_slot_variable("step", graph=graph), 47 | self._get_non_slot_variable("beta1_power", graph=graph), 48 | self._get_non_slot_variable("beta2_power", graph=graph)) 49 | 50 | def _create_slots(self, var_list): 51 | first_var = min(var_list, key=lambda x: x.name) 52 | self._create_non_slot_variable(initial_value=1.0, name="step", colocate_with=first_var) 53 | self._create_non_slot_variable(initial_value=self._beta1, name="beta1_power", colocate_with=first_var) 54 | self._create_non_slot_variable(initial_value=self._beta2, name="beta2_power", colocate_with=first_var) 55 | 56 | for v in var_list: 57 | self._zeros_slot(v, "m", self._name) 58 | self._zeros_slot(v, "v", self._name) 59 | 60 | def _prepare(self): 61 | lr = self._call_if_callable(self._lr) 62 | beta1 = self._call_if_callable(self._beta1) 63 | beta2 = self._call_if_callable(self._beta2) 64 | epsilon = self._call_if_callable(self._epsilon) 65 | weight_decay = self._call_if_callable(self._weight_decay) 66 | 67 | self._lr_t = ops.convert_to_tensor(lr, name="learning_rate") 68 | self._beta1_t = ops.convert_to_tensor(beta1, name="beta1") 69 | self._beta2_t = ops.convert_to_tensor(beta2, name="beta2") 70 | self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon") 71 | self._weight_decay_t = ops.convert_to_tensor(weight_decay, name="weight_decay") 72 | 73 | def _apply_dense(self, grad, var): 74 | return self._resource_apply_dense(grad, var) 75 | 76 | def _resource_apply_dense(self, grad, var): 77 | step, beta1_power, beta2_power = self._get_beta_accumulators() 78 | step = math_ops.cast(step, var.dtype.base_dtype) 79 | beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) 80 | beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) 81 | lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) 82 | 83 | beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) 84 | beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) 85 | epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) 86 | 87 | sma_inf = 2.0 / (1.0 - beta2_t) - 
1.0 88 | sma_t = sma_inf - 2.0 * step * beta2_power / (1.0 - beta2_power) 89 | 90 | m = self.get_slot(var, "m") 91 | m_t = state_ops.assign(m, beta1_t * m + (1.0 - beta1_t) * grad, use_locking=self._use_locking) 92 | mhat_t = m_t / (1.0 - beta1_power) 93 | 94 | v = self.get_slot(var, "v") 95 | v_t = state_ops.assign(v, beta2_t * v + (1.0 - beta2_t) * math_ops.square(grad), use_locking=self._use_locking) 96 | vhat_t = math_ops.sqrt(v_t / ((1.0 - beta2_power) + epsilon_t)) 97 | 98 | r_t = math_ops.sqrt( ((sma_t - 4.0) * (sma_t - 2.0) * sma_inf) / ((sma_inf - 4.0) * (sma_inf - 2.0) * sma_t) ) 99 | 100 | var_t = tf.cond(sma_t >= 5.0, lambda : r_t * mhat_t / (vhat_t + epsilon_t), lambda : mhat_t) 101 | 102 | if self._weight_decay > 0.0: 103 | var_t += math_ops.cast(self._weight_decay_t, var.dtype.base_dtype) * var 104 | 105 | var_update = state_ops.assign_sub(var, lr_t * var_t, use_locking=self._use_locking) 106 | 107 | updates = [var_update, m_t, v_t] 108 | 109 | return control_flow_ops.group(*updates) 110 | 111 | def _apply_sparse_shared(self, grad, var, indices, scatter_add): 112 | step, beta1_power, beta2_power = self._get_beta_accumulators() 113 | step = math_ops.cast(step, var.dtype.base_dtype) 114 | beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) 115 | beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) 116 | lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) 117 | 118 | beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) 119 | beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) 120 | epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) 121 | 122 | sma_inf = 2.0 / (1.0 - beta2_t) - 1.0 123 | sma_t = sma_inf - 2.0 * step * beta2_power / (1.0 - beta2_power) 124 | 125 | m = self.get_slot(var, "m") 126 | m_scaled_g_values = grad * (1 - beta1_t) 127 | m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking) 128 | 129 | with ops.control_dependencies([m_t]): 130 | m_t = scatter_add(m, indices, m_scaled_g_values) 131 | 132 | mhat_t = m_t / (1.0 - beta1_power) 133 | 134 | v = self.get_slot(var, "v") 135 | v_scaled_g_values = (grad * grad) * (1 - beta2_t) 136 | v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking) 137 | 138 | with ops.control_dependencies([v_t]): 139 | v_t = scatter_add(v, indices, v_scaled_g_values) 140 | 141 | vhat_t = math_ops.sqrt(v_t / (1.0 - beta2_power) + epsilon_t) 142 | 143 | r_t = math_ops.sqrt( ((sma_t - 4.0) * (sma_t - 2.0) * sma_inf) / ((sma_inf - 4.0) * (sma_inf - 2.0) * sma_t) ) 144 | 145 | var_t = tf.cond(sma_t >= 5.0, lambda : r_t * mhat_t / (vhat_t + epsilon_t), lambda : mhat_t) 146 | 147 | if self._weight_decay > 0.0: 148 | var_t += math_ops.cast(self._weight_decay_t, var.dtype.base_dtype) * var 149 | 150 | var_update = state_ops.assign_sub(var, lr_t * var_t, use_locking=self._use_locking) 151 | 152 | updates = [var_update, m_t, v_t] 153 | 154 | return control_flow_ops.group(*updates) 155 | 156 | def _apply_sparse(self, grad, var): 157 | return self._apply_sparse_shared( 158 | grad.values, 159 | var, 160 | grad.indices, 161 | lambda x, i, v: state_ops.scatter_add(x, i, v, use_locking=self._use_locking)) 162 | 163 | def _resource_scatter_add(self, x, i, v): 164 | with ops.control_dependencies([resource_variable_ops.resource_scatter_add(x.handle, i, v)]): 165 | return x.value() 166 | 167 | def _resource_apply_sparse(self, grad, var, indices): 168 | return self._apply_sparse_shared(grad, var, indices, self._resource_scatter_add) 169 | 170 | def _finish(self, update_ops, 
name_scope): 171 | with ops.control_dependencies(update_ops): 172 | step, beta1_power, beta2_power = self._get_beta_accumulators() 173 | with ops.colocate_with(beta1_power): 174 | update_step = step.assign(step + 1.0, use_locking=self._use_locking) 175 | update_beta1 = beta1_power.assign(beta1_power * self._beta1_t, use_locking=self._use_locking) 176 | update_beta2 = beta2_power.assign(beta2_power * self._beta2_t, use_locking=self._use_locking) 177 | return control_flow_ops.group(*update_ops + [update_step, update_beta1, update_beta2], name=name_scope) 178 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# RAdam-Tensorflow 2 | ### On the Variance of the Adaptive Learning Rate and Beyond
3 | 4 | ### [Paper](https://arxiv.org/abs/1908.03265) | [Official PyTorch code](https://github.com/LiyuanLucasLiu/RAdam) 5 | 6 | ## Usage 7 | ```python 8 | from RAdam import RAdamOptimizer 9 | 10 | train_op = RAdamOptimizer(learning_rate=0.001, beta1=0.9, beta2=0.999, weight_decay=0.0).minimize(loss) 11 | ``` 12 | 13 | ## Algorithm 14 |
15 | ![algorithm](./assets/alg.png) 16 |
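For readers who want the algorithm as executable math rather than an image, here is a minimal NumPy sketch of one dense RAdam step, mirroring the `_resource_apply_dense` logic in `RAdam.py` above. The function name, signature, and the standalone-NumPy framing are illustrative assumptions, not part of this repo:

```python
import numpy as np

def radam_step(param, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One RAdam update on a single parameter array (sketch of the dense path)."""
    m = beta1 * m + (1.0 - beta1) * grad               # first-moment (mean) estimate
    v = beta2 * v + (1.0 - beta2) * grad * grad        # second-moment (uncentered variance) estimate
    m_hat = m / (1.0 - beta1 ** t)                     # bias-corrected mean

    sma_inf = 2.0 / (1.0 - beta2) - 1.0                # maximum length of the approximated SMA
    sma_t = sma_inf - 2.0 * t * beta2 ** t / (1.0 - beta2 ** t)

    if sma_t >= 5.0:                                   # variance is tractable: take the rectified adaptive step
        v_hat = np.sqrt(v / (1.0 - beta2 ** t))
        r_t = np.sqrt(((sma_t - 4.0) * (sma_t - 2.0) * sma_inf) /
                      ((sma_inf - 4.0) * (sma_inf - 2.0) * sma_t))
        update = r_t * m_hat / (v_hat + eps)
    else:                                              # early steps: fall back to an un-adapted momentum step
        update = m_hat
    return param - lr * update, m, v
```

Note that this implementation gates the rectified branch at `sma_t >= 5.0`, whereas the paper's pseudocode states the condition as the SMA length exceeding 4; in either form, the `r_t` factor shrinks the adaptive step while the variance of the adaptive learning rate is still large and approaches 1 as training proceeds.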
17 | 18 | 19 | ## Result 20 | ![result](./assets/result.png) 21 | 22 | ## Author 23 | [Junho Kim](http://bit.ly/jhkim_ai) 24 | -------------------------------------------------------------------------------- /assets/alg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/RAdam-Tensorflow/c990c86158373d3d0d4083a47983d12ee57b1d0b/assets/alg.png -------------------------------------------------------------------------------- /assets/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/RAdam-Tensorflow/c990c86158373d3d0d4083a47983d12ee57b1d0b/assets/result.png --------------------------------------------------------------------------------
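Beyond the one-line snippet in the README's Usage section, the sketch below shows how the optimizer might be wired into a complete TF 1.x graph-mode training loop. The toy regression data, the `tf.layers.dense` model, and the hyperparameters are illustrative assumptions, not part of this repo:

```python
import numpy as np
import tensorflow as tf  # TF 1.x graph mode, matching the tf.train-style optimizer above

from RAdam import RAdamOptimizer

# Hypothetical toy regression problem: learn y = sum(x) from random batches.
x = tf.placeholder(tf.float32, [None, 4], name="inputs")
y = tf.placeholder(tf.float32, [None, 1], name="targets")
pred = tf.layers.dense(x, units=1)                     # single linear layer
loss = tf.reduce_mean(tf.square(pred - y))

train_op = RAdamOptimizer(learning_rate=0.001, beta1=0.9, beta2=0.999,
                          weight_decay=0.0).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(200):
        xb = np.random.randn(32, 4).astype(np.float32)
        yb = xb.sum(axis=1, keepdims=True)
        _, cur_loss = sess.run([train_op, loss], feed_dict={x: xb, y: yb})
        if step % 50 == 0:
            print("step %d  loss %.4f" % (step, cur_loss))
```

Because `RAdamOptimizer` subclasses `tf.train`'s `optimizer.Optimizer`, `minimize`, `compute_gradients`, and `apply_gradients` behave as they do for the built-in optimizers, so it can be dropped into existing graph-mode training code without other changes.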