├── .floydexpt ├── .floydignore ├── .gitignore ├── LICENSE ├── Pipfile ├── Pipfile.lock ├── README.md ├── analysis ├── Max min Time_to_train_vs_learning_rate.png ├── Time_to_train_vs_learning_rate.png ├── Time_to_train_vs_learning_rate.xlsx ├── Time_to_train_vs_size_of_model.xlsx ├── Validation_accuracy_over_time matlib.png ├── Validation_accuracy_over_time.png ├── Validation_accuracy_over_time.xlsx ├── time-vs-size-all.png └── time-vs-size-single.png ├── hooks.py ├── ploty.py ├── run-notebook.sh ├── run-time-vs-lr.sh ├── run-time-vs-size.sh ├── run-trace.sh └── train.py /.floydexpt: -------------------------------------------------------------------------------- 1 | {"name": "learning-rates", "family_id": "ixv47BML3dyKqTbKb4AFHP", "namespace": "davidmack"} -------------------------------------------------------------------------------- /.floydignore: -------------------------------------------------------------------------------- 1 | 2 | # Directories and files to ignore when uploading code to floyd 3 | 4 | data 5 | output 6 | analysis 7 | MNIST-data 8 | __pycache__ 9 | 10 | .git 11 | .eggs 12 | eggs 13 | lib 14 | lib64 15 | parts 16 | sdist 17 | var 18 | *.pyc 19 | *.swp 20 | .DS_Store 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | output/ 3 | MNIST-data/ 4 | __pycache__ 5 | *~ 6 | .DS_Store 7 | 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | env/ 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | .hypothesis/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # celery beat schedule file 85 | celerybeat-schedule 86 | 87 | # SageMath parsed files 88 | *.sage.py 89 | 90 | # dotenv 91 | .env 92 | 93 | # virtualenv 94 | .venv 95 | venv/ 96 | ENV/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | This is free and unencumbered software released into the public domain. 
2 | 3 | Anyone is free to copy, modify, publish, use, compile, sell, or 4 | distribute this software, either in source code form or as a compiled 5 | binary, for any purpose, commercial or non-commercial, and by any 6 | means. 7 | 8 | In jurisdictions that recognize copyright laws, the author or authors 9 | of this software dedicate any and all copyright interest in the 10 | software to the public domain. We make this dedication for the benefit 11 | of the public at large and to the detriment of our heirs and 12 | successors. We intend this dedication to be an overt act of 13 | relinquishment in perpetuity of all present and future rights to this 14 | software under copyright law. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 23 | 24 | For more information, please refer to 25 | -------------------------------------------------------------------------------- /Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | 3 | url = "https://pypi.python.org/simple" 4 | verify_ssl = true 5 | name = "pypi" 6 | 7 | 8 | [dev-packages] 9 | 10 | 11 | 12 | [packages] 13 | 14 | tensorflow = "*" 15 | pprint = "*" 16 | 17 | 18 | [requires] 19 | 20 | python_version = "3.6" 21 | -------------------------------------------------------------------------------- /Pipfile.lock: -------------------------------------------------------------------------------- 1 | { 2 | "_meta": { 3 | "hash": { 4 | "sha256": "eccc08e4abee5da73f79df9a7d05f348839877aec81615cfc1eb5ea3f56873db" 5 | }, 6 | "host-environment-markers": { 7 | "implementation_name": "cpython", 8 | "implementation_version": "3.6.2", 9 | "os_name": "posix", 10 | "platform_machine": "x86_64", 11 | "platform_python_implementation": "CPython", 12 | "platform_release": "17.4.0", 13 | "platform_system": "Darwin", 14 | "platform_version": "Darwin Kernel Version 17.4.0: Sun Dec 17 09:19:54 PST 2017; root:xnu-4570.41.2~1/RELEASE_X86_64", 15 | "python_full_version": "3.6.2", 16 | "python_version": "3.6", 17 | "sys_platform": "darwin" 18 | }, 19 | "pipfile-spec": 6, 20 | "requires": { 21 | "python_version": "3.6" 22 | }, 23 | "sources": [ 24 | { 25 | "name": "pypi", 26 | "url": "https://pypi.python.org/simple", 27 | "verify_ssl": true 28 | } 29 | ] 30 | }, 31 | "default": { 32 | "absl-py": { 33 | "hashes": [ 34 | "sha256:908eba9a96a37c10f10074aba57d685070b814906b02a1ea2cf54bb10a6b8c74" 35 | ], 36 | "version": "==0.1.10" 37 | }, 38 | "astor": { 39 | "hashes": [ 40 | "sha256:64c805f1ad6fbc505633416b6174fc23796eb164f371a7dc1f3951ea30560fb5", 41 | "sha256:ff6d2e2962d834acb125cc4dcc80c54a8c17c253f4cc9d9c43b5102a560bb75d" 42 | ], 43 | "version": "==0.6.2" 44 | }, 45 | "backports.weakref": { 46 | "hashes": [ 47 | "sha256:81bc9b51c0abc58edc76aefbbc68c62a787918ffe943a37947e162c3f8e19e82", 48 | "sha256:bc4170a29915f8b22c9e7c4939701859650f2eb84184aee80da329ac0b9825c2" 49 | ], 50 | "version": "==1.0.post1" 51 | }, 52 | "bleach": { 53 | "hashes": [ 54 | "sha256:e67f46adcec78dbc3c04462f3aba3213a673d5652eba2609ed1ef15492a44b8d", 55 | "sha256:978e758599b54cd3caa2e160d74102879b230ea8dc93871d0783721eef58bc65" 56 
| ], 57 | "version": "==1.5.0" 58 | }, 59 | "enum34": { 60 | "hashes": [ 61 | "sha256:6bd0f6ad48ec2aa117d3d141940d484deccda84d4fcd884f5c3d93c23ecd8c79", 62 | "sha256:644837f692e5f550741432dd3f223bbb9852018674981b1664e5dc339387588a", 63 | "sha256:8ad8c4783bf61ded74527bffb48ed9b54166685e4230386a9ed9b1279e2df5b1", 64 | "sha256:2d81cbbe0e73112bdfe6ef8576f2238f2ba27dd0d55752a776c41d38b7da2850" 65 | ], 66 | "version": "==1.1.6" 67 | }, 68 | "funcsigs": { 69 | "hashes": [ 70 | "sha256:330cc27ccbf7f1e992e69fef78261dc7c6569012cf397db8d3de0234e6c937ca", 71 | "sha256:a7bb0f2cf3a3fd1ab2732cb49eba4252c2af4240442415b4abce3b87022a8f50" 72 | ], 73 | "markers": "python_version < '3.3'", 74 | "version": "==1.0.2" 75 | }, 76 | "futures": { 77 | "hashes": [ 78 | "sha256:c4884a65654a7c45435063e14ae85280eb1f111d94e542396717ba9828c4337f", 79 | "sha256:51ecb45f0add83c806c68e4b06106f90db260585b25ef2abfcda0bd95c0132fd" 80 | ], 81 | "markers": "python_version < '3'", 82 | "version": "==3.1.1" 83 | }, 84 | "gast": { 85 | "hashes": [ 86 | "sha256:7068908321ecd2774f145193c4b34a11305bd104b4551b09273dfd1d6a374930" 87 | ], 88 | "version": "==0.2.0" 89 | }, 90 | "grpcio": { 91 | "hashes": [ 92 | "sha256:a1bc37c9910d0fbf4d9e80d5822f92c6e01e28dd1eb01323636ed19666b537cb", 93 | "sha256:14bca481188c3f19135012aaff9abefa4e15529c7e1aca3084183d78094d06bd", 94 | "sha256:4fa658a7e1ba5727ca066b1c8bb64c6befb98f2b8007f04a16c7c84555bf11b9", 95 | "sha256:982439a872d41f969724efc139e0416ba45e0d7446e9a41fd2ebe19351adff9a", 96 | "sha256:025a591606b0aca13bec3e019d6acec01a39421f01b915b98a3a93ea0a53b412", 97 | "sha256:d410835e7554d064c2d99cfa0dd393ffbb0ccf52145ab51c725a8472ed254a3c", 98 | "sha256:88afda198adb0a9da52a66152062027a57877b46f59ffcf55acc3cbfaff77160", 99 | "sha256:ae82bf2f7ceac6ba956e816120b4f66bda035571350e46b61bbdde1808aed1dd", 100 | "sha256:10efe2e016c3ca7a370771ffcf1de9baa3456d4bccefde0f4ce3be091d871c8f", 101 | "sha256:224c513fbe0c3ca546870e5c21b08a8a56cd25795b76b3192ee9702a3344764b", 102 | "sha256:aa473b8276de39eeccc4ad6cbb7fd7feab0868180d72c0c93226033c79fa69b7", 103 | "sha256:022dc2a6d1537a5a16af4ccc3355ad7b512f9c627a1d5d579cd7c18830378bb3", 104 | "sha256:d9e3105f6de6cb759b028702bdd21cb36d27e010227669e43c675b9957a3c180", 105 | "sha256:b56e4f355c2499bb0bf8f8f4d0362b618b06afdfd2c10722710596dc7e295c6c", 106 | "sha256:e86639989c03831912fd9924beda26f6e9ffcc267656cea035bde9d88cf793b2", 107 | "sha256:f4a38071dd27f140cfe774f56aecdf0e33de926c21289cc9c7521ce8dd91fc1c", 108 | "sha256:2894466c499d9752e0d49ee8adc5ee12c676d86211fc1b292bf713cc7cfe9853", 109 | "sha256:1bc36e512741f82c1d73f42df536aa2ab75d840f0d35c149b5d0bee1aed16862", 110 | "sha256:ea9564f58144e2f07995d57fb8e636be5efb084cd59c8651391ada2bb75dc0ff", 111 | "sha256:435b3bab2e34814666854eec203c77b169df1cd56cf22fe449cf5510af416e7d", 112 | "sha256:da306c80d69801a3e4115c448ed4ad481957d723ec1e00b99497c6661573c3e5", 113 | "sha256:17240d672b5c1c9ff22e52236c1870413b7fb5af762b97ce5a747a55e0a57e98", 114 | "sha256:03265472d39bf26f124c3ef68446f7873c8260893e6ae65b323a5b51ed52e580" 115 | ], 116 | "version": "==1.10.0" 117 | }, 118 | "html5lib": { 119 | "hashes": [ 120 | "sha256:2612a191a8d5842bfa057e41ba50bbb9dcb722419d2408c78cff4758d0754868" 121 | ], 122 | "version": "==0.9999999" 123 | }, 124 | "markdown": { 125 | "hashes": [ 126 | "sha256:9ba587db9daee7ec761cfc656272be6aabe2ed300fece21208e4aab2e457bc8f", 127 | "sha256:a856869c7ff079ad84a3e19cd87a64998350c2b94e9e08e44270faef33400f81" 128 | ], 129 | "version": "==2.6.11" 130 | }, 131 | "mock": { 132 | "hashes": [ 133 | 
"sha256:5ce3c71c5545b472da17b72268978914d0252980348636840bd34a00b5cc96c1", 134 | "sha256:b158b6df76edd239b8208d481dc46b6afd45a846b7812ff0ce58971cf5bc8bba" 135 | ], 136 | "version": "==2.0.0" 137 | }, 138 | "numpy": { 139 | "hashes": [ 140 | "sha256:e2335d56d2fd9fc4e3a3f2d3148aafec4962682375f429f05c45a64dacf19436", 141 | "sha256:9b762e78739b6e021124adbea07611682db99cd3fca7f3c3a8b98b8f74ea5699", 142 | "sha256:7d4c549e41507db4f04ec7cfab5597de8acf7871b16c9cf64cebcb9d39031ca6", 143 | "sha256:b803306c4c201e7dcda0ce1b9a9c87f61a7c7ce43de2c60c8e56147b76849a1a", 144 | "sha256:2da8dff91d489fea3e20155d41f4cd680de7d01d9a89fdd0ebb1bee6e72d3800", 145 | "sha256:6b8c2daacbbffc83b4a2ba83a61aa3ce60c66340b07b962bd27b6c6bb175bee1", 146 | "sha256:89b9419019c47ec87cf4cfca77d85da4611cc0be636ec87b5290346490b98450", 147 | "sha256:49880b47d7272f902946dd995f346842c95fe275e2deb3082ef0495f0c718a69", 148 | "sha256:3d7ddd5bdfb12ec9668edf1aa49a4a3eddb0db4661b57ea431477eb9a2468894", 149 | "sha256:788e1757f8e409cd805a7cd82993cd9252fa19e334758a4c6eb5a8b334abb084", 150 | "sha256:377def0873bbb1fbdedb14b3275b10a29b1b55619a3f7f775c4e7f9ce2461b9c", 151 | "sha256:9501c9ccd081977ca5579a3ec4009d6baff6bacb04bf07214aade3324734195a", 152 | "sha256:a1f5173df8190ef9c6235d260d70ca70c6fb029683ceb66e244c5cc6e335947a", 153 | "sha256:12cf4b27039b88e407ad66894d99a957ef60fea0eeb442026af325add2ab264d", 154 | "sha256:4e2fc841c8c642f7fd44591ef856ca409cedba6aea27928df34004c533839eee", 155 | "sha256:e5ade7a69dccbd99c4fdbb95b6d091d941e62ffa588b0ed8fb0a2854118fef3f", 156 | "sha256:6b1011ffc87d7e2b1b7bcc6dc21bdf177163658746ef778dcd21bf0516b9126c", 157 | "sha256:a8bc80f69570e11967763636db9b24c1e3e3689881d10ae793cec74cf7a627b6", 158 | "sha256:81b9d8f6450e752bd82e7d9618fa053df8db1725747880e76fb09710b57f78d0", 159 | "sha256:e8522cad377cc2ef20fe13aae742cc265172910c98e8a0d6014b1a8d564019e2", 160 | "sha256:a3d5dd437112292c707e54f47141be2f1100221242f07eda7bd8477f3ddc2252", 161 | "sha256:c8000a6cbc5140629be8c038c9c9cdb3a1c85ff90bd4180ec99f0f0c73050b5e", 162 | "sha256:fa0944650d5d3fb95869eaacd8eedbd2d83610c85e271bd9d3495ffa9bc4dc9c" 163 | ], 164 | "version": "==1.14.1" 165 | }, 166 | "pbr": { 167 | "hashes": [ 168 | "sha256:60c25b7dfd054ef9bb0ae327af949dd4676aa09ac3a9471cdc871d8a9213f9ac", 169 | "sha256:05f61c71aaefc02d8e37c0a3eeb9815ff526ea28b3b76324769e6158d7f95be1" 170 | ], 171 | "version": "==3.1.1" 172 | }, 173 | "pprint": { 174 | "hashes": [ 175 | "sha256:c0fa22d1462351671ca098e9779bb26a23880011e93eea5f199a150ee7b92a16" 176 | ], 177 | "version": "==0.1" 178 | }, 179 | "protobuf": { 180 | "hashes": [ 181 | "sha256:11788df3e176f44e0375fe6361342d7258a457b346504ea259a21b77ffc18a90", 182 | "sha256:50c24f0d00b7efb3a72ae638ddc118e713cfe8cef40527afe24f7ebcb878e46d", 183 | "sha256:41661f9a442eba2f1967f15333ebe9ecc7e7c51bcbaa2972303ad33a4ca0168e", 184 | "sha256:06ec363b74bceb7d018f2171e0892f03ab6816530e2b0f77d725a58264551e48", 185 | "sha256:b20f861b55efd8206428c13e017cc8e2c34b40b2a714446eb202bbf0ff7597a6", 186 | "sha256:c1f9c36004a7ae6f1ce4a23f06070f6b07f57495f251851aa15cc4da16d08378", 187 | "sha256:4d2e665410b0a278d2eb2c0a529ca2366bb325eb2ae34e189a826b71fb1b28cd", 188 | "sha256:95b78959572de7d7fafa3acb718ed71f482932ddddddbd29ba8319c10639d863" 189 | ], 190 | "version": "==3.5.1" 191 | }, 192 | "six": { 193 | "hashes": [ 194 | "sha256:832dc0e10feb1aa2c68dcc57dbb658f1c7e65b9b61af69048abc87a2db00a0eb", 195 | "sha256:70e8a77beed4562e7f14fe23a786b54f6296e34344c23bc42f07b15018ff98e9" 196 | ], 197 | "version": "==1.11.0" 198 | }, 199 | "tensorboard": { 200 | "hashes": [ 
201 | "sha256:835ecbfee0ce505f8365435c23d1c7d6fd527fcad8a4829e4145b9f6f41d0ca2", 202 | "sha256:ab3e4568a277d4d06fc2928e6e5aa2a32ece073ad234a6b7ade08671dbf8f339" 203 | ], 204 | "version": "==1.6.0" 205 | }, 206 | "tensorflow": { 207 | "hashes": [ 208 | "sha256:d9c628c857ccd9213d75e0747fcad6ef4d6866a79d1a3987d87a5957c3f0f819", 209 | "sha256:6d6f8ac26f857ab79c13b594efcb13d5485d1dbfdaa8e54b064ea7b4b28a9cf8", 210 | "sha256:8c504b9c70649d8602592b148944b34207fde8d1f5d61aaf435dab99118a58bc", 211 | "sha256:638e690a4efc675dbfd0bb98f0c5db1b6b54e9791513fced4a4df80a3db946cb", 212 | "sha256:b6b4c52c61f9cea58a93e4c178e8acec157ab2ed8824350ec536f107fa67a491", 213 | "sha256:1ba17a5df0c2d3e0000f88479b3ce7a35e05ced3056e4594781e81afb45ac8fd", 214 | "sha256:eaa4af9eb161af500a5e3c65dcc434d63478801539f51cdd8c01572eda9dfa2d", 215 | "sha256:afb1f05ab3f8ff569b91be5ecb066b926dff1afd029f4f296964317f461268a6", 216 | "sha256:2f89cf8a05198e7e2415d040b679b4f50ef86c0c3834f4bf86d0b836ecc14a27", 217 | "sha256:17981691a5711dd8b0479b293229f3a22d2f0503d04c2691a2ef66ee50903257", 218 | "sha256:d0e5daf51cf9a711a0dc95b5bf4a9eeeb1cb485197cb923b8940f2a905850407", 219 | "sha256:0b41fb0d1d0be2d495dad67138b5a118f7b206a07b2c96474414109dbfe056a8" 220 | ], 221 | "version": "==1.6.0" 222 | }, 223 | "termcolor": { 224 | "hashes": [ 225 | "sha256:1d6d69ce66211143803fbc56652b41d73b4a400a2891d7bf7a1cdf4c02de613b" 226 | ], 227 | "version": "==1.1.0" 228 | }, 229 | "werkzeug": { 230 | "hashes": [ 231 | "sha256:d5da73735293558eb1651ee2fddc4d0dedcfa06538b8813a2e20011583c9e49b", 232 | "sha256:c3fd7a7d41976d9f44db327260e263132466836cef6f91512889ed60ad26557c" 233 | ], 234 | "version": "==0.14.1" 235 | }, 236 | "wheel": { 237 | "hashes": [ 238 | "sha256:e721e53864f084f956f40f96124a74da0631ac13fbbd1ba99e8e2b5e9cafdf64", 239 | "sha256:9515fe0a94e823fd90b08d22de45d7bde57c90edce705b22f5e1ecf7e1b653c8" 240 | ], 241 | "markers": "python_version < '3'", 242 | "version": "==0.30.0" 243 | } 244 | }, 245 | "develop": {} 246 | } 247 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Which optimizer and learning rate should I use for deep learning? 3 | 4 | A common problem we all face when working on deep learning projects is chosing hyper-parameters. If you’re like me, you find yourself guessing an optimizer and learning rate, then checking if they work (and we’re not alone). This is laborious and error prone. 5 | 6 | To better understand the affect of optimizer and learning rate choice, I trained the same model 500 times. The results show that the right hyper-parameters are crucial to training success. 7 | 8 | In this article I’ll show the results of training the same model across 6 different optimizers and 48 different learning rates. I’ll also show the results of how scaling the model up 10x affects its training on fixed hyper-parameters. 
9 | 
10 | The results show that:
11 | - Most learning rates fail to train the model
12 | - Training time vs. learning rate exhibits a “valley” shape, with the fastest training occurring in a narrow band of learning rates
13 | - Each optimizer has a different optimal learning rate
14 | - No single learning rate trains the model successfully across all of the optimizers tested
15 | 
16 | 
17 | ### Read more
18 | - [Read the full article](https://medium.com/octavian-ai/which-optimizer-and-learning-rate-should-i-use-for-deep-learning-5acb418f9b2)
19 | - [See the code on GitHub](https://github.com/Octavian-ai/learning-rates)
20 | - [Run the code on FloydHub](https://www.floydhub.com/davidmack/projects/learning-rates/)
21 | - [Octavian.ai](https://octavian.ai)
22 | 
--------------------------------------------------------------------------------
/analysis/Max min Time_to_train_vs_learning_rate.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Octavian-ai/learning-rates/bed9edbc412949d0514ed7180221256688657be5/analysis/Max min Time_to_train_vs_learning_rate.png
--------------------------------------------------------------------------------
/analysis/Time_to_train_vs_learning_rate.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Octavian-ai/learning-rates/bed9edbc412949d0514ed7180221256688657be5/analysis/Time_to_train_vs_learning_rate.png
--------------------------------------------------------------------------------
/analysis/Time_to_train_vs_learning_rate.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Octavian-ai/learning-rates/bed9edbc412949d0514ed7180221256688657be5/analysis/Time_to_train_vs_learning_rate.xlsx
--------------------------------------------------------------------------------
/analysis/Time_to_train_vs_size_of_model.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Octavian-ai/learning-rates/bed9edbc412949d0514ed7180221256688657be5/analysis/Time_to_train_vs_size_of_model.xlsx
--------------------------------------------------------------------------------
/analysis/Validation_accuracy_over_time matlib.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Octavian-ai/learning-rates/bed9edbc412949d0514ed7180221256688657be5/analysis/Validation_accuracy_over_time matlib.png
--------------------------------------------------------------------------------
/analysis/Validation_accuracy_over_time.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Octavian-ai/learning-rates/bed9edbc412949d0514ed7180221256688657be5/analysis/Validation_accuracy_over_time.png
--------------------------------------------------------------------------------
/analysis/Validation_accuracy_over_time.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Octavian-ai/learning-rates/bed9edbc412949d0514ed7180221256688657be5/analysis/Validation_accuracy_over_time.xlsx
--------------------------------------------------------------------------------
/analysis/time-vs-size-all.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Octavian-ai/learning-rates/bed9edbc412949d0514ed7180221256688657be5/analysis/time-vs-size-all.png
--------------------------------------------------------------------------------
/analysis/time-vs-size-single.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Octavian-ai/learning-rates/bed9edbc412949d0514ed7180221256688657be5/analysis/time-vs-size-single.png
--------------------------------------------------------------------------------
/hooks.py:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | import time
4 | 
5 | import numpy as np
6 | 
7 | import tensorflow as tf
8 | from tensorflow.python.training import session_run_hook
9 | 
10 | 
11 | 
12 | class EarlyStopping(session_run_hook.SessionRunHook):
13 | 
14 | 	def __init__(self, metric, start_time, target=0.97, check_every=100, max_secs=10):
15 | 		self.metric = metric
16 | 		self.target = target
17 | 		self.counter = 0
18 | 		self.check_every = check_every
19 | 		self.max_secs = max_secs
20 | 		self.start_time = start_time
21 | 
22 | 	def before_run(self, run_context):
23 | 		self.counter += 1
24 | 		self.should_check = (self.counter % self.check_every) == 0
25 | 
26 | 		if self.should_check:
27 | 			return session_run_hook.SessionRunArgs([self.metric])
28 | 
29 | 	def after_run(self, run_context, run_values):
30 | 		if self.should_check and run_values.results is not None:
31 | 			t = run_values.results[0][1]
32 | 			if t > self.target:
33 | 				tf.logging.info(f"Early stopping as accuracy exceeded target: {t} > {self.target}")
34 | 				run_context.request_stop()
35 | 
36 | 		if (time.time() - self.start_time) > self.max_secs:
37 | 			tf.logging.info(f"Early stopping as time ran out: {time.time() - self.start_time} > {self.max_secs}")
38 | 			run_context.request_stop()
39 | 
40 | 
41 | 
42 | class CallbackHook(session_run_hook.SessionRunHook):
43 | 	def __init__(self, metrics=None, callback_after=None, callback_end=None):
44 | 		self.metrics = metrics
45 | 		self.callback_after = callback_after
46 | 		self.callback_end = callback_end
47 | 
48 | 	def before_run(self, run_context):
49 | 		if self.metrics is not None:
50 | 			return session_run_hook.SessionRunArgs(self.metrics)
51 | 
52 | 	def after_run(self, run_context, run_values):
53 | 		if self.callback_after is not None:
54 | 			self.callback_after(run_context, run_values)
55 | 
56 | 	def end(self, session):
57 | 		if self.callback_end is not None:
58 | 			self.callback_end(session)
59 | 
60 | 
61 | 
62 | class LastMetricHook(session_run_hook.SessionRunHook):
63 | 	def __init__(self, metric, cb):
64 | 		self.metric = metric
65 | 		self.cb = cb
66 | 		self.reading = None
67 | 
68 | 	def before_run(self, run_context):
69 | 		return session_run_hook.SessionRunArgs([self.metric])
70 | 
71 | 	def after_run(self, run_context, run_values):
72 | 		self.reading = run_values.results[0][1]
73 | 
74 | 	def end(self, session):
75 | 		self.cb(self.reading)
76 | 
77 | 
78 | class MetricHook(session_run_hook.SessionRunHook):
79 | 	def __init__(self, metric, cb):
80 | 		self.metric = metric
81 | 		self.cb = cb
82 | 		self.readings = []
83 | 
84 | 	def before_run(self, run_context):
85 | 		return session_run_hook.SessionRunArgs([self.metric])
86 | 
87 | 	def after_run(self, run_context, run_values):
88 | 		self.readings.append(run_values.results[0][1])
89 | 
90 | 	def end(self, session):
91 | 		self.cb(np.average(self.readings))
92 | 		self.readings.clear()
93 | 
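The hooks above are plain `SessionRunHook` implementations, so they can be reused in any TF 1.x Estimator. Below is a minimal sketch of how they are meant to be wired up; the `toy_model_fn` and its single dense layer are hypothetical placeholders, and only `EarlyStopping` and `MetricHook` come from this repo (`train.py` below does the same thing with a real CNN).

```python
import time
import tensorflow as tf
from hooks import EarlyStopping, MetricHook

def toy_model_fn(features, labels, mode):
	logits = tf.layers.dense(features["x"], 10)
	loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
	# tf.metrics.accuracy returns a (value, update_op) pair; the hooks fetch it
	# and read the updated value at index [1], as train.py does.
	accuracy = tf.metrics.accuracy(labels=labels, predictions=tf.argmax(logits, axis=1))

	train_hooks = [
		# Stop once accuracy exceeds 0.97 (checked every 100 steps) or after 60s.
		EarlyStopping(accuracy, start_time=time.time(), target=0.97, check_every=100, max_secs=60),
		# Average the accuracy readings and report them when training ends.
		MetricHook(accuracy, lambda avg: tf.logging.info(f"mean accuracy: {avg}")),
	]

	train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=tf.train.get_global_step())
	return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op, training_hooks=train_hooks)
```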
--------------------------------------------------------------------------------
/ploty.py:
--------------------------------------------------------------------------------
1 | import os
2 | import csv
3 | from datetime import datetime
4 | import numpy as np
5 | import matplotlib
6 | matplotlib.use("Agg")
7 | import tensorflow as tf
8 | from matplotlib import pyplot as plt
9 | from IPython.display import clear_output
10 | 
11 | try:
12 | 	from google.colab import auth
13 | 	from googleapiclient.discovery import build
14 | 	from googleapiclient.http import MediaFileUpload
15 | except ImportError:
16 | 	pass
17 | 
18 | 
19 | class Ploty(object):
20 | 
21 | 	def __init__(self, output_path, title='', x='', y="Time to training complete", legend=True, log_y=False, log_x=False, clear_screen=True, terminal=True, auto_render=True):
22 | 		self.output_path = output_path
23 | 		self.title = title
24 | 		self.label_x = x
25 | 		self.label_y = y
26 | 		self.log_y = log_y
27 | 		self.log_x = log_x
28 | 		self.clear_screen = clear_screen
29 | 		self.legend = legend
30 | 		self.terminal = terminal
31 | 		self.auto_render = auto_render
32 | 
33 | 		self.header = ["x", "y", "label"]
34 | 		self.datas = {}
35 | 
36 | 		self.c_i = 0
37 | 		self.cmap = plt.cm.get_cmap('hsv', 10)
38 | 
39 | 		self.fig = plt.figure()
40 | 		self.ax = self.fig.add_subplot(111)
41 | 
42 | 		if self.log_x:
43 | 			self.ax.set_xscale('log')
44 | 
45 | 		if self.log_y:
46 | 			self.ax.set_yscale('log')
47 | 
48 | 
49 | 	def ensure(self, name, extra_data):
50 | 		if name not in self.datas:
51 | 			self.datas[name] = {
52 | 				"c": self.cmap(self.c_i),
53 | 				"x": [],
54 | 				"y": [],
55 | 				"m": ".",
56 | 				"l": '-'
57 | 			}
58 | 
59 | 			for i in extra_data.keys():
60 | 				self.datas[name][i] = []
61 | 				if i not in self.header:
62 | 					self.header.append(i)
63 | 
64 | 			self.c_i += 1
65 | 
66 | 	# This method assumes extra_data will have the same keys every single call, otherwise csv writing will crash
67 | 	def add_result(self, x, y, name, marker="o", line="-", extra_data={}):
68 | 		self.ensure(name, extra_data)
69 | 		self.datas[name]["x"].append(x)
70 | 		self.datas[name]["y"].append(y)
71 | 		self.datas[name]["m"] = marker
72 | 		self.datas[name]["l"] = line
73 | 
74 | 		for key, value in extra_data.items():
75 | 			self.datas[name][key].append(value)
76 | 
77 | 		if self.terminal:
78 | 			print(f'{{"metric": "{name}", "value": {y}, "x": {x} }}')
79 | 
80 | 		if self.auto_render:
81 | 			self.render()
82 | 			self.save_csv()
83 | 	@staticmethod
84 | 	def runningMeanFast(x, N):
85 | 		return np.convolve(np.array(x), np.ones((N,)) / N)[(N - 1):]
86 | 
87 | 	def render(self):
88 | 		self.render_pre()
89 | 
90 | 		for k, d in self.datas.items():
91 | 			plt.plot(d['x'], d['y'], d["l"] + d["m"], label=k)
92 | 
93 | 		self.render_post()
94 | 
95 | 
96 | 	def render_pre(self):
97 | 		if self.clear_screen and not self.terminal:
98 | 			clear_output()
99 | 
100 | 		plt.cla()
101 | 
102 | 	def render_post(self):
103 | 		img_name = self.output_path + '/' + self.title.replace(" ", "_") + '.png'
104 | 
105 | 		artists = []
106 | 
107 | 		self.fig.suptitle(self.title, fontsize=14, fontweight='bold')
108 | 		self.ax.set_xlabel(self.label_x)
109 | 		self.ax.set_ylabel(self.label_y)
110 | 
111 | 		if self.legend:
112 | 			lgd = plt.legend(bbox_to_anchor=(1.04, 0.5), loc="center left", borderaxespad=0)
113 | 			artists.append(lgd)
114 | 
115 | 		try:
116 | 			os.remove(img_name)
117 | 		except OSError:
118 | 			pass
119 | 
120 | 		plt.savefig(img_name, bbox_extra_artists=artists, bbox_inches='tight')
121 | 		tf.logging.info("Saved image: " + img_name)
122 | 
123 | 		if not self.terminal:
124 | 			plt.show()
125 | 
126 | 	def save_csv(self):
127 | 		try:
128 | 			os.remove(self.output_path + '/' + self.title.replace(" ", "_") + '.csv')
129 | 		except OSError:
130 | 			pass
131 | 
132 | 		csv_name = self.output_path
+ '/' + self.title.replace(" ", "_") + '.csv' 133 | 134 | with open(csv_name, 'w') as csvfile: 135 | writer = csv.writer(csvfile) 136 | writer.writerow(self.header) 137 | 138 | for k, d in self.datas.items(): 139 | for i in range(len(d["x"])): 140 | row = [ 141 | k if h == "label" else d[h][i] for h in self.header 142 | ] 143 | writer.writerow(row) 144 | 145 | tf.logging.info("Saved CSV: " + csv_name) 146 | 147 | 148 | def copy_to_drive(self, snapshot=False): 149 | auth.authenticate_user() 150 | drive_service = build('drive', 'v3') 151 | 152 | if snapshot: 153 | name = self.title + "_latest" 154 | else: 155 | name = self.title +'_' + str(datetime.now()) 156 | 157 | def do_copy(source_name, dest_name, mime): 158 | file_metadata = { 159 | 'name': dest_name, 160 | 'mimeType': mime 161 | } 162 | media = MediaFileUpload(self.output_path + source_name, 163 | mimetype=file_metadata['mimeType'], 164 | resumable=True) 165 | 166 | created = drive_service.files().create(body=file_metadata, 167 | media_body=media, 168 | fields='id').execute() 169 | 170 | do_copy(self.title+'.csv', name + '.csv', 'text/csv') 171 | do_copy(self.title+'.png', name + '.png', 'image/png') 172 | 173 | 174 | -------------------------------------------------------------------------------- /run-notebook.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | floyd run --mode jupyter --gpu --env tensorflow-1.5 -------------------------------------------------------------------------------- /run-time-vs-lr.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | floyd run --gpu --env tensorflow-1.5 "python train.py --output-dir /output --task time_vs_lr" -------------------------------------------------------------------------------- /run-time-vs-size.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | floyd run --gpu --env tensorflow-1.5 "python train.py --output-dir /output --task time_vs_size" -------------------------------------------------------------------------------- /run-trace.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | floyd run --cpu --env tensorflow-1.5 "python train.py --output-dir /output --task trace" -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Learning rates comparison - CNN 3 | 4 | Automatically generated by Colaboratory. 5 | 6 | Original file is located at 7 | https://colab.research.google.com/drive/1ynqfIQK9HgbAHqaED6mBxAVEP2MMsHhb 8 | """ 9 | 10 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 11 | # 12 | # Licensed under the Apache License, Version 2.0 (the "License"); 13 | # you may not use this file except in compliance with the License. 14 | # You may obtain a copy of the License at 15 | # 16 | # http://www.apache.org/licenses/LICENSE-2.0 17 | # 18 | # Unless required by applicable law or agreed to in writing, software 19 | # distributed under the License is distributed on an "AS IS" BASIS, 20 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 21 | # See the License for the specific language governing permissions and 22 | # limitations under the License. 
23 | """Convolutional Neural Network Estimator for MNIST, built with tf.layers.""" 24 | 25 | import time 26 | from datetime import datetime 27 | import traceback 28 | import uuid 29 | import shutil 30 | import os 31 | import argparse 32 | 33 | import numpy as np 34 | import tensorflow as tf 35 | 36 | from ploty import Ploty 37 | from hooks import * 38 | 39 | 40 | 41 | class Model(object): 42 | 43 | def __init__(self, 44 | optimizer_fn=None, 45 | val_target=0.99, 46 | max_secs=100, 47 | scale=1, 48 | output_path="/tmp/", 49 | train_callback=None, 50 | eval_callback=None, 51 | train_end_callback=None, 52 | check_stopping_every=50): 53 | 54 | self.optimizer_fn = optimizer_fn 55 | self.val_target = val_target 56 | self.max_secs = max_secs 57 | self.scale = scale 58 | self.output_path = output_path 59 | self.train_callback = train_callback 60 | self.train_end_callback = train_end_callback 61 | self.eval_callback = eval_callback 62 | self.check_stopping_every = check_stopping_every 63 | self.early_stop = True 64 | 65 | self.start_time = time.time() 66 | 67 | # Load training and eval data 68 | mnist = tf.contrib.learn.datasets.load_dataset("mnist") 69 | train_data = mnist.train.images # Returns np.array 70 | train_labels = np.asarray(mnist.train.labels, dtype=np.int32) 71 | eval_data = mnist.test.images # Returns np.array 72 | eval_labels = np.asarray(mnist.test.labels, dtype=np.int32) 73 | 74 | # Data input functions 75 | self.train_input_fn = tf.estimator.inputs.numpy_input_fn( 76 | x={"x": train_data}, 77 | y=train_labels, 78 | batch_size=100, 79 | num_epochs=None, 80 | shuffle=True) 81 | 82 | self.eval_input_fn = tf.estimator.inputs.numpy_input_fn( 83 | x={"x": eval_data}, 84 | y=eval_labels, 85 | num_epochs=1, 86 | shuffle=False) 87 | 88 | # Create a model 89 | # This lambda hack removes the self reference 90 | self.model_fn = lambda features, labels, mode: self.model_fn_bare(features, labels, mode) 91 | 92 | 93 | 94 | def model_fn_bare(self, features, labels, mode): 95 | """Model function for CNN.""" 96 | 97 | # Input Layer 98 | # Reshape X to 4-D tensor: [batch_size, width, height, channels] 99 | # MNIST images are 28x28 pixels, and have one color channel 100 | input_layer = tf.reshape(features["x"], [-1, 28, 28, 1]) 101 | 102 | # Convolutional Layer #1 103 | # Computes 32 features using a 5x5 filter with ReLU activation. 104 | # Padding is added to preserve width and height. 105 | # Input Tensor Shape: [batch_size, 28, 28, 1] 106 | # Output Tensor Shape: [batch_size, 28, 28, 32] 107 | conv1 = tf.layers.conv2d( 108 | inputs=input_layer, 109 | filters=round(32*self.scale), 110 | kernel_size=[5, 5], 111 | padding="same", 112 | activation=tf.nn.relu) 113 | 114 | # Pooling Layer #1 115 | # First max pooling layer with a 2x2 filter and stride of 2 116 | # Input Tensor Shape: [batch_size, 28, 28, 32] 117 | # Output Tensor Shape: [batch_size, 14, 14, 32] 118 | pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2) 119 | 120 | # Convolutional Layer #2 121 | # Computes 64 features using a 5x5 filter. 122 | # Padding is added to preserve width and height. 
123 | # Input Tensor Shape: [batch_size, 14, 14, 32] 124 | # Output Tensor Shape: [batch_size, 14, 14, 64] 125 | conv2 = tf.layers.conv2d( 126 | inputs=pool1, 127 | filters=round(64 * self.scale), 128 | kernel_size=[5, 5], 129 | padding="same", 130 | activation=tf.nn.relu) 131 | 132 | # Pooling Layer #2 133 | # Second max pooling layer with a 2x2 filter and stride of 2 134 | # Input Tensor Shape: [batch_size, 14, 14, 64] 135 | # Output Tensor Shape: [batch_size, 7, 7, 64] 136 | pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2) 137 | 138 | # Flatten tensor into a batch of vectors 139 | # Input Tensor Shape: [batch_size, 7, 7, 64] 140 | # Output Tensor Shape: [batch_size, 7 * 7 * 64] 141 | pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * round(self.scale* 64)]) 142 | 143 | # Dense Layer 144 | # Densely connected layer with 1024 neurons 145 | # Input Tensor Shape: [batch_size, 7 * 7 * 64] 146 | # Output Tensor Shape: [batch_size, 1024] 147 | dense = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu) 148 | 149 | # Add dropout operation; 0.6 probability that element will be kept 150 | dropout = tf.layers.dropout( 151 | inputs=dense, rate=0.4, training=mode == tf.estimator.ModeKeys.TRAIN) 152 | 153 | # Logits layer 154 | # Input Tensor Shape: [batch_size, 1024] 155 | # Output Tensor Shape: [batch_size, 10] 156 | logits = tf.layers.dense(inputs=dropout, units=10) 157 | 158 | predictions = { 159 | # Generate predictions (for PREDICT and EVAL mode) 160 | "classes": tf.argmax(input=logits, axis=1), 161 | # Add `softmax_tensor` to the graph. It is used for PREDICT and by the 162 | # `logging_hook`. 163 | "probabilities": tf.nn.softmax(logits, name="softmax_tensor") 164 | } 165 | 166 | # Add evaluation metrics (for EVAL mode) 167 | eval_metric_ops = { 168 | "accuracy": tf.metrics.accuracy( 169 | labels=labels, predictions=predictions["classes"]) 170 | } 171 | 172 | 173 | # Hooks 174 | train_hooks = [] 175 | eval_hooks = [] 176 | 177 | early_stop = EarlyStopping( 178 | eval_metric_ops["accuracy"], 179 | start_time=self.start_time, 180 | target=self.val_target, 181 | check_every=self.check_stopping_every, 182 | max_secs=self.max_secs) 183 | 184 | if self.early_stop: 185 | train_hooks.append(early_stop) 186 | 187 | if self.train_end_callback is not None: 188 | m = LastMetricHook(eval_metric_ops["accuracy"], self.train_end_callback) 189 | train_hooks.append(m) 190 | 191 | if self.train_callback is not None: 192 | m = MetricHook(eval_metric_ops["accuracy"], self.train_callback) 193 | train_hooks.append(m) 194 | 195 | if self.eval_callback is not None: 196 | m = MetricHook(eval_metric_ops["accuracy"], self.eval_callback) 197 | eval_hooks.append(m) 198 | 199 | ### Create EstimatorSpecs ### 200 | 201 | if mode == tf.estimator.ModeKeys.PREDICT: 202 | return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions) 203 | 204 | # Calculate Loss (for both TRAIN and EVAL modes) 205 | loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) 206 | 207 | # Configure the Training Op (for TRAIN mode) 208 | if mode == tf.estimator.ModeKeys.TRAIN: 209 | global_step = tf.train.get_global_step() 210 | self.optimizer = self.optimizer_fn(global_step) 211 | train_op = self.optimizer.minimize( 212 | loss=loss, 213 | global_step=global_step) 214 | 215 | return tf.estimator.EstimatorSpec( 216 | mode=mode, 217 | loss=loss, 218 | train_op=train_op, 219 | training_hooks=train_hooks) 220 | 221 | if mode == tf.estimator.ModeKeys.EVAL: 222 | return 
tf.estimator.EstimatorSpec( 223 | mode=mode, 224 | loss=loss, 225 | eval_metric_ops=eval_metric_ops, 226 | evaluation_hooks=eval_hooks) 227 | 228 | def generate_config(self): 229 | # Create the Estimator 230 | model_dir = self.output_path + str(uuid.uuid1()) 231 | 232 | config = tf.estimator.RunConfig( 233 | model_dir=model_dir, 234 | tf_random_seed=3141592 235 | ) 236 | 237 | return config 238 | 239 | def post_run(self, config): 240 | 241 | try: 242 | shutil.rmtree(config.model_dir) 243 | except: 244 | pass 245 | 246 | 247 | def train_and_evaluate(self, max_steps, eval_throttle_secs): 248 | 249 | config = self.generate_config() 250 | 251 | mnist_classifier = tf.estimator.Estimator( 252 | model_fn=self.model_fn, config=config) 253 | 254 | # Specs for train and eval 255 | train_spec = tf.estimator.TrainSpec(input_fn=self.train_input_fn, max_steps=max_steps) 256 | eval_spec = tf.estimator.EvalSpec(input_fn=self.eval_input_fn, throttle_secs=eval_throttle_secs) 257 | 258 | tf.estimator.train_and_evaluate(mnist_classifier, train_spec, eval_spec) 259 | 260 | self.post_run(config) 261 | 262 | def train(self, steps=None, max_steps=None): 263 | 264 | config = self.generate_config() 265 | 266 | mnist_classifier = tf.estimator.Estimator( 267 | model_fn=self.model_fn, config=config) 268 | 269 | r = mnist_classifier.train(self.train_input_fn, steps=steps, max_steps=max_steps) 270 | 271 | self.post_run(config) 272 | 273 | return r 274 | 275 | 276 | 277 | 278 | ### Static data ### 279 | 280 | output_path = "/tmp/" 281 | 282 | optimizers = { 283 | "Adam": tf.train.AdamOptimizer, 284 | "Adagrad": tf.train.AdagradOptimizer, 285 | "Momentum": lambda lr: tf.train.MomentumOptimizer(lr, 0.5), 286 | "GD": tf.train.GradientDescentOptimizer, 287 | "Adadelta": tf.train.AdadeltaOptimizer, 288 | "RMSProp": tf.train.RMSPropOptimizer, 289 | } 290 | 291 | # The best learning rates our grid search identified 292 | ideal_lr = { 293 | "Adam": 0.00146, 294 | "Adagrad": 0.1, 295 | "Momentum": 0.215, 296 | "GD": 0.215, 297 | "Adadelta": 3.16, 298 | "RMSProp": 0.00146, 299 | } 300 | 301 | schedules = [ 302 | # "exp_decay", 303 | "fixed", 304 | # "cosine_restart" 305 | ] 306 | 307 | 308 | 309 | 310 | ### Learning rates ### 311 | 312 | # A logarithmic grid search of learning rates 313 | def LRRange(mul=5): 314 | 315 | for i in range(mul*6, 0, -1): 316 | lr = pow(0.1, i/mul) 317 | yield lr 318 | 319 | for i in range(1, 2*mul+1): 320 | lr = pow(10, i/mul) 321 | yield lr 322 | 323 | 324 | def LRRangeAdam(): 325 | 326 | yield ideal_lr["Adam"] 327 | 328 | for i in range(1, 5): 329 | lr = pow(0.1, i) 330 | yield lr 331 | 332 | 333 | def lr_schedule(optimizer, starter_learning_rate=0.1, 334 | global_step=None, mode="fixed", 335 | decay_rate=0.96, decay_steps=100, 336 | cycle_lr_decay=0.001, cycle_length=1000): 337 | 338 | if mode == "fixed": 339 | return optimizer(starter_learning_rate) 340 | 341 | elif mode == "exp_decay": 342 | lr = tf.train.exponential_decay(starter_learning_rate, global_step, 343 | decay_steps, decay_rate, staircase=True) 344 | return optimizer(lr) 345 | 346 | elif mode == "cosine_restart": 347 | lr = tf.train.cosine_decay_restarts( 348 | starter_learning_rate, 349 | global_step, 350 | cycle_length, 351 | alpha=cycle_lr_decay) 352 | 353 | return optimizer(lr) 354 | 355 | elif mode == "triangle": 356 | 357 | min_lr = starter_learning_rate * cycle_lr_decay 358 | 359 | cycle = tf.floor(1+global_step/(2*cycle_length)) 360 | x = tf.abs(global_step/cycle_length - 2*cycle + 1) 361 | lr = starter_learning_rate + 
(starter_learning_rate - min_lr) * tf.maximum(0., 1. - x) / tf.pow(2., cycle - 1)
362 | 		return optimizer(lr)
363 | 
364 | 
365 | 
366 | 
367 | def build_model(
368 | 	FLAGS,
369 | 	max_secs,
370 | 	optimizer="Adam",
371 | 	schedule="fixed",
372 | 	lr=0.01,
373 | 	scale=1,
374 | 	train_callback=None,
375 | 	eval_callback=None,
376 | 	train_end_callback=None,
377 | 	stop_after_acc=0.97):
378 | 
379 | 	print(f"Starting run {optimizer}({lr}) scale={scale}")
380 | 
381 | 	opt = optimizers[optimizer]
382 | 
383 | 	def get_optimizer(global_step):
384 | 		return lr_schedule(opt, lr, global_step=global_step, mode=schedule)
385 | 
386 | 	m = Model(
387 | 		optimizer_fn=get_optimizer,
388 | 		val_target=stop_after_acc,
389 | 		max_secs=max_secs,
390 | 		scale=scale,
391 | 		train_callback=train_callback,
392 | 		eval_callback=eval_callback,
393 | 		train_end_callback=train_end_callback,
394 | 		check_stopping_every=50)
395 | 
396 | 	return m
397 | 
398 | 
399 | def prewarm(FLAGS):
400 | 	# Warm up the system caches and throw this result away.
401 | 	# Without this the first real run is falsely reported as slower.
402 | 	m = build_model(
403 | 		FLAGS,
404 | 		max_secs=60*4,
405 | 		optimizer="Adam",
406 | 		schedule="fixed",
407 | 		lr=0.001,
408 | 		scale=0.4,
409 | 		stop_after_acc=0.1
410 | 	)
411 | 	m.train()
412 | 
413 | def plt_time_vs_lr(FLAGS):
414 | 
415 | 	prewarm(FLAGS)
416 | 	scale = FLAGS.scale
417 | 
418 | 	p = Ploty(output_path=FLAGS.output_dir, title="Time to train vs learning rate", x="Learning rate", log_x=True, log_y=True)
419 | 
420 | 	for opt in optimizers.keys():
421 | 		for sched in schedules:
422 | 			for lr in LRRange(6):
423 | 				try:
424 | 					# Mutable dict so the nested callback can read time_start, which is set just before training
425 | 					d = {}
426 | 
427 | 					def cb(acc):
428 | 						taken = time.time() - d["time_start"]
429 | 						print("Finished!", acc, taken)
430 | 						if acc >= FLAGS.stop_after_acc:
431 | 							p.add_result(lr, taken, opt, extra_data={"acc": acc, "lr": lr, "opt": opt, "scale": scale, "time": taken, "schedule": sched})
432 | 						else:
433 | 							tf.logging.error("Failed to train.")
434 | 
435 | 					m = build_model(
436 | 						FLAGS,
437 | 						max_secs=60*4,
438 | 						optimizer=opt,
439 | 						schedule=sched,
440 | 						lr=lr,
441 | 						scale=scale,
442 | 						train_end_callback=cb,
443 | 						stop_after_acc=FLAGS.stop_after_acc
444 | 					)
445 | 
446 | 					d["time_start"] = time.time()
447 | 					m.train()
448 | 
449 | 				except Exception:
450 | 					traceback.print_exc()
451 | 
452 | 
453 | 
454 | 
455 | def plt_time_vs_model_size(FLAGS):
456 | 
457 | 	oversample = FLAGS.oversample
458 | 
459 | 	prewarm(FLAGS)
460 | 
461 | 	# Perform the real experiment
462 | 	p = Ploty(output_path=FLAGS.output_dir, title="Time to train vs size of model", x="Model scale", clear_screen=True)
463 | 	for opt in ["Adam"]:
464 | 		for sched in schedules:
465 | 			for lr in LRRangeAdam():
466 | 				for i in range(1*oversample, 10*oversample):
467 | 					scale = i/oversample
468 | 
469 | 					try:
470 | 						# Mutable dict so the nested callback can read time_start (see plt_time_vs_lr)
471 | 						d = {}
472 | 
473 | 						def cb(acc):
474 | 							taken = time.time() - d["time_start"]
475 | 							if acc >= FLAGS.stop_after_acc:
476 | 								p.add_result(scale, taken, opt+"("+str(lr)+")", extra_data={"acc": acc, "lr": lr, "opt": opt, "scale": scale, "time": taken})
477 | 							else:
478 | 								tf.logging.error("Failed to train.")
479 | 
480 | 						m = build_model(
481 | 							FLAGS,
482 | 							max_secs=60*4,
483 | 							optimizer=opt,
484 | 							schedule=sched,
485 | 							lr=lr,
486 | 							scale=scale,
487 | 							train_end_callback=cb,
488 | 							stop_after_acc=FLAGS.stop_after_acc
489 | 						)
490 | 
491 | 						d["time_start"] = time.time()
492 | 						m.train()
493 | 
494 | 					except Exception:
495 | 						traceback.print_exc()
496 | 
497 | 
498 | 
499 | 
500 | def plt_train_trace(FLAGS):
501 | 	p = Ploty(
502 | 		output_path=FLAGS.output_dir,
503 | 		title="Validation accuracy over time",
504 | 		x="Time",
505 | 		y="Validation accuracy",
506 | 		log_x=True,
507 | 		log_y=True,
508 | 		legend=True)
509 | 
510 | 	sched = "fixed"
511 | 
512 | 	for opt in optimizers.keys():
513 | 
514 | 		lr = ideal_lr[opt]
515 | 
516 | 		try:
517 | 			tf.logging.info(f"Running {opt} {sched} {lr}")
518 | 
519 | 			time_start = time.time()
520 | 
521 | 			def cb(mode):
522 | 				def d(acc):
523 | 					taken = time.time() - time_start
524 | 					p.add_result(taken, acc, opt+"-"+mode)
525 | 				return d
526 | 
527 | 			m = build_model(FLAGS,
528 | 				max_secs=FLAGS.max_secs,
529 | 				optimizer=opt,
530 | 				schedule=sched,
531 | 				lr=lr,
532 | 				scale=FLAGS.scale,
533 | 				train_callback=cb("train"),
534 | 				eval_callback=cb("eval"))
535 | 
536 | 
537 | 			m.train_and_evaluate(max_steps=70, eval_throttle_secs=3)
538 | 
539 | 
540 | 		except Exception:
541 | 			traceback.print_exc()
542 | 
543 | 
544 | 
545 | if __name__ == "__main__":
546 | 
547 | 	tf.logging.set_verbosity('INFO')
548 | 
549 | 	tasks = {
550 | 		"trace": plt_train_trace,
551 | 		"time_vs_lr": plt_time_vs_lr,
552 | 		"time_vs_size": plt_time_vs_model_size
553 | 	}
554 | 
555 | 	parser = argparse.ArgumentParser()
556 | 	parser.add_argument('--max-secs', type=float, default=120)
557 | 	parser.add_argument('--stop-after-acc', type=float, default=0.96)
558 | 	parser.add_argument('--scale', type=int, default=3)
559 | 	parser.add_argument('--oversample', type=int, default=4)
560 | 	parser.add_argument('--task', type=str, choices=tasks.keys(), required=True)
561 | 	parser.add_argument('--output-dir', type=str, default="./output")
562 | 
563 | 	FLAGS = parser.parse_args()
564 | 
565 | 	tf.logging.info("starting...")
566 | 	tasks[FLAGS.task](FLAGS)
567 | 
568 | 
--------------------------------------------------------------------------------
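A note on the grid search: `LRRange(mul)` in `train.py` generates the logarithmic learning-rate grid, `mul` points per decade, ascending from 10^-6 up to 10^2, and with `LRRange(6)` this is exactly the 48 learning rates the README mentions. Since `train.py` only launches experiments under its `__main__` guard, the generator can be inspected directly:

```python
from train import LRRange

grid = list(LRRange(6))
print(len(grid))             # 48 candidate learning rates
print(min(grid), max(grid))  # ~1e-06 up to 100.0
```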
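The `triangle` branch of `lr_schedule` is a cyclical policy: over every `2 * cycle_length` steps the rate ramps from the base value up to a peak and back, and the peak amplitude halves each cycle. Here is a NumPy transcription of the same formula, useful for plotting the waveform without building a TensorFlow graph (the function name is mine, not the repo's):

```python
import numpy as np

def triangle_lr(step, start_lr=0.1, cycle_length=1000, cycle_lr_decay=0.001):
	# Mirrors the "triangle" branch of lr_schedule in train.py
	min_lr = start_lr * cycle_lr_decay
	cycle = np.floor(1 + step / (2 * cycle_length))
	x = np.abs(step / cycle_length - 2 * cycle + 1)
	return start_lr + (start_lr - min_lr) * np.maximum(0.0, 1 - x) / 2.0 ** (cycle - 1)

steps = np.arange(4001)
lrs = triangle_lr(steps)
print(lrs[0], lrs[1000], lrs[3000])  # base rate, first peak, second (halved) peak
```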
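`ploty.py` is more than a plotting wrapper: each `add_result` call appends a data point, re-renders the PNG, rewrites the CSV, and prints a `{"metric": ...}` JSON line in the format FloydHub's log parser picks up. A minimal standalone sketch (the numbers are invented for illustration):

```python
import os
from ploty import Ploty

os.makedirs("./output", exist_ok=True)

p = Ploty(output_path="./output", title="Demo sweep", x="Learning rate", log_x=True, log_y=True)

# Each call appends a point, then rewrites Demo_sweep.png and Demo_sweep.csv.
# extra_data must use the same keys on every call, or the CSV writer will fail.
for lr, secs in [(0.001, 80.0), (0.01, 35.0), (0.1, 110.0)]:
	p.add_result(lr, secs, "Adam", extra_data={"acc": 0.96, "lr": lr})
```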
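Finally, a usage note: the `run-*.sh` scripts assume FloydHub, but `train.py` runs anywhere, and each flag maps onto the `argparse` arguments defined at the bottom of the file. For example, `python train.py --task time_vs_lr --output-dir ./output` runs the learning-rate sweep, and `python train.py --task trace --scale 3 --output-dir ./output` traces validation accuracy over time at each optimizer's best fixed rate.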