├── .gitignore ├── LICENSE ├── README.md ├── adabound.py ├── assets └── mnist_acc.png ├── logs ├── adabound │ └── events.out.tfevents.1552196804.KOZISTR ├── adagrad │ └── events.out.tfevents.1551702045.KOZISTR ├── adam │ └── events.out.tfevents.1551701853.KOZISTR ├── amsbound │ └── events.out.tfevents.1552197385.KOZISTR ├── momentum │ └── events.out.tfevents.1551702320.KOZISTR └── sgd │ └── events.out.tfevents.1551702195.KOZISTR └── mnist_test.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # pycharm 104 | .idea/ 105 | 106 | # github pr forms 107 | .github/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AdaBound in Tensorflow 2 | An optimizer that trains as fast as Adam and generalizes as well as SGD, implemented in Tensorflow 3 | 4 | This repo is based on the official PyTorch implementation: [original repo](https://github.com/Luolc/AdaBound) 5 | 6 | [![Total alerts](https://img.shields.io/lgtm/alerts/g/kozistr/AdaBound-tensorflow.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/kozistr/AdaBound-tensorflow/alerts/) 7 | [![Language grade: Python](https://img.shields.io/lgtm/grade/python/g/kozistr/AdaBound-tensorflow.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/kozistr/AdaBound-tensorflow/context:python) 8 | 9 | ## Explanation 10 | An optimizer that trains as fast as Adam and generalizes as well as SGD, 11 | intended for developing state-of-the-art deep learning models on a wide variety of popular tasks in CV, NLP, etc. 12 | 13 | Based on Luo et al. (2019). [Adaptive Gradient Methods with Dynamic Bound of Learning Rate](https://openreview.net/forum?id=Bkg3g2R9FX). In Proc. of ICLR 2019. 14 | 15 | ## Requirements 16 | * Python 3.x 17 | * Tensorflow 1.x (may also work with 2.x) 18 | 19 | ## Usage 20 | 21 | ```python 22 | # learning_rate can be either a scalar or a tensor 23 | 24 | # use the exclude_from_weight_decay option 25 | # if you want to selectively exclude certain weights from weight decay 26 | 27 | optimizer = AdaBoundOptimizer( 28 | learning_rate=1e-3, 29 | final_lr=1e-1, 30 | beta1=0.9, 31 | beta2=0.999, 32 | gamma=1e-3, 33 | epsilon=1e-6, 34 | amsbound=False, 35 | decay=0., 36 | weight_decay=0., 37 | exclude_from_weight_decay=["..."] 38 | ) 39 | ``` 40 | 41 | You can test the optimizers on the MNIST dataset with the model below. 42 | 43 | For the `AdaBound` optimizer: 44 | ```shell 45 | python3 mnist_test.py --optimizer "adabound" 46 | ``` 47 | 48 | For the `AMSBound` optimizer: 49 | ```shell 50 | python3 mnist_test.py --optimizer "amsbound" 51 | ``` 52 | 53 | ## Results 54 | 55 | Test accuracy & loss of the optimizers on several datasets under the same conditions. 56 | 57 | ### MNIST Dataset 58 | 59 | ![acc](./assets/mnist_acc.png) 60 | 61 | *Optimizer* | *Test Acc* | *Time* | *Etc* | 62 | :---: | :---: | :---: | :---: | 63 | AdaBound | **97.77%** | 5m 45s | | 64 | AMSBound | 97.72% | 5m 52s | | 65 | Adam | 97.62% | 4m 18s | | 66 | AdaGrad | 90.15% | **4m 07s** | | 67 | SGD | 87.88% | 5m 26s | | 68 | Momentum | 87.88% | 4m 26s | w/ Nesterov | 69 | 70 | ## Citation 71 | 72 | ``` 73 | @inproceedings{Luo2019AdaBound, 74 | author = {Luo, Liangchen and Xiong, Yuanhao and Liu, Yan and Sun, Xu}, 75 | title = {Adaptive Gradient Methods with Dynamic Bound of Learning Rate}, 76 | booktitle = {Proceedings of the 7th International Conference on Learning Representations}, 77 | month = {May}, 78 | year = {2019}, 79 | address = {New Orleans, Louisiana} 80 | } 81 | ``` 82 | 83 | ## Author 84 | 85 | Hyeongchan Kim / [kozistr](http://kozistr.tech) 86 | -------------------------------------------------------------------------------- /adabound.py: -------------------------------------------------------------------------------- 1 | """AdaBound for Tensorflow.""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import tensorflow as tf 7 | import re 8 | 9 | 10 | class AdaBoundOptimizer(tf.train.Optimizer): 11 | """Optimizer that implements the AdaBound algorithm.
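    A minimal usage sketch (it simply mirrors how `mnist_test.py` in this repo
    wires the optimizer; `loss`, `t_vars`, and `global_step` are assumed to be
    defined by the caller):

        opt = AdaBoundOptimizer(learning_rate=1e-3, final_lr=1e-1)
        grads = tf.gradients(loss, t_vars)
        train_op = opt.apply_gradients(zip(grads, t_vars), global_step=global_step)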
12 | 13 | See [Luo et al., 2019](https://openreview.net/forum?id=Bkg3g2R9FX) 14 | ([pdf](https://openreview.net/pdf?id=Bkg3g2R9FX)). 15 | """ 16 | 17 | def __init__(self, 18 | learning_rate=0.001, 19 | final_lr=0.1, 20 | beta1=0.9, 21 | beta2=0.999, 22 | gamma=1e-3, 23 | epsilon=1e-8, 24 | amsbound=False, 25 | decay=0., 26 | weight_decay=0., 27 | exclude_from_weight_decay=None, 28 | use_locking=False, name="AdaBound"): 29 | super(AdaBoundOptimizer, self).__init__(use_locking, name) 30 | 31 | if final_lr <= 0.: 32 | raise ValueError("Invalid final learning rate : {}".format(final_lr)) 33 | if not 0. <= beta1 < 1.: 34 | raise ValueError("Invalid beta1 value : {}".format(beta1)) 35 | if not 0. <= beta2 < 1.: 36 | raise ValueError("Invalid beta2 value : {}".format(beta2)) 37 | if not 0. <= gamma < 1.: 38 | raise ValueError("Invalid gamma value : {}".format(gamma)) 39 | if epsilon <= 0.: 40 | raise ValueError("Invalid epsilon value : {}".format(epsilon)) 41 | 42 | self._lr = learning_rate 43 | self._beta1 = beta1 44 | self._beta2 = beta2 45 | self._final_lr = final_lr 46 | self._gamma = gamma 47 | self._epsilon = epsilon 48 | self._amsbound = amsbound 49 | self._decay = decay 50 | self._weight_decay = weight_decay 51 | self._exclude_from_weight_decay = exclude_from_weight_decay 52 | 53 | self._base_lr = learning_rate 54 | 55 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 56 | lr = self._lr 57 | t = tf.cast(global_step, dtype=tf.float32) 58 | 59 | if self._decay > 0.: 60 | lr *= (1. / (1. + self._decay * t)) 61 | 62 | t += 1 63 | 64 | bias_correction1 = 1. - (self._beta1 ** t) 65 | bias_correction2 = 1. - (self._beta2 ** t) 66 | step_size = (lr * tf.sqrt(bias_correction2) / bias_correction1) 67 | 68 | # Applies bounds on actual learning rate 69 | # lr_scheduler cannot affect final_lr, this is a workaround to apply lr decay 70 | final_lr = self._final_lr * lr / self._base_lr 71 | lower_bound = final_lr * (1. - 1. / (self._gamma * t + 1.)) 72 | upper_bound = final_lr * (1. + 1. / (self._gamma * t)) 73 | 74 | assignments = [] 75 | for grad, param in grads_and_vars: 76 | if grad is None or param is None: 77 | continue 78 | 79 | param_name = self._get_variable_name(param.name) 80 | 81 | m = tf.get_variable( 82 | name=param_name + "/adabound_m", 83 | shape=param.shape.as_list(), 84 | dtype=tf.float32, 85 | trainable=False, 86 | initializer=tf.zeros_initializer()) 87 | v = tf.get_variable( 88 | name=param_name + "/adabound_v", 89 | shape=param.shape.as_list(), 90 | dtype=tf.float32, 91 | trainable=False, 92 | initializer=tf.zeros_initializer()) 93 | if self._amsbound: 94 | v_hat = tf.get_variable( 95 | name=param_name + "/adabound_v_hat", 96 | shape=param.shape.as_list(), 97 | dtype=tf.float32, 98 | trainable=False, 99 | initializer=tf.zeros_initializer()) 100 | 101 | m_t = ( 102 | tf.multiply(self._beta1, m) + tf.multiply(1. - self._beta1, grad)) 103 | v_t = ( 104 | tf.multiply(self._beta2, v) + tf.multiply(1. - self._beta2, tf.square(grad))) 105 | 106 | if self._amsbound: 107 | # Maintains the maximum of all 2nd moment running avg. till now 108 | v_hat_t = tf.maximum(v_hat, v_t) 109 | 110 | # Use the max. for normalizing running avg. 
of gradient 111 | denom = (tf.sqrt(v_hat_t) + self._epsilon) 112 | else: 113 | denom = (tf.sqrt(v_t) + self._epsilon) 114 | 115 | step_size_p = step_size * tf.ones_like(denom) 116 | step_size_p_bound = step_size_p / denom 117 | 118 | lr_t = m_t * tf.clip_by_value(t=step_size_p_bound, 119 | clip_value_min=lower_bound, 120 | clip_value_max=upper_bound) 121 | p_t = param - lr_t 122 | 123 | if self._do_use_weight_decay(param_name): 124 | p_t -= self._weight_decay * param 125 | 126 | update_list = [param.assign(p_t), m.assign(m_t), v.assign(v_t)] 127 | if self._amsbound: 128 | update_list.append(v_hat.assign(v_hat_t)) 129 | 130 | assignments.extend(update_list) 131 | 132 | # update the global step 133 | assignments.append(global_step.assign_add(1)) 134 | 135 | return tf.group(*assignments, name=name) 136 | 137 | def _do_use_weight_decay(self, param_name): 138 | """Whether to use L2 weight decay for `param_name`.""" 139 | if not self._weight_decay: 140 | return False 141 | if self._exclude_from_weight_decay: 142 | for r in self._exclude_from_weight_decay: 143 | if re.search(r, param_name) is not None: 144 | return False 145 | return True 146 | 147 | @staticmethod 148 | def _get_variable_name(param_name): 149 | """Get the variable name from the tensor name.""" 150 | m = re.match("^(.*):\\d+$", param_name) 151 | if m is not None: 152 | param_name = m.group(1) 153 | return param_name 154 | -------------------------------------------------------------------------------- /assets/mnist_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/AdaBound-tensorflow/4da54bed8af9f90b302edb561b27e24d9b3eb6c1/assets/mnist_acc.png -------------------------------------------------------------------------------- /logs/adabound/events.out.tfevents.1552196804.KOZISTR: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/AdaBound-tensorflow/4da54bed8af9f90b302edb561b27e24d9b3eb6c1/logs/adabound/events.out.tfevents.1552196804.KOZISTR -------------------------------------------------------------------------------- /logs/adagrad/events.out.tfevents.1551702045.KOZISTR: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/AdaBound-tensorflow/4da54bed8af9f90b302edb561b27e24d9b3eb6c1/logs/adagrad/events.out.tfevents.1551702045.KOZISTR -------------------------------------------------------------------------------- /logs/adam/events.out.tfevents.1551701853.KOZISTR: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/AdaBound-tensorflow/4da54bed8af9f90b302edb561b27e24d9b3eb6c1/logs/adam/events.out.tfevents.1551701853.KOZISTR -------------------------------------------------------------------------------- /logs/amsbound/events.out.tfevents.1552197385.KOZISTR: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/AdaBound-tensorflow/4da54bed8af9f90b302edb561b27e24d9b3eb6c1/logs/amsbound/events.out.tfevents.1552197385.KOZISTR -------------------------------------------------------------------------------- /logs/momentum/events.out.tfevents.1551702320.KOZISTR: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/kozistr/AdaBound-tensorflow/4da54bed8af9f90b302edb561b27e24d9b3eb6c1/logs/momentum/events.out.tfevents.1551702320.KOZISTR -------------------------------------------------------------------------------- /logs/sgd/events.out.tfevents.1551702195.KOZISTR: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kozistr/AdaBound-tensorflow/4da54bed8af9f90b302edb561b27e24d9b3eb6c1/logs/sgd/events.out.tfevents.1551702195.KOZISTR -------------------------------------------------------------------------------- /mnist_test.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | import argparse 7 | import os 8 | 9 | from tensorflow.examples.tutorials.mnist import input_data 10 | 11 | from adabound import AdaBoundOptimizer 12 | 13 | 14 | w_init = tf.contrib.layers.variance_scaling_initializer(factor=3., mode='FAN_AVG', uniform=True) 15 | w_reg = tf.contrib.layers.l2_regularizer(5e-4) 16 | 17 | 18 | def train(sess, 19 | input_shape=(None, 784), n_classes=10, 20 | n_feat=32, n_blocks=2, 21 | optimizer="adabound", lr=1e-3, grad_clip=0., 22 | log_dir="./logs"): 23 | def prepare_optimizer(optimizer_name, _global_step): 24 | # You can just use learning rate 25 | # either scalar value or tensor like below form 26 | # learning_rate = lr 27 | learning_rate = tf.train.exponential_decay( 28 | learning_rate=lr, 29 | global_step=_global_step, 30 | decay_steps=500, 31 | decay_rate=.95, 32 | staircase=True, 33 | ) 34 | 35 | if optimizer_name == "adabound": 36 | return AdaBoundOptimizer(learning_rate=learning_rate) 37 | elif optimizer_name == "amsbound": 38 | return AdaBoundOptimizer(learning_rate=learning_rate, amsbound=True) 39 | elif optimizer_name == "adam": 40 | return tf.train.AdamOptimizer(learning_rate=learning_rate) 41 | elif optimizer_name == "sgd": 42 | return tf.train.GradientDescentOptimizer(learning_rate=learning_rate) 43 | elif optimizer_name == "adagrad": 44 | return tf.train.AdagradOptimizer(learning_rate=learning_rate) 45 | elif optimizer_name == "momentum": 46 | return tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=1e-6, use_nesterov=True) 47 | else: 48 | raise NotImplementedError("[-] Unsupported Optimizer %s" % optimizer_name) 49 | 50 | with tf.name_scope("inputs"): 51 | img = tf.placeholder(tf.float32, shape=input_shape, name="x-image") 52 | label = tf.placeholder(tf.int32, shape=(None, n_classes), name="y-label") 53 | do_rate = tf.placeholder(tf.float32, shape=(), name="dropout") 54 | 55 | """ 56 | # CNN architecture example 57 | with tf.variable_scope("simple_cnn_model"): 58 | x = tf.reshape(img, [-1, 28, 28, 1]) 59 | 60 | for n_layer_idx in range(n_blocks): 61 | with tf.variable_scope("cnn_layer_%d" % n_layer_idx): 62 | x = tf.layers.conv2d(x, filters=n_feat, kernel_size=3, strides=1, padding='SAME', 63 | kernel_initializer=w_init, kernel_regularizer=w_reg) 64 | x = tf.nn.leaky_relu(x, alpha=0.2) 65 | x = tf.nn.dropout(x, keep_prob=do_rate) 66 | x = tf.layers.max_pooling2d(x, pool_size=(2, 2), strides=(2, 2), padding='SAME') \ 67 | if n_layer_idx % 2 == 0 else x 68 | n_feat *= 2 69 | 70 | x = tf.layers.flatten(x) 71 | 72 | x = tf.layers.dense(x, units=256, 73 | kernel_initializer=w_init, kernel_regularizer=w_reg) 74 | x = tf.nn.leaky_relu(x, alpha=0.2) 75 | x = tf.nn.dropout(x, keep_prob=do_rate) 76 | 
77 | logits = tf.layers.dense(x, units=n_classes, 78 | kernel_initializer=w_init, kernel_regularizer=w_reg) 79 | pred = tf.nn.softmax(logits) 80 | """ 81 | 82 | with tf.variable_scope("simple_nn_model"): 83 | x = tf.layers.dense(img, units=256) 84 | x = tf.nn.leaky_relu(x, alpha=0.2) 85 | x = tf.nn.dropout(x, do_rate) 86 | 87 | x = tf.layers.dense(x, units=64) 88 | x = tf.nn.leaky_relu(x, alpha=0.2) 89 | x = tf.nn.dropout(x, do_rate) 90 | 91 | logits = tf.layers.dense(x, units=n_classes) 92 | pred = tf.nn.softmax(logits) 93 | with tf.name_scope("loss"): 94 | loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=label)) 95 | 96 | with tf.name_scope("train"): 97 | global_step = tf.train.get_or_create_global_step() 98 | opt = prepare_optimizer(optimizer, global_step) 99 | 100 | t_vars = tf.trainable_variables() 101 | grads = tf.gradients(loss, t_vars) 102 | 103 | # gradient clipping 104 | if grad_clip: 105 | grads, _ = tf.clip_by_global_norm(grads, clip_norm=grad_clip) 106 | 107 | train_op = opt.apply_gradients(zip(grads, t_vars), global_step=global_step) 108 | 109 | with tf.name_scope("metric"): 110 | corr_pred = tf.equal(tf.argmax(pred, axis=1), tf.argmax(label, axis=1)) 111 | acc = tf.reduce_mean(tf.cast(corr_pred, dtype=tf.float32)) 112 | 113 | with tf.name_scope("summary"): 114 | tf.summary.scalar("loss/loss", loss) 115 | tf.summary.scalar("metric/acc", acc) 116 | 117 | merged = tf.summary.merge_all() 118 | train_writer = tf.summary.FileWriter(os.path.join(log_dir, "train"), sess.graph) 119 | test_writer = tf.summary.FileWriter(os.path.join(log_dir, "test"), sess.graph) 120 | saver = tf.train.Saver(max_to_keep=1) 121 | return (img, label, do_rate), merged, train_op, loss, (train_writer, test_writer, saver) 122 | 123 | 124 | def main(training_steps, 125 | batch_size, 126 | n_classes, 127 | learning_rate, 128 | optimizer, 129 | n_blocks, 130 | filters, 131 | dropout, 132 | model_dir, 133 | data_dir, 134 | log_dir, 135 | logging_steps): 136 | # 0. prepare folders 137 | os.makedirs(log_dir, exist_ok=True) 138 | os.makedirs(model_dir, exist_ok=True) 139 | 140 | # 1. loading the MNIST dataset 141 | mnist = input_data.read_data_sets(data_dir, one_hot=True) 142 | 143 | config = tf.ConfigProto() 144 | config.gpu_options.allow_growth = True 145 | with tf.Session(config=config) as sess: 146 | # 2. loading the model 147 | (x, y, do_rate), merged, train_op, loss, (tr_writer, te_writer, saver) = train( 148 | sess=sess, 149 | input_shape=(None, 28 * 28), 150 | n_classes=n_classes, 151 | n_blocks=n_blocks, 152 | n_feat=filters, 153 | optimizer=optimizer, 154 | lr=learning_rate, 155 | log_dir=log_dir 156 | ) 157 | 158 | sess.run(tf.global_variables_initializer()) 159 | 160 | # 2-1. 
loading pre-trained model 161 | ckpt = tf.train.get_checkpoint_state(model_dir) 162 | if ckpt and ckpt.model_checkpoint_path: 163 | # Restores from checkpoint 164 | saver.restore(sess, ckpt.model_checkpoint_path) 165 | 166 | global_step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]) 167 | print("[+] global step : %d" % global_step, " successfully loaded") 168 | else: 169 | global_step = 0 170 | print('[-] No checkpoint file found') 171 | 172 | for steps in range(global_step, training_steps): 173 | x_tr, y_tr = mnist.train.next_batch(batch_size) 174 | 175 | _, tr_loss = sess.run([train_op, loss], 176 | feed_dict={ 177 | x: x_tr, 178 | y: y_tr, 179 | do_rate: dropout, 180 | }) 181 | 182 | if steps and steps % logging_steps == 0: 183 | summary = sess.run(merged, 184 | feed_dict={ 185 | x: mnist.test.images, 186 | y: mnist.test.labels, 187 | do_rate: 1., 188 | }) 189 | 190 | te_writer.add_summary(summary, global_step) 191 | saver.save(sess, model_dir, global_step) 192 | 193 | if steps and steps % logging_steps == 0: 194 | print("[*] steps %05d : loss %.6f" % (steps, tr_loss)) 195 | 196 | summary = sess.run(merged, 197 | feed_dict={ 198 | x: x_tr, 199 | y: y_tr, 200 | do_rate: dropout, 201 | }) 202 | 203 | tr_writer.add_summary(summary, global_step) 204 | 205 | global_step += 1 206 | 207 | 208 | if __name__ == "__main__": 209 | parser = argparse.ArgumentParser() 210 | parser.add_argument('--training_steps', required=False, type=int, default=50001) 211 | parser.add_argument('--n_classes', required=False, type=int, default=10) 212 | parser.add_argument('--batch_size', required=False, type=int, default=128) 213 | parser.add_argument('--learning_rate', required=False, type=float, default=0.001) 214 | parser.add_argument('--optimizer', required=False, type=str, default="adabound", 215 | choices=["adabound", "amsbound", "adam", "sgd", "momentum", "adagrad"]) 216 | parser.add_argument('--filters', required=False, type=int, default=32) 217 | parser.add_argument('--n_blocks', required=False, type=int, default=4) 218 | parser.add_argument('--dropout', required=False, type=float, default=0.5) 219 | parser.add_argument('--model_dir', required=False, type=str, default="./model/") 220 | parser.add_argument('--data_dir', required=False, type=str, default="./mnist/") 221 | parser.add_argument('--log_dir', required=False, type=str, default="./logs") 222 | parser.add_argument('--logging_steps', required=False, type=int, default=1000) 223 | parser.add_argument('--seed', required=False, type=int, default=1337) 224 | args = vars(parser.parse_args()) 225 | 226 | for k, v in args.items(): 227 | print("[+] {} : {}".format(k, v)) 228 | 229 | # reproducibility 230 | tf.set_random_seed(args["seed"]) 231 | 232 | main( 233 | training_steps=args["training_steps"], 234 | n_classes=args["n_classes"], 235 | batch_size=args["batch_size"], 236 | learning_rate=args["learning_rate"], 237 | optimizer=args["optimizer"], 238 | n_blocks=args["n_blocks"], 239 | filters=args["filters"], 240 | dropout=args["dropout"], 241 | model_dir=args["model_dir"], 242 | data_dir=args["data_dir"], 243 | log_dir=args["log_dir"], 244 | logging_steps=args["logging_steps"], 245 | ) 246 | --------------------------------------------------------------------------------
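
A small illustrative sketch (not part of the repository; it just re-evaluates, in plain Python, the bound formulas used in `AdaBoundOptimizer.apply_gradients` with the README's example hyper-parameters) showing how the dynamic bounds tighten around `final_lr` as training progresses, which is what moves the update rule from Adam-like towards SGD-like behaviour:

```python
# Illustration only: the lower/upper bound schedule that
# AdaBoundOptimizer.apply_gradients applies to the per-parameter step size.
final_lr = 1e-1   # target (SGD-like) learning rate
gamma = 1e-3      # convergence speed of the bounds

for t in (1, 100, 10000):
    lower_bound = final_lr * (1. - 1. / (gamma * t + 1.))
    upper_bound = final_lr * (1. + 1. / (gamma * t))
    print("step %6d  lower=%.5f  upper=%.5f" % (t, lower_bound, upper_bound))

# step      1  lower=0.00010  upper=100.10000  (almost unbounded -> Adam-like)
# step    100  lower=0.00909  upper=1.10000
# step  10000  lower=0.09091  upper=0.11000    (clipped near final_lr -> SGD-like)
```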