├── .gitignore ├── LICENSE ├── README.md ├── adamp_tf.py └── sgdp_tf.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Junho Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## AdamP Optimizer — Unofficial TensorFlow Implementation 2 | ### "Slowing Down the Weight Norm Increase in Momentum-based Optimizers" 3 | ## Implemented by [Junho Kim](http://bit.ly/jhkim_ai) 4 | ### [[Paper]](https://arxiv.org/abs/2006.08217) [[Project page]](https://clovaai.github.io/AdamP/) [[Official Pytorch]](https://github.com/clovaai/AdamP) 5 | 6 |
7 | 8 | 9 |
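## Quick sanity check

A minimal sketch of dropping `AdamP` into a standard Keras training loop. The toy MNIST model below is only an illustration (it is not part of this repository); any `tf.keras` model works the same way:

```python
import tensorflow as tf
from adamp_tf import AdamP

# Toy data/model purely for illustration.
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0

model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation="relu", input_shape=(784,)),
    tf.keras.layers.Dense(10),
])

# AdamP/SGDP subclass OptimizerV2, so they plug into compile/fit directly.
model.compile(
    optimizer=AdamP(learning_rate=0.001, weight_decay=1e-2),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
model.fit(x_train, y_train, batch_size=128, epochs=1)
```

`SGDP` can be swapped in the same way, e.g. `optimizer=SGDP(learning_rate=0.1, momentum=0.9, nesterov=True, weight_decay=1e-5)`.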
10 | 11 | 12 | ## Validation 13 | I have checked that the code runs, but I have not confirmed that its performance matches the official implementation. 14 | 15 | ## Usage 16 | Usage is exactly the same as the [tf.keras.optimizers](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers) library! 17 | ```python 18 | 19 | from adamp_tf import AdamP 20 | from sgdp_tf import SGDP 21 | 22 | optimizer_adamp = AdamP(learning_rate=0.001, beta_1=0.9, beta_2=0.999, weight_decay=1e-2) 23 | optimizer_sgdp = SGDP(learning_rate=0.1, weight_decay=1e-5, momentum=0.9, nesterov=True) 24 | ``` 25 | * **Do not use with `tf.nn.scale_regularization_loss`.** Use the `weight_decay` argument instead. 26 | 27 | ## Arguments 28 | `SGDP` and `AdamP` share arguments with [tf.keras.optimizers.SGD](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/SGD) and [tf.keras.optimizers.Adam](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adam). 29 | There are two additional hyperparameters; we recommend using the default values. 30 | - `delta` : threshold that determines whether a set of parameters is scale-invariant (default: 0.1) 31 | - `wd_ratio` : relative weight decay applied to _scale-invariant_ parameters, compared to that applied to _scale-variant_ parameters (default: 0.1) 32 | 33 | Both `SGDP` and `AdamP` support Nesterov momentum. 34 | - `nesterov` : enables Nesterov momentum (default: False) 35 | 36 | ## How to cite 37 | 38 | ``` 39 | @article{heo2020adamp, 40 | title={Slowing Down the Weight Norm Increase in Momentum-based Optimizers}, 41 | author={Heo, Byeongho and Chun, Sanghyuk and Oh, Seong Joon and Han, Dongyoon and Yun, Sangdoo and Uh, Youngjung and Ha, Jung-Woo}, 42 | year={2020}, 43 | journal={arXiv preprint arXiv:2006.08217}, 44 | } 45 | ``` 46 | -------------------------------------------------------------------------------- /adamp_tf.py: -------------------------------------------------------------------------------- 1 | """AdamP for TensorFlow.""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | from tensorflow.python.framework import ops 7 | from tensorflow.python.keras import backend_config 8 | from tensorflow.python.keras.optimizer_v2 import optimizer_v2 9 | from tensorflow.python.ops import array_ops 10 | from tensorflow.python.ops import control_flow_ops 11 | from tensorflow.python.ops import math_ops 12 | from tensorflow.python.ops import state_ops 13 | 14 | 15 | class AdamP(optimizer_v2.OptimizerV2): 16 | _HAS_AGGREGATE_GRAD = True 17 | 18 | def __init__(self, 19 | learning_rate=0.001, 20 | beta_1=0.9, 21 | beta_2=0.999, 22 | epsilon=1e-8, 23 | weight_decay=0.0, 24 | delta=0.1, wd_ratio=0.1, nesterov=False, 25 | name='AdamP', 26 | **kwargs): 27 | 28 | super(AdamP, self).__init__(name, **kwargs) 29 | self._set_hyper('learning_rate', kwargs.get('lr', learning_rate)) 30 | self._set_hyper('beta_1', beta_1) 31 | self._set_hyper('beta_2', beta_2) 32 | self._set_hyper('delta', delta) 33 | self._set_hyper('wd_ratio', wd_ratio) 34 | 35 | self.epsilon = epsilon or backend_config.epsilon() 36 | self.weight_decay = weight_decay 37 | self.nesterov = nesterov 38 | 39 | def _create_slots(self, var_list): 40 | # Create slots for the first and second moments. 41 | # Separate for-loops to respect the ordering of slot variables from v1.
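    # Slot semantics: 'm' is the exponential moving average of gradients (first moment),
    # 'v' is the exponential moving average of squared gradients (second moment), and
    # 'p' is allocated here but is not read by the dense/sparse update rules below.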
42 | for var in var_list: 43 | self.add_slot(var, 'm') 44 | for var in var_list: 45 | self.add_slot(var, 'v') 46 | for var in var_list: 47 | self.add_slot(var, 'p') 48 | 49 | def _prepare_local(self, var_device, var_dtype, apply_state): 50 | super(AdamP, self)._prepare_local(var_device, var_dtype, apply_state) 51 | 52 | local_step = math_ops.cast(self.iterations + 1, var_dtype) 53 | beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype)) 54 | beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype)) 55 | beta_1_power = math_ops.pow(beta_1_t, local_step) 56 | beta_2_power = math_ops.pow(beta_2_t, local_step) 57 | 58 | lr = apply_state[(var_device, var_dtype)]['lr_t'] 59 | bias_correction1 = 1 - beta_1_power 60 | bias_correction2 = 1 - beta_2_power 61 | 62 | delta = array_ops.identity(self._get_hyper('delta', var_dtype)) 63 | wd_ratio = array_ops.identity(self._get_hyper('wd_ratio', var_dtype)) 64 | 65 | apply_state[(var_device, var_dtype)].update( 66 | dict( 67 | lr=lr, 68 | epsilon=ops.convert_to_tensor_v2(self.epsilon, var_dtype), 69 | weight_decay=ops.convert_to_tensor_v2(self.weight_decay, var_dtype), 70 | beta_1_t=beta_1_t, 71 | beta_1_power=beta_1_power, 72 | one_minus_beta_1_t=1 - beta_1_t, 73 | beta_2_t=beta_2_t, 74 | beta_2_power=beta_2_power, 75 | one_minus_beta_2_t=1 - beta_2_t, 76 | bias_correction1=bias_correction1, 77 | bias_correction2=bias_correction2, 78 | delta=delta, 79 | wd_ratio=wd_ratio)) 80 | 81 | def set_weights(self, weights): 82 | params = self.weights 83 | # If the weights are generated by Keras V1 optimizer, it includes vhats 84 | # optimizer has 2x + 1 variables. Filter vhats out for compatibility. 85 | num_vars = int((len(params) - 1) / 2) 86 | if len(weights) == 3 * num_vars + 1: 87 | weights = weights[:len(params)] 88 | super(AdamP, self).set_weights(weights) 89 | 90 | def _resource_apply_dense(self, grad, var, apply_state=None): 91 | var_device, var_dtype = var.device, var.dtype.base_dtype 92 | coefficients = ((apply_state or {}).get((var_device, var_dtype)) 93 | or self._fallback_apply_state(var_device, var_dtype)) 94 | 95 | # m_t = beta1 * m + (1 - beta1) * g_t 96 | m = self.get_slot(var, 'm') 97 | m_scaled_g_values = grad * coefficients['one_minus_beta_1_t'] 98 | m_t = state_ops.assign(m, m * coefficients['beta_1_t'] + m_scaled_g_values, use_locking=self._use_locking) 99 | 100 | # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) 101 | v = self.get_slot(var, 'v') 102 | v_scaled_g_values = (grad * grad) * coefficients['one_minus_beta_2_t'] 103 | v_t = state_ops.assign(v, v * coefficients['beta_2_t'] + v_scaled_g_values, use_locking=self._use_locking) 104 | 105 | denorm = (math_ops.sqrt(v_t) / math_ops.sqrt(coefficients['bias_correction2'])) + coefficients['epsilon'] 106 | step_size = coefficients['lr'] / coefficients['bias_correction1'] 107 | 108 | if self.nesterov: 109 | perturb = (coefficients['beta_1_t'] * m_t + coefficients['one_minus_beta_1_t'] * grad) / denorm 110 | else: 111 | perturb = m_t / denorm 112 | 113 | # Projection 114 | wd_ratio = 1 115 | if len(var.shape) > 1: 116 | perturb, wd_ratio = self._projection(var, grad, perturb, coefficients['delta'], coefficients['wd_ratio'], coefficients['epsilon']) 117 | 118 | # Weight decay 119 | 120 | if self.weight_decay > 0: 121 | var = state_ops.assign(var, var * (1 - coefficients['lr'] * coefficients['weight_decay'] * wd_ratio), use_locking=self._use_locking) 122 | 123 | var_update = state_ops.assign_sub(var, step_size * perturb, use_locking=self._use_locking) 124 | 125 | return 
control_flow_ops.group(*[var_update, m_t, v_t]) 126 | 127 | 128 | def _resource_apply_sparse(self, grad, var, indices, apply_state=None): 129 | 130 | var_device, var_dtype = var.device, var.dtype.base_dtype 131 | coefficients = ((apply_state or {}).get((var_device, var_dtype)) 132 | or self._fallback_apply_state(var_device, var_dtype)) 133 | """ 134 | Adam 135 | """ 136 | # m_t = beta1 * m + (1 - beta1) * g_t 137 | m = self.get_slot(var, 'm') 138 | m_scaled_g_values = grad * coefficients['one_minus_beta_1_t'] 139 | m_t = state_ops.assign(m, m * coefficients['beta_1_t'], 140 | use_locking=self._use_locking) 141 | with ops.control_dependencies([m_t]): 142 | m_t = self._resource_scatter_add(m, indices, m_scaled_g_values) 143 | 144 | # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) 145 | v = self.get_slot(var, 'v') 146 | v_scaled_g_values = (grad * grad) * coefficients['one_minus_beta_2_t'] 147 | v_t = state_ops.assign(v, v * coefficients['beta_2_t'], 148 | use_locking=self._use_locking) 149 | with ops.control_dependencies([v_t]): 150 | v_t = self._resource_scatter_add(v, indices, v_scaled_g_values) 151 | 152 | denorm = (math_ops.sqrt(v_t) / math_ops.sqrt(coefficients['bias_correction2'])) + coefficients['epsilon'] 153 | step_size = coefficients['lr'] / coefficients['bias_correction1'] 154 | 155 | if self.nesterov: 156 | p_scaled_g_values = grad * coefficients['one_minus_beta_1_t'] 157 | perturb = m_t * coefficients['beta_1_t'] 158 | perturb = self._resource_scatter_add(perturb, indices, p_scaled_g_values) / denorm 159 | 160 | else: 161 | perturb = m_t / denorm 162 | 163 | # Projection 164 | wd_ratio = 1 165 | if len(var.shape) > 1: 166 | perturb, wd_ratio = self._projection(var, grad, perturb, coefficients['delta'], coefficients['wd_ratio'], coefficients['epsilon']) 167 | 168 | # Weight decay 169 | if self.weight_decay > 0: 170 | var = state_ops.assign(var, var * (1 - coefficients['lr'] * coefficients['weight_decay'] * wd_ratio), use_locking=self._use_locking) 171 | 172 | var_update = state_ops.assign_sub(var, step_size * perturb, use_locking=self._use_locking) 173 | 174 | return control_flow_ops.group(*[var_update, m_t, v_t]) 175 | 176 | def _channel_view(self, x): 177 | return array_ops.reshape(x, shape=[x.shape[0], -1]) 178 | 179 | def _layer_view(self, x): 180 | return array_ops.reshape(x, shape=[1, -1]) 181 | 182 | def _cosine_similarity(self, x, y, eps, view_func): 183 | x = view_func(x) 184 | y = view_func(y) 185 | 186 | x_norm = math_ops.euclidean_norm(x, axis=-1) + eps 187 | y_norm = math_ops.euclidean_norm(y, axis=-1) + eps 188 | dot = math_ops.reduce_sum(x * y, axis=-1) 189 | 190 | return math_ops.abs(dot) / x_norm / y_norm 191 | 192 | def _projection(self, var, grad, perturb, delta, wd_ratio, eps): 193 | # channel_view 194 | cosine_sim = self._cosine_similarity(grad, var, eps, self._channel_view) 195 | cosine_max = math_ops.reduce_max(cosine_sim) 196 | compare_val = delta / math_ops.sqrt(math_ops.cast(self._channel_view(var).shape[-1], dtype=delta.dtype)) 197 | 198 | perturb, wd = control_flow_ops.cond(pred=cosine_max < compare_val, 199 | true_fn=lambda : self.channel_true_fn(var, perturb, wd_ratio, eps), 200 | false_fn=lambda : self.channel_false_fn(var, grad, perturb, delta, wd_ratio, eps)) 201 | 202 | return perturb, wd 203 | 204 | def channel_true_fn(self, var, perturb, wd_ratio, eps): 205 | expand_size = [-1] + [1] * (len(var.shape) - 1) 206 | var_n = var / (array_ops.reshape(math_ops.euclidean_norm(self._channel_view(var), axis=-1), shape=expand_size) + eps) 207 | perturb -= 
var_n * array_ops.reshape(math_ops.reduce_sum(self._channel_view(var_n * perturb), axis=-1), shape=expand_size) 208 | wd = wd_ratio 209 | 210 | return perturb, wd 211 | 212 | def channel_false_fn(self, var, grad, perturb, delta, wd_ratio, eps): 213 | cosine_sim = self._cosine_similarity(grad, var, eps, self._layer_view) 214 | cosine_max = math_ops.reduce_max(cosine_sim) 215 | compare_val = delta / math_ops.sqrt(math_ops.cast(self._layer_view(var).shape[-1], dtype=delta.dtype)) 216 | 217 | perturb, wd = control_flow_ops.cond(cosine_max < compare_val, 218 | true_fn=lambda : self.layer_true_fn(var, perturb, wd_ratio, eps), 219 | false_fn=lambda : self.identity_fn(perturb)) 220 | 221 | return perturb, wd 222 | 223 | def layer_true_fn(self, var, perturb, wd_ratio, eps): 224 | expand_size = [-1] + [1] * (len(var.shape) - 1) 225 | var_n = var / (array_ops.reshape(math_ops.euclidean_norm(self._layer_view(var), axis=-1), shape=expand_size) + eps) 226 | perturb -= var_n * array_ops.reshape(math_ops.reduce_sum(self._layer_view(var_n * perturb), axis=-1), shape=expand_size) 227 | wd = wd_ratio 228 | 229 | return perturb, wd 230 | 231 | def identity_fn(self, perturb): 232 | wd = 1.0 233 | 234 | return perturb, wd 235 | 236 | def get_config(self): 237 | config = super(AdamP, self).get_config() 238 | config.update({ 239 | 'learning_rate': self._serialize_hyperparameter('learning_rate'), 240 | 'beta_1': self._serialize_hyperparameter('beta_1'), 241 | 'beta_2': self._serialize_hyperparameter('beta_2'), 242 | 'delta': self._serialize_hyperparameter('delta'), 243 | 'wd_ratio': self._serialize_hyperparameter('wd_ratio'), 244 | 'epsilon': self.epsilon, 245 | 'weight_decay': self.weight_decay, 246 | 'nesterov': self.nesterov 247 | }) 248 | return config -------------------------------------------------------------------------------- /sgdp_tf.py: -------------------------------------------------------------------------------- 1 | """ SGDP for TensorFlow.""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | from tensorflow.python.framework import ops 7 | from tensorflow.python.keras.optimizer_v2 import optimizer_v2 8 | from tensorflow.python.ops import array_ops 9 | from tensorflow.python.ops import math_ops 10 | from tensorflow.python.ops import state_ops 11 | from tensorflow.python.ops import control_flow_ops 12 | 13 | 14 | class SGDP(optimizer_v2.OptimizerV2): 15 | _HAS_AGGREGATE_GRAD = True 16 | def __init__(self, 17 | learning_rate=0.1, 18 | momentum=0.0, 19 | dampening=0.0, 20 | weight_decay=0.0, 21 | nesterov=False, 22 | epsilon=1e-8, 23 | delta=0.1, 24 | wd_ratio=0.1, 25 | name="SGDP", 26 | **kwargs): 27 | 28 | super(SGDP, self).__init__(name, **kwargs) 29 | self._set_hyper("learning_rate", kwargs.get("lr", learning_rate)) 30 | 31 | self._momentum = False 32 | if isinstance(momentum, ops.Tensor) or callable(momentum) or momentum > 0: 33 | self._momentum = True 34 | if isinstance(momentum, (int, float)) and (momentum < 0 or momentum > 1): 35 | raise ValueError("`momentum` must be between [0, 1].") 36 | 37 | self._set_hyper("momentum", momentum) 38 | self._set_hyper("dampening", dampening) 39 | self._set_hyper("epsilon", epsilon) 40 | self._set_hyper("delta", delta) 41 | self._set_hyper("wd_ratio", wd_ratio) 42 | 43 | self.nesterov = nesterov 44 | self.weight_decay = weight_decay 45 | 46 | def _create_slots(self, var_list): 47 | if self._momentum: 48 | for var in var_list: 49 | self.add_slot(var, "momentum") 50 | for var in 
var_list: 51 | self.add_slot(var, "buf") 52 | 53 | def _prepare_local(self, var_device, var_dtype, apply_state): 54 | super(SGDP, self)._prepare_local(var_device, var_dtype, apply_state) 55 | lr = apply_state[(var_device, var_dtype)]['lr_t'] 56 | 57 | momentum = array_ops.identity(self._get_hyper("momentum", var_dtype)) 58 | dampening = array_ops.identity(self._get_hyper('dampening', var_dtype)) 59 | delta = array_ops.identity(self._get_hyper('delta', var_dtype)) 60 | wd_ratio = array_ops.identity(self._get_hyper('wd_ratio', var_dtype)) 61 | 62 | apply_state[(var_device, var_dtype)].update( 63 | dict( 64 | lr=lr, 65 | epsilon=ops.convert_to_tensor_v2(self.epsilon, var_dtype), 66 | weight_decay=ops.convert_to_tensor_v2(self.weight_decay, var_dtype), 67 | momentum=momentum, 68 | dampening=dampening, 69 | delta=delta, 70 | wd_ratio=wd_ratio)) 71 | 72 | 73 | def _resource_apply_dense(self, grad, var, apply_state=None): 74 | var_device, var_dtype = var.device, var.dtype.base_dtype 75 | coefficients = ((apply_state or {}).get((var_device, var_dtype)) 76 | or self._fallback_apply_state(var_device, var_dtype)) 77 | 78 | buf = self.get_slot(var, 'buf') 79 | b_scaled_g_values = grad * (1 - coefficients['dampening']) 80 | buf_t = state_ops.assign(buf, buf * coefficients['momentum'] + b_scaled_g_values, use_locking=self._use_locking) 81 | 82 | if self.nesterov: 83 | d_p = grad + coefficients['momentum'] * buf_t 84 | else: 85 | d_p = buf_t 86 | 87 | # Projection 88 | wd_ratio = 1 89 | if len(var.shape) > 1: 90 | d_p, wd_ratio = self._projection(var, grad, d_p, coefficients['delta'], coefficients['wd_ratio'], coefficients['epsilon']) 91 | 92 | # Weight decay 93 | if self.weight_decay > 0: 94 | var = state_ops.assign(var, var * (1 - coefficients['lr'] * coefficients['weight_decay'] * wd_ratio), use_locking=self._use_locking) 95 | 96 | var_update = state_ops.assign_sub(var, coefficients['lr'] * d_p, use_locking=self._use_locking) 97 | 98 | return control_flow_ops.group(*[var_update, buf_t]) 99 | 100 | def _resource_apply_sparse(self, grad, var, indices, apply_state=None): 101 | # This method is only needed for momentum optimization. 
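    # Sparse path: the momentum buffer is first decayed in place, then the dampened
    # gradient is scatter-added only at the rows selected by `indices`, mirroring the
    # dense update buf = momentum * buf + (1 - dampening) * grad.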
102 | var_device, var_dtype = var.device, var.dtype.base_dtype 103 | coefficients = ((apply_state or {}).get((var_device, var_dtype)) 104 | or self._fallback_apply_state(var_device, var_dtype)) 105 | 106 | buf = self.get_slot(var, 'buf') 107 | b_scaled_g_values = grad * (1 - coefficients['dampening']) 108 | buf_t = state_ops.assign(buf, buf * coefficients['momentum'], use_locking=self._use_locking) 109 | 110 | with ops.control_dependencies([buf_t]): 111 | buf_t = self._resource_scatter_add(buf, indices, b_scaled_g_values) 112 | 113 | if self.nesterov: 114 | d_p = self._resource_scatter_add(buf_t * coefficients['momentum'], indices, grad) 115 | else: 116 | d_p = buf_t 117 | 118 | # Projection 119 | wd_ratio = 1 120 | if len(array_ops.shape(var)) > 1: 121 | d_p, wd_ratio = self._projection(var, grad, d_p, coefficients['delta'], coefficients['wd_ratio'], 122 | coefficients['epsilon']) 123 | 124 | # Weight decay 125 | if self.weight_decay > 0: 126 | var = state_ops.assign(var, var * (1 - coefficients['lr'] * coefficients['weight_decay'] * wd_ratio), use_locking=self._use_locking) 127 | 128 | var_update = state_ops.assign_sub(var, coefficients['lr'] * d_p, use_locking=self._use_locking) 129 | 130 | return control_flow_ops.group(*[var_update, buf_t]) 131 | 132 | def _channel_view(self, x): 133 | return array_ops.reshape(x, shape=[array_ops.shape(x)[0], -1]) 134 | 135 | def _layer_view(self, x): 136 | return array_ops.reshape(x, shape=[1, -1]) 137 | 138 | def _cosine_similarity(self, x, y, eps, view_func): 139 | x = view_func(x) 140 | y = view_func(y) 141 | 142 | x_norm = math_ops.euclidean_norm(x, axis=-1) + eps 143 | y_norm = math_ops.euclidean_norm(y, axis=-1) + eps 144 | dot = math_ops.reduce_sum(x * y, axis=-1) 145 | 146 | return math_ops.abs(dot) / x_norm / y_norm 147 | 148 | def _projection(self, var, grad, perturb, delta, wd_ratio, eps): 149 | # channel_view 150 | cosine_sim = self._cosine_similarity(grad, var, eps, self._channel_view) 151 | cosine_max = math_ops.reduce_max(cosine_sim) 152 | compare_val = delta / math_ops.sqrt(math_ops.cast(self._channel_view(var).shape[-1], dtype=delta.dtype)) 153 | 154 | perturb, wd = control_flow_ops.cond(pred=cosine_max < compare_val, 155 | true_fn=lambda : self.channel_true_fn(var, perturb, wd_ratio, eps), 156 | false_fn=lambda : self.channel_false_fn(var, grad, perturb, delta, wd_ratio, eps)) 157 | 158 | return perturb, wd 159 | 160 | def channel_true_fn(self, var, perturb, wd_ratio, eps): 161 | expand_size = [-1] + [1] * (len(var.shape) - 1) 162 | var_n = var / (array_ops.reshape(math_ops.euclidean_norm(self._channel_view(var), axis=-1), shape=expand_size) + eps) 163 | perturb = state_ops.assign_sub(perturb, var_n * array_ops.reshape(math_ops.reduce_sum(self._channel_view(var_n * perturb), axis=-1), shape=expand_size)) 164 | wd = wd_ratio 165 | 166 | return perturb, wd 167 | 168 | def channel_false_fn(self, var, grad, perturb, delta, wd_ratio, eps): 169 | cosine_sim = self._cosine_similarity(grad, var, eps, self._layer_view) 170 | cosine_max = math_ops.reduce_max(cosine_sim) 171 | compare_val = delta / math_ops.sqrt(math_ops.cast(self._layer_view(var).shape[-1], dtype=delta.dtype)) 172 | 173 | perturb, wd = control_flow_ops.cond(cosine_max < compare_val, 174 | true_fn=lambda : self.layer_true_fn(var, perturb, wd_ratio, eps), 175 | false_fn=lambda : self.identity_fn(perturb)) 176 | 177 | return perturb, wd 178 | 179 | def layer_true_fn(self, var, perturb, wd_ratio, eps): 180 | expand_size = [-1] + [1] * (len(var.shape) - 1) 181 | var_n = var / 
(array_ops.reshape(math_ops.euclidean_norm(self._layer_view(var), axis=-1), shape=expand_size) + eps) 182 | perturb = state_ops.assign_sub(perturb, var_n * array_ops.reshape(math_ops.reduce_sum(self._layer_view(var_n * perturb), axis=-1), shape=expand_size)) 183 | wd = wd_ratio 184 | 185 | return perturb, wd 186 | 187 | def identity_fn(self, perturb): 188 | wd = 1.0 189 | 190 | return perturb, wd 191 | 192 | def get_config(self): 193 | config = super(SGDP, self).get_config() 194 | config.update({ 195 | 'learning_rate': self._serialize_hyperparameter('learning_rate'), 196 | 'momentum': self._serialize_hyperparameter('momentum'), 197 | 'dampening': self._serialize_hyperparameter('dampening'), 198 | 'delta': self._serialize_hyperparameter('delta'), 199 | 'wd_ratio': self._serialize_hyperparameter('wd_ratio'), 200 | 'epsilon': self.epsilon, 201 | 'weight_decay': self.weight_decay, 202 | "nesterov": self.nesterov, 203 | }) 204 | return config 205 | --------------------------------------------------------------------------------
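A minimal smoke test, as a hedged sketch (the quadratic objective below is illustrative only and not part of the repository): both optimizers inherit the standard `apply_gradients` / `from_config` interface from `OptimizerV2`, and define `get_config` above, so a config round-trip and a single eager update step can be checked as follows.

```python
import tensorflow as tf
from sgdp_tf import SGDP

# Construct the optimizer and round-trip its config
# (get_config is defined above; from_config comes from OptimizerV2).
opt = SGDP(learning_rate=0.1, momentum=0.9, weight_decay=1e-5, nesterov=True)
restored = SGDP.from_config(opt.get_config())
print(restored.get_config()["learning_rate"])  # 0.1

# One eager optimizer step on a 2-D variable.
var = tf.Variable([[1.0, 2.0], [3.0, 4.0]])
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(tf.square(var))
grads = tape.gradient(loss, [var])
opt.apply_gradients(zip(grads, [var]))
print(var.numpy())
```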