├── .gitignore
├── LICENSE
├── README.md
├── adamp_tf.py
└── sgdp_tf.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Junho Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## AdamP Optimizer — Unofficial TensorFlow Implementation
2 | ### "Slowing Down the Weight Norm Increase in Momentum-based Optimizers"
3 | ## Implemented by [Junho Kim](http://bit.ly/jhkim_ai)
4 | ### [[Paper]](https://arxiv.org/abs/2006.08217) [[Project page]](https://clovaai.github.io/AdamP/) [[Official Pytorch]](https://github.com/clovaai/AdamP)
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 | ## Validation
13 | I have verified that the code runs, but I have not been able to confirm that its performance matches the official implementation.
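
A minimal smoke test along those lines, assuming TensorFlow 2.x (the toy variable and loss below are placeholders, not part of this repo):

```python
import tensorflow as tf
from adamp_tf import AdamP

# Toy problem: drive a random matrix toward zero and watch the loss fall.
opt = AdamP(learning_rate=1e-2, weight_decay=1e-2)
w = tf.Variable(tf.random.normal([10, 3]))

for _ in range(200):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(w))
    grads = tape.gradient(loss, [w])
    opt.apply_gradients(zip(grads, [w]))

print(float(loss))  # should be far below its initial value
```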
14 |
15 | ## Usage
16 | Usage is exactly the same as the built-in [tf.keras.optimizers](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers) classes!
17 | ```python
18 |
19 | from adamp_tf import AdamP
20 | from sgdp_tf import SGDP
21 |
22 | optimizer_adamp = AdamP(learning_rate=0.001, beta_1=0.9, beta_2=0.999, weight_decay=1e-2)
23 | optimizer_sgdp = SGDP(learning_rate=0.1, weight_decay=1e-5, momentum=0.9, nesterov=True)
24 | ```
25 | * **Do not use with `tf.nn.scale_regularization_loss`.** Use the `weight_decay` argument.
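
For a full Keras training run, the optimizers drop into `model.compile` like any built-in optimizer. A minimal sketch, assuming TensorFlow 2.x (the tiny MLP and the 784-dimensional inputs are only illustrative):

```python
import tensorflow as tf
from adamp_tf import AdamP

# Any Keras model works; this small MLP is only for illustration.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10, activation='softmax'),
])

model.compile(optimizer=AdamP(learning_rate=1e-3, weight_decay=1e-2),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=5)
```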
26 |
27 | ## Arguments
28 | `SGDP` and `AdamP` share arguments with [tf.keras.optimizers.SGD](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/SGD) and [tf.keras.optimizers.Adam](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adam).
29 | There are two additional hyperparameters; we recommend keeping the default values (an example of overriding them is shown below).
30 | - `delta` : threshold that determines whether a set of parameters is scale-invariant or not (default: 0.1)
31 | - `wd_ratio` : relative weight decay applied to _scale-invariant_ parameters, compared to that applied to _scale-variant_ parameters (default: 0.1)
32 |
33 | Both `SGDP` and `AdamP` support Nesterov momentum.
34 | - `nesterov` : enables Nesterov momentum (default: False)
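
A sketch of overriding these defaults (the values below are arbitrary, chosen only to show the arguments):

```python
from adamp_tf import AdamP
from sgdp_tf import SGDP

# Non-default `delta` / `wd_ratio` plus Nesterov momentum, purely illustrative.
optimizer_adamp = AdamP(learning_rate=1e-3, weight_decay=1e-2,
                        delta=0.2, wd_ratio=0.01, nesterov=True)
optimizer_sgdp = SGDP(learning_rate=0.1, momentum=0.9, weight_decay=1e-5,
                      delta=0.2, wd_ratio=0.01, nesterov=True)
```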
35 |
36 | ## How to cite
37 |
38 | ```
39 | @article{heo2020adamp,
40 | title={Slowing Down the Weight Norm Increase in Momentum-based Optimizers},
41 | author={Heo, Byeongho and Chun, Sanghyuk and Oh, Seong Joon and Han, Dongyoon and Yun, Sangdoo and Uh, Youngjung and Ha, Jung-Woo},
42 | year={2020},
43 | journal={arXiv preprint arXiv:2006.08217},
44 | }
45 | ```
46 |
--------------------------------------------------------------------------------
/adamp_tf.py:
--------------------------------------------------------------------------------
1 | """AdamP for TensorFlow."""
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | from tensorflow.python.framework import ops
7 | from tensorflow.python.keras import backend_config
8 | from tensorflow.python.keras.optimizer_v2 import optimizer_v2
9 | from tensorflow.python.ops import array_ops
10 | from tensorflow.python.ops import control_flow_ops
11 | from tensorflow.python.ops import math_ops
12 | from tensorflow.python.ops import state_ops
13 |
14 |
15 | class AdamP(optimizer_v2.OptimizerV2):
16 | _HAS_AGGREGATE_GRAD = True
17 |
18 | def __init__(self,
19 | learning_rate=0.001,
20 | beta_1=0.9,
21 | beta_2=0.999,
22 | epsilon=1e-8,
23 | weight_decay=0.0,
24 | delta=0.1, wd_ratio=0.1, nesterov=False,
25 | name='AdamP',
26 | **kwargs):
27 |
28 | super(AdamP, self).__init__(name, **kwargs)
29 | self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
30 | self._set_hyper('beta_1', beta_1)
31 | self._set_hyper('beta_2', beta_2)
32 | self._set_hyper('delta', delta)
33 | self._set_hyper('wd_ratio', wd_ratio)
34 |
35 | self.epsilon = epsilon or backend_config.epsilon()
36 | self.weight_decay = weight_decay
37 | self.nesterov = nesterov
38 |
39 | def _create_slots(self, var_list):
40 | # Create slots for the first and second moments.
41 | # Separate for-loops to respect the ordering of slot variables from v1.
42 | for var in var_list:
43 | self.add_slot(var, 'm')
44 | for var in var_list:
45 | self.add_slot(var, 'v')
46 | for var in var_list:
47 | self.add_slot(var, 'p')
48 |
49 | def _prepare_local(self, var_device, var_dtype, apply_state):
50 | super(AdamP, self)._prepare_local(var_device, var_dtype, apply_state)
51 |
52 | local_step = math_ops.cast(self.iterations + 1, var_dtype)
53 | beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype))
54 | beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype))
55 | beta_1_power = math_ops.pow(beta_1_t, local_step)
56 | beta_2_power = math_ops.pow(beta_2_t, local_step)
57 |
58 | lr = apply_state[(var_device, var_dtype)]['lr_t']
59 | bias_correction1 = 1 - beta_1_power
60 | bias_correction2 = 1 - beta_2_power
61 |
62 | delta = array_ops.identity(self._get_hyper('delta', var_dtype))
63 | wd_ratio = array_ops.identity(self._get_hyper('wd_ratio', var_dtype))
64 |
65 | apply_state[(var_device, var_dtype)].update(
66 | dict(
67 | lr=lr,
68 | epsilon=ops.convert_to_tensor_v2(self.epsilon, var_dtype),
69 | weight_decay=ops.convert_to_tensor_v2(self.weight_decay, var_dtype),
70 | beta_1_t=beta_1_t,
71 | beta_1_power=beta_1_power,
72 | one_minus_beta_1_t=1 - beta_1_t,
73 | beta_2_t=beta_2_t,
74 | beta_2_power=beta_2_power,
75 | one_minus_beta_2_t=1 - beta_2_t,
76 | bias_correction1=bias_correction1,
77 | bias_correction2=bias_correction2,
78 | delta=delta,
79 | wd_ratio=wd_ratio))
80 |
81 | def set_weights(self, weights):
82 | params = self.weights
83 |         # Weights saved by a Keras V1 optimizer include vhats even without amsgrad
84 |         # (3x + 1 variables vs. 2x + 1 here). Filter the vhats out for compatibility.
85 | num_vars = int((len(params) - 1) / 2)
86 | if len(weights) == 3 * num_vars + 1:
87 | weights = weights[:len(params)]
88 | super(AdamP, self).set_weights(weights)
89 |
90 | def _resource_apply_dense(self, grad, var, apply_state=None):
91 | var_device, var_dtype = var.device, var.dtype.base_dtype
92 | coefficients = ((apply_state or {}).get((var_device, var_dtype))
93 | or self._fallback_apply_state(var_device, var_dtype))
94 |
95 | # m_t = beta1 * m + (1 - beta1) * g_t
96 | m = self.get_slot(var, 'm')
97 | m_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
98 | m_t = state_ops.assign(m, m * coefficients['beta_1_t'] + m_scaled_g_values, use_locking=self._use_locking)
99 |
100 | # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
101 | v = self.get_slot(var, 'v')
102 | v_scaled_g_values = (grad * grad) * coefficients['one_minus_beta_2_t']
103 | v_t = state_ops.assign(v, v * coefficients['beta_2_t'] + v_scaled_g_values, use_locking=self._use_locking)
104 |
105 | denorm = (math_ops.sqrt(v_t) / math_ops.sqrt(coefficients['bias_correction2'])) + coefficients['epsilon']
106 | step_size = coefficients['lr'] / coefficients['bias_correction1']
107 |
108 | if self.nesterov:
109 | perturb = (coefficients['beta_1_t'] * m_t + coefficients['one_minus_beta_1_t'] * grad) / denorm
110 | else:
111 | perturb = m_t / denorm
112 |
113 | # Projection
114 | wd_ratio = 1
115 | if len(var.shape) > 1:
116 | perturb, wd_ratio = self._projection(var, grad, perturb, coefficients['delta'], coefficients['wd_ratio'], coefficients['epsilon'])
117 |
118 | # Weight decay
119 |
120 | if self.weight_decay > 0:
121 | var = state_ops.assign(var, var * (1 - coefficients['lr'] * coefficients['weight_decay'] * wd_ratio), use_locking=self._use_locking)
122 |
123 | var_update = state_ops.assign_sub(var, step_size * perturb, use_locking=self._use_locking)
124 |
125 | return control_flow_ops.group(*[var_update, m_t, v_t])
126 |
127 |
128 | def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
129 |
130 | var_device, var_dtype = var.device, var.dtype.base_dtype
131 | coefficients = ((apply_state or {}).get((var_device, var_dtype))
132 | or self._fallback_apply_state(var_device, var_dtype))
133 | 
134 |         # Adam-style moment updates:
135 | 
136 | # m_t = beta1 * m + (1 - beta1) * g_t
137 | m = self.get_slot(var, 'm')
138 | m_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
139 | m_t = state_ops.assign(m, m * coefficients['beta_1_t'],
140 | use_locking=self._use_locking)
141 | with ops.control_dependencies([m_t]):
142 | m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)
143 |
144 | # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
145 | v = self.get_slot(var, 'v')
146 | v_scaled_g_values = (grad * grad) * coefficients['one_minus_beta_2_t']
147 | v_t = state_ops.assign(v, v * coefficients['beta_2_t'],
148 | use_locking=self._use_locking)
149 | with ops.control_dependencies([v_t]):
150 | v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)
151 |
152 | denorm = (math_ops.sqrt(v_t) / math_ops.sqrt(coefficients['bias_correction2'])) + coefficients['epsilon']
153 | step_size = coefficients['lr'] / coefficients['bias_correction1']
154 |
155 |         if self.nesterov:
156 |             # Tensor-level scatter-add: `m_t * beta_1_t` is a plain tensor here, not a variable.
157 |             p_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
158 |             perturb = array_ops.tensor_scatter_add(m_t * coefficients['beta_1_t'], array_ops.expand_dims(indices, -1), p_scaled_g_values) / denorm
159 |
160 | else:
161 | perturb = m_t / denorm
162 |
163 | # Projection
164 | wd_ratio = 1
165 | if len(var.shape) > 1:
166 | perturb, wd_ratio = self._projection(var, grad, perturb, coefficients['delta'], coefficients['wd_ratio'], coefficients['epsilon'])
167 |
168 | # Weight decay
169 | if self.weight_decay > 0:
170 | var = state_ops.assign(var, var * (1 - coefficients['lr'] * coefficients['weight_decay'] * wd_ratio), use_locking=self._use_locking)
171 |
172 | var_update = state_ops.assign_sub(var, step_size * perturb, use_locking=self._use_locking)
173 |
174 | return control_flow_ops.group(*[var_update, m_t, v_t])
175 |
176 | def _channel_view(self, x):
177 | return array_ops.reshape(x, shape=[x.shape[0], -1])
178 |
179 | def _layer_view(self, x):
180 | return array_ops.reshape(x, shape=[1, -1])
181 |
182 | def _cosine_similarity(self, x, y, eps, view_func):
183 | x = view_func(x)
184 | y = view_func(y)
185 |
186 | x_norm = math_ops.euclidean_norm(x, axis=-1) + eps
187 | y_norm = math_ops.euclidean_norm(y, axis=-1) + eps
188 | dot = math_ops.reduce_sum(x * y, axis=-1)
189 |
190 | return math_ops.abs(dot) / x_norm / y_norm
191 |
192 | def _projection(self, var, grad, perturb, delta, wd_ratio, eps):
193 | # channel_view
194 | cosine_sim = self._cosine_similarity(grad, var, eps, self._channel_view)
195 | cosine_max = math_ops.reduce_max(cosine_sim)
196 | compare_val = delta / math_ops.sqrt(math_ops.cast(self._channel_view(var).shape[-1], dtype=delta.dtype))
197 |
198 | perturb, wd = control_flow_ops.cond(pred=cosine_max < compare_val,
199 | true_fn=lambda : self.channel_true_fn(var, perturb, wd_ratio, eps),
200 | false_fn=lambda : self.channel_false_fn(var, grad, perturb, delta, wd_ratio, eps))
201 |
202 | return perturb, wd
203 |
204 | def channel_true_fn(self, var, perturb, wd_ratio, eps):
205 | expand_size = [-1] + [1] * (len(var.shape) - 1)
206 | var_n = var / (array_ops.reshape(math_ops.euclidean_norm(self._channel_view(var), axis=-1), shape=expand_size) + eps)
207 | perturb -= var_n * array_ops.reshape(math_ops.reduce_sum(self._channel_view(var_n * perturb), axis=-1), shape=expand_size)
208 | wd = wd_ratio
209 |
210 | return perturb, wd
211 |
212 | def channel_false_fn(self, var, grad, perturb, delta, wd_ratio, eps):
213 | cosine_sim = self._cosine_similarity(grad, var, eps, self._layer_view)
214 | cosine_max = math_ops.reduce_max(cosine_sim)
215 | compare_val = delta / math_ops.sqrt(math_ops.cast(self._layer_view(var).shape[-1], dtype=delta.dtype))
216 |
217 | perturb, wd = control_flow_ops.cond(cosine_max < compare_val,
218 | true_fn=lambda : self.layer_true_fn(var, perturb, wd_ratio, eps),
219 | false_fn=lambda : self.identity_fn(perturb))
220 |
221 | return perturb, wd
222 |
223 | def layer_true_fn(self, var, perturb, wd_ratio, eps):
224 | expand_size = [-1] + [1] * (len(var.shape) - 1)
225 | var_n = var / (array_ops.reshape(math_ops.euclidean_norm(self._layer_view(var), axis=-1), shape=expand_size) + eps)
226 | perturb -= var_n * array_ops.reshape(math_ops.reduce_sum(self._layer_view(var_n * perturb), axis=-1), shape=expand_size)
227 | wd = wd_ratio
228 |
229 | return perturb, wd
230 |
231 | def identity_fn(self, perturb):
232 | wd = 1.0
233 |
234 | return perturb, wd
235 |
236 | def get_config(self):
237 | config = super(AdamP, self).get_config()
238 | config.update({
239 | 'learning_rate': self._serialize_hyperparameter('learning_rate'),
240 | 'beta_1': self._serialize_hyperparameter('beta_1'),
241 | 'beta_2': self._serialize_hyperparameter('beta_2'),
242 | 'delta': self._serialize_hyperparameter('delta'),
243 | 'wd_ratio': self._serialize_hyperparameter('wd_ratio'),
244 | 'epsilon': self.epsilon,
245 | 'weight_decay': self.weight_decay,
246 | 'nesterov': self.nesterov
247 | })
248 | return config
--------------------------------------------------------------------------------
/sgdp_tf.py:
--------------------------------------------------------------------------------
1 | """ SGDP for TensorFlow."""
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | from tensorflow.python.framework import ops
7 | from tensorflow.python.keras.optimizer_v2 import optimizer_v2
8 | from tensorflow.python.ops import array_ops
9 | from tensorflow.python.ops import math_ops
10 | from tensorflow.python.ops import state_ops
11 | from tensorflow.python.ops import control_flow_ops
12 |
13 |
14 | class SGDP(optimizer_v2.OptimizerV2):
15 | _HAS_AGGREGATE_GRAD = True
16 | def __init__(self,
17 | learning_rate=0.1,
18 | momentum=0.0,
19 | dampening=0.0,
20 | weight_decay=0.0,
21 | nesterov=False,
22 | epsilon=1e-8,
23 | delta=0.1,
24 | wd_ratio=0.1,
25 | name="SGDP",
26 | **kwargs):
27 |
28 | super(SGDP, self).__init__(name, **kwargs)
29 | self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
30 |
31 | self._momentum = False
32 | if isinstance(momentum, ops.Tensor) or callable(momentum) or momentum > 0:
33 | self._momentum = True
34 | if isinstance(momentum, (int, float)) and (momentum < 0 or momentum > 1):
35 | raise ValueError("`momentum` must be between [0, 1].")
36 |
37 | self._set_hyper("momentum", momentum)
38 | self._set_hyper("dampening", dampening)
39 | self._set_hyper("epsilon", epsilon)
40 | self._set_hyper("delta", delta)
41 | self._set_hyper("wd_ratio", wd_ratio)
42 |
43 | self.nesterov = nesterov
44 | self.weight_decay = weight_decay
45 |
46 | def _create_slots(self, var_list):
47 | if self._momentum:
48 | for var in var_list:
49 | self.add_slot(var, "momentum")
50 | for var in var_list:
51 | self.add_slot(var, "buf")
52 |
53 | def _prepare_local(self, var_device, var_dtype, apply_state):
54 | super(SGDP, self)._prepare_local(var_device, var_dtype, apply_state)
55 | lr = apply_state[(var_device, var_dtype)]['lr_t']
56 |
57 | momentum = array_ops.identity(self._get_hyper("momentum", var_dtype))
58 | dampening = array_ops.identity(self._get_hyper('dampening', var_dtype))
59 | delta = array_ops.identity(self._get_hyper('delta', var_dtype))
60 | wd_ratio = array_ops.identity(self._get_hyper('wd_ratio', var_dtype))
61 |
62 | apply_state[(var_device, var_dtype)].update(
63 | dict(
64 | lr=lr,
65 | epsilon=ops.convert_to_tensor_v2(self.epsilon, var_dtype),
66 | weight_decay=ops.convert_to_tensor_v2(self.weight_decay, var_dtype),
67 | momentum=momentum,
68 | dampening=dampening,
69 | delta=delta,
70 | wd_ratio=wd_ratio))
71 |
72 |
73 | def _resource_apply_dense(self, grad, var, apply_state=None):
74 | var_device, var_dtype = var.device, var.dtype.base_dtype
75 | coefficients = ((apply_state or {}).get((var_device, var_dtype))
76 | or self._fallback_apply_state(var_device, var_dtype))
77 |
78 | buf = self.get_slot(var, 'buf')
79 | b_scaled_g_values = grad * (1 - coefficients['dampening'])
80 | buf_t = state_ops.assign(buf, buf * coefficients['momentum'] + b_scaled_g_values, use_locking=self._use_locking)
81 |
82 | if self.nesterov:
83 | d_p = grad + coefficients['momentum'] * buf_t
84 | else:
85 | d_p = buf_t
86 |
87 | # Projection
88 | wd_ratio = 1
89 | if len(var.shape) > 1:
90 | d_p, wd_ratio = self._projection(var, grad, d_p, coefficients['delta'], coefficients['wd_ratio'], coefficients['epsilon'])
91 |
92 | # Weight decay
93 | if self.weight_decay > 0:
94 | var = state_ops.assign(var, var * (1 - coefficients['lr'] * coefficients['weight_decay'] * wd_ratio), use_locking=self._use_locking)
95 |
96 | var_update = state_ops.assign_sub(var, coefficients['lr'] * d_p, use_locking=self._use_locking)
97 |
98 | return control_flow_ops.group(*[var_update, buf_t])
99 |
100 | def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
101 | # This method is only needed for momentum optimization.
102 | var_device, var_dtype = var.device, var.dtype.base_dtype
103 | coefficients = ((apply_state or {}).get((var_device, var_dtype))
104 | or self._fallback_apply_state(var_device, var_dtype))
105 |
106 | buf = self.get_slot(var, 'buf')
107 | b_scaled_g_values = grad * (1 - coefficients['dampening'])
108 | buf_t = state_ops.assign(buf, buf * coefficients['momentum'], use_locking=self._use_locking)
109 |
110 | with ops.control_dependencies([buf_t]):
111 | buf_t = self._resource_scatter_add(buf, indices, b_scaled_g_values)
112 |
113 |         if self.nesterov:  # tensor-level scatter-add: `buf_t * momentum` is a plain tensor, not a variable
114 |             d_p = array_ops.tensor_scatter_add(buf_t * coefficients['momentum'], array_ops.expand_dims(indices, -1), grad)
115 | else:
116 | d_p = buf_t
117 |
118 | # Projection
119 | wd_ratio = 1
120 |         if len(var.shape) > 1:
121 | d_p, wd_ratio = self._projection(var, grad, d_p, coefficients['delta'], coefficients['wd_ratio'],
122 | coefficients['epsilon'])
123 |
124 | # Weight decay
125 | if self.weight_decay > 0:
126 | var = state_ops.assign(var, var * (1 - coefficients['lr'] * coefficients['weight_decay'] * wd_ratio), use_locking=self._use_locking)
127 |
128 | var_update = state_ops.assign_sub(var, coefficients['lr'] * d_p, use_locking=self._use_locking)
129 |
130 | return control_flow_ops.group(*[var_update, buf_t])
131 |
132 | def _channel_view(self, x):
133 | return array_ops.reshape(x, shape=[array_ops.shape(x)[0], -1])
134 |
135 | def _layer_view(self, x):
136 | return array_ops.reshape(x, shape=[1, -1])
137 |
138 | def _cosine_similarity(self, x, y, eps, view_func):
139 | x = view_func(x)
140 | y = view_func(y)
141 |
142 | x_norm = math_ops.euclidean_norm(x, axis=-1) + eps
143 | y_norm = math_ops.euclidean_norm(y, axis=-1) + eps
144 | dot = math_ops.reduce_sum(x * y, axis=-1)
145 |
146 | return math_ops.abs(dot) / x_norm / y_norm
147 |
148 | def _projection(self, var, grad, perturb, delta, wd_ratio, eps):
149 | # channel_view
150 | cosine_sim = self._cosine_similarity(grad, var, eps, self._channel_view)
151 | cosine_max = math_ops.reduce_max(cosine_sim)
152 | compare_val = delta / math_ops.sqrt(math_ops.cast(self._channel_view(var).shape[-1], dtype=delta.dtype))
153 |
154 | perturb, wd = control_flow_ops.cond(pred=cosine_max < compare_val,
155 | true_fn=lambda : self.channel_true_fn(var, perturb, wd_ratio, eps),
156 | false_fn=lambda : self.channel_false_fn(var, grad, perturb, delta, wd_ratio, eps))
157 |
158 | return perturb, wd
159 |
160 | def channel_true_fn(self, var, perturb, wd_ratio, eps):
161 | expand_size = [-1] + [1] * (len(var.shape) - 1)
162 | var_n = var / (array_ops.reshape(math_ops.euclidean_norm(self._channel_view(var), axis=-1), shape=expand_size) + eps)
163 |         perturb = perturb - var_n * array_ops.reshape(math_ops.reduce_sum(self._channel_view(var_n * perturb), axis=-1), shape=expand_size)
164 | wd = wd_ratio
165 |
166 | return perturb, wd
167 |
168 | def channel_false_fn(self, var, grad, perturb, delta, wd_ratio, eps):
169 | cosine_sim = self._cosine_similarity(grad, var, eps, self._layer_view)
170 | cosine_max = math_ops.reduce_max(cosine_sim)
171 | compare_val = delta / math_ops.sqrt(math_ops.cast(self._layer_view(var).shape[-1], dtype=delta.dtype))
172 |
173 | perturb, wd = control_flow_ops.cond(cosine_max < compare_val,
174 | true_fn=lambda : self.layer_true_fn(var, perturb, wd_ratio, eps),
175 | false_fn=lambda : self.identity_fn(perturb))
176 |
177 | return perturb, wd
178 |
179 | def layer_true_fn(self, var, perturb, wd_ratio, eps):
180 | expand_size = [-1] + [1] * (len(var.shape) - 1)
181 | var_n = var / (array_ops.reshape(math_ops.euclidean_norm(self._layer_view(var), axis=-1), shape=expand_size) + eps)
182 |         perturb = perturb - var_n * array_ops.reshape(math_ops.reduce_sum(self._layer_view(var_n * perturb), axis=-1), shape=expand_size)
183 | wd = wd_ratio
184 |
185 | return perturb, wd
186 |
187 | def identity_fn(self, perturb):
188 | wd = 1.0
189 |
190 | return perturb, wd
191 |
192 | def get_config(self):
193 | config = super(SGDP, self).get_config()
194 | config.update({
195 | 'learning_rate': self._serialize_hyperparameter('learning_rate'),
196 | 'momentum': self._serialize_hyperparameter('momentum'),
197 | 'dampening': self._serialize_hyperparameter('dampening'),
198 | 'delta': self._serialize_hyperparameter('delta'),
199 | 'wd_ratio': self._serialize_hyperparameter('wd_ratio'),
200 | 'epsilon': self.epsilon,
201 | 'weight_decay': self.weight_decay,
202 | "nesterov": self.nesterov,
203 | })
204 | return config
205 |
--------------------------------------------------------------------------------