├── .gitignore ├── LICENSE ├── Pipfile ├── Pipfile.lock ├── README.md ├── classifier.png ├── code ├── MNISTClassifier.py ├── requirements.txt └── train.py ├── create_execution_role.sh ├── main.py ├── notebook ├── .ipynb_checkpoints │ └── lightning_sagemaker-checkpoint.ipynb ├── lightning_logs │ ├── version_7 │ │ ├── checkpoints │ │ │ └── epoch=2.ckpt │ │ ├── events.out.tfevents.1594563548.ip-172-16-16-124.31276.7 │ │ └── hparams.yaml │ ├── version_8 │ │ ├── checkpoints │ │ │ └── epoch=9.ckpt │ │ ├── events.out.tfevents.1594563766.ip-172-16-16-124.31276.8 │ │ └── hparams.yaml │ └── version_9 │ │ ├── checkpoints │ │ └── epoch=6.ckpt │ │ ├── events.out.tfevents.1594564263.ip-172-16-16-124.31276.9 │ │ └── hparams.yaml ├── lightning_mnist_experiment.ipynb └── sagemaker_deploy.ipynb ├── sagemaker-run.png └── sagemaker-sdk.png /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | .vscode 141 | dataset 142 | code/out 143 | model.pth 144 | 145 | .ipynb_checkpoints 146 | lightning_logs 147 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Luca Bianchi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | name = "pypi" 3 | url = "https://pypi.org/simple" 4 | verify_ssl = true 5 | 6 | [dev-packages] 7 | 8 | [packages] 9 | sagemaker = "*" 10 | torch = "*" 11 | pytorch-lightning = "*" 12 | boto3 = "*" 13 | botocore = "*" 14 | 15 | [requires] 16 | python_version = "3.7" 17 | -------------------------------------------------------------------------------- /Pipfile.lock: -------------------------------------------------------------------------------- 1 | { 2 | "_meta": { 3 | "hash": { 4 | "sha256": "089b6eab9cff89b7db1cc2d016c2478e7a11ac9dba4b1a058897ec34805e3603" 5 | }, 6 | "pipfile-spec": 6, 7 | "requires": { 8 | "python_version": "3.7" 9 | }, 10 | "sources": [ 11 | { 12 | "name": "pypi", 13 | "url": "https://pypi.org/simple", 14 | "verify_ssl": true 15 | } 16 | ] 17 | }, 18 | "default": { 19 | "absl-py": { 20 | "hashes": [ 21 | "sha256:34995df9bd7a09b3b8749e230408f5a2a2dd7a68a0d33c12a3d0cb15a041a507", 22 | "sha256:463c38a08d2e4cef6c498b76ba5bd4858e4c6ef51da1a5a1f27139a022e20248" 23 | ], 24 | "markers": "python_full_version >= '3.6.0'", 25 | "version": "==1.3.0" 26 | }, 27 | "bleach": { 28 | "hashes": [ 29 | "sha256:085f7f33c15bd408dd9b17a4ad77c577db66d76203e5984b1bd59baeee948b2a", 30 | "sha256:0d03255c47eb9bd2f26aa9bb7f2107732e7e8fe195ca2f64709fcf3b0a4a085c" 31 | ], 32 | "markers": "python_version >= '3.7'", 33 | "version": "==5.0.1" 34 | }, 35 | "boto3": { 36 | "hashes": [ 37 | "sha256:e6ab26155b2f83798218106580ab2b3cd47691e25aba912e0351502eda8d86e0", 38 | 
"sha256:f7aa33b382cc9e73ef7f590b885e72732ad2bd9628c5e312c9aeb8ba011c6820" 39 | ], 40 | "index": "pypi", 41 | "version": "==1.14.20" 42 | }, 43 | "botocore": { 44 | "hashes": [ 45 | "sha256:d1bf8c2085719221683edf54913c6155c68705f26ab4a72c45e4de5176a8cf7b", 46 | "sha256:e7fee600092b51ca8016c541d5c50a8b39179d5c184ec3fd430400d99ba0c55a" 47 | ], 48 | "index": "pypi", 49 | "version": "==1.17.20" 50 | }, 51 | "cachetools": { 52 | "hashes": [ 53 | "sha256:6a94c6402995a99c3970cc7e4884bb60b4a8639938157eeed436098bf9831757", 54 | "sha256:f9f17d2aec496a9aa6b76f53e3b614c965223c061982d434d160f930c698a9db" 55 | ], 56 | "markers": "python_version ~= '3.7'", 57 | "version": "==5.2.0" 58 | }, 59 | "certifi": { 60 | "hashes": [ 61 | "sha256:0d9c601124e5a6ba9712dbc60d9c53c21e34f5f641fe83002317394311bdce14", 62 | "sha256:90c1a32f1d68f940488354e36370f6cca89f0f106db09518524c88d6ed83f382" 63 | ], 64 | "markers": "python_full_version >= '3.6.0'", 65 | "version": "==2022.9.24" 66 | }, 67 | "charset-normalizer": { 68 | "hashes": [ 69 | "sha256:5a3d016c7c547f69d6f81fb0db9449ce888b418b5b9952cc5e6e66843e9dd845", 70 | "sha256:83e9a75d1911279afd89352c68b45348559d1fc0506b054b346651b5e7fee29f" 71 | ], 72 | "markers": "python_full_version >= '3.6.0'", 73 | "version": "==2.1.1" 74 | }, 75 | "docutils": { 76 | "hashes": [ 77 | "sha256:6c4f696463b79f1fb8ba0c594b63840ebd41f059e92b31957c46b74a4599b6d0", 78 | "sha256:9e4d7ecfc600058e07ba661411a2b7de2fd0fafa17d1a7f7361cd47b1175c827", 79 | "sha256:a2aeea129088da402665e92e0b25b04b073c04b2dce4ab65caaa38b7ce2e1a99" 80 | ], 81 | "markers": "python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3'", 82 | "version": "==0.15.2" 83 | }, 84 | "future": { 85 | "hashes": [ 86 | "sha256:b1bead90b70cf6ec3f0710ae53a525360fa360d306a86583adc6bf83a4db537d" 87 | ], 88 | "markers": "python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3'", 89 | "version": "==0.18.2" 90 | }, 91 | "google-auth": { 92 | "hashes": [ 93 | 
"sha256:ccaa901f31ad5cbb562615eb8b664b3dd0bf5404a67618e642307f00613eda4d", 94 | "sha256:f5d8701633bebc12e0deea4df8abd8aff31c28b355360597f7f2ee60f2e4d016" 95 | ], 96 | "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5'", 97 | "version": "==2.14.1" 98 | }, 99 | "google-auth-oauthlib": { 100 | "hashes": [ 101 | "sha256:3f2a6e802eebbb6fb736a370fbf3b055edcb6b52878bf2f26330b5e041316c73", 102 | "sha256:a90a072f6993f2c327067bf65270046384cda5a8ecb20b94ea9a687f1f233a7a" 103 | ], 104 | "markers": "python_full_version >= '3.6.0'", 105 | "version": "==0.4.6" 106 | }, 107 | "grpcio": { 108 | "hashes": [ 109 | "sha256:16c2e5d498f292d64146da5623986521a694afd2d0823184c71d917a8d2d63d3", 110 | "sha256:19f783b3e164673be26eb1b3b53a40b9ca12b034189e0c15c70561a596035652", 111 | "sha256:248c32ee8a0e7d16cfee74d14c8e169e434edbd74b2146a2a96d06200a8f7d03", 112 | "sha256:2f39dee8a6516089ab5469ebf03031a6a574508728e38627e5c5c6589a4d46b9", 113 | "sha256:32e7d90d02121aab458d1ba5a899ce4639df14b387e3e436eab17d167d23b386", 114 | "sha256:3344952c68588d26e877de2e15934437a027526609d0cea698e41151ac4a50cb", 115 | "sha256:36a4af2a423688d9a0dd70e5f0a118713446c49ddfebff14547b3418d15c3714", 116 | "sha256:39187d5ec73a69bd6ac497aaac59e8c226b6c7c16fc0e00cedf0f83ac1c79daa", 117 | "sha256:3ac1c1b63da14172ef2d9a012076d0c6a5885f0fd4fb6f1104f60da6470b9591", 118 | "sha256:3b9f67b703d0b8a0ca5308f31b5552a9bf1692060e04b929d2242656c0493dcf", 119 | "sha256:3c52aa919c5483179df2612076c9b24608f44277a7978eeeea3014bb31bfae18", 120 | "sha256:3d0d7ee2e5bd43175110d323ba9650595707504e5a7d049d6be337df7b4aee77", 121 | "sha256:425baff0b0182514b739565d95962b080be134499967c925dbd6aa7a5bc9d82b", 122 | "sha256:4bfe3622d2b0406ca867e7558d41a2b7aba312209ab889bfc17ae2d1cd17b6ba", 123 | "sha256:4c012b0d4c7c4ba3f832a339330d35975f0181f032ba66f40dd71bd99849b224", 124 | "sha256:4de91e690d9179aad372ccecf0d3600f37ad63dd17ae061f521b7baff4e3e1a0", 125 | 
"sha256:63b5d4543bba518ff0c1af3b8031595422a6e243f38d23588b528a08a7851522", 126 | "sha256:694e282673dbf178d08c1d0236f22838ec1e2ea1db144f59fe01f9b0514a2f73", 127 | "sha256:6f644401f26aae6d012c461eb0ffa3d1c4e4f447a23b061baad9c2e201fe1573", 128 | "sha256:748686b1bdd17d5752f35256bb92aacbcb5b6584ad612178a0af4b5301d38f06", 129 | "sha256:74aa0a5157fd28037e2e951cb0228a4a7cdcaaac75d8e5650191dfca8e8f9449", 130 | "sha256:8a4f4de445aa2ccd99abb3a09d27af38eee75b0d411b28db65aebd46fb6373ca", 131 | "sha256:94d2c83590dde9d87ab3ff5f47b293e0c3e908f683ad327f47eb78bb7f7e198d", 132 | "sha256:98a4858613f1bc991f79b2dc5d453cf5a8dd95a4fdab8b6daaaff33d8710dd1e", 133 | "sha256:a103811a4b318abc9b1592cd6cc9187d34bfb2e192eb0fe113dfabb38a162284", 134 | "sha256:a1d9409bc633028f992c05b6b8342dd94b946072fa192c059e626032fb7cc3e0", 135 | "sha256:a29913905bc23b0054be96234b8d39fdf3222430c5bf5756a97278ecd49dad0a", 136 | "sha256:a9fc58eba09407124edbb17d91e091a296702cef486bd03d50d34e40689c1406", 137 | "sha256:bf33ba7e178e75f61f24dd06947599ef04a911d0b384616a4fcb843485f6d92d", 138 | "sha256:c0bae51edce0eff1b4a89b57fead01638632ec2baadca3e5b94c730c5baa74e8", 139 | "sha256:c15b03b864c046ea3198cbb9be04b8a892216c105d3b5e7a5b7e33bc286aaacb", 140 | "sha256:c5157927f98ef9fce6919d3c46330c97e9a179ef1d912458127c71a07372f341", 141 | "sha256:ce4a5546f17d68ba0c25ea28edce70c56b354a3cf4e7d78617dec037864d38d5", 142 | "sha256:d08475daadc04756d2a20a4bd51fe873b9b4f02fbe11f7bfd8995a7199792c38", 143 | "sha256:d1b1418f9d2b7be2fbe3d6e12e19f48a8d24e18597315f01f4df5e019b2d2685", 144 | "sha256:d33e17ee64ed77ad9c6deeecf1146e9ee0472b1ca7d37dc4f5ae28fdf6ed8e4e", 145 | "sha256:da4e8d1daa741e9165ed9d6f1dba8b94d7d2fc477a43564d4e0b61eac7d5b8e9", 146 | "sha256:db22902f085dd4c7979abe3cc415035f4a75a553140c98b751a3ca22d978f2a2", 147 | "sha256:e0837ef421af003ba36be10b830e187e8434c9052ef90bfb1a6de1c8586eb5e8", 148 | "sha256:e743a40579711c02388b021203f0785f00ac486637fea6ce9f9ab12cbc13ad69", 149 | 
"sha256:ef535dd391e4029834cc1a30a0353ca0d11b21201df570df46e3d6f62323d65a", 150 | "sha256:f085e0439129e562c157f85195d8c3f9f538d0d90315d0e41593c47580386f93", 151 | "sha256:f5ea8351b4a8713dd2818721c771a2b43c53e229cfc095898779c3aaff6f786e", 152 | "sha256:fc5438ea90768017818d54049603e96d5ad0c8b592ec33b2cdc35eaf75582802", 153 | "sha256:fe4552c363be5affe5133b9ba66a04e3c1e4f243154666a94fefa357f6352250" 154 | ], 155 | "markers": "python_version >= '3.7'", 156 | "version": "==1.51.0" 157 | }, 158 | "idna": { 159 | "hashes": [ 160 | "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4", 161 | "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2" 162 | ], 163 | "markers": "python_version >= '3.5'", 164 | "version": "==3.4" 165 | }, 166 | "imageio": { 167 | "hashes": [ 168 | "sha256:0fae027addf02bc89c73a56cc157ad84557f8b8b84aa19b4cb706fefca2d88ff", 169 | "sha256:bb173f8af27e4921f59539c4d45068fcedb892e58261fce8253f31c9a0ff9ccf" 170 | ], 171 | "markers": "python_version >= '3.7'", 172 | "version": "==2.22.4" 173 | }, 174 | "importlib-metadata": { 175 | "hashes": [ 176 | "sha256:da31db32b304314d044d3c12c79bd59e307889b287ad12ff387b3500835fc2ab", 177 | "sha256:ddb0e35065e8938f867ed4928d0ae5bf2a53b7773871bfe6bcc7e4fcdc7dea43" 178 | ], 179 | "markers": "python_version >= '3.7'", 180 | "version": "==5.0.0" 181 | }, 182 | "jmespath": { 183 | "hashes": [ 184 | "sha256:b85d0567b8666149a93172712e68920734333c0ce7e89b78b3e987f71e5ed4f9", 185 | "sha256:cdf6525904cc597730141d61b36f2e4b8ecc257c420fa2f4549bac2c2d0cb72f" 186 | ], 187 | "markers": "python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3'", 188 | "version": "==0.10.0" 189 | }, 190 | "markdown": { 191 | "hashes": [ 192 | "sha256:08fb8465cffd03d10b9dd34a5c3fea908e20391a2a90b88d66362cb05beed186", 193 | "sha256:3b809086bb6efad416156e00a0da66fe47618a5d6918dd688f53f40c8e4cfeff" 194 | ], 195 | "markers": "python_version >= '3.7'", 196 | "version": "==3.4.1" 197 | }, 198 | "markupsafe": 
{ 199 | "hashes": [ 200 | "sha256:0212a68688482dc52b2d45013df70d169f542b7394fc744c02a57374a4207003", 201 | "sha256:089cf3dbf0cd6c100f02945abeb18484bd1ee57a079aefd52cffd17fba910b88", 202 | "sha256:10c1bfff05d95783da83491be968e8fe789263689c02724e0c691933c52994f5", 203 | "sha256:33b74d289bd2f5e527beadcaa3f401e0df0a89927c1559c8566c066fa4248ab7", 204 | "sha256:3799351e2336dc91ea70b034983ee71cf2f9533cdff7c14c90ea126bfd95d65a", 205 | "sha256:3ce11ee3f23f79dbd06fb3d63e2f6af7b12db1d46932fe7bd8afa259a5996603", 206 | "sha256:421be9fbf0ffe9ffd7a378aafebbf6f4602d564d34be190fc19a193232fd12b1", 207 | "sha256:43093fb83d8343aac0b1baa75516da6092f58f41200907ef92448ecab8825135", 208 | "sha256:46d00d6cfecdde84d40e572d63735ef81423ad31184100411e6e3388d405e247", 209 | "sha256:4a33dea2b688b3190ee12bd7cfa29d39c9ed176bda40bfa11099a3ce5d3a7ac6", 210 | "sha256:4b9fe39a2ccc108a4accc2676e77da025ce383c108593d65cc909add5c3bd601", 211 | "sha256:56442863ed2b06d19c37f94d999035e15ee982988920e12a5b4ba29b62ad1f77", 212 | "sha256:671cd1187ed5e62818414afe79ed29da836dde67166a9fac6d435873c44fdd02", 213 | "sha256:694deca8d702d5db21ec83983ce0bb4b26a578e71fbdbd4fdcd387daa90e4d5e", 214 | "sha256:6a074d34ee7a5ce3effbc526b7083ec9731bb3cbf921bbe1d3005d4d2bdb3a63", 215 | "sha256:6d0072fea50feec76a4c418096652f2c3238eaa014b2f94aeb1d56a66b41403f", 216 | "sha256:6fbf47b5d3728c6aea2abb0589b5d30459e369baa772e0f37a0320185e87c980", 217 | "sha256:7f91197cc9e48f989d12e4e6fbc46495c446636dfc81b9ccf50bb0ec74b91d4b", 218 | "sha256:86b1f75c4e7c2ac2ccdaec2b9022845dbb81880ca318bb7a0a01fbf7813e3812", 219 | "sha256:8dc1c72a69aa7e082593c4a203dcf94ddb74bb5c8a731e4e1eb68d031e8498ff", 220 | "sha256:8e3dcf21f367459434c18e71b2a9532d96547aef8a871872a5bd69a715c15f96", 221 | "sha256:8e576a51ad59e4bfaac456023a78f6b5e6e7651dcd383bcc3e18d06f9b55d6d1", 222 | "sha256:96e37a3dc86e80bf81758c152fe66dbf60ed5eca3d26305edf01892257049925", 223 | "sha256:97a68e6ada378df82bc9f16b800ab77cbf4b2fada0081794318520138c088e4a", 224 | 
"sha256:99a2a507ed3ac881b975a2976d59f38c19386d128e7a9a18b7df6fff1fd4c1d6", 225 | "sha256:a49907dd8420c5685cfa064a1335b6754b74541bbb3706c259c02ed65b644b3e", 226 | "sha256:b09bf97215625a311f669476f44b8b318b075847b49316d3e28c08e41a7a573f", 227 | "sha256:b7bd98b796e2b6553da7225aeb61f447f80a1ca64f41d83612e6139ca5213aa4", 228 | "sha256:b87db4360013327109564f0e591bd2a3b318547bcef31b468a92ee504d07ae4f", 229 | "sha256:bcb3ed405ed3222f9904899563d6fc492ff75cce56cba05e32eff40e6acbeaa3", 230 | "sha256:d4306c36ca495956b6d568d276ac11fdd9c30a36f1b6eb928070dc5360b22e1c", 231 | "sha256:d5ee4f386140395a2c818d149221149c54849dfcfcb9f1debfe07a8b8bd63f9a", 232 | "sha256:dda30ba7e87fbbb7eab1ec9f58678558fd9a6b8b853530e176eabd064da81417", 233 | "sha256:e04e26803c9c3851c931eac40c695602c6295b8d432cbe78609649ad9bd2da8a", 234 | "sha256:e1c0b87e09fa55a220f058d1d49d3fb8df88fbfab58558f1198e08c1e1de842a", 235 | "sha256:e72591e9ecd94d7feb70c1cbd7be7b3ebea3f548870aa91e2732960fa4d57a37", 236 | "sha256:e8c843bbcda3a2f1e3c2ab25913c80a3c5376cd00c6e8c4a86a89a28c8dc5452", 237 | "sha256:efc1913fd2ca4f334418481c7e595c00aad186563bbc1ec76067848c7ca0a933", 238 | "sha256:f121a1420d4e173a5d96e47e9a0c0dcff965afdf1626d28de1460815f7c4ee7a", 239 | "sha256:fc7b548b17d238737688817ab67deebb30e8073c95749d55538ed473130ec0c7" 240 | ], 241 | "markers": "python_version >= '3.7'", 242 | "version": "==2.1.1" 243 | }, 244 | "numpy": { 245 | "hashes": [ 246 | "sha256:0778076e764e146d3078b17c24c4d89e0ecd4ac5401beff8e1c87879043a0633", 247 | "sha256:141c7102f20abe6cf0d54c4ced8d565b86df4d3077ba2343b61a6db996cefec7", 248 | "sha256:14270a1ee8917d11e7753fb54fc7ffd1934f4d529235beec0b275e2ccf00333b", 249 | "sha256:27e11c7a8ec9d5838bc59f809bfa86efc8a4fd02e58960fa9c49d998e14332d5", 250 | "sha256:2a04dda79606f3d2f760384c38ccd3d5b9bb79d4c8126b67aff5eb09a253763e", 251 | "sha256:3c26010c1b51e1224a3ca6b8df807de6e95128b0908c7e34f190e7775455b0ca", 252 | "sha256:52c40f1a4262c896420c6ea1c6fda62cf67070e3947e3307f5562bd783a90336", 253 | 
"sha256:6e4f8d9e8aa79321657079b9ac03f3cf3fd067bf31c1cca4f56d49543f4356a5", 254 | "sha256:7242be12a58fec245ee9734e625964b97cf7e3f2f7d016603f9e56660ce479c7", 255 | "sha256:7dc253b542bfd4b4eb88d9dbae4ca079e7bf2e2afd819ee18891a43db66c60c7", 256 | "sha256:94f5bd885f67bbb25c82d80184abbf7ce4f6c3c3a41fbaa4182f034bba803e69", 257 | "sha256:a89e188daa119ffa0d03ce5123dee3f8ffd5115c896c2a9d4f0dbb3d8b95bfa3", 258 | "sha256:ad3399da9b0ca36e2f24de72f67ab2854a62e623274607e37e0ce5f5d5fa9166", 259 | "sha256:b0348be89275fd1d4c44ffa39530c41a21062f52299b1e3ee7d1c61f060044b8", 260 | "sha256:b5554368e4ede1856121b0dfa35ce71768102e4aa55e526cb8de7f374ff78722", 261 | "sha256:cbddc56b2502d3f87fda4f98d948eb5b11f36ff3902e17cb6cc44727f2200525", 262 | "sha256:d79f18f41751725c56eceab2a886f021d70fd70a6188fd386e29a045945ffc10", 263 | "sha256:dc2ca26a19ab32dc475dbad9dfe723d3a64c835f4c23f625c2b6566ca32b9f29", 264 | "sha256:dd9bcd4f294eb0633bb33d1a74febdd2b9018b8b8ed325f861fffcd2c7660bb8", 265 | "sha256:e8baab1bc7c9152715844f1faca6744f2416929de10d7639ed49555a85549f52", 266 | "sha256:ec31fe12668af687b99acf1567399632a7c47b0e17cfb9ae47c098644ef36797", 267 | "sha256:f12b4f7e2d8f9da3141564e6737d79016fe5336cc92de6814eba579744f65b0a", 268 | "sha256:f58ac38d5ca045a377b3b377c84df8175ab992c970a53332fa8ac2373df44ff7" 269 | ], 270 | "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", 271 | "version": "==1.16.4" 272 | }, 273 | "oauthlib": { 274 | "hashes": [ 275 | "sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca", 276 | "sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918" 277 | ], 278 | "markers": "python_full_version >= '3.6.0'", 279 | "version": "==3.2.2" 280 | }, 281 | "packaging": { 282 | "hashes": [ 283 | "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb", 284 | "sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522" 285 | ], 286 | "markers": "python_full_version >= '3.6.0'", 287 | "version": 
"==21.3" 288 | }, 289 | "pandas": { 290 | "hashes": [ 291 | "sha256:0a643bae4283a37732ddfcecab3f62dd082996021b980f580903f4e8e01b3c5b", 292 | "sha256:0de3ddb414d30798cbf56e642d82cac30a80223ad6fe484d66c0ce01a84d6f2f", 293 | "sha256:19a2148a1d02791352e9fa637899a78e371a3516ac6da5c4edc718f60cbae648", 294 | "sha256:21b5a2b033380adbdd36b3116faaf9a4663e375325831dac1b519a44f9e439bb", 295 | "sha256:24c7f8d4aee71bfa6401faeba367dd654f696a77151a8a28bc2013f7ced4af98", 296 | "sha256:26fa92d3ac743a149a31b21d6f4337b0594b6302ea5575b37af9ca9611e8981a", 297 | "sha256:2860a97cbb25444ffc0088b457da0a79dc79f9c601238a3e0644312fcc14bf11", 298 | "sha256:2b1c6cd28a0dfda75c7b5957363333f01d370936e4c6276b7b8e696dd500582a", 299 | "sha256:2c2f7c670ea4e60318e4b7e474d56447cf0c7d83b3c2a5405a0dbb2600b9c48e", 300 | "sha256:3be7a7a0ca71a2640e81d9276f526bca63505850add10206d0da2e8a0a325dae", 301 | "sha256:4c62e94d5d49db116bef1bd5c2486723a292d79409fc9abd51adf9e05329101d", 302 | "sha256:5008374ebb990dad9ed48b0f5d0038124c73748f5384cc8c46904dace27082d9", 303 | "sha256:5447ea7af4005b0daf695a316a423b96374c9c73ffbd4533209c5ddc369e644b", 304 | "sha256:573fba5b05bf2c69271a32e52399c8de599e4a15ab7cec47d3b9c904125ab788", 305 | "sha256:5a780260afc88268a9d3ac3511d8f494fdcf637eece62fb9eb656a63d53eb7ca", 306 | "sha256:70865f96bb38fec46f7ebd66d4b5cfd0aa6b842073f298d621385ae3898d28b5", 307 | "sha256:731568be71fba1e13cae212c362f3d2ca8932e83cb1b85e3f1b4dd77d019254a", 308 | "sha256:b61080750d19a0122469ab59b087380721d6b72a4e7d962e4d7e63e0c4504814", 309 | "sha256:bf23a3b54d128b50f4f9d4675b3c1857a688cc6731a32f931837d72effb2698d", 310 | "sha256:c16d59c15d946111d2716856dd5479221c9e4f2f5c7bc2d617f39d870031e086", 311 | "sha256:c61c043aafb69329d0f961b19faa30b1dab709dd34c9388143fc55680059e55a", 312 | "sha256:c94ff2780a1fd89f190390130d6d36173ca59fcfb3fe0ff596f9a56518191ccb", 313 | "sha256:edda9bacc3843dfbeebaf7a701763e68e741b08fccb889c003b0a52f0ee95782", 314 | "sha256:f10fc41ee3c75a474d3bdf68d396f10782d013d7f67db99c0efbfd0acb99701b" 
315 | ], 316 | "markers": "python_full_version >= '3.6.1'", 317 | "version": "==1.1.5" 318 | }, 319 | "pillow": { 320 | "hashes": [ 321 | "sha256:03150abd92771742d4a8cd6f2fa6246d847dcd2e332a18d0c15cc75bf6703040", 322 | "sha256:073adb2ae23431d3b9bcbcff3fe698b62ed47211d0716b067385538a1b0f28b8", 323 | "sha256:0b07fffc13f474264c336298d1b4ce01d9c5a011415b79d4ee5527bb69ae6f65", 324 | "sha256:0b7257127d646ff8676ec8a15520013a698d1fdc48bc2a79ba4e53df792526f2", 325 | "sha256:12ce4932caf2ddf3e41d17fc9c02d67126935a44b86df6a206cf0d7161548627", 326 | "sha256:15c42fb9dea42465dfd902fb0ecf584b8848ceb28b41ee2b58f866411be33f07", 327 | "sha256:18498994b29e1cf86d505edcb7edbe814d133d2232d256db8c7a8ceb34d18cef", 328 | "sha256:1c7c8ae3864846fc95f4611c78129301e203aaa2af813b703c55d10cc1628535", 329 | "sha256:22b012ea2d065fd163ca096f4e37e47cd8b59cf4b0fd47bfca6abb93df70b34c", 330 | "sha256:276a5ca930c913f714e372b2591a22c4bd3b81a418c0f6635ba832daec1cbcfc", 331 | "sha256:2e0918e03aa0c72ea56edbb00d4d664294815aa11291a11504a377ea018330d3", 332 | "sha256:3033fbe1feb1b59394615a1cafaee85e49d01b51d54de0cbf6aa8e64182518a1", 333 | "sha256:3168434d303babf495d4ba58fc22d6604f6e2afb97adc6a423e917dab828939c", 334 | "sha256:32a44128c4bdca7f31de5be641187367fe2a450ad83b833ef78910397db491aa", 335 | "sha256:3dd6caf940756101205dffc5367babf288a30043d35f80936f9bfb37f8355b32", 336 | "sha256:40e1ce476a7804b0fb74bcfa80b0a2206ea6a882938eaba917f7a0f004b42502", 337 | "sha256:41e0051336807468be450d52b8edd12ac60bebaa97fe10c8b660f116e50b30e4", 338 | "sha256:4390e9ce199fc1951fcfa65795f239a8a4944117b5935a9317fb320e7767b40f", 339 | "sha256:502526a2cbfa431d9fc2a079bdd9061a2397b842bb6bc4239bb176da00993812", 340 | "sha256:51e0e543a33ed92db9f5ef69a0356e0b1a7a6b6a71b80df99f1d181ae5875636", 341 | "sha256:57751894f6618fd4308ed8e0c36c333e2f5469744c34729a27532b3db106ee20", 342 | "sha256:5d77adcd56a42d00cc1be30843d3426aa4e660cab4a61021dc84467123f7a00c", 343 | "sha256:655a83b0058ba47c7c52e4e2df5ecf484c1b0b0349805896dd350cbc416bdd91", 344 
| "sha256:68943d632f1f9e3dce98908e873b3a090f6cba1cbb1b892a9e8d97c938871fbe", 345 | "sha256:6c738585d7a9961d8c2821a1eb3dcb978d14e238be3d70f0a706f7fa9316946b", 346 | "sha256:73bd195e43f3fadecfc50c682f5055ec32ee2c933243cafbfdec69ab1aa87cad", 347 | "sha256:772a91fc0e03eaf922c63badeca75e91baa80fe2f5f87bdaed4280662aad25c9", 348 | "sha256:77ec3e7be99629898c9a6d24a09de089fa5356ee408cdffffe62d67bb75fdd72", 349 | "sha256:7db8b751ad307d7cf238f02101e8e36a128a6cb199326e867d1398067381bff4", 350 | "sha256:801ec82e4188e935c7f5e22e006d01611d6b41661bba9fe45b60e7ac1a8f84de", 351 | "sha256:82409ffe29d70fd733ff3c1025a602abb3e67405d41b9403b00b01debc4c9a29", 352 | "sha256:828989c45c245518065a110434246c44a56a8b2b2f6347d1409c787e6e4651ee", 353 | "sha256:829f97c8e258593b9daa80638aee3789b7df9da5cf1336035016d76f03b8860c", 354 | "sha256:871b72c3643e516db4ecf20efe735deb27fe30ca17800e661d769faab45a18d7", 355 | "sha256:89dca0ce00a2b49024df6325925555d406b14aa3efc2f752dbb5940c52c56b11", 356 | "sha256:90fb88843d3902fe7c9586d439d1e8c05258f41da473952aa8b328d8b907498c", 357 | "sha256:97aabc5c50312afa5e0a2b07c17d4ac5e865b250986f8afe2b02d772567a380c", 358 | "sha256:9aaa107275d8527e9d6e7670b64aabaaa36e5b6bd71a1015ddd21da0d4e06448", 359 | "sha256:9f47eabcd2ded7698106b05c2c338672d16a6f2a485e74481f524e2a23c2794b", 360 | "sha256:a0a06a052c5f37b4ed81c613a455a81f9a3a69429b4fd7bb913c3fa98abefc20", 361 | "sha256:ab388aaa3f6ce52ac1cb8e122c4bd46657c15905904b3120a6248b5b8b0bc228", 362 | "sha256:ad58d27a5b0262c0c19b47d54c5802db9b34d38bbf886665b626aff83c74bacd", 363 | "sha256:ae5331c23ce118c53b172fa64a4c037eb83c9165aba3a7ba9ddd3ec9fa64a699", 364 | "sha256:af0372acb5d3598f36ec0914deed2a63f6bcdb7b606da04dc19a88d31bf0c05b", 365 | "sha256:afa4107d1b306cdf8953edde0534562607fe8811b6c4d9a486298ad31de733b2", 366 | "sha256:b03ae6f1a1878233ac620c98f3459f79fd77c7e3c2b20d460284e1fb370557d4", 367 | "sha256:b0915e734b33a474d76c28e07292f196cdf2a590a0d25bcc06e64e545f2d146c", 368 | 
"sha256:b4012d06c846dc2b80651b120e2cdd787b013deb39c09f407727ba90015c684f", 369 | "sha256:b472b5ea442148d1c3e2209f20f1e0bb0eb556538690fa70b5e1f79fa0ba8dc2", 370 | "sha256:b59430236b8e58840a0dfb4099a0e8717ffb779c952426a69ae435ca1f57210c", 371 | "sha256:b90f7616ea170e92820775ed47e136208e04c967271c9ef615b6fbd08d9af0e3", 372 | "sha256:b9a65733d103311331875c1dca05cb4606997fd33d6acfed695b1232ba1df193", 373 | "sha256:bac18ab8d2d1e6b4ce25e3424f709aceef668347db8637c2296bcf41acb7cf48", 374 | "sha256:bca31dd6014cb8b0b2db1e46081b0ca7d936f856da3b39744aef499db5d84d02", 375 | "sha256:be55f8457cd1eac957af0c3f5ece7bc3f033f89b114ef30f710882717670b2a8", 376 | "sha256:c7025dce65566eb6e89f56c9509d4f628fddcedb131d9465cacd3d8bac337e7e", 377 | "sha256:c935a22a557a560108d780f9a0fc426dd7459940dc54faa49d83249c8d3e760f", 378 | "sha256:dbb8e7f2abee51cef77673be97760abff1674ed32847ce04b4af90f610144c7b", 379 | "sha256:e6ea6b856a74d560d9326c0f5895ef8050126acfdc7ca08ad703eb0081e82b74", 380 | "sha256:ebf2029c1f464c59b8bdbe5143c79fa2045a581ac53679733d3a91d400ff9efb", 381 | "sha256:f1ff2ee69f10f13a9596480335f406dd1f70c3650349e2be67ca3139280cade0" 382 | ], 383 | "index": "pypi", 384 | "version": "==9.3.0" 385 | }, 386 | "pkginfo": { 387 | "hashes": [ 388 | "sha256:848865108ec99d4901b2f7e84058b6e7660aae8ae10164e015a6dcf5b242a594", 389 | "sha256:a84da4318dd86f870a9447a8c98340aa06216bfc6f2b7bdc4b8766984ae1867c" 390 | ], 391 | "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5'", 392 | "version": "==1.8.3" 393 | }, 394 | "protobuf": { 395 | "hashes": [ 396 | "sha256:03038ac1cfbc41aa21f6afcbcd357281d7521b4157926f30ebecc8d4ea59dcb7", 397 | "sha256:28545383d61f55b57cf4df63eebd9827754fd2dc25f80c5253f9184235db242c", 398 | "sha256:2e3427429c9cffebf259491be0af70189607f365c2f41c7c3764af6f337105f2", 399 | "sha256:398a9e0c3eaceb34ec1aee71894ca3299605fa8e761544934378bbc6c97de23b", 400 | "sha256:44246bab5dd4b7fbd3c0c80b6f16686808fab0e4aca819ade6e8d294a29c7050", 401 | 
"sha256:447d43819997825d4e71bf5769d869b968ce96848b6479397e29fc24c4a5dfe9", 402 | "sha256:67a3598f0a2dcbc58d02dd1928544e7d88f764b47d4a286202913f0b2801c2e7", 403 | "sha256:74480f79a023f90dc6e18febbf7b8bac7508420f2006fabd512013c0c238f454", 404 | "sha256:819559cafa1a373b7096a482b504ae8a857c89593cf3a25af743ac9ecbd23480", 405 | "sha256:899dc660cd599d7352d6f10d83c95df430a38b410c1b66b407a6b29265d66469", 406 | "sha256:8c0c984a1b8fef4086329ff8dd19ac77576b384079247c770f29cc8ce3afa06c", 407 | "sha256:9aae4406ea63d825636cc11ffb34ad3379335803216ee3a856787bcf5ccc751e", 408 | "sha256:a7ca6d488aa8ff7f329d4c545b2dbad8ac31464f1d8b1c87ad1346717731e4db", 409 | "sha256:b6cc7ba72a8850621bfec987cb72623e703b7fe2b9127a161ce61e61558ad905", 410 | "sha256:bf01b5720be110540be4286e791db73f84a2b721072a3711efff6c324cdf074b", 411 | "sha256:c02ce36ec760252242a33967d51c289fd0e1c0e6e5cc9397e2279177716add86", 412 | "sha256:d9e4432ff660d67d775c66ac42a67cf2453c27cb4d738fc22cb53b5d84c135d4", 413 | "sha256:daa564862dd0d39c00f8086f88700fdbe8bc717e993a21e90711acfed02f2402", 414 | "sha256:de78575669dddf6099a8a0f46a27e82a1783c557ccc38ee620ed8cc96d3be7d7", 415 | "sha256:e64857f395505ebf3d2569935506ae0dfc4a15cb80dc25261176c784662cdcc4", 416 | "sha256:f4bd856d702e5b0d96a00ec6b307b0f51c1982c2bf9c0052cf9019e9a544ba99", 417 | "sha256:f4c42102bc82a51108e449cbb32b19b180022941c727bac0cfd50170341f16ee" 418 | ], 419 | "markers": "python_version >= '3.7'", 420 | "version": "==3.20.3" 421 | }, 422 | "protobuf3-to-dict": { 423 | "hashes": [ 424 | "sha256:1e42c25b5afb5868e3a9b1962811077e492c17557f9c66f0fe40d821375d2b5a" 425 | ], 426 | "version": "==0.1.5" 427 | }, 428 | "pyasn1": { 429 | "hashes": [ 430 | "sha256:014c0e9976956a08139dc0712ae195324a75e142284d5f87f1a87ee1b068a359", 431 | "sha256:03840c999ba71680a131cfaee6fab142e1ed9bbd9c693e285cc6aca0d555e576", 432 | "sha256:0458773cfe65b153891ac249bcf1b5f8f320b7c2ce462151f8fa74de8934becf", 433 | "sha256:08c3c53b75eaa48d71cf8c710312316392ed40899cb34710d092e96745a358b7", 434 | 
"sha256:39c7e2ec30515947ff4e87fb6f456dfc6e84857d34be479c9d4a4ba4bf46aa5d", 435 | "sha256:5c9414dcfede6e441f7e8f81b43b34e834731003427e5b09e4e00e3172a10f00", 436 | "sha256:6e7545f1a61025a4e58bb336952c5061697da694db1cae97b116e9c46abcf7c8", 437 | "sha256:78fa6da68ed2727915c4767bb386ab32cdba863caa7dbe473eaae45f9959da86", 438 | "sha256:7ab8a544af125fb704feadb008c99a88805126fb525280b2270bb25cc1d78a12", 439 | "sha256:99fcc3c8d804d1bc6d9a099921e39d827026409a58f2a720dcdb89374ea0c776", 440 | "sha256:aef77c9fb94a3ac588e87841208bdec464471d9871bd5050a287cc9a475cd0ba", 441 | "sha256:e89bf84b5437b532b0803ba5c9a5e054d21fec423a89952a74f87fa2c9b7bce2", 442 | "sha256:fec3e9d8e36808a28efb59b489e4528c10ad0f480e57dcc32b4de5c9d8c9fdf3" 443 | ], 444 | "version": "==0.4.8" 445 | }, 446 | "pyasn1-modules": { 447 | "hashes": [ 448 | "sha256:0845a5582f6a02bb3e1bde9ecfc4bfcae6ec3210dd270522fee602365430c3f8", 449 | "sha256:0fe1b68d1e486a1ed5473f1302bd991c1611d319bba158e98b106ff86e1d7199", 450 | "sha256:15b7c67fabc7fc240d87fb9aabf999cf82311a6d6fb2c70d00d3d0604878c811", 451 | "sha256:426edb7a5e8879f1ec54a1864f16b882c2837bfd06eee62f2c982315ee2473ed", 452 | "sha256:65cebbaffc913f4fe9e4808735c95ea22d7a7775646ab690518c056784bc21b4", 453 | "sha256:905f84c712230b2c592c19470d3ca8d552de726050d1d1716282a1f6146be65e", 454 | "sha256:a50b808ffeb97cb3601dd25981f6b016cbb3d31fbf57a8b8a87428e6158d0c74", 455 | "sha256:a99324196732f53093a84c4369c996713eb8c89d360a496b599fb1a9c47fc3eb", 456 | "sha256:b80486a6c77252ea3a3e9b1e360bc9cf28eaac41263d173c032581ad2f20fe45", 457 | "sha256:c29a5e5cc7a3f05926aff34e097e84f8589cd790ce0ed41b67aed6857b26aafd", 458 | "sha256:cbac4bc38d117f2a49aeedec4407d23e8866ea4ac27ff2cf7fb3e5b570df19e0", 459 | "sha256:f39edd8c4ecaa4556e989147ebf219227e2cd2e8a43c7e7fcb1f1c18c5fd6a3d", 460 | "sha256:fe0644d9ab041506b62782e92b06b8c68cca799e1a9636ec398675459e031405" 461 | ], 462 | "version": "==0.2.8" 463 | }, 464 | "pygments": { 465 | "hashes": [ 466 | 
"sha256:56a8508ae95f98e2b9bdf93a6be5ae3f7d8af858b43e02c5a2ff083726be40c1", 467 | "sha256:f643f331ab57ba3c9d89212ee4a2dabc6e94f117cf4eefde99a0574720d14c42" 468 | ], 469 | "markers": "python_full_version >= '3.6.0'", 470 | "version": "==2.13.0" 471 | }, 472 | "pyparsing": { 473 | "hashes": [ 474 | "sha256:2b020ecf7d21b687f219b71ecad3631f644a47f01403fa1d1036b0c6416d70fb", 475 | "sha256:5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc" 476 | ], 477 | "markers": "python_full_version >= '3.6.8'", 478 | "version": "==3.0.9" 479 | }, 480 | "python-dateutil": { 481 | "hashes": [ 482 | "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86", 483 | "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9" 484 | ], 485 | "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", 486 | "version": "==2.8.2" 487 | }, 488 | "pytorch-lightning": { 489 | "hashes": [ 490 | "sha256:e76dea0ef7765a6d2fe1c40f591247aa8000f3cc02d7ad53a6e2856bdbee42b4" 491 | ], 492 | "index": "pypi", 493 | "version": "==0.5.3.1" 494 | }, 495 | "pytz": { 496 | "hashes": [ 497 | "sha256:222439474e9c98fced559f1709d89e6c9cbf8d79c794ff3eb9f8800064291427", 498 | "sha256:e89512406b793ca39f5971bc999cc538ce125c0e51c27941bef4568b460095e2" 499 | ], 500 | "version": "==2022.6" 501 | }, 502 | "readme-renderer": { 503 | "hashes": [ 504 | "sha256:cd653186dfc73055656f090f227f5cb22a046d7f71a841dfa305f55c9a513273", 505 | "sha256:f67a16caedfa71eef48a31b39708637a6f4664c4394801a7b0d6432d13907343" 506 | ], 507 | "markers": "python_version >= '3.7'", 508 | "version": "==37.3" 509 | }, 510 | "requests": { 511 | "hashes": [ 512 | "sha256:7c5599b102feddaa661c826c56ab4fee28bfd17f5abca1ebbe3e7f19d7c97983", 513 | "sha256:8fefa2a1a1365bf5520aac41836fbee479da67864514bdb821f31ce07ce65349" 514 | ], 515 | "markers": "python_version >= '3.7' and python_version < '4'", 516 | "version": "==2.28.1" 517 | }, 518 | "requests-oauthlib": { 519 | "hashes": [ 520 | 
"sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5", 521 | "sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a" 522 | ], 523 | "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", 524 | "version": "==1.3.1" 525 | }, 526 | "requests-toolbelt": { 527 | "hashes": [ 528 | "sha256:18565aa58116d9951ac39baa288d3adb5b3ff975c4f25eee78555d89e8f247f7", 529 | "sha256:62e09f7ff5ccbda92772a29f394a49c3ad6cb181d568b1337626b2abb628a63d" 530 | ], 531 | "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", 532 | "version": "==0.10.1" 533 | }, 534 | "rsa": { 535 | "hashes": [ 536 | "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7", 537 | "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21" 538 | ], 539 | "markers": "python_full_version >= '3.6.0'", 540 | "version": "==4.9" 541 | }, 542 | "s3transfer": { 543 | "hashes": [ 544 | "sha256:35627b86af8ff97e7ac27975fe0a98a312814b46c6333d8a6b889627bcd80994", 545 | "sha256:efa5bd92a897b6a8d5c1383828dca3d52d0790e0756d49740563a3fb6ed03246" 546 | ], 547 | "version": "==0.3.7" 548 | }, 549 | "sagemaker": { 550 | "hashes": [ 551 | "sha256:76aa1a9fe6f892e48cfa9212a15044fa5b713bcff02a86ee098329b1b18ceadb" 552 | ], 553 | "index": "pypi", 554 | "version": "==1.69.0" 555 | }, 556 | "scikit-learn": { 557 | "hashes": [ 558 | "sha256:05d061606657af85365b5f71484e3362d924429edde17a90068960843ad597f5", 559 | "sha256:071317afbb5c67fa493635376ddd724b414290255cbf6947c1155846956e93f7", 560 | "sha256:0d03aaf19a25e59edac3099cda6879ba05129f0fa1e152e23b728ccd36104f57", 561 | "sha256:1665ea0d4b75ef24f5f2a9d1527b7296eeabcbe3a1329791c954541e2ebde5a2", 562 | "sha256:24eccb0ff31f84e88e00936c09197735ef1dcabd370aacb10e55dbc8ee464a78", 563 | "sha256:27b48cabacce677a205e6bcda1f32bdc968fbf40cd2aa0a4f52852f6997fce51", 564 | "sha256:2c51826b9daa87d7d356bebd39f8665f7c32e90e3b21cbe853d6c7f0d6b0d23b", 565 | 
"sha256:3116299d392bd1d054655fa2a740e7854de87f1d573fa85503e64494e52ac795", 566 | "sha256:3771861abe1fd1b2bbeaec7ba8cfca58fdedd75d790f099960e5332af9d1ff7a", 567 | "sha256:473ba7d9a5eaec47909ee83d74b4a3be47a44505c5189d2cab67c0418cd030f1", 568 | "sha256:621e2c91f9afde06e9295d128cb15cb6fc77dc00719393e9ec9d47119895b0d4", 569 | "sha256:645865462c383e5faad473b93145a8aee97d839c9ad1fd7a17ae54ec8256d42b", 570 | "sha256:80e2276d4869d302e84b7c03b5bac4a67f6cd331162e62ae775a3e5855441a60", 571 | "sha256:84d2cfe0dee3c22b26364266d69850e0eb406d99714045929875032f91d3c918", 572 | "sha256:87ea9ace7fe811638dfc39b850b60887509b8bfc93c4006d5552fa066d04ddc7", 573 | "sha256:a4d1e535c75881f668010e6e53dfeb89dd50db85b05c5c45af1991c8b832d757", 574 | "sha256:a4f14c4327d2e44567bfb3a0bee8c55470f820bc9a67af3faf200abd8ed79bf2", 575 | "sha256:a7b3c24e193e8c6eaeac075b5d0bb0a7fea478aa2e4b991f6a7b030fc4fd410d", 576 | "sha256:ab2919aca84f1ac6ef60a482148eec0944364ab1832e63f28679b16f9ef279c8", 577 | "sha256:b0f79d5ff74f3c68a4198ad5b4dfa891326b5ce272dd064d11d572b25aae5b43", 578 | "sha256:bc5bc7c7ee2572a1edcb51698a6caf11fae554194aaab9a38105d9ec419f29e6", 579 | "sha256:bc5c750d548795def79576533f8f0f065915f17f48d6e443afce2a111f713747", 580 | "sha256:c68969c30b3b2c1fe07c1376110928eade61da4fc29c24c9f1a89435a7d08abe", 581 | "sha256:d3b4f791d2645fe936579d61f1ff9b5dcf0c8f50db7f0245ca8f16407d7a5a46", 582 | "sha256:dac0cd9fdd8ac6dd6108a10558e2e0ca1b411b8ea0a3165641f9ab0b4322df4e", 583 | "sha256:eb7ddbdf33eb822fdc916819b0ab7009d954eb43c3a78e7dd2ec5455e074922a", 584 | "sha256:ed537844348402ed53420187b3a6948c576986d0b2811a987a49613b6a26f29e", 585 | "sha256:fcca54733e692fe03b8584f7d4b9344f4b6e3a74f5b326c6e5f5e9d2504bdce7" 586 | ], 587 | "version": "==0.20.2" 588 | }, 589 | "scipy": { 590 | "hashes": [ 591 | "sha256:168c45c0c32e23f613db7c9e4e780bc61982d71dcd406ead746c7c7c2f2004ce", 592 | "sha256:213bc59191da2f479984ad4ec39406bf949a99aba70e9237b916ce7547b6ef42", 593 | 
"sha256:25b241034215247481f53355e05f9e25462682b13bd9191359075682adcd9554", 594 | "sha256:2c872de0c69ed20fb1a9b9cf6f77298b04a26f0b8720a5457be08be254366c6e", 595 | "sha256:3397c129b479846d7eaa18f999369a24322d008fac0782e7828fa567358c36ce", 596 | "sha256:368c0f69f93186309e1b4beb8e26d51dd6f5010b79264c0f1e9ca00cd92ea8c9", 597 | "sha256:3d5db5d815370c28d938cf9b0809dade4acf7aba57eaf7ef733bfedc9b2474c4", 598 | "sha256:4598cf03136067000855d6b44d7a1f4f46994164bcd450fb2c3d481afc25dd06", 599 | "sha256:4a453d5e5689de62e5d38edf40af3f17560bfd63c9c5bd228c18c1f99afa155b", 600 | "sha256:4f12d13ffbc16e988fa40809cbbd7a8b45bc05ff6ea0ba8e3e41f6f4db3a9e47", 601 | "sha256:634568a3018bc16a83cda28d4f7aed0d803dd5618facb36e977e53b2df868443", 602 | "sha256:65923bc3809524e46fb7eb4d6346552cbb6a1ffc41be748535aa502a2e3d3389", 603 | "sha256:6b0ceb23560f46dd236a8ad4378fc40bad1783e997604ba845e131d6c680963e", 604 | "sha256:8c8d6ca19c8497344b810b0b0344f8375af5f6bb9c98bd42e33f747417ab3f57", 605 | "sha256:9ad4fcddcbf5dc67619379782e6aeef41218a79e17979aaed01ed099876c0e62", 606 | "sha256:a254b98dbcc744c723a838c03b74a8a34c0558c9ac5c86d5561703362231107d", 607 | "sha256:b03c4338d6d3d299e8ca494194c0ae4f611548da59e3c038813f1a43976cb437", 608 | "sha256:cc1f78ebc982cd0602c9a7615d878396bec94908db67d4ecddca864d049112f2", 609 | "sha256:d6d25c41a009e3c6b7e757338948d0076ee1dd1770d1c09ec131f11946883c54", 610 | "sha256:d84cadd7d7998433334c99fa55bcba0d8b4aeff0edb123b2a1dfcface538e474", 611 | "sha256:e360cb2299028d0b0d0f65a5c5e51fc16a335f1603aa2357c25766c8dab56938", 612 | "sha256:e98d49a5717369d8241d6cf33ecb0ca72deee392414118198a8e5b4c35c56340", 613 | "sha256:ed572470af2438b526ea574ff8f05e7f39b44ac37f712105e57fc4d53a6fb660", 614 | "sha256:f87b39f4d69cf7d7529d7b1098cb712033b17ea7714aed831b95628f483fd012", 615 | "sha256:fa789583fc94a7689b45834453fec095245c7e69c58561dc159b5d5277057e4c" 616 | ], 617 | "markers": "python_full_version >= '3.6.0'", 618 | "version": "==1.5.4" 619 | }, 620 | "setuptools": { 621 | "hashes": [ 622 | 
"sha256:6211d2f5eddad8757bd0484923ca7c0a6302ebc4ab32ea5e94357176e0ca0840", 623 | "sha256:d1eebf881c6114e51df1664bc2c9133d022f78d12d5f4f665b9191f084e2862d" 624 | ], 625 | "markers": "python_version >= '3.7'", 626 | "version": "==65.6.0" 627 | }, 628 | "six": { 629 | "hashes": [ 630 | "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926", 631 | "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254" 632 | ], 633 | "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", 634 | "version": "==1.16.0" 635 | }, 636 | "smdebug-rulesconfig": { 637 | "hashes": [ 638 | "sha256:72e69cffd2f35708dc2a758d07368c6b15ddc031bb183a2a75273ffe2c0f8319", 639 | "sha256:8c8e14926222451b2821e7ab065830d90770e895e6f095fe7cefd235c84b5996" 640 | ], 641 | "markers": "python_version >= '2.7'", 642 | "version": "==0.1.4" 643 | }, 644 | "tensorboard": { 645 | "hashes": [ 646 | "sha256:a0e592ee87962e17af3f0dce7faae3fbbd239030159e9e625cce810b7e35c53d" 647 | ], 648 | "markers": "python_version >= '3.7'", 649 | "version": "==2.11.0" 650 | }, 651 | "tensorboard-data-server": { 652 | "hashes": [ 653 | "sha256:809fe9887682d35c1f7d1f54f0f40f98bb1f771b14265b453ca051e2ce58fca7", 654 | "sha256:d8237580755e58eff68d1f3abefb5b1e39ae5c8b127cc40920f9c4fb33f4b98a", 655 | "sha256:fa8cef9be4fcae2f2363c88176638baf2da19c5ec90addb49b1cde05c95c88ee" 656 | ], 657 | "markers": "python_full_version >= '3.6.0'", 658 | "version": "==0.6.1" 659 | }, 660 | "tensorboard-plugin-wit": { 661 | "hashes": [ 662 | "sha256:ff26bdd583d155aa951ee3b152b3d0cffae8005dc697f72b44a8e8c2a77a8cbe" 663 | ], 664 | "version": "==1.8.1" 665 | }, 666 | "test-tube": { 667 | "hashes": [ 668 | "sha256:1379c33eb8cde3e9b36610f87da0f16c2e06496b1cfebac473df4e7be2faa124" 669 | ], 670 | "version": "==0.7.5" 671 | }, 672 | "torch": { 673 | "hashes": [ 674 | "sha256:271d4d1e44df6ed57c530f8849b028447c62b8a19b8e8740dd9baa56e7f682c1", 675 | 
"sha256:30ce089475b287a37d6fbb8d71853e672edaf66699e3dd2eb19be6ce6296732a", 676 | "sha256:405b9eb40e44037d2525b3ddb5bc4c66b519cd742bff249d4207d23f83e88ea5", 677 | "sha256:504915c6bc6051ba6a4c2a43c446463dff04411e352f1e26fe13debeae431778", 678 | "sha256:54d06a0e8ee85e5a437c24f4af9f4196c819294c23ffb5914e177756f55f1829", 679 | "sha256:6f2fd9eb8c7eaf38a982ab266dbbfba0f29fb643bc74e677d045d6f2595e4692", 680 | "sha256:8856f334aa9ecb742c1504bd2563d0ffb8dceb97149c8d72a04afa357f667dbc", 681 | "sha256:8fff03bf7b474c16e4b50da65ea14200cc64553b67b9b2307f9dc7e8c69b9d28", 682 | "sha256:9a1b1db73d8dcfd94b2eee24b939368742aa85f1217c55b8f5681e76c581e99a", 683 | "sha256:bb1e87063661414e1149bef2e3a2499ce0b5060290799d7e26bc5578037075ba", 684 | "sha256:d7b34a78f021935ad727a3bede56a8a8d4fda0b0272314a04c5f6890bbe7bb29" 685 | ], 686 | "index": "pypi", 687 | "version": "==1.4.0" 688 | }, 689 | "torchvision": { 690 | "hashes": [ 691 | "sha256:0ca9cae9ddf1784737493e201aa9411abe62a4479b2e67d1d51b4b7acf16f6eb", 692 | "sha256:1a68d3d98e074d995f3d42a492cca716b0d94605a6fadddf0ce9665425968669", 693 | "sha256:1af6d7b0a515d2a83fe9b6e7969b57ba94ba87a3333e7ed707324a5be1ef5f60", 694 | "sha256:2bf1dc1e16c73c5810d96e4ea463e61129e890100740cd57724413a84d301e41", 695 | "sha256:323500d349d8d91ce2662de41212e8eb1845c68dbf5d4f215ca1e94c7f20723b", 696 | "sha256:358967343eaba74fd748a87f40ea75ca23757e947dbef9a11cd53414d707f793", 697 | "sha256:35e9483858cf8a38debc647c74741605c5c12448d314aa96961082380aadf7e5", 698 | "sha256:4dd05cbc497210928ae3d4d6194561985263c879c3554e9f1823a0fa43d35746", 699 | "sha256:517425af7d41b64caae0f5d9e6b14eeb48d6e62d45f302b73a11a9ec5ee3b6c8", 700 | "sha256:78d455a1da7d10bd38f2e2a0d2ac285e4845c9e7e28aafdf068472cc96bd156b", 701 | "sha256:9e85ba17ff93a0cf6afd39b9a0ad56ca7321db4f1eb90d2034d3b0ecd79be47b", 702 | "sha256:a696ec5009eb52356508eb9b23ddb977043fb82ff7b204459e4c81aca1e5affe", 703 | "sha256:aa4354d339de2c5ea2633a6c94294c68bae3e42a4b099624299e2a50c9e97a85", 704 | 
"sha256:ec7e4cd54f5ff3a889b90f24b33da1fa9fe3f78d17348965678d9503de1e4a49", 705 | "sha256:fea3d431bf639c0719afff5972eb568ebe143eba447c1c8bb491c7dfb0025ed6" 706 | ], 707 | "version": "==0.5.0" 708 | }, 709 | "tqdm": { 710 | "hashes": [ 711 | "sha256:1be3e4e3198f2d0e47b928e9d9a8ec1b63525db29095cec1467f4c5a4ea8ebf9", 712 | "sha256:7e39a30e3d34a7a6539378e39d7490326253b7ee354878a92255656dc4284457" 713 | ], 714 | "markers": "python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3'", 715 | "version": "==4.35.0" 716 | }, 717 | "twine": { 718 | "hashes": [ 719 | "sha256:0fb0bfa3df4f62076cab5def36b1a71a2e4acb4d1fa5c97475b048117b1a6446", 720 | "sha256:d6c29c933ecfc74e9b1d9fa13aa1f87c5d5770e119f5a4ce032092f0ff5b14dc" 721 | ], 722 | "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", 723 | "version": "==1.13.0" 724 | }, 725 | "typing-extensions": { 726 | "hashes": [ 727 | "sha256:1511434bb92bf8dd198c12b1cc812e800d4181cfcb867674e0f8279cc93087aa", 728 | "sha256:16fa4864408f655d35ec496218b85f79b3437c829e93320c7c9215ccfd92489e" 729 | ], 730 | "markers": "python_version < '3.8'", 731 | "version": "==4.4.0" 732 | }, 733 | "urllib3": { 734 | "hashes": [ 735 | "sha256:8d7eaa5a82a1cac232164990f04874c594c9453ec55eef02eab885aa02fc17a2", 736 | "sha256:f5321fbe4bf3fefa0efd0bfe7fb14e90909eb62a48ccda331726b4319897dd5e" 737 | ], 738 | "markers": "python_version != '3.4'", 739 | "version": "==1.25.11" 740 | }, 741 | "webencodings": { 742 | "hashes": [ 743 | "sha256:a0af1213f3c2226497a97e2b3aa01a7e4bee4f403f95be16fc9acd2947514a78", 744 | "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923" 745 | ], 746 | "version": "==0.5.1" 747 | }, 748 | "werkzeug": { 749 | "hashes": [ 750 | "sha256:7ea2d48322cc7c0f8b3a215ed73eabd7b5d75d0b50e31ab006286ccff9e00b8f", 751 | "sha256:f979ab81f58d7318e064e99c4506445d60135ac5cd2e177a2de0089bfd4c9bd5" 752 | ], 753 | "markers": "python_version >= '3.7'", 754 | "version": "==2.2.2" 755 | }, 756 | 
"wheel": { 757 | "hashes": [ 758 | "sha256:965f5259b566725405b05e7cf774052044b1ed30119b5d586b2703aafe8719ac", 759 | "sha256:b60533f3f5d530e971d6737ca6d58681ee434818fab630c83a734bb10c083ce8" 760 | ], 761 | "markers": "python_version >= '3.7'", 762 | "version": "==0.38.4" 763 | }, 764 | "zipp": { 765 | "hashes": [ 766 | "sha256:4fcb6f278987a6605757302a6e40e896257570d11c51628968ccb2a47e80c6c1", 767 | "sha256:7a7262fd930bd3e36c50b9a64897aec3fafff3dfdeec9623ae22b40e93f99bb8" 768 | ], 769 | "markers": "python_version >= '3.7'", 770 | "version": "==3.10.0" 771 | } 772 | }, 773 | "develop": {} 774 | } 775 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Networks on Amazon SageMaker with PyTorch Lightning 2 | 3 | [![made-with-python](https://img.shields.io/badge/Made%20with-Python-1f425f.svg)](https://www.python.org/) 4 | [![Open Source Love svg2](https://badges.frapsoft.com/os/v2/open-source.svg?v=103)](https://github.com/ellerbrock/open-source-badges/) 5 | [![GitHub license](https://img.shields.io/github/license/aletheia/mnist_pl_sagemaker.svg)](https://github.com/aletheia/mnist_pl_sagemaker/blob/master/LICENSE) 6 | [![GitHub issues](https://img.shields.io/github/issues/Naereen/StrapDown.js.svg)](https://GitHub.com/Naereen/StrapDown.js/issues/) 7 | [![GitHub pull-requests](https://img.shields.io/github/issues-pr/Naereen/StrapDown.js.svg)](https://GitHub.com/Naereen/StrapDown.js/pull/) 8 | [![saythanks](https://img.shields.io/badge/say-thanks-ff69b4.svg)](https://saythanks.io/to/aletheia) 9 | 10 | This is an example project showing how you could use Amazon SageMaker with Pytorch Lightning, from getting started to model training. 
11 | A detailed discussion about SageMaker and PyTorch Lightning can be found in the article [**Neural Network on Amazon SageMaker with PyTorch Lightning**](https://medium.com/@aletheia/machine-learning-model-development-on-amazon-sagemaker-with-pytorch-lightning-63730ec740ea). 12 | 13 | ### PyTorch Lightning 14 | The super cool [Pytorch Lightning Framework](https://github.com/PyTorchLightning/pytorch-lightning) is used to simplify machine learning model development. 15 | 16 | Pytorch Lightning (PL) offers support for a wide number of advanced functions to ML researchers developing models. It is also useful when managing multiple projects because it imposes a defined structure: 17 | 18 | ![Image](./classifier.png) 19 | 20 | ### Amazon SageMaker 21 | Amazon SageMaker offers support for model training and instance management through a number of features exposed to developers and researchers. One of the most interesting features is its capability of managing GPU instances on your behalf, through the Python CLI: 22 | 23 | ![Image](./sagemaker-sdk.png) 24 | 25 | ### Running this code 26 | The code contained in this repo can be run either using SageMaker Notebook Instances or from a standard Python project. I personally prefer the latter approach because it does not require any Jupyter Notebook instance to be set up and configured, and has the improved capability to create computation resources when they are needed and destroy them after usage, in a Serverless-compliant way. 
27 | 28 | #### Run on SageMaker 29 | To run this project on Amazon SageMaker, please spin up your Amazon SageMaker Notebook, attach a github repository, then run **notebook/sagemaker_deploy.ipynb** 30 | 31 | #### Run using Python 32 | To run the project using Python 33 | ![Sagemaker run](./sagemaker-run.png) 34 | 35 | ## Credits 36 | This project uses [SageMaker Execution Role creation](https://medium.com/ml-bytes/how-to-a-create-a-sagemaker-execution-role-539866910bda) script by [Márcio Dos Santos](https://medium.com/@marcio_santos?source=post_page-----539866910bda----------------------) 37 | 38 | ## License 39 | This project is released under [MIT](./LICENSE) license. Feel free to download, change and distribute the code published here. -------------------------------------------------------------------------------- /classifier.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aletheia/mnist_pl_sagemaker/9083684ee22f3dbf369c4a1509d2044760628a07/classifier.png -------------------------------------------------------------------------------- /code/MNISTClassifier.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import random as rn 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn import functional as F 9 | from torch.utils.data import DataLoader 10 | from torch.utils.data.sampler import SubsetRandomSampler 11 | from torchvision import transforms as T, datasets 12 | import pytorch_lightning as pl 13 | 14 | 15 | class MNISTClassifier(pl.LightningModule): 16 | def __init__(self, train_data_dir, batch_size=128, test_data_dir=None, num_workers=4): 17 | '''Constructor method 18 | 19 | Parameters: 20 | train_data_dir (string): path of training dataset to be used either for training and validation 21 | batch_size (int): number of images per batch. Defaults to 128. 
22 | test_data_dir (string): path of testing dataset to be used after training. Optional. 23 | num_workers (int): number of processes used by data loader. Defaults to 4. 24 | 25 | ''' 26 | 27 | # Invoke constructor 28 | super(MNISTClassifier, self).__init__() 29 | 30 | # Set up class attributes 31 | self.batch_size = batch_size 32 | self.train_data_dir = train_data_dir 33 | self.test_data_dir = test_data_dir 34 | self.num_workers = num_workers 35 | 36 | # Define network layers as class attributes to be used 37 | self.conv_layer_1 = torch.nn.Sequential( 38 | # The first block is made of a convolutional layer (3 channels, 28x28 images and a kernel mask of 5), 39 | torch.nn.Conv2d(3, 28, kernel_size=5), 40 | # a non linear activation function 41 | torch.nn.ReLU(), 42 | # a maximization layer, with mask of size 2 43 | torch.nn.MaxPool2d(kernel_size=2)) 44 | 45 | # A second block is equal to the first, except for input size which is different 46 | self.conv_layer_2 = torch.nn.Sequential( 47 | torch.nn.Conv2d(28, 10, kernel_size=2), 48 | torch.nn.ReLU(), 49 | torch.nn.MaxPool2d(kernel_size=2)) 50 | 51 | # A dropout layer, useful to reduce network overfitting 52 | self.dropout1 = torch.nn.Dropout(0.25) 53 | 54 | # A fully connected layer to reduce dimensionality 55 | self.fully_connected_1 = torch.nn.Linear(250, 18) 56 | 57 | # Another fine tuning dropout layer to make network fine tune 58 | self.dropout2 = torch.nn.Dropout(0.08) 59 | 60 | # The final fully connected layer wich output maps to the number of desired classes 61 | self.fully_connected_2 = torch.nn.Linear(18, 10) 62 | 63 | def load_split_train_test(self, valid_size=.2): 64 | '''Loads data and builds training/validation dataset with provided split size 65 | 66 | Parameters: 67 | valid_size (float): the percentage of data reserved to validation 68 | 69 | Returns: 70 | (torch.utils.data.DataLoader): Training data loader 71 | (torch.utils.data.DataLoader): Validation data loader 72 | (torch.utils.data.DataLoader): 
Test data loader 73 | 74 | ''' 75 | 76 | num_workers = self.num_workers 77 | 78 | # Create transforms for data augmentation. Since we don't care wheter numbers are upside-down, we add a horizontal flip, 79 | # then normalized data to PyTorch defaults 80 | train_transforms = T.Compose([T.RandomHorizontalFlip(), 81 | T.ToTensor(), 82 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 83 | # Use ImageFolder to load data from main folder. Images are contained in subfolders wich name represents their label. I.e. 84 | # training 85 | # |--> 0 86 | # | |--> image023.png 87 | # | |--> image024.png 88 | # | ... 89 | # |--> 1 90 | # | |--> image032.png 91 | # | |--> image0433.png 92 | # | ... 93 | # ... 94 | train_data = datasets.ImageFolder( 95 | self.train_data_dir, transform=train_transforms) 96 | 97 | # loads image indexes within dataset, then computes split and shuffles images to add randomness 98 | num_train = len(train_data) 99 | indices = list(range(num_train)) 100 | split = int(np.floor(valid_size * num_train)) 101 | np.random.shuffle(indices) 102 | 103 | # extracts indexes for train and validation, then builds a random sampler 104 | train_idx, val_idx = indices[split:], indices[:split] 105 | train_sampler = SubsetRandomSampler(train_idx) 106 | val_sampler = SubsetRandomSampler(val_idx) 107 | # which is passed to data loader to perform image sampling when loading data 108 | train_loader = torch.utils.data.DataLoader( 109 | train_data, sampler=train_sampler, batch_size=self.batch_size, num_workers=num_workers) 110 | val_loader = torch.utils.data.DataLoader( 111 | train_data, sampler=val_sampler, batch_size=self.batch_size, num_workers=num_workers) 112 | 113 | # if testing dataset is defined, we build its data loader as well 114 | test_loader = None 115 | if self.test_data_dir is not None: 116 | test_transforms = T.Compose([T.ToTensor(), T.Normalize( 117 | [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 118 | test_data = datasets.ImageFolder( 119 | 
self.test_data_dir, transform=test_transforms) 120 | test_loader = torch.utils.data.DataLoader( 121 | train_data, batch_size=self.batch_size, num_workers=num_workers) 122 | return train_loader, val_loader, test_loader 123 | 124 | def prepare_data(self): 125 | '''Prepares datasets. Called once per training execution 126 | ''' 127 | self.train_loader, self.val_loader, self.test_loader = self.load_split_train_test() 128 | 129 | def train_dataloader(self): 130 | ''' 131 | Returns: 132 | (torch.utils.data.DataLoader): Training set data loader 133 | ''' 134 | return self.train_loader 135 | 136 | def val_dataloader(self): 137 | ''' 138 | Returns: 139 | (torch.utils.data.DataLoader): Validation set data loader 140 | ''' 141 | return self.val_loader 142 | 143 | def test_dataloader(self): 144 | ''' 145 | Returns: 146 | (torch.utils.data.DataLoader): Testing set data loader 147 | ''' 148 | return DataLoader(MNIST(os.getcwd(), train=False, download=False, transform=transform.ToTensor()), batch_size=128) 149 | 150 | def forward(self, x): 151 | '''Forward pass, it is equal to PyTorch forward method. 
Here network computational graph is built 152 | 153 | Parameters: 154 | x (Tensor): A Tensor containing the input batch of the network 155 | 156 | Returns: 157 | An one dimensional Tensor with probability array for each input image 158 | ''' 159 | x = self.conv_layer_1(x) 160 | x = self.conv_layer_2(x) 161 | x = self.dropout1(x) 162 | x = torch.relu(self.fully_connected_1(x.view(x.size(0), -1))) 163 | x = F.leaky_relu(self.dropout2(x)) 164 | return F.softmax(self.fully_connected_2(x), dim=1) 165 | 166 | def configure_optimizers(self): 167 | ''' 168 | Returns: 169 | (Optimizer): Adam optimizer tuned wit model parameters 170 | ''' 171 | return torch.optim.Adam(self.parameters()) 172 | 173 | def training_step(self, batch, batch_idx): 174 | '''Called for every training step, uses NLL Loss to compute training loss, then logs and sends back 175 | logs parameter to Trainer to perform backpropagation 176 | 177 | ''' 178 | 179 | # Get input and output from batch 180 | x, labels = batch 181 | 182 | # Compute prediction through the network 183 | prediction = self.forward(x) 184 | 185 | loss = F.nll_loss(prediction, labels) 186 | 187 | # Logs training loss 188 | logs = {'train_loss': loss} 189 | 190 | output = { 191 | # This is required in training to be used by backpropagation 192 | 'loss': loss, 193 | # This is optional for logging pourposes 194 | 'log': logs 195 | } 196 | 197 | return output 198 | 199 | def test_step(self, batch, batch_idx): 200 | '''Called for every testing step, uses NLL Loss to compute testing loss 201 | ''' 202 | # Get input and output from batch 203 | x, labels = batch 204 | 205 | # Compute prediction through the network 206 | prediction = self.forward(x) 207 | 208 | loss = F.nll_loss(prediction, labels) 209 | 210 | # Logs training loss 211 | logs = {'train_loss': loss} 212 | 213 | output = { 214 | # This is required in training to be used by backpropagation 215 | 'loss': loss, 216 | # This is optional for logging pourposes 217 | 'log': logs 218 | } 
219 | 220 | return output 221 | 222 | def validation_step(self, batch, batch_idx): 223 | ''' Prforms model validation computing cross entropy for predictions and labels 224 | ''' 225 | x, labels = batch 226 | prediction = self.forward(x) 227 | return { 228 | 'val_loss': F.cross_entropy(prediction, labels) 229 | } 230 | 231 | def validation_epoch_end(self, outputs): 232 | '''Called after every epoch, stacks validation loss 233 | ''' 234 | val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean() 235 | return {'val_loss': val_loss_mean} 236 | 237 | def validation_end(self, outputs): 238 | '''Called after validation completes. Stacks all testing loss and computes average. 239 | ''' 240 | avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() 241 | print('Average training loss: '+str(avg_loss.item())) 242 | logs = {'val_loss': avg_loss} 243 | return { 244 | 'avg_val_loss': avg_loss, 245 | 'log': logs 246 | } 247 | 248 | def testing_step(self, batch, batch_idx): 249 | ''' Prforms model testing, computing cross entropy for predictions and labels 250 | ''' 251 | x, labels = batch 252 | prediction = self.forward(x) 253 | return { 254 | 'test_loss': F.cross_entropy(prediction, labels) 255 | } 256 | 257 | def testing_epoch_end(self, outputs): 258 | '''Called after every epoch, stacks testing loss 259 | ''' 260 | test_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean() 261 | return {'test_loss': test_loss_mean} 262 | 263 | def testing_end(self, outputs): 264 | '''Called after testing completes. Stacks all testing loss and computes average. 
265 | ''' 266 | avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean() 267 | print('Average testing loss: '+str(avg_loss.item())) 268 | logs = {'test_loss': avg_loss} 269 | return { 270 | 'avg_test_loss': avg_loss, 271 | 'log': logs 272 | } 273 | -------------------------------------------------------------------------------- /code/requirements.txt: -------------------------------------------------------------------------------- 1 | -i https://pypi.org/simple 2 | absl-py==0.9.0 3 | boto3==1.14.20 4 | botocore==1.17.20 5 | cachetools==4.1.1; python_version ~= '3.5' 6 | certifi==2022.12.7 7 | chardet==3.0.4 8 | docutils==0.15.2; python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3' 9 | future==0.18.2; python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3' 10 | google-auth-oauthlib==0.4.1 11 | google-auth==1.18.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3' 12 | grpcio==1.30.0 13 | idna==2.10; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3' 14 | importlib-metadata==1.7.0; python_version < '3.8' 15 | jmespath==0.10.0; python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3' 16 | markdown==3.2.2; python_version >= '3.5' 17 | numpy==1.19.0; python_version >= '3.6' 18 | oauthlib==3.1.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3' 19 | packaging==20.4; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3' 20 | protobuf3-to-dict==0.1.5 21 | protobuf==3.18.3 22 | pyasn1-modules==0.2.8 23 | pyasn1==0.4.8 24 | pyparsing==2.4.7; python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3' 25 | python-dateutil==2.8.1; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3' 26 | pytorch-lightning==1.6.0 27 | pyyaml==5.4 28 | requests-oauthlib==1.3.0 29 | requests==2.24.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4' 30 | rsa==4.7; python_version >= '3' 31 | s3transfer==0.3.3 
32 | sagemaker==1.69.0 33 | scipy==1.5.1; python_version >= '3.6' 34 | six==1.15.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3' 35 | smdebug-rulesconfig==0.1.4; python_version >= '2.7' 36 | tensorboard-plugin-wit==1.7.0 37 | tensorboard==2.2.2; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3' 38 | torch==1.5.1 39 | tqdm==4.47.0; python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3' 40 | urllib3==1.26.5; python_version != '3.4' 41 | werkzeug==1.0.1; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4' 42 | wheel==0.38.1; python_version >= '3' 43 | zipp==3.1.0; python_version >= '3.6' 44 | -------------------------------------------------------------------------------- /code/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | # default pytorch import 5 | import torch 6 | 7 | # import lightning library 8 | import pytorch_lightning as pl 9 | 10 | # import trainer class, which orchestrates our model training 11 | from pytorch_lightning import Trainer 12 | 13 | # import our model class, to be trained 14 | from MNISTClassifier import MNISTClassifier 15 | 16 | # This is the main method, to be run when train.py is invoked 17 | if __name__ =='__main__': 18 | 19 | parser = argparse.ArgumentParser() 20 | 21 | # hyperparameters sent by the client are passed as command-line arguments to the script. 22 | parser.add_argument('--epochs', type=int, default=50) 23 | parser.add_argument('--batch-size', type=int, default=64) 24 | parser.add_argument('--gpus', type=int, default=1) # used to support multi-GPU or CPU training 25 | 26 | # Data, model, and output directories. 
Passed by sagemaker with default to os env variables 27 | parser.add_argument('-o','--output-data-dir', type=str, default=os.environ['SM_OUTPUT_DATA_DIR']) 28 | parser.add_argument('-m','--model-dir', type=str, default=os.environ['SM_MODEL_DIR']) 29 | parser.add_argument('-tr','--train', type=str, default=os.environ['SM_CHANNEL_TRAIN']) 30 | parser.add_argument('-te','--test', type=str, default=os.environ['SM_CHANNEL_TEST']) 31 | 32 | args, _ = parser.parse_known_args() 33 | print(args) 34 | 35 | # Now we have all parameters and hyperparameters available and we need to match them with sagemaker 36 | # structure. default_root_dir is set to out_put_data_dir to retrieve from training instances all the 37 | # checkpoint and intermediary data produced by lightning 38 | mnistTrainer=pl.Trainer(gpus=args.gpus, max_epochs=args.epochs, default_root_dir=args.output_data_dir) 39 | 40 | # Set up our classifier class, passing params to the constructor 41 | model = MNISTClassifier( 42 | batch_size=args.batch_size, 43 | train_data_dir=args.train, 44 | test_data_dir=args.test 45 | ) 46 | 47 | # Runs model training 48 | mnistTrainer.fit(model) 49 | 50 | # After model has been trained, save its state into model_dir which is then copied to back S3 51 | with open(os.path.join(args.model_dir, 'model.pth'), 'wb') as f: 52 | torch.save(model.state_dict(), f) 53 | -------------------------------------------------------------------------------- /create_execution_role.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # This script creates a role named SageMakerRole 4 | # that can be used by SageMaker and has Full access to S3. 5 | 6 | ROLE_NAME=SageMakerRole_MNIST 7 | 8 | # WARNING: this policy gives full S3 access to container that 9 | # is running in SageMaker. You can change this policy to a more 10 | # restrictive one, or create your own policy. 
11 | POLICY=arn:aws:iam::aws:policy/AmazonS3FullAccess 12 | 13 | # Creates an AWS trust-policy document that allows SageMaker 14 | # to assume this role 15 | cat <<EOF > ./assume-role-policy-document.json 16 | { 17 | "Version": "2012-10-17", 18 | "Statement": [{ 19 | "Effect": "Allow", 20 | "Principal": { 21 | "Service": "sagemaker.amazonaws.com" 22 | }, 23 | "Action": "sts:AssumeRole" 24 | }] 25 | } 26 | EOF 27 | 28 | # Creates the role 29 | aws iam create-role --role-name ${ROLE_NAME} --assume-role-policy-document file://./assume-role-policy-document.json 30 | 31 | # attaches the S3 full access policy to the role 32 | aws iam attach-role-policy --policy-arn ${POLICY} --role-name ${ROLE_NAME} -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # MNIST on SageMaker with PyTorch Lightning 2 | import json 3 | import boto3 4 | import sagemaker 5 | from sagemaker.pytorch import PyTorch 6 | 7 | # Initializes SageMaker session which holds context data 8 | sagemaker_session = sagemaker.Session() 9 | 10 | # The bucket containing our input data 11 | bucket = 's3://dataset.mnist' 12 | 13 | # The IAM Role which SageMaker will impersonate to run the estimator 14 | # Remember you cannot use sagemaker.get_execution_role() 15 | # if you're not in a SageMaker notebook, an EC2 or a Lambda 16 | # (i.e. running from your local PC) 17 | 18 | # sagemaker.get_execution_role() 19 | role = 'arn:aws:iam::XXXXXX:role/SageMakerRole_MNIST' 21 | # Creates a new PyTorch Estimator with params 22 | estimator = PyTorch( 23 | # name of the runnable script containing __main__ function (entrypoint) 24 | entry_point='train.py', 25 | # path of the folder containing training code.
It could also contain a 26 | # requirements.txt file with all the dependencies that need 27 | # to be installed before running 28 | source_dir='code', 29 | role=role, 30 | framework_version='1.4.0', 31 | train_instance_count=1, 32 | train_instance_type='ml.p2.xlarge', 33 | # these hyperparameters are passed to the main script as arguments and 34 | # can be overridden when fine tuning the algorithm 35 | hyperparameters={ 36 | 'epochs': 6, 37 | 'batch-size': 128, 38 | }) 39 | 40 | # Call fit method on estimator, which trains our model, passing training 41 | # and testing datasets as environment variables. Data is copied from S3 42 | # before initializing the container 43 | estimator.fit({ 44 | 'train': bucket+'/training', 45 | 'test': bucket+'/testing' 46 | }) 47 | -------------------------------------------------------------------------------- /notebook/.ipynb_checkpoints/lightning_sagemaker-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# MNIST on SageMaker with PyTorch Lightning" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Download dataset to local folder" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "S3_DATA_BUCKET = 'dataset.mnist'\n", 24 | "S3_TRAINING_DATA = S3_DATA_BUCKET+'/training'\n", 25 | "S3_TESTING_DATA = S3_DATA_BUCKET+'/testing'\n", 26 | "\n", 27 | "DATA_PATH = '../dataset'\n", 28 | "BATCH_SIZE = 128" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": { 35 | "collapsed": true, 36 | "jupyter": { 37 | "outputs_hidden": true 38 | } 39 | }, 40 | "outputs": [], 41 | "source": [ 42 | "#!mkdir -p $DATA_PATH/training\n", 43 | "#!mkdir -p $DATA_PATH/testing\n", 44 | "#!aws s3api get-object --bucket $S3_DATA_BUCKET --key
mnist.tar.gz $DATA_PATH/mnist.tar.gz\n", 45 | "#!cd $DATA_PATH && tar xvf mnist.tar.gz && rm -f mnist.tar.gz" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 3, 51 | "metadata": { 52 | "collapsed": true, 53 | "jupyter": { 54 | "outputs_hidden": true 55 | } 56 | }, 57 | "outputs": [ 58 | { 59 | "name": "stdout", 60 | "output_type": "stream", 61 | "text": [ 62 | "Requirement already satisfied: torch in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (1.4.0)\n", 63 | "\u001b[33mWARNING: You are using pip version 20.0.2; however, version 20.1.1 is available.\n", 64 | "You should consider upgrading via the '/home/ec2-user/anaconda3/envs/pytorch_p36/bin/python -m pip install --upgrade pip' command.\u001b[0m\n", 65 | "Requirement already satisfied: torchvision in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (0.5.0)\n", 66 | "Requirement already satisfied: numpy in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from torchvision) (1.18.1)\n", 67 | "Requirement already satisfied: six in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from torchvision) (1.14.0)\n", 68 | "Requirement already satisfied: torch in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from torchvision) (1.4.0)\n", 69 | "Requirement already satisfied: pillow>=4.1.1 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from torchvision) (7.0.0)\n", 70 | "\u001b[33mWARNING: You are using pip version 20.0.2; however, version 20.1.1 is available.\n", 71 | "You should consider upgrading via the '/home/ec2-user/anaconda3/envs/pytorch_p36/bin/python -m pip install --upgrade pip' command.\u001b[0m\n", 72 | "Requirement already satisfied: pytorch_lightning in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (0.8.5)\n", 73 | "Requirement already satisfied: tensorboard>=1.14 in 
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from pytorch_lightning) (2.2.2)\n", 74 | "Requirement already satisfied: torch>=1.3 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from pytorch_lightning) (1.4.0)\n", 75 | "Requirement already satisfied: tqdm>=4.41.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from pytorch_lightning) (4.44.1)\n", 76 | "Requirement already satisfied: numpy>=1.16.4 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from pytorch_lightning) (1.18.1)\n", 77 | "Requirement already satisfied: future>=0.17.1 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from pytorch_lightning) (0.18.2)\n", 78 | "Requirement already satisfied: PyYAML>=5.1 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from pytorch_lightning) (5.3.1)\n", 79 | "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from tensorboard>=1.14->pytorch_lightning) (0.4.1)\n", 80 | "Requirement already satisfied: google-auth<2,>=1.6.3 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from tensorboard>=1.14->pytorch_lightning) (1.18.0)\n", 81 | "Requirement already satisfied: grpcio>=1.24.3 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from tensorboard>=1.14->pytorch_lightning) (1.30.0)\n", 82 | "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from tensorboard>=1.14->pytorch_lightning) (1.7.0)\n", 83 | "Requirement already satisfied: werkzeug>=0.11.15 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from tensorboard>=1.14->pytorch_lightning) (1.0.1)\n", 84 | "Requirement already satisfied: wheel>=0.26; python_version >= \"3\" in 
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from tensorboard>=1.14->pytorch_lightning) (0.34.2)\n", 85 | "Requirement already satisfied: protobuf>=3.6.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from tensorboard>=1.14->pytorch_lightning) (3.12.2)\n", 86 | "Requirement already satisfied: requests<3,>=2.21.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from tensorboard>=1.14->pytorch_lightning) (2.23.0)\n", 87 | "Requirement already satisfied: absl-py>=0.4 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from tensorboard>=1.14->pytorch_lightning) (0.9.0)\n", 88 | "Requirement already satisfied: six>=1.10.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from tensorboard>=1.14->pytorch_lightning) (1.14.0)\n", 89 | "Requirement already satisfied: markdown>=2.6.8 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from tensorboard>=1.14->pytorch_lightning) (3.2.2)\n", 90 | "Requirement already satisfied: setuptools>=41.0.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from tensorboard>=1.14->pytorch_lightning) (46.1.3.post20200330)\n", 91 | "Requirement already satisfied: requests-oauthlib>=0.7.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=1.14->pytorch_lightning) (1.3.0)\n", 92 | "Requirement already satisfied: rsa<5,>=3.1.4; python_version >= \"3\" in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from google-auth<2,>=1.6.3->tensorboard>=1.14->pytorch_lightning) (3.4.2)\n", 93 | "Requirement already satisfied: pyasn1-modules>=0.2.1 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from google-auth<2,>=1.6.3->tensorboard>=1.14->pytorch_lightning) (0.2.8)\n", 94 | "Requirement already satisfied: cachetools<5.0,>=2.0.0 in 
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from google-auth<2,>=1.6.3->tensorboard>=1.14->pytorch_lightning) (4.1.1)\n", 95 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from requests<3,>=2.21.0->tensorboard>=1.14->pytorch_lightning) (1.25.8)\n", 96 | "Requirement already satisfied: idna<3,>=2.5 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from requests<3,>=2.21.0->tensorboard>=1.14->pytorch_lightning) (2.9)\n", 97 | "Requirement already satisfied: certifi>=2017.4.17 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from requests<3,>=2.21.0->tensorboard>=1.14->pytorch_lightning) (2020.4.5.2)\n", 98 | "Requirement already satisfied: chardet<4,>=3.0.2 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from requests<3,>=2.21.0->tensorboard>=1.14->pytorch_lightning) (3.0.4)\n", 99 | "Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from markdown>=2.6.8->tensorboard>=1.14->pytorch_lightning) (1.5.0)\n", 100 | "Requirement already satisfied: oauthlib>=3.0.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=1.14->pytorch_lightning) (3.1.0)\n", 101 | "Requirement already satisfied: pyasn1>=0.1.3 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from rsa<5,>=3.1.4; python_version >= \"3\"->google-auth<2,>=1.6.3->tensorboard>=1.14->pytorch_lightning) (0.4.8)\n", 102 | "Requirement already satisfied: zipp>=0.5 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from importlib-metadata; python_version < \"3.8\"->markdown>=2.6.8->tensorboard>=1.14->pytorch_lightning) (2.2.0)\n", 103 | "\u001b[33mWARNING: You are using pip version 
20.0.2; however, version 20.1.1 is available.\n", 104 | "You should consider upgrading via the '/home/ec2-user/anaconda3/envs/pytorch_p36/bin/python -m pip install --upgrade pip' command.\u001b[0m\n" 105 | ] 106 | } 107 | ], 108 | "source": [ 109 | "# Install libraries if not already installed\n", 110 | "! pip install torch\n", 111 | "! pip install torchvision\n", 112 | "! pip install pytorch_lightning" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 4, 118 | "metadata": { 119 | "collapsed": true, 120 | "jupyter": { 121 | "outputs_hidden": true 122 | } 123 | }, 124 | "outputs": [ 125 | { 126 | "name": "stdout", 127 | "output_type": "stream", 128 | "text": [ 129 | "Requirement already satisfied: ipywidgets in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (7.5.1)\n", 130 | "Requirement already satisfied: nbformat>=4.2.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from ipywidgets) (5.0.4)\n", 131 | "Requirement already satisfied: widgetsnbextension~=3.5.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from ipywidgets) (3.5.1)\n", 132 | "Requirement already satisfied: traitlets>=4.3.1 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from ipywidgets) (4.3.3)\n", 133 | "Requirement already satisfied: ipython>=4.0.0; python_version >= \"3.3\" in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from ipywidgets) (7.13.0)\n", 134 | "Requirement already satisfied: ipykernel>=4.5.1 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from ipywidgets) (5.1.4)\n", 135 | "Requirement already satisfied: jsonschema!=2.5.0,>=2.4 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from nbformat>=4.2.0->ipywidgets) (3.2.0)\n", 136 | "Requirement already satisfied: ipython-genutils in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from nbformat>=4.2.0->ipywidgets) 
(0.2.0)\n", 137 | "Requirement already satisfied: jupyter-core in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from nbformat>=4.2.0->ipywidgets) (4.6.3)\n", 138 | "Requirement already satisfied: notebook>=4.4.1 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from widgetsnbextension~=3.5.0->ipywidgets) (6.0.3)\n", 139 | "Requirement already satisfied: six in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from traitlets>=4.3.1->ipywidgets) (1.14.0)\n", 140 | "Requirement already satisfied: decorator in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from traitlets>=4.3.1->ipywidgets) (4.4.2)\n", 141 | "Requirement already satisfied: pygments in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (2.6.1)\n", 142 | "Requirement already satisfied: pexpect; sys_platform != \"win32\" in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (4.8.0)\n", 143 | "Requirement already satisfied: setuptools>=18.5 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (46.1.3.post20200330)\n", 144 | "Requirement already satisfied: pickleshare in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (0.7.5)\n", 145 | "Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (3.0.4)\n", 146 | "Requirement already satisfied: jedi>=0.10 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (0.15.2)\n", 147 | "Requirement already satisfied: backcall in 
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (0.1.0)\n", 148 | "Requirement already satisfied: tornado>=4.2 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from ipykernel>=4.5.1->ipywidgets) (6.0.4)\n", 149 | "Requirement already satisfied: jupyter-client in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from ipykernel>=4.5.1->ipywidgets) (6.1.2)\n", 150 | "Requirement already satisfied: pyrsistent>=0.14.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets) (0.16.0)\n", 151 | "Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets) (1.5.0)\n", 152 | "Requirement already satisfied: attrs>=17.4.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets) (19.3.0)\n", 153 | "Requirement already satisfied: Send2Trash in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (1.5.0)\n", 154 | "Requirement already satisfied: nbconvert in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (5.6.1)\n", 155 | "Requirement already satisfied: pyzmq>=17 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (18.1.1)\n", 156 | "Requirement already satisfied: terminado>=0.8.1 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.8.3)\n", 157 | "Requirement already satisfied: prometheus-client in 
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.7.1)\n", 158 | "Requirement already satisfied: jinja2 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (2.11.1)\n", 159 | "Requirement already satisfied: ptyprocess>=0.5 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from pexpect; sys_platform != \"win32\"->ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (0.6.0)\n", 160 | "Requirement already satisfied: wcwidth in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (0.1.9)\n", 161 | "Requirement already satisfied: parso>=0.5.2 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from jedi>=0.10->ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (0.5.2)\n", 162 | "Requirement already satisfied: python-dateutil>=2.1 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from jupyter-client->ipykernel>=4.5.1->ipywidgets) (2.8.1)\n", 163 | "Requirement already satisfied: zipp>=0.5 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from importlib-metadata; python_version < \"3.8\"->jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets) (2.2.0)\n", 164 | "Requirement already satisfied: mistune<2,>=0.8.1 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.8.4)\n", 165 | "Requirement already satisfied: bleach in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (3.1.4)\n", 166 | "Requirement already satisfied: defusedxml in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from 
nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.6.0)\n", 167 | "Requirement already satisfied: entrypoints>=0.2.2 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.3)\n", 168 | "Requirement already satisfied: testpath in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.4.4)\n", 169 | "Requirement already satisfied: pandocfilters>=1.4.1 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (1.4.2)\n", 170 | "Requirement already satisfied: MarkupSafe>=0.23 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from jinja2->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (1.1.1)\n", 171 | "Requirement already satisfied: webencodings in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from bleach->nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.5.1)\n", 172 | "\u001b[33mWARNING: You are using pip version 20.0.2; however, version 20.1.1 is available.\n", 173 | "You should consider upgrading via the '/home/ec2-user/anaconda3/envs/pytorch_p36/bin/python -m pip install --upgrade pip' command.\u001b[0m\n", 174 | "Enabling notebook extension jupyter-js-widgets/extension...\n", 175 | " - Validating: \u001b[32mOK\u001b[0m\n" 176 | ] 177 | } 178 | ], 179 | "source": [ 180 | "# Install libraries specific for Jupyter Notebook\n", 181 | "! pip install ipywidgets\n", 182 | "! 
jupyter nbextension enable --py widgetsnbextension" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 21, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "import os\n", 192 | "import math\n", 193 | "import random as rn\n", 194 | "\n", 195 | "import numpy as np # linear algebra\n", 196 | "import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n", 197 | "from sklearn.model_selection import train_test_split\n", 198 | "from sklearn.metrics import accuracy_score\n", 199 | "from PIL import Image, ImageFile\n", 200 | "\n", 201 | "import torch\n", 202 | "import torch.nn as nn\n", 203 | "from torch.nn import functional as F\n", 204 | "from torch.utils.data import DataLoader\n", 205 | "from torchvision import transforms as T, datasets\n", 206 | "import pytorch_lightning as pl\n", 207 | "from pytorch_lightning import Trainer\n", 208 | "from pytorch_lightning.callbacks import ModelCheckpoint\n" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": 6, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "# Import libraries specific for Jupyter notebook visualization\n", 218 | "from matplotlib import pyplot as plt, image\n", 219 | "from PIL import Image\n", 220 | "%matplotlib inline" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 8, 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "#train_data_dir = DATA_PATH+'/training'\n", 230 | "#dataset = datasets.ImageFolder(train_data_dir)\n", 231 | "#train_set, val_set = torch.utils.data.random_split(dataset, [55000, 5000])" 232 | ] 233 | }, 234 | { 235 | "cell_type": "markdown", 236 | "metadata": {}, 237 | "source": [ 238 | "Display a few images to have an idea about the input" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": 9, 244 | "metadata": { 245 | "collapsed": true, 246 | "jupyter": { 247 | "outputs_hidden": true 248 | } 249 | }, 250 | "outputs": [], 251 | 
"source": [ 252 | "#plt.imshow(image.imread(train_set.dataset.imgs[0][0]))\n", 253 | "#plt.figure()\n", 254 | "#plt.imshow(image.imread(train_set.dataset.imgs[7000][0]))\n", 255 | "#plt.figure()\n", 256 | "#plt.imshow(image.imread(train_set.dataset.imgs[20000][0]))\n", 257 | "#plt.figure()" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 10, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "#fix random seed\n", 267 | "os.environ['PYTHONHASHSEED'] = '0'\n", 268 | "np.random.seed(42)\n", 269 | "rn.seed(12345)\n", 270 | "torch.manual_seed(2020)\n", 271 | "torch.cuda.manual_seed(2020)\n", 272 | "torch.cuda.manual_seed_all(2020)\n", 273 | "torch.backends.cudnn.deterministic = True" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": 11, 279 | "metadata": {}, 280 | "outputs": [], 281 | "source": [ 282 | "num_workers = 4\n", 283 | "epochs = 10\n", 284 | "validation_size = .3\n", 285 | "batch_size = 128" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": 12, 291 | "metadata": {}, 292 | "outputs": [], 293 | "source": [ 294 | "data_dir = DATA_PATH+'/training'" 295 | ] 296 | }, 297 | { 298 | "cell_type": "markdown", 299 | "metadata": {}, 300 | "source": [ 301 | "## Create model" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": 22, 307 | "metadata": {}, 308 | "outputs": [], 309 | "source": [ 310 | "class MNISTClassifier(pl.LightningModule):\n", 311 | " def __init__(self):\n", 312 | " super(MNISTClassifier, self).__init__()\n", 313 | " self.conv_layer_1 = torch.nn.Sequential(\n", 314 | " torch.nn.Conv2d(3,28, kernel_size=5),\n", 315 | " torch.nn.ReLU(),\n", 316 | " torch.nn.MaxPool2d(kernel_size=2))\n", 317 | " self.conv_layer_2 = torch.nn.Sequential(\n", 318 | " torch.nn.Conv2d(28,10, kernel_size=2),\n", 319 | " torch.nn.ReLU(),\n", 320 | " torch.nn.MaxPool2d(kernel_size=2))\n", 321 | " self.dropout1=torch.nn.Dropout(0.25)\n", 322 | " 
self.fully_connected_1=torch.nn.Linear(250,18)\n", 323 | " self.dropout2=torch.nn.Dropout(0.08)\n", 324 | " self.fully_connected_2=torch.nn.Linear(18,10)\n", 325 | "\n", 326 | " def load_split_train_test(self, datadir, valid_size = .2):\n", 327 | " train_transforms = T.Compose([T.RandomHorizontalFlip(), \n", 328 | " T.ToTensor(),\n", 329 | " T.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])\n", 330 | "\n", 331 | " test_transforms = T.Compose([T.ToTensor(),T.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])\n", 332 | "\n", 333 | " train_data = datasets.ImageFolder(datadir, transform=train_transforms)\n", 334 | " test_data = datasets.ImageFolder(datadir, transform=test_transforms)\n", 335 | "\n", 336 | " num_train = len(train_data)\n", 337 | " indices = list(range(num_train))\n", 338 | " split = int(np.floor(valid_size * num_train))\n", 339 | " np.random.shuffle(indices)\n", 340 | " from torch.utils.data.sampler import SubsetRandomSampler\n", 341 | " train_idx, test_idx = indices[split:], indices[:split]\n", 342 | " train_sampler = SubsetRandomSampler(train_idx)\n", 343 | " test_sampler = SubsetRandomSampler(test_idx)\n", 344 | " trainloader = torch.utils.data.DataLoader(train_data, sampler=train_sampler, batch_size=batch_size, num_workers=num_workers)\n", 345 | " testloader = torch.utils.data.DataLoader(test_data, sampler=test_sampler, batch_size=batch_size, num_workers=num_workers)\n", 346 | " return trainloader, testloader\n", 347 | " \n", 348 | " def prepare_data(self):\n", 349 | " self.train_loader, self.val_loader = self.load_split_train_test(data_dir, validation_size)\n", 350 | " \n", 351 | " def train_dataloader(self):\n", 352 | " return self.train_loader\n", 353 | " \n", 354 | " def val_dataloader(self):\n", 355 | " return self.val_loader\n", 356 | " \n", 357 | "# def test_dataloader(self):\n", 358 | "# return DataLoader(MNIST(os.getcwd(), train=False, download=False, transform=transform.ToTensor()), batch_size=128)\n", 359 | " \n", 360 | " 
def forward(self,x):\n", 361 | " x=self.conv_layer_1(x)\n", 362 | " x=self.conv_layer_2(x)\n", 363 | " x=self.dropout1(x)\n", 364 | " x=torch.relu(self.fully_connected_1(x.view(x.size(0),-1)))\n", 365 | " x=F.leaky_relu(self.dropout2(x))\n", 366 | " return F.softmax(self.fully_connected_2(x), dim=1)\n", 367 | " \n", 368 | " def configure_optimizers(self):\n", 369 | " return torch.optim.Adam(self.parameters())\n", 370 | " \n", 371 | " def training_step(self, batch, batch_idx):\n", 372 | " \n", 373 | " # Get input and output from batch\n", 374 | " x, labels = batch\n", 375 | " \n", 376 | " # Compute prediction through the network\n", 377 | " prediction = self.forward(x)\n", 378 | " \n", 379 | " loss = F.nll_loss(prediction, labels)\n", 380 | " \n", 381 | " # Logs training loss\n", 382 | " logs={'train_loss':loss}\n", 383 | " \n", 384 | " output = {\n", 385 | " # This is required in training to be used by backpropagation\n", 386 | " 'loss':loss,\n", 387 | " # This is optional for logging pourposes\n", 388 | " 'log':logs\n", 389 | " }\n", 390 | " \n", 391 | " return output\n", 392 | " \n", 393 | " def validation_step(self, batch, batch_idx):\n", 394 | " x, labels = batch\n", 395 | " prediction = self.forward(x)\n", 396 | " return {\n", 397 | " 'val_loss': F.cross_entropy(prediction, labels)\n", 398 | " }\n", 399 | " \n", 400 | " def validation_epoch_end(self, outputs):\n", 401 | " val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean()\n", 402 | " return {'val_loss': val_loss_mean}\n", 403 | "\n", 404 | " \n", 405 | " def validation_end(self, outputs):\n", 406 | " avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()\n", 407 | " print('Average training loss: '+str(avg_loss.item()))\n", 408 | " logs = {'val_loss':avg_loss}\n", 409 | " return {\n", 410 | " 'avg_val_loss':avg_loss,\n", 411 | " 'log':logs\n", 412 | " }" 413 | ] 414 | }, 415 | { 416 | "cell_type": "code", 417 | "execution_count": 23, 418 | "metadata": {}, 419 | "outputs": [ 420 | { 
421 | "name": "stderr", 422 | "output_type": "stream", 423 | "text": [ 424 | "GPU available: True, used: True\n", 425 | "TPU available: False, using: 0 TPU cores\n", 426 | "CUDA_VISIBLE_DEVICES: [0]\n", 427 | "\n", 428 | " | Name | Type | Params\n", 429 | "-------------------------------------------------\n", 430 | "0 | conv_layer_1 | Sequential | 2 K \n", 431 | "1 | conv_layer_2 | Sequential | 1 K \n", 432 | "2 | dropout1 | Dropout | 0 \n", 433 | "3 | fully_connected_1 | Linear | 4 K \n", 434 | "4 | dropout2 | Dropout | 0 \n", 435 | "5 | fully_connected_2 | Linear | 190 \n" 436 | ] 437 | }, 438 | { 439 | "data": { 440 | "application/vnd.jupyter.widget-view+json": { 441 | "model_id": "", 442 | "version_major": 2, 443 | "version_minor": 0 444 | }, 445 | "text/plain": [ 446 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…" 447 | ] 448 | }, 449 | "metadata": {}, 450 | "output_type": "display_data" 451 | }, 452 | { 453 | "name": "stdout", 454 | "output_type": "stream", 455 | "text": [ 456 | "Average training loss: 2.302227735519409\n", 457 | "\r" 458 | ] 459 | }, 460 | { 461 | "data": { 462 | "application/vnd.jupyter.widget-view+json": { 463 | "model_id": "51e3aaaa31df49be8101f9725da2be9e", 464 | "version_major": 2, 465 | "version_minor": 0 466 | }, 467 | "text/plain": [ 468 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…" 469 | ] 470 | }, 471 | "metadata": {}, 472 | "output_type": "display_data" 473 | }, 474 | { 475 | "data": { 476 | "application/vnd.jupyter.widget-view+json": { 477 | "model_id": "", 478 | "version_major": 2, 479 | "version_minor": 0 480 | }, 481 | "text/plain": [ 482 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" 483 | ] 484 | }, 485 | "metadata": {}, 486 | "output_type": "display_data" 487 | }, 488 | { 489 | "name": "stdout", 490 | "output_type": 
"stream", 491 | "text": [ 492 | "Average training loss: 1.8066515922546387\n" 493 | ] 494 | }, 495 | { 496 | "data": { 497 | "application/vnd.jupyter.widget-view+json": { 498 | "model_id": "", 499 | "version_major": 2, 500 | "version_minor": 0 501 | }, 502 | "text/plain": [ 503 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" 504 | ] 505 | }, 506 | "metadata": {}, 507 | "output_type": "display_data" 508 | }, 509 | { 510 | "name": "stdout", 511 | "output_type": "stream", 512 | "text": [ 513 | "Average training loss: 1.7187001705169678\n" 514 | ] 515 | }, 516 | { 517 | "data": { 518 | "application/vnd.jupyter.widget-view+json": { 519 | "model_id": "", 520 | "version_major": 2, 521 | "version_minor": 0 522 | }, 523 | "text/plain": [ 524 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" 525 | ] 526 | }, 527 | "metadata": {}, 528 | "output_type": "display_data" 529 | }, 530 | { 531 | "name": "stdout", 532 | "output_type": "stream", 533 | "text": [ 534 | "Average training loss: 1.705643892288208\n" 535 | ] 536 | }, 537 | { 538 | "data": { 539 | "application/vnd.jupyter.widget-view+json": { 540 | "model_id": "", 541 | "version_major": 2, 542 | "version_minor": 0 543 | }, 544 | "text/plain": [ 545 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" 546 | ] 547 | }, 548 | "metadata": {}, 549 | "output_type": "display_data" 550 | }, 551 | { 552 | "name": "stdout", 553 | "output_type": "stream", 554 | "text": [ 555 | "Average training loss: 1.702525019645691\n" 556 | ] 557 | }, 558 | { 559 | "data": { 560 | "application/vnd.jupyter.widget-view+json": { 561 | "model_id": "", 562 | "version_major": 2, 563 | "version_minor": 0 564 | }, 565 | "text/plain": [ 566 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" 567 | ] 
568 | }, 569 | "metadata": {}, 570 | "output_type": "display_data" 571 | }, 572 | { 573 | "name": "stdout", 574 | "output_type": "stream", 575 | "text": [ 576 | "Average training loss: 1.701408863067627\n" 577 | ] 578 | }, 579 | { 580 | "data": { 581 | "application/vnd.jupyter.widget-view+json": { 582 | "model_id": "", 583 | "version_major": 2, 584 | "version_minor": 0 585 | }, 586 | "text/plain": [ 587 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" 588 | ] 589 | }, 590 | "metadata": {}, 591 | "output_type": "display_data" 592 | }, 593 | { 594 | "name": "stdout", 595 | "output_type": "stream", 596 | "text": [ 597 | "Average training loss: 1.69669771194458\n" 598 | ] 599 | }, 600 | { 601 | "data": { 602 | "application/vnd.jupyter.widget-view+json": { 603 | "model_id": "", 604 | "version_major": 2, 605 | "version_minor": 0 606 | }, 607 | "text/plain": [ 608 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" 609 | ] 610 | }, 611 | "metadata": {}, 612 | "output_type": "display_data" 613 | }, 614 | { 615 | "name": "stdout", 616 | "output_type": "stream", 617 | "text": [ 618 | "Average training loss: 1.697007179260254\n" 619 | ] 620 | }, 621 | { 622 | "data": { 623 | "application/vnd.jupyter.widget-view+json": { 624 | "model_id": "", 625 | "version_major": 2, 626 | "version_minor": 0 627 | }, 628 | "text/plain": [ 629 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" 630 | ] 631 | }, 632 | "metadata": {}, 633 | "output_type": "display_data" 634 | }, 635 | { 636 | "name": "stdout", 637 | "output_type": "stream", 638 | "text": [ 639 | "Average training loss: 1.6952210664749146\n" 640 | ] 641 | }, 642 | { 643 | "data": { 644 | "application/vnd.jupyter.widget-view+json": { 645 | "model_id": "", 646 | "version_major": 2, 647 | "version_minor": 0 648 | }, 649 | "text/plain": [ 650 
| "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" 651 | ] 652 | }, 653 | "metadata": {}, 654 | "output_type": "display_data" 655 | }, 656 | { 657 | "name": "stdout", 658 | "output_type": "stream", 659 | "text": [ 660 | "Average training loss: 1.693630337715149\n" 661 | ] 662 | }, 663 | { 664 | "data": { 665 | "application/vnd.jupyter.widget-view+json": { 666 | "model_id": "", 667 | "version_major": 2, 668 | "version_minor": 0 669 | }, 670 | "text/plain": [ 671 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" 672 | ] 673 | }, 674 | "metadata": {}, 675 | "output_type": "display_data" 676 | }, 677 | { 678 | "name": "stdout", 679 | "output_type": "stream", 680 | "text": [ 681 | "Average training loss: 1.696158528327942\n", 682 | "\n" 683 | ] 684 | }, 685 | { 686 | "data": { 687 | "text/plain": [ 688 | "1" 689 | ] 690 | }, 691 | "execution_count": 23, 692 | "metadata": {}, 693 | "output_type": "execute_result" 694 | } 695 | ], 696 | "source": [ 697 | "# The trainer abstracts training, validation and test loops\n", 698 | "\n", 699 | "mnistTrainer=pl.Trainer(gpus=1, max_epochs=epochs)\n", 700 | "\n", 701 | "model = MNISTClassifier()\n", 702 | "mnistTrainer.fit(model)" 703 | ] 704 | }, 705 | { 706 | "cell_type": "code", 707 | "execution_count": 26, 708 | "metadata": {}, 709 | "outputs": [], 710 | "source": [ 711 | "with open(os.path.join('./', 'model.pth'), 'wb') as f:\n", 712 | " torch.save(model.state_dict(), f)" 713 | ] 714 | }, 715 | { 716 | "cell_type": "code", 717 | "execution_count": null, 718 | "metadata": {}, 719 | "outputs": [], 720 | "source": [] 721 | } 722 | ], 723 | "metadata": { 724 | "instance_type": "ml.g4dn.xlarge", 725 | "kernelspec": { 726 | "display_name": "conda_pytorch_p36", 727 | "language": "python", 728 | "name": "conda_pytorch_p36" 729 | }, 730 | "language_info": { 731 | "codemirror_mode": { 732 | "name": "ipython", 
733 | "version": 3 734 | }, 735 | "file_extension": ".py", 736 | "mimetype": "text/x-python", 737 | "name": "python", 738 | "nbconvert_exporter": "python", 739 | "pygments_lexer": "ipython3", 740 | "version": "3.6.10" 741 | } 742 | }, 743 | "nbformat": 4, 744 | "nbformat_minor": 4 745 | } 746 | -------------------------------------------------------------------------------- /notebook/lightning_logs/version_7/checkpoints/epoch=2.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aletheia/mnist_pl_sagemaker/9083684ee22f3dbf369c4a1509d2044760628a07/notebook/lightning_logs/version_7/checkpoints/epoch=2.ckpt -------------------------------------------------------------------------------- /notebook/lightning_logs/version_7/events.out.tfevents.1594563548.ip-172-16-16-124.31276.7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aletheia/mnist_pl_sagemaker/9083684ee22f3dbf369c4a1509d2044760628a07/notebook/lightning_logs/version_7/events.out.tfevents.1594563548.ip-172-16-16-124.31276.7 -------------------------------------------------------------------------------- /notebook/lightning_logs/version_7/hparams.yaml: -------------------------------------------------------------------------------- 1 | {} 2 | -------------------------------------------------------------------------------- /notebook/lightning_logs/version_8/checkpoints/epoch=9.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aletheia/mnist_pl_sagemaker/9083684ee22f3dbf369c4a1509d2044760628a07/notebook/lightning_logs/version_8/checkpoints/epoch=9.ckpt -------------------------------------------------------------------------------- /notebook/lightning_logs/version_8/events.out.tfevents.1594563766.ip-172-16-16-124.31276.8: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aletheia/mnist_pl_sagemaker/9083684ee22f3dbf369c4a1509d2044760628a07/notebook/lightning_logs/version_8/events.out.tfevents.1594563766.ip-172-16-16-124.31276.8 -------------------------------------------------------------------------------- /notebook/lightning_logs/version_8/hparams.yaml: -------------------------------------------------------------------------------- 1 | {} 2 | -------------------------------------------------------------------------------- /notebook/lightning_logs/version_9/checkpoints/epoch=6.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aletheia/mnist_pl_sagemaker/9083684ee22f3dbf369c4a1509d2044760628a07/notebook/lightning_logs/version_9/checkpoints/epoch=6.ckpt -------------------------------------------------------------------------------- /notebook/lightning_logs/version_9/events.out.tfevents.1594564263.ip-172-16-16-124.31276.9: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aletheia/mnist_pl_sagemaker/9083684ee22f3dbf369c4a1509d2044760628a07/notebook/lightning_logs/version_9/events.out.tfevents.1594564263.ip-172-16-16-124.31276.9 -------------------------------------------------------------------------------- /notebook/lightning_logs/version_9/hparams.yaml: -------------------------------------------------------------------------------- 1 | {} 2 | -------------------------------------------------------------------------------- /notebook/lightning_mnist_experiment.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# MNIST on SageMaker with PyTorch Lightning" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Download dataset to local folder" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 
| "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "S3_DATA_BUCKET = 'dataset.mnist'\n", 24 | "S3_TRAINING_DATA = S3_DATA_BUCKET+'/training'\n", 25 | "S3_TESTING_DATA = S3_DATA_BUCKET+'/testing'\n", 26 | "\n", 27 | "DATA_PATH = '../dataset'\n", 28 | "BATCH_SIZE = 128" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": { 35 | "collapsed": true, 36 | "jupyter": { 37 | "outputs_hidden": true 38 | } 39 | }, 40 | "outputs": [], 41 | "source": [ 42 | "#!mkdir -p $DATA_PATH/training\n", 43 | "#!mkdir -p $DATA_PATH/testing\n", 44 | "#!aws s3api get-object --bucket $S3_DATA_BUCKET --key mnist.tar.gz $DATA_PATH/mnist.tar.gz\n", 45 | "#!cd $DATA_PATH && tar xvf mnist.tar.gz && rm -f mnist.tar.gz" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 3, 51 | "metadata": { 52 | "collapsed": true, 53 | "jupyter": { 54 | "outputs_hidden": true 55 | } 56 | }, 57 | "outputs": [ 58 | { 59 | "name": "stdout", 60 | "output_type": "stream", 61 | "text": [ 62 | "Requirement already satisfied: torch in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (1.4.0)\n", 63 | "\u001b[33mWARNING: You are using pip version 20.0.2; however, version 20.1.1 is available.\n", 64 | "You should consider upgrading via the '/home/ec2-user/anaconda3/envs/pytorch_p36/bin/python -m pip install --upgrade pip' command.\u001b[0m\n", 65 | "Requirement already satisfied: torchvision in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (0.5.0)\n", 66 | "Requirement already satisfied: numpy in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from torchvision) (1.18.1)\n", 67 | "Requirement already satisfied: six in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from torchvision) (1.14.0)\n", 68 | "Requirement already satisfied: torch in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from torchvision) (1.4.0)\n", 
69 | "Requirement already satisfied: pillow>=4.1.1 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from torchvision) (7.0.0)\n", 70 | "\u001b[33mWARNING: You are using pip version 20.0.2; however, version 20.1.1 is available.\n", 71 | "You should consider upgrading via the '/home/ec2-user/anaconda3/envs/pytorch_p36/bin/python -m pip install --upgrade pip' command.\u001b[0m\n", 72 | "Requirement already satisfied: pytorch_lightning in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (0.8.5)\n", 73 | "Requirement already satisfied: tensorboard>=1.14 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from pytorch_lightning) (2.2.2)\n", 74 | "Requirement already satisfied: torch>=1.3 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from pytorch_lightning) (1.4.0)\n", 75 | "Requirement already satisfied: tqdm>=4.41.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from pytorch_lightning) (4.44.1)\n", 76 | "Requirement already satisfied: numpy>=1.16.4 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from pytorch_lightning) (1.18.1)\n", 77 | "Requirement already satisfied: future>=0.17.1 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from pytorch_lightning) (0.18.2)\n", 78 | "Requirement already satisfied: PyYAML>=5.1 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from pytorch_lightning) (5.3.1)\n", 79 | "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from tensorboard>=1.14->pytorch_lightning) (0.4.1)\n", 80 | "Requirement already satisfied: google-auth<2,>=1.6.3 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from tensorboard>=1.14->pytorch_lightning) (1.18.0)\n", 81 | "Requirement already satisfied: grpcio>=1.24.3 in 
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from tensorboard>=1.14->pytorch_lightning) (1.30.0)\n", 82 | "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from tensorboard>=1.14->pytorch_lightning) (1.7.0)\n", 83 | "Requirement already satisfied: werkzeug>=0.11.15 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from tensorboard>=1.14->pytorch_lightning) (1.0.1)\n", 84 | "Requirement already satisfied: wheel>=0.26; python_version >= \"3\" in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from tensorboard>=1.14->pytorch_lightning) (0.34.2)\n", 85 | "Requirement already satisfied: protobuf>=3.6.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from tensorboard>=1.14->pytorch_lightning) (3.12.2)\n", 86 | "Requirement already satisfied: requests<3,>=2.21.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from tensorboard>=1.14->pytorch_lightning) (2.23.0)\n", 87 | "Requirement already satisfied: absl-py>=0.4 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from tensorboard>=1.14->pytorch_lightning) (0.9.0)\n", 88 | "Requirement already satisfied: six>=1.10.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from tensorboard>=1.14->pytorch_lightning) (1.14.0)\n", 89 | "Requirement already satisfied: markdown>=2.6.8 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from tensorboard>=1.14->pytorch_lightning) (3.2.2)\n", 90 | "Requirement already satisfied: setuptools>=41.0.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from tensorboard>=1.14->pytorch_lightning) (46.1.3.post20200330)\n", 91 | "Requirement already satisfied: requests-oauthlib>=0.7.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from 
google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=1.14->pytorch_lightning) (1.3.0)\n", 92 | "Requirement already satisfied: rsa<5,>=3.1.4; python_version >= \"3\" in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from google-auth<2,>=1.6.3->tensorboard>=1.14->pytorch_lightning) (3.4.2)\n", 93 | "Requirement already satisfied: pyasn1-modules>=0.2.1 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from google-auth<2,>=1.6.3->tensorboard>=1.14->pytorch_lightning) (0.2.8)\n", 94 | "Requirement already satisfied: cachetools<5.0,>=2.0.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from google-auth<2,>=1.6.3->tensorboard>=1.14->pytorch_lightning) (4.1.1)\n", 95 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from requests<3,>=2.21.0->tensorboard>=1.14->pytorch_lightning) (1.25.8)\n", 96 | "Requirement already satisfied: idna<3,>=2.5 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from requests<3,>=2.21.0->tensorboard>=1.14->pytorch_lightning) (2.9)\n", 97 | "Requirement already satisfied: certifi>=2017.4.17 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from requests<3,>=2.21.0->tensorboard>=1.14->pytorch_lightning) (2020.4.5.2)\n", 98 | "Requirement already satisfied: chardet<4,>=3.0.2 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from requests<3,>=2.21.0->tensorboard>=1.14->pytorch_lightning) (3.0.4)\n", 99 | "Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from markdown>=2.6.8->tensorboard>=1.14->pytorch_lightning) (1.5.0)\n", 100 | "Requirement already satisfied: oauthlib>=3.0.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from 
requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=1.14->pytorch_lightning) (3.1.0)\n", 101 | "Requirement already satisfied: pyasn1>=0.1.3 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from rsa<5,>=3.1.4; python_version >= \"3\"->google-auth<2,>=1.6.3->tensorboard>=1.14->pytorch_lightning) (0.4.8)\n", 102 | "Requirement already satisfied: zipp>=0.5 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from importlib-metadata; python_version < \"3.8\"->markdown>=2.6.8->tensorboard>=1.14->pytorch_lightning) (2.2.0)\n", 103 | "\u001b[33mWARNING: You are using pip version 20.0.2; however, version 20.1.1 is available.\n", 104 | "You should consider upgrading via the '/home/ec2-user/anaconda3/envs/pytorch_p36/bin/python -m pip install --upgrade pip' command.\u001b[0m\n" 105 | ] 106 | } 107 | ], 108 | "source": [ 109 | "# Install libraries if not already installed\n", 110 | "! pip install torch\n", 111 | "! pip install torchvision\n", 112 | "! 
pip install pytorch_lightning" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 4, 118 | "metadata": { 119 | "collapsed": true, 120 | "jupyter": { 121 | "outputs_hidden": true 122 | } 123 | }, 124 | "outputs": [ 125 | { 126 | "name": "stdout", 127 | "output_type": "stream", 128 | "text": [ 129 | "Requirement already satisfied: ipywidgets in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (7.5.1)\n", 130 | "Requirement already satisfied: nbformat>=4.2.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from ipywidgets) (5.0.4)\n", 131 | "Requirement already satisfied: widgetsnbextension~=3.5.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from ipywidgets) (3.5.1)\n", 132 | "Requirement already satisfied: traitlets>=4.3.1 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from ipywidgets) (4.3.3)\n", 133 | "Requirement already satisfied: ipython>=4.0.0; python_version >= \"3.3\" in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from ipywidgets) (7.13.0)\n", 134 | "Requirement already satisfied: ipykernel>=4.5.1 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from ipywidgets) (5.1.4)\n", 135 | "Requirement already satisfied: jsonschema!=2.5.0,>=2.4 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from nbformat>=4.2.0->ipywidgets) (3.2.0)\n", 136 | "Requirement already satisfied: ipython-genutils in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from nbformat>=4.2.0->ipywidgets) (0.2.0)\n", 137 | "Requirement already satisfied: jupyter-core in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from nbformat>=4.2.0->ipywidgets) (4.6.3)\n", 138 | "Requirement already satisfied: notebook>=4.4.1 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from widgetsnbextension~=3.5.0->ipywidgets) (6.0.3)\n", 139 | 
"Requirement already satisfied: six in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from traitlets>=4.3.1->ipywidgets) (1.14.0)\n", 140 | "Requirement already satisfied: decorator in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from traitlets>=4.3.1->ipywidgets) (4.4.2)\n", 141 | "Requirement already satisfied: pygments in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (2.6.1)\n", 142 | "Requirement already satisfied: pexpect; sys_platform != \"win32\" in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (4.8.0)\n", 143 | "Requirement already satisfied: setuptools>=18.5 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (46.1.3.post20200330)\n", 144 | "Requirement already satisfied: pickleshare in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (0.7.5)\n", 145 | "Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (3.0.4)\n", 146 | "Requirement already satisfied: jedi>=0.10 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (0.15.2)\n", 147 | "Requirement already satisfied: backcall in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (0.1.0)\n", 148 | "Requirement already satisfied: tornado>=4.2 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from ipykernel>=4.5.1->ipywidgets) (6.0.4)\n", 149 | "Requirement already satisfied: jupyter-client in 
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from ipykernel>=4.5.1->ipywidgets) (6.1.2)\n", 150 | "Requirement already satisfied: pyrsistent>=0.14.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets) (0.16.0)\n", 151 | "Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets) (1.5.0)\n", 152 | "Requirement already satisfied: attrs>=17.4.0 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets) (19.3.0)\n", 153 | "Requirement already satisfied: Send2Trash in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (1.5.0)\n", 154 | "Requirement already satisfied: nbconvert in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (5.6.1)\n", 155 | "Requirement already satisfied: pyzmq>=17 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (18.1.1)\n", 156 | "Requirement already satisfied: terminado>=0.8.1 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.8.3)\n", 157 | "Requirement already satisfied: prometheus-client in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.7.1)\n", 158 | "Requirement already satisfied: jinja2 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (2.11.1)\n", 159 | "Requirement already satisfied: ptyprocess>=0.5 in 
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from pexpect; sys_platform != \"win32\"->ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (0.6.0)\n", 160 | "Requirement already satisfied: wcwidth in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (0.1.9)\n", 161 | "Requirement already satisfied: parso>=0.5.2 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from jedi>=0.10->ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (0.5.2)\n", 162 | "Requirement already satisfied: python-dateutil>=2.1 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from jupyter-client->ipykernel>=4.5.1->ipywidgets) (2.8.1)\n", 163 | "Requirement already satisfied: zipp>=0.5 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from importlib-metadata; python_version < \"3.8\"->jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets) (2.2.0)\n", 164 | "Requirement already satisfied: mistune<2,>=0.8.1 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.8.4)\n", 165 | "Requirement already satisfied: bleach in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (3.1.4)\n", 166 | "Requirement already satisfied: defusedxml in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.6.0)\n", 167 | "Requirement already satisfied: entrypoints>=0.2.2 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.3)\n", 168 | "Requirement already satisfied: testpath in 
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.4.4)\n", 169 | "Requirement already satisfied: pandocfilters>=1.4.1 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (1.4.2)\n", 170 | "Requirement already satisfied: MarkupSafe>=0.23 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from jinja2->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (1.1.1)\n", 171 | "Requirement already satisfied: webencodings in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from bleach->nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.5.1)\n", 172 | "\u001b[33mWARNING: You are using pip version 20.0.2; however, version 20.1.1 is available.\n", 173 | "You should consider upgrading via the '/home/ec2-user/anaconda3/envs/pytorch_p36/bin/python -m pip install --upgrade pip' command.\u001b[0m\n", 174 | "Enabling notebook extension jupyter-js-widgets/extension...\n", 175 | " - Validating: \u001b[32mOK\u001b[0m\n" 176 | ] 177 | } 178 | ], 179 | "source": [ 180 | "# Install libraries specific for Jupyter Notebook\n", 181 | "! pip install ipywidgets\n", 182 | "! jupyter nbextension enable --py widgetsnbextension" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 21, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "import os\n", 192 | "import math\n", 193 | "import random as rn\n", 194 | "\n", 195 | "import numpy as np # linear algebra\n", 196 | "import pandas as pd # data processing, CSV file I/O (e.g. 
pd.read_csv)\n", 197 | "from sklearn.model_selection import train_test_split\n", 198 | "from sklearn.metrics import accuracy_score\n", 199 | "from PIL import Image, ImageFile\n", 200 | "\n", 201 | "import torch\n", 202 | "import torch.nn as nn\n", 203 | "from torch.nn import functional as F\n", 204 | "from torch.utils.data import DataLoader\n", 205 | "from torchvision import transforms as T, datasets\n", 206 | "import pytorch_lightning as pl\n", 207 | "from pytorch_lightning import Trainer\n", 208 | "from pytorch_lightning.callbacks import ModelCheckpoint\n" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": 6, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "# Import libraries specific for Jupyter notebook visualization\n", 218 | "from matplotlib import pyplot as plt, image\n", 219 | "from PIL import Image\n", 220 | "%matplotlib inline" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 8, 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "#train_data_dir = DATA_PATH+'/training'\n", 230 | "#dataset = datasets.ImageFolder(train_data_dir)\n", 231 | "#train_set, val_set = torch.utils.data.random_split(dataset, [55000, 5000])" 232 | ] 233 | }, 234 | { 235 | "cell_type": "markdown", 236 | "metadata": {}, 237 | "source": [ 238 | "Display a few images to have an idea about the input" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": 9, 244 | "metadata": { 245 | "collapsed": true, 246 | "jupyter": { 247 | "outputs_hidden": true 248 | } 249 | }, 250 | "outputs": [], 251 | "source": [ 252 | "#plt.imshow(image.imread(train_set.dataset.imgs[0][0]))\n", 253 | "#plt.figure()\n", 254 | "#plt.imshow(image.imread(train_set.dataset.imgs[7000][0]))\n", 255 | "#plt.figure()\n", 256 | "#plt.imshow(image.imread(train_set.dataset.imgs[20000][0]))\n", 257 | "#plt.figure()" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 10, 263 | "metadata": {}, 
264 | "outputs": [], 265 | "source": [ 266 | "#fix random seed\n", 267 | "os.environ['PYTHONHASHSEED'] = '0'\n", 268 | "np.random.seed(42)\n", 269 | "rn.seed(12345)\n", 270 | "torch.manual_seed(2020)\n", 271 | "torch.cuda.manual_seed(2020)\n", 272 | "torch.cuda.manual_seed_all(2020)\n", 273 | "torch.backends.cudnn.deterministic = True" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": 11, 279 | "metadata": {}, 280 | "outputs": [], 281 | "source": [ 282 | "num_workers = 4\n", 283 | "epochs = 10\n", 284 | "validation_size = .3\n", 285 | "batch_size = 128" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": 12, 291 | "metadata": {}, 292 | "outputs": [], 293 | "source": [ 294 | "data_dir = DATA_PATH+'/training'" 295 | ] 296 | }, 297 | { 298 | "cell_type": "markdown", 299 | "metadata": {}, 300 | "source": [ 301 | "## Create model" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": 22, 307 | "metadata": {}, 308 | "outputs": [], 309 | "source": [ 310 | "class MNISTClassifier(pl.LightningModule):\n", 311 | " def __init__(self):\n", 312 | " super(MNISTClassifier, self).__init__()\n", 313 | " self.conv_layer_1 = torch.nn.Sequential(\n", 314 | " torch.nn.Conv2d(3,28, kernel_size=5),\n", 315 | " torch.nn.ReLU(),\n", 316 | " torch.nn.MaxPool2d(kernel_size=2))\n", 317 | " self.conv_layer_2 = torch.nn.Sequential(\n", 318 | " torch.nn.Conv2d(28,10, kernel_size=2),\n", 319 | " torch.nn.ReLU(),\n", 320 | " torch.nn.MaxPool2d(kernel_size=2))\n", 321 | " self.dropout1=torch.nn.Dropout(0.25)\n", 322 | " self.fully_connected_1=torch.nn.Linear(250,18)\n", 323 | " self.dropout2=torch.nn.Dropout(0.08)\n", 324 | " self.fully_connected_2=torch.nn.Linear(18,10)\n", 325 | "\n", 326 | " def load_split_train_test(self, datadir, valid_size = .2):\n", 327 | " train_transforms = T.Compose([T.RandomHorizontalFlip(), \n", 328 | " T.ToTensor(),\n", 329 | " T.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])\n", 330 | 
"\n", 331 | " test_transforms = T.Compose([T.ToTensor(),T.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])\n", 332 | "\n", 333 | " train_data = datasets.ImageFolder(datadir, transform=train_transforms)\n", 334 | " test_data = datasets.ImageFolder(datadir, transform=test_transforms)\n", 335 | "\n", 336 | " num_train = len(train_data)\n", 337 | " indices = list(range(num_train))\n", 338 | " split = int(np.floor(valid_size * num_train))\n", 339 | " np.random.shuffle(indices)\n", 340 | " from torch.utils.data.sampler import SubsetRandomSampler\n", 341 | " train_idx, test_idx = indices[split:], indices[:split]\n", 342 | " train_sampler = SubsetRandomSampler(train_idx)\n", 343 | " test_sampler = SubsetRandomSampler(test_idx)\n", 344 | " trainloader = torch.utils.data.DataLoader(train_data, sampler=train_sampler, batch_size=batch_size, num_workers=num_workers)\n", 345 | " testloader = torch.utils.data.DataLoader(test_data, sampler=test_sampler, batch_size=batch_size, num_workers=num_workers)\n", 346 | " return trainloader, testloader\n", 347 | " \n", 348 | " def prepare_data(self):\n", 349 | " self.train_loader, self.val_loader = self.load_split_train_test(data_dir, validation_size)\n", 350 | " \n", 351 | " def train_dataloader(self):\n", 352 | " return self.train_loader\n", 353 | " \n", 354 | " def val_dataloader(self):\n", 355 | " return self.val_loader\n", 356 | " \n", 357 | "# def test_dataloader(self):\n", 358 | "# return DataLoader(MNIST(os.getcwd(), train=False, download=False, transform=transform.ToTensor()), batch_size=128)\n", 359 | " \n", 360 | " def forward(self,x):\n", 361 | " x=self.conv_layer_1(x)\n", 362 | " x=self.conv_layer_2(x)\n", 363 | " x=self.dropout1(x)\n", 364 | " x=torch.relu(self.fully_connected_1(x.view(x.size(0),-1)))\n", 365 | " x=F.leaky_relu(self.dropout2(x))\n", 366 | " return F.softmax(self.fully_connected_2(x), dim=1)\n", 367 | " \n", 368 | " def configure_optimizers(self):\n", 369 | " return 
torch.optim.Adam(self.parameters())\n", 370 | "    \n", 371 | "    def training_step(self, batch, batch_idx):\n", 372 | "        \n", 373 | "        # Get input and output from batch\n", 374 | "        x, labels = batch\n", 375 | "        \n", 376 | "        # Compute prediction through the network\n", 377 | "        prediction = self.forward(x)\n", 378 | "        \n", 379 | "        loss = F.nll_loss(prediction, labels)\n", 380 | "        \n", 381 | "        # Logs training loss\n", 382 | "        logs={'train_loss':loss}\n", 383 | "        \n", 384 | "        output = {\n", 385 | "            # This is required in training to be used by backpropagation\n", 386 | "            'loss':loss,\n", 387 | "            # This is optional for logging purposes\n", 388 | "            'log':logs\n", 389 | "        }\n", 390 | "        \n", 391 | "        return output\n", 392 | "    \n", 393 | "    def validation_step(self, batch, batch_idx):\n", 394 | "        x, labels = batch\n", 395 | "        prediction = self.forward(x)\n", 396 | "        return {\n", 397 | "            'val_loss': F.cross_entropy(prediction, labels)\n", 398 | "        }\n", 399 | "    \n", 400 | "    def validation_epoch_end(self, outputs):\n", 401 | "        val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean()\n", 402 | "        return {'val_loss': val_loss_mean}\n", 403 | "\n", 404 | "    \n", 405 | "    def validation_end(self, outputs):\n", 406 | "        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()\n", 407 | "        print('Average training loss: '+str(avg_loss.item()))\n", 408 | "        logs = {'val_loss':avg_loss}\n", 409 | "        return {\n", 410 | "            'avg_val_loss':avg_loss,\n", 411 | "            'log':logs\n", 412 | "        }" 413 | ] 414 | }, 415 | { 416 | "cell_type": "code", 417 | "execution_count": 23, 418 | "metadata": {}, 419 | "outputs": [ 420 | { 421 | "name": "stderr", 422 | "output_type": "stream", 423 | "text": [ 424 | "GPU available: True, used: True\n", 425 | "TPU available: False, using: 0 TPU cores\n", 426 | "CUDA_VISIBLE_DEVICES: [0]\n", 427 | "\n", 428 | " | Name | Type | Params\n", 429 | "-------------------------------------------------\n", 430 | "0 | conv_layer_1 | Sequential | 2 K \n", 431 | "1 | 
conv_layer_2 | Sequential | 1 K \n", 432 | "2 | dropout1 | Dropout | 0 \n", 433 | "3 | fully_connected_1 | Linear | 4 K \n", 434 | "4 | dropout2 | Dropout | 0 \n", 435 | "5 | fully_connected_2 | Linear | 190 \n" 436 | ] 437 | }, 438 | { 439 | "data": { 440 | "application/vnd.jupyter.widget-view+json": { 441 | "model_id": "", 442 | "version_major": 2, 443 | "version_minor": 0 444 | }, 445 | "text/plain": [ 446 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…" 447 | ] 448 | }, 449 | "metadata": {}, 450 | "output_type": "display_data" 451 | }, 452 | { 453 | "name": "stdout", 454 | "output_type": "stream", 455 | "text": [ 456 | "Average training loss: 2.302227735519409\n", 457 | "\r" 458 | ] 459 | }, 460 | { 461 | "data": { 462 | "application/vnd.jupyter.widget-view+json": { 463 | "model_id": "51e3aaaa31df49be8101f9725da2be9e", 464 | "version_major": 2, 465 | "version_minor": 0 466 | }, 467 | "text/plain": [ 468 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…" 469 | ] 470 | }, 471 | "metadata": {}, 472 | "output_type": "display_data" 473 | }, 474 | { 475 | "data": { 476 | "application/vnd.jupyter.widget-view+json": { 477 | "model_id": "", 478 | "version_major": 2, 479 | "version_minor": 0 480 | }, 481 | "text/plain": [ 482 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" 483 | ] 484 | }, 485 | "metadata": {}, 486 | "output_type": "display_data" 487 | }, 488 | { 489 | "name": "stdout", 490 | "output_type": "stream", 491 | "text": [ 492 | "Average training loss: 1.8066515922546387\n" 493 | ] 494 | }, 495 | { 496 | "data": { 497 | "application/vnd.jupyter.widget-view+json": { 498 | "model_id": "", 499 | "version_major": 2, 500 | "version_minor": 0 501 | }, 502 | "text/plain": [ 503 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', 
layout=Layout(flex='2'), m…" 504 | ] 505 | }, 506 | "metadata": {}, 507 | "output_type": "display_data" 508 | }, 509 | { 510 | "name": "stdout", 511 | "output_type": "stream", 512 | "text": [ 513 | "Average training loss: 1.7187001705169678\n" 514 | ] 515 | }, 516 | { 517 | "data": { 518 | "application/vnd.jupyter.widget-view+json": { 519 | "model_id": "", 520 | "version_major": 2, 521 | "version_minor": 0 522 | }, 523 | "text/plain": [ 524 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" 525 | ] 526 | }, 527 | "metadata": {}, 528 | "output_type": "display_data" 529 | }, 530 | { 531 | "name": "stdout", 532 | "output_type": "stream", 533 | "text": [ 534 | "Average training loss: 1.705643892288208\n" 535 | ] 536 | }, 537 | { 538 | "data": { 539 | "application/vnd.jupyter.widget-view+json": { 540 | "model_id": "", 541 | "version_major": 2, 542 | "version_minor": 0 543 | }, 544 | "text/plain": [ 545 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" 546 | ] 547 | }, 548 | "metadata": {}, 549 | "output_type": "display_data" 550 | }, 551 | { 552 | "name": "stdout", 553 | "output_type": "stream", 554 | "text": [ 555 | "Average training loss: 1.702525019645691\n" 556 | ] 557 | }, 558 | { 559 | "data": { 560 | "application/vnd.jupyter.widget-view+json": { 561 | "model_id": "", 562 | "version_major": 2, 563 | "version_minor": 0 564 | }, 565 | "text/plain": [ 566 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" 567 | ] 568 | }, 569 | "metadata": {}, 570 | "output_type": "display_data" 571 | }, 572 | { 573 | "name": "stdout", 574 | "output_type": "stream", 575 | "text": [ 576 | "Average training loss: 1.701408863067627\n" 577 | ] 578 | }, 579 | { 580 | "data": { 581 | "application/vnd.jupyter.widget-view+json": { 582 | "model_id": "", 583 | "version_major": 2, 584 | 
"version_minor": 0 585 | }, 586 | "text/plain": [ 587 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" 588 | ] 589 | }, 590 | "metadata": {}, 591 | "output_type": "display_data" 592 | }, 593 | { 594 | "name": "stdout", 595 | "output_type": "stream", 596 | "text": [ 597 | "Average training loss: 1.69669771194458\n" 598 | ] 599 | }, 600 | { 601 | "data": { 602 | "application/vnd.jupyter.widget-view+json": { 603 | "model_id": "", 604 | "version_major": 2, 605 | "version_minor": 0 606 | }, 607 | "text/plain": [ 608 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" 609 | ] 610 | }, 611 | "metadata": {}, 612 | "output_type": "display_data" 613 | }, 614 | { 615 | "name": "stdout", 616 | "output_type": "stream", 617 | "text": [ 618 | "Average training loss: 1.697007179260254\n" 619 | ] 620 | }, 621 | { 622 | "data": { 623 | "application/vnd.jupyter.widget-view+json": { 624 | "model_id": "", 625 | "version_major": 2, 626 | "version_minor": 0 627 | }, 628 | "text/plain": [ 629 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" 630 | ] 631 | }, 632 | "metadata": {}, 633 | "output_type": "display_data" 634 | }, 635 | { 636 | "name": "stdout", 637 | "output_type": "stream", 638 | "text": [ 639 | "Average training loss: 1.6952210664749146\n" 640 | ] 641 | }, 642 | { 643 | "data": { 644 | "application/vnd.jupyter.widget-view+json": { 645 | "model_id": "", 646 | "version_major": 2, 647 | "version_minor": 0 648 | }, 649 | "text/plain": [ 650 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" 651 | ] 652 | }, 653 | "metadata": {}, 654 | "output_type": "display_data" 655 | }, 656 | { 657 | "name": "stdout", 658 | "output_type": "stream", 659 | "text": [ 660 | "Average training loss: 1.693630337715149\n" 661 | ] 662 | }, 663 | 
{ 664 | "data": { 665 | "application/vnd.jupyter.widget-view+json": { 666 | "model_id": "", 667 | "version_major": 2, 668 | "version_minor": 0 669 | }, 670 | "text/plain": [ 671 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" 672 | ] 673 | }, 674 | "metadata": {}, 675 | "output_type": "display_data" 676 | }, 677 | { 678 | "name": "stdout", 679 | "output_type": "stream", 680 | "text": [ 681 | "Average training loss: 1.696158528327942\n", 682 | "\n" 683 | ] 684 | }, 685 | { 686 | "data": { 687 | "text/plain": [ 688 | "1" 689 | ] 690 | }, 691 | "execution_count": 23, 692 | "metadata": {}, 693 | "output_type": "execute_result" 694 | } 695 | ], 696 | "source": [ 697 | "# The trainer abstracts training, validation and test loops\n", 698 | "\n", 699 | "mnistTrainer=pl.Trainer(gpus=1, max_epochs=epochs)\n", 700 | "\n", 701 | "model = MNISTClassifier()\n", 702 | "mnistTrainer.fit(model)" 703 | ] 704 | }, 705 | { 706 | "cell_type": "code", 707 | "execution_count": 26, 708 | "metadata": {}, 709 | "outputs": [], 710 | "source": [ 711 | "with open(os.path.join('./', 'model.pth'), 'wb') as f:\n", 712 | " torch.save(model.state_dict(), f)" 713 | ] 714 | }, 715 | { 716 | "cell_type": "code", 717 | "execution_count": null, 718 | "metadata": {}, 719 | "outputs": [], 720 | "source": [] 721 | } 722 | ], 723 | "metadata": { 724 | "instance_type": "ml.g4dn.xlarge", 725 | "kernelspec": { 726 | "display_name": "conda_pytorch_p36", 727 | "language": "python", 728 | "name": "conda_pytorch_p36" 729 | }, 730 | "language_info": { 731 | "codemirror_mode": { 732 | "name": "ipython", 733 | "version": 3 734 | }, 735 | "file_extension": ".py", 736 | "mimetype": "text/x-python", 737 | "name": "python", 738 | "nbconvert_exporter": "python", 739 | "pygments_lexer": "ipython3", 740 | "version": "3.6.10" 741 | } 742 | }, 743 | "nbformat": 4, 744 | "nbformat_minor": 4 745 | } 746 | 
-------------------------------------------------------------------------------- /notebook/sagemaker_deploy.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Sagemaker setup of MNIST classifier" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "Import SageMaker [Python SDK](https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html) into our project, and **Pytorch Estimator**" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 3, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import sagemaker\n", 24 | "from sagemaker.pytorch import PyTorch" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "Create a session object, which initializes data related to execution role" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 4, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "sagemaker_session = sagemaker.Session()\n", 41 | "\n", 42 | "bucket = 's3://dataset.mnist'\n", 43 | "role = sagemaker.get_execution_role()" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 11, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "estimator = PyTorch(entry_point='train.py',\n", 53 | " source_dir='../code',\n", 54 | " role=role,\n", 55 | " framework_version='1.4.0',\n", 56 | " train_instance_count=1,\n", 57 | " train_instance_type='ml.p2.xlarge',\n", 58 | " hyperparameters={\n", 59 | " 'epochs': 6,\n", 60 | " 'batch-size': 128,\n", 61 | " })\n" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 12, 67 | "metadata": {}, 68 | "outputs": [ 69 | { 70 | "name": "stderr", 71 | "output_type": "stream", 72 | "text": [ 73 | "'create_image_uri' will be deprecated in favor of 'ImageURIProvider' class in SageMaker Python SDK v2.\n", 
74 | "'s3_input' class will be renamed to 'TrainingInput' in SageMaker Python SDK v2.\n", 75 | "'s3_input' class will be renamed to 'TrainingInput' in SageMaker Python SDK v2.\n", 76 | "'create_image_uri' will be deprecated in favor of 'ImageURIProvider' class in SageMaker Python SDK v2.\n" 77 | ] 78 | }, 79 | { 80 | "name": "stdout", 81 | "output_type": "stream", 82 | "text": [ 83 | "2020-07-13 13:04:37 Starting - Starting the training job...\n", 84 | "2020-07-13 13:04:39 Starting - Launching requested ML instances......\n", 85 | "2020-07-13 13:05:46 Starting - Preparing the instances for training......\n", 86 | "2020-07-13 13:07:03 Downloading - Downloading input data....................................\n", 87 | "2020-07-13 13:13:11 Training - Training image download completed. Training in progress..\u001b[34mbash: cannot set terminal process group (-1): Inappropriate ioctl for device\u001b[0m\n", 88 | "\u001b[34mbash: no job control in this shell\u001b[0m\n", 89 | "\u001b[34m2020-07-13 13:13:13,037 sagemaker-containers INFO Imported framework sagemaker_pytorch_container.training\u001b[0m\n", 90 | "\u001b[34m2020-07-13 13:13:13,061 sagemaker_pytorch_container.training INFO Block until all host DNS lookups succeed.\u001b[0m\n", 91 | "\u001b[34m2020-07-13 13:13:14,484 sagemaker_pytorch_container.training INFO Invoking user training script.\u001b[0m\n", 92 | "\u001b[34m2020-07-13 13:13:14,838 sagemaker-containers INFO Module default_user_module_name does not provide a setup.py. \u001b[0m\n", 93 | "\u001b[34mGenerating setup.py\u001b[0m\n", 94 | "\u001b[34m2020-07-13 13:13:14,839 sagemaker-containers INFO Generating setup.cfg\u001b[0m\n", 95 | "\u001b[34m2020-07-13 13:13:14,839 sagemaker-containers INFO Generating MANIFEST.in\u001b[0m\n", 96 | "\u001b[34m2020-07-13 13:13:14,839 sagemaker-containers INFO Installing module with the following command:\u001b[0m\n", 97 | "\u001b[34m/opt/conda/bin/python -m pip install . 
-r requirements.txt\u001b[0m\n", 98 | "\u001b[34mProcessing /tmp/tmp7fk17u99/module_dir\u001b[0m\n", 99 | "\u001b[34mCollecting absl-py==0.9.0\n", 100 | " Downloading absl-py-0.9.0.tar.gz (104 kB)\u001b[0m\n", 101 | "\u001b[34mCollecting boto3==1.14.20\n", 102 | " Downloading boto3-1.14.20-py2.py3-none-any.whl (128 kB)\u001b[0m\n", 103 | "\u001b[34mCollecting botocore==1.17.20\n", 104 | " Downloading botocore-1.17.20-py2.py3-none-any.whl (6.3 MB)\u001b[0m\n", 105 | "\u001b[34mCollecting cachetools==4.1.1\n", 106 | " Downloading cachetools-4.1.1-py3-none-any.whl (10 kB)\u001b[0m\n", 107 | "\u001b[34mCollecting certifi==2020.6.20\n", 108 | " Downloading certifi-2020.6.20-py2.py3-none-any.whl (156 kB)\u001b[0m\n", 109 | "\u001b[34mRequirement already satisfied: chardet==3.0.4 in /opt/conda/lib/python3.6/site-packages (from -r requirements.txt (line 7)) (3.0.4)\u001b[0m\n", 110 | "\u001b[34mRequirement already satisfied: docutils==0.15.2 in /opt/conda/lib/python3.6/site-packages (from -r requirements.txt (line 8)) (0.15.2)\u001b[0m\n", 111 | "\u001b[34mCollecting future==0.18.2\n", 112 | " Downloading future-0.18.2.tar.gz (829 kB)\u001b[0m\n", 113 | "\u001b[34mCollecting google-auth-oauthlib==0.4.1\n", 114 | " Downloading google_auth_oauthlib-0.4.1-py2.py3-none-any.whl (18 kB)\u001b[0m\n", 115 | "\u001b[34mCollecting google-auth==1.18.0\n", 116 | " Downloading google_auth-1.18.0-py2.py3-none-any.whl (90 kB)\u001b[0m\n", 117 | "\u001b[34mCollecting grpcio==1.30.0\n", 118 | " Downloading grpcio-1.30.0-cp36-cp36m-manylinux2010_x86_64.whl (3.0 MB)\u001b[0m\n", 119 | "\u001b[34mCollecting idna==2.10\n", 120 | " Downloading idna-2.10-py2.py3-none-any.whl (58 kB)\u001b[0m\n", 121 | "\u001b[34mCollecting importlib-metadata==1.7.0\n", 122 | " Downloading importlib_metadata-1.7.0-py2.py3-none-any.whl (31 kB)\u001b[0m\n", 123 | "\u001b[34mRequirement already satisfied: jmespath==0.10.0 in /opt/conda/lib/python3.6/site-packages (from -r requirements.txt (line 15)) 
(0.10.0)\u001b[0m\n", 124 | "\u001b[34mCollecting markdown==3.2.2\n", 125 | " Downloading Markdown-3.2.2-py3-none-any.whl (88 kB)\u001b[0m\n", 126 | "\u001b[34mCollecting numpy==1.19.0\n", 127 | " Downloading numpy-1.19.0-cp36-cp36m-manylinux2010_x86_64.whl (14.6 MB)\u001b[0m\n", 128 | "\u001b[34mCollecting oauthlib==3.1.0\n", 129 | " Downloading oauthlib-3.1.0-py2.py3-none-any.whl (147 kB)\u001b[0m\n", 130 | "\u001b[34mRequirement already satisfied: packaging==20.4 in /opt/conda/lib/python3.6/site-packages (from -r requirements.txt (line 19)) (20.4)\u001b[0m\n", 131 | "\u001b[34mRequirement already satisfied: protobuf3-to-dict==0.1.5 in /opt/conda/lib/python3.6/site-packages (from -r requirements.txt (line 20)) (0.1.5)\u001b[0m\n", 132 | "\u001b[34mRequirement already satisfied: protobuf==3.12.2 in /opt/conda/lib/python3.6/site-packages (from -r requirements.txt (line 21)) (3.12.2)\u001b[0m\n", 133 | "\u001b[34mCollecting pyasn1-modules==0.2.8\n", 134 | " Downloading pyasn1_modules-0.2.8-py2.py3-none-any.whl (155 kB)\u001b[0m\n", 135 | "\u001b[34mRequirement already satisfied: pyasn1==0.4.8 in /opt/conda/lib/python3.6/site-packages (from -r requirements.txt (line 23)) (0.4.8)\u001b[0m\n", 136 | "\u001b[34mRequirement already satisfied: pyparsing==2.4.7 in /opt/conda/lib/python3.6/site-packages (from -r requirements.txt (line 24)) (2.4.7)\u001b[0m\n", 137 | "\u001b[34mRequirement already satisfied: python-dateutil==2.8.1 in /opt/conda/lib/python3.6/site-packages (from -r requirements.txt (line 25)) (2.8.1)\u001b[0m\n", 138 | "\u001b[34mCollecting pytorch-lightning==0.8.5\n", 139 | " Downloading pytorch_lightning-0.8.5-py3-none-any.whl (313 kB)\u001b[0m\n", 140 | "\u001b[34mRequirement already satisfied: pyyaml==5.3.1 in /opt/conda/lib/python3.6/site-packages (from -r requirements.txt (line 27)) (5.3.1)\u001b[0m\n", 141 | "\u001b[34mCollecting requests-oauthlib==1.3.0\n", 142 | " Downloading requests_oauthlib-1.3.0-py2.py3-none-any.whl (23 kB)\u001b[0m\n", 143 | 
"\u001b[34mCollecting requests==2.24.0\n", 144 | " Downloading requests-2.24.0-py2.py3-none-any.whl (61 kB)\u001b[0m\n", 145 | "\u001b[34mCollecting rsa==4.6\n", 146 | " Downloading rsa-4.6-py3-none-any.whl (47 kB)\u001b[0m\n", 147 | "\u001b[34mRequirement already satisfied: s3transfer==0.3.3 in /opt/conda/lib/python3.6/site-packages (from -r requirements.txt (line 31)) (0.3.3)\u001b[0m\n", 148 | "\u001b[34mCollecting sagemaker==1.69.0\n", 149 | " Downloading sagemaker-1.69.0.tar.gz (296 kB)\u001b[0m\n", 150 | "\u001b[34mCollecting scipy==1.5.1\n", 151 | " Downloading scipy-1.5.1-cp36-cp36m-manylinux1_x86_64.whl (25.9 MB)\u001b[0m\n", 152 | "\u001b[34mCollecting six==1.15.0\n", 153 | " Downloading six-1.15.0-py2.py3-none-any.whl (10 kB)\u001b[0m\n", 154 | "\u001b[34mCollecting smdebug-rulesconfig==0.1.4\n", 155 | " Downloading smdebug_rulesconfig-0.1.4-py2.py3-none-any.whl (10 kB)\u001b[0m\n", 156 | "\u001b[34mCollecting tensorboard-plugin-wit==1.7.0\n", 157 | " Downloading tensorboard_plugin_wit-1.7.0-py3-none-any.whl (779 kB)\u001b[0m\n", 158 | "\u001b[34mCollecting tensorboard==2.2.2\n", 159 | " Downloading tensorboard-2.2.2-py3-none-any.whl (3.0 MB)\u001b[0m\n", 160 | "\u001b[34mCollecting torch==1.5.1\n", 161 | " Downloading torch-1.5.1-cp36-cp36m-manylinux1_x86_64.whl (753.2 MB)\u001b[0m\n", 162 | "\u001b[34mCollecting tqdm==4.47.0\n", 163 | " Downloading tqdm-4.47.0-py2.py3-none-any.whl (66 kB)\u001b[0m\n", 164 | "\u001b[34mCollecting urllib3==1.25.9\n", 165 | " Downloading urllib3-1.25.9-py2.py3-none-any.whl (126 kB)\u001b[0m\n", 166 | "\u001b[34mRequirement already satisfied: werkzeug==1.0.1 in /opt/conda/lib/python3.6/site-packages (from -r requirements.txt (line 41)) (1.0.1)\u001b[0m\n", 167 | "\u001b[34mRequirement already satisfied: wheel==0.34.2 in /opt/conda/lib/python3.6/site-packages (from -r requirements.txt (line 42)) (0.34.2)\u001b[0m\n", 168 | "\u001b[34mRequirement already satisfied: zipp==3.1.0 in /opt/conda/lib/python3.6/site-packages (from 
-r requirements.txt (line 43)) (3.1.0)\u001b[0m\n", 169 | "\u001b[34mRequirement already satisfied: setuptools>=40.3.0 in /opt/conda/lib/python3.6/site-packages (from google-auth==1.18.0->-r requirements.txt (line 11)) (46.4.0.post20200518)\u001b[0m\n", 170 | "\u001b[34mBuilding wheels for collected packages: absl-py, future, sagemaker, default-user-module-name\n", 171 | " Building wheel for absl-py (setup.py): started\n", 172 | " Building wheel for absl-py (setup.py): finished with status 'done'\n", 173 | " Created wheel for absl-py: filename=absl_py-0.9.0-py3-none-any.whl size=121931 sha256=d36ea27ba24b2d652ab5e36074d533d28496efe18c0dba78e266b6776edc2149\n", 174 | " Stored in directory: /root/.cache/pip/wheels/c3/af/84/3962a6af7b4ab336e951b7877dcfb758cf94548bb1771e0679\n", 175 | " Building wheel for future (setup.py): started\u001b[0m\n", 176 | "\u001b[34m Building wheel for future (setup.py): finished with status 'done'\n", 177 | " Created wheel for future: filename=future-0.18.2-py3-none-any.whl size=491058 sha256=bf241790f31df521b0568d9548e7d17ecee1c36e98cc662ebc2bb1ed0c712b0d\n", 178 | " Stored in directory: /root/.cache/pip/wheels/6e/9c/ed/4499c9865ac1002697793e0ae05ba6be33553d098f3347fb94\n", 179 | " Building wheel for sagemaker (setup.py): started\n", 180 | " Building wheel for sagemaker (setup.py): finished with status 'done'\n", 181 | " Created wheel for sagemaker: filename=sagemaker-1.69.0-py2.py3-none-any.whl size=384828 sha256=6932dd7ec9bf42671f101eebb8a13f776c90d690f211f4b44c532e9918feb08e\n", 182 | " Stored in directory: /root/.cache/pip/wheels/43/11/be/445ade3de346b409945d1275aa2e46dd79fb6dcedff0417d1d\n", 183 | " Building wheel for default-user-module-name (setup.py): started\u001b[0m\n", 184 | "\u001b[34m Building wheel for default-user-module-name (setup.py): finished with status 'done'\n", 185 | " Created wheel for default-user-module-name: filename=default_user_module_name-1.0.0-py2.py3-none-any.whl size=8193 
sha256=fbd6f52066b476ceb4fe96dcc35abd4b0382592c00e33169301b54de51ce9fab\n", 186 | " Stored in directory: /tmp/pip-ephem-wheel-cache-qdc6v6xn/wheels/8d/f2/d4/d9825ea9d81c8dd3b114e917ff0864d00eb14a9e06d85f95bc\u001b[0m\n", 187 | "\u001b[34mSuccessfully built absl-py future sagemaker default-user-module-name\u001b[0m\n", 188 | "\u001b[34mERROR: torchvision 0.5.0 has requirement torch==1.4.0, but you'll have torch 1.5.1 which is incompatible.\u001b[0m\n", 189 | "\u001b[34mERROR: awscli 1.18.73 has requirement botocore==1.16.23, but you'll have botocore 1.17.20 which is incompatible.\u001b[0m\n", 190 | "\u001b[34mERROR: awscli 1.18.73 has requirement rsa<=3.5.0,>=3.1.2, but you'll have rsa 4.6 which is incompatible.\u001b[0m\n", 191 | "\u001b[34mInstalling collected packages: six, absl-py, urllib3, botocore, boto3, cachetools, certifi, future, oauthlib, idna, requests, requests-oauthlib, rsa, pyasn1-modules, google-auth, google-auth-oauthlib, grpcio, importlib-metadata, markdown, numpy, tensorboard-plugin-wit, tensorboard, tqdm, torch, pytorch-lightning, scipy, smdebug-rulesconfig, sagemaker, default-user-module-name\n", 192 | " Attempting uninstall: six\n", 193 | " Found existing installation: six 1.14.0\n", 194 | " Uninstalling six-1.14.0:\u001b[0m\n", 195 | "\u001b[34m Successfully uninstalled six-1.14.0\n", 196 | " Attempting uninstall: urllib3\n", 197 | " Found existing installation: urllib3 1.25.8\n", 198 | " Uninstalling urllib3-1.25.8:\n", 199 | " Successfully uninstalled urllib3-1.25.8\n", 200 | " Attempting uninstall: botocore\n", 201 | " Found existing installation: botocore 1.16.23\n", 202 | " Uninstalling botocore-1.16.23:\n", 203 | " Successfully uninstalled botocore-1.16.23\u001b[0m\n", 204 | "\u001b[34m Attempting uninstall: boto3\n", 205 | " Found existing installation: boto3 1.13.23\n", 206 | " Uninstalling boto3-1.13.23:\n", 207 | " Successfully uninstalled boto3-1.13.23\n", 208 | " Attempting uninstall: certifi\n", 209 | " Found existing 
installation: certifi 2020.4.5.1\n", 210 | " Uninstalling certifi-2020.4.5.1:\n", 211 | " Successfully uninstalled certifi-2020.4.5.1\n", 212 | " Attempting uninstall: future\n", 213 | " Found existing installation: future 0.17.1\n", 214 | " Uninstalling future-0.17.1:\n", 215 | " Successfully uninstalled future-0.17.1\u001b[0m\n", 216 | "\u001b[34m Attempting uninstall: idna\n", 217 | " Found existing installation: idna 2.8\n", 218 | " Uninstalling idna-2.8:\n", 219 | " Successfully uninstalled idna-2.8\n", 220 | " Attempting uninstall: requests\n", 221 | " Found existing installation: requests 2.22.0\n", 222 | " Uninstalling requests-2.22.0:\n", 223 | " Successfully uninstalled requests-2.22.0\n", 224 | " Attempting uninstall: rsa\n", 225 | " Found existing installation: rsa 3.4.2\n", 226 | " Uninstalling rsa-3.4.2:\n", 227 | " Successfully uninstalled rsa-3.4.2\u001b[0m\n", 228 | "\u001b[34m Attempting uninstall: importlib-metadata\n", 229 | " Found existing installation: importlib-metadata 1.6.0\n", 230 | " Uninstalling importlib-metadata-1.6.0:\n", 231 | " Successfully uninstalled importlib-metadata-1.6.0\n", 232 | " Attempting uninstall: numpy\n", 233 | " Found existing installation: numpy 1.16.4\n", 234 | " Uninstalling numpy-1.16.4:\n", 235 | " Successfully uninstalled numpy-1.16.4\u001b[0m\n", 236 | "\u001b[34m Attempting uninstall: tqdm\n", 237 | " Found existing installation: tqdm 4.42.1\n", 238 | " Uninstalling tqdm-4.42.1:\n", 239 | " Successfully uninstalled tqdm-4.42.1\n", 240 | " Attempting uninstall: torch\n", 241 | " Found existing installation: torch 1.4.0\n", 242 | " Uninstalling torch-1.4.0:\n", 243 | " Successfully uninstalled torch-1.4.0\u001b[0m\n", 244 | "\u001b[34m Attempting uninstall: scipy\n", 245 | " Found existing installation: scipy 1.2.2\n", 246 | " Uninstalling scipy-1.2.2:\n", 247 | " Successfully uninstalled scipy-1.2.2\u001b[0m\n", 248 | "\u001b[34m Attempting uninstall: smdebug-rulesconfig\n", 249 | " Found existing 
installation: smdebug-rulesconfig 0.1.2\n", 250 | " Uninstalling smdebug-rulesconfig-0.1.2:\n", 251 | " Successfully uninstalled smdebug-rulesconfig-0.1.2\n", 252 | " Attempting uninstall: sagemaker\n", 253 | " Found existing installation: sagemaker 1.50.17\n", 254 | " Uninstalling sagemaker-1.50.17:\n", 255 | " Successfully uninstalled sagemaker-1.50.17\u001b[0m\n", 256 | "\u001b[34mSuccessfully installed absl-py-0.9.0 boto3-1.14.20 botocore-1.17.20 cachetools-4.1.1 certifi-2020.6.20 default-user-module-name-1.0.0 future-0.18.2 google-auth-1.18.0 google-auth-oauthlib-0.4.1 grpcio-1.30.0 idna-2.10 importlib-metadata-1.7.0 markdown-3.2.2 numpy-1.19.0 oauthlib-3.1.0 pyasn1-modules-0.2.8 pytorch-lightning-0.8.5 requests-2.24.0 requests-oauthlib-1.3.0 rsa-4.6 sagemaker-1.69.0 scipy-1.5.1 six-1.15.0 smdebug-rulesconfig-0.1.4 tensorboard-2.2.2 tensorboard-plugin-wit-1.7.0 torch-1.5.1 tqdm-4.47.0 urllib3-1.25.9\u001b[0m\n", 257 | "\u001b[34m2020-07-13 13:14:48,434 sagemaker-containers INFO Invoking user script\n", 258 | "\u001b[0m\n", 259 | "\u001b[34mTraining Env:\n", 260 | "\u001b[0m\n", 261 | "\u001b[34m{\n", 262 | " \"additional_framework_parameters\": {},\n", 263 | " \"channel_input_dirs\": {\n", 264 | " \"test\": \"/opt/ml/input/data/test\",\n", 265 | " \"train\": \"/opt/ml/input/data/train\"\n", 266 | " },\n", 267 | " \"current_host\": \"algo-1\",\n", 268 | " \"framework_module\": \"sagemaker_pytorch_container.training:main\",\n", 269 | " \"hosts\": [\n", 270 | " \"algo-1\"\n", 271 | " ],\n", 272 | " \"hyperparameters\": {\n", 273 | " \"batch-size\": 128,\n", 274 | " \"epochs\": 6\n", 275 | " },\n", 276 | " \"input_config_dir\": \"/opt/ml/input/config\",\n", 277 | " \"input_data_config\": {\n", 278 | " \"test\": {\n", 279 | " \"TrainingInputMode\": \"File\",\n", 280 | " \"S3DistributionType\": \"FullyReplicated\",\n", 281 | " \"RecordWrapperType\": \"None\"\n", 282 | " },\n", 283 | " \"train\": {\n", 284 | " \"TrainingInputMode\": \"File\",\n", 285 | " 
\"S3DistributionType\": \"FullyReplicated\",\n", 286 | " \"RecordWrapperType\": \"None\"\n", 287 | " }\n", 288 | " },\n", 289 | " \"input_dir\": \"/opt/ml/input\",\n", 290 | " \"is_master\": true,\n", 291 | " \"job_name\": \"pytorch-training-2020-07-13-13-04-36-994\",\n", 292 | " \"log_level\": 20,\n", 293 | " \"master_hostname\": \"algo-1\",\n", 294 | " \"model_dir\": \"/opt/ml/model\",\n", 295 | " \"module_dir\": \"s3://sagemaker-eu-west-1-682411330166/pytorch-training-2020-07-13-13-04-36-994/source/sourcedir.tar.gz\",\n", 296 | " \"module_name\": \"train\",\n", 297 | " \"network_interface_name\": \"eth0\",\n", 298 | " \"num_cpus\": 4,\n", 299 | " \"num_gpus\": 1,\n", 300 | " \"output_data_dir\": \"/opt/ml/output/data\",\n", 301 | " \"output_dir\": \"/opt/ml/output\",\n", 302 | " \"output_intermediate_dir\": \"/opt/ml/output/intermediate\",\n", 303 | " \"resource_config\": {\n", 304 | " \"current_host\": \"algo-1\",\n", 305 | " \"hosts\": [\n", 306 | " \"algo-1\"\n", 307 | " ],\n", 308 | " \"network_interface_name\": \"eth0\"\n", 309 | " },\n", 310 | " \"user_entry_point\": \"train.py\"\u001b[0m\n", 311 | "\u001b[34m}\n", 312 | "\u001b[0m\n", 313 | "\u001b[34mEnvironment variables:\n", 314 | "\u001b[0m\n", 315 | "\u001b[34mSM_HOSTS=[\"algo-1\"]\u001b[0m\n", 316 | "\u001b[34mSM_NETWORK_INTERFACE_NAME=eth0\u001b[0m\n", 317 | "\u001b[34mSM_HPS={\"batch-size\":128,\"epochs\":6}\u001b[0m\n", 318 | "\u001b[34mSM_USER_ENTRY_POINT=train.py\u001b[0m\n", 319 | "\u001b[34mSM_FRAMEWORK_PARAMS={}\u001b[0m\n", 320 | "\u001b[34mSM_RESOURCE_CONFIG={\"current_host\":\"algo-1\",\"hosts\":[\"algo-1\"],\"network_interface_name\":\"eth0\"}\u001b[0m\n", 321 | "\u001b[34mSM_INPUT_DATA_CONFIG={\"test\":{\"RecordWrapperType\":\"None\",\"S3DistributionType\":\"FullyReplicated\",\"TrainingInputMode\":\"File\"},\"train\":{\"RecordWrapperType\":\"None\",\"S3DistributionType\":\"FullyReplicated\",\"TrainingInputMode\":\"File\"}}\u001b[0m\n", 322 | 
"\u001b[34mSM_OUTPUT_DATA_DIR=/opt/ml/output/data\u001b[0m\n", 323 | "\u001b[34mSM_CHANNELS=[\"test\",\"train\"]\u001b[0m\n", 324 | "\u001b[34mSM_CURRENT_HOST=algo-1\u001b[0m\n", 325 | "\u001b[34mSM_MODULE_NAME=train\u001b[0m\n", 326 | "\u001b[34mSM_LOG_LEVEL=20\u001b[0m\n", 327 | "\u001b[34mSM_FRAMEWORK_MODULE=sagemaker_pytorch_container.training:main\u001b[0m\n", 328 | "\u001b[34mSM_INPUT_DIR=/opt/ml/input\u001b[0m\n", 329 | "\u001b[34mSM_INPUT_CONFIG_DIR=/opt/ml/input/config\u001b[0m\n", 330 | "\u001b[34mSM_OUTPUT_DIR=/opt/ml/output\u001b[0m\n", 331 | "\u001b[34mSM_NUM_CPUS=4\u001b[0m\n", 332 | "\u001b[34mSM_NUM_GPUS=1\u001b[0m\n", 333 | "\u001b[34mSM_MODEL_DIR=/opt/ml/model\u001b[0m\n", 334 | "\u001b[34mSM_MODULE_DIR=s3://sagemaker-eu-west-1-682411330166/pytorch-training-2020-07-13-13-04-36-994/source/sourcedir.tar.gz\u001b[0m\n", 335 | "\u001b[34mSM_TRAINING_ENV={\"additional_framework_parameters\":{},\"channel_input_dirs\":{\"test\":\"/opt/ml/input/data/test\",\"train\":\"/opt/ml/input/data/train\"},\"current_host\":\"algo-1\",\"framework_module\":\"sagemaker_pytorch_container.training:main\",\"hosts\":[\"algo-1\"],\"hyperparameters\":{\"batch-size\":128,\"epochs\":6},\"input_config_dir\":\"/opt/ml/input/config\",\"input_data_config\":{\"test\":{\"RecordWrapperType\":\"None\",\"S3DistributionType\":\"FullyReplicated\",\"TrainingInputMode\":\"File\"},\"train\":{\"RecordWrapperType\":\"None\",\"S3DistributionType\":\"FullyReplicated\",\"TrainingInputMode\":\"File\"}},\"input_dir\":\"/opt/ml/input\",\"is_master\":true,\"job_name\":\"pytorch-training-2020-07-13-13-04-36-994\",\"log_level\":20,\"master_hostname\":\"algo-1\",\"model_dir\":\"/opt/ml/model\",\"module_dir\":\"s3://sagemaker-eu-west-1-682411330166/pytorch-training-2020-07-13-13-04-36-994/source/sourcedir.tar.gz\",\"module_name\":\"train\",\"network_interface_name\":\"eth0\",\"num_cpus\":4,\"num_gpus\":1,\"output_data_dir\":\"/opt/ml/output/data\",\"output_dir\":\"/opt/ml/output\",\"output_intermediate_d
ir\":\"/opt/ml/output/intermediate\",\"resource_config\":{\"current_host\":\"algo-1\",\"hosts\":[\"algo-1\"],\"network_interface_name\":\"eth0\"},\"user_entry_point\":\"train.py\"}\u001b[0m\n", 336 | "\u001b[34mSM_USER_ARGS=[\"--batch-size\",\"128\",\"--epochs\",\"6\"]\u001b[0m\n", 337 | "\u001b[34mSM_OUTPUT_INTERMEDIATE_DIR=/opt/ml/output/intermediate\u001b[0m\n", 338 | "\u001b[34mSM_CHANNEL_TEST=/opt/ml/input/data/test\u001b[0m\n", 339 | "\u001b[34mSM_CHANNEL_TRAIN=/opt/ml/input/data/train\u001b[0m\n", 340 | "\u001b[34mSM_HP_BATCH-SIZE=128\u001b[0m\n", 341 | "\u001b[34mSM_HP_EPOCHS=6\u001b[0m\n", 342 | "\u001b[34mPYTHONPATH=/opt/ml/code:/opt/conda/bin:/opt/conda/lib/python36.zip:/opt/conda/lib/python3.6:/opt/conda/lib/python3.6/lib-dynload:/opt/conda/lib/python3.6/site-packages\n", 343 | "\u001b[0m\n", 344 | "\u001b[34mInvoking script with the following command:\n", 345 | "\u001b[0m\n", 346 | "\u001b[34m/opt/conda/bin/python train.py --batch-size 128 --epochs 6\n", 347 | "\n", 348 | "\u001b[0m\n", 349 | "\n", 350 | "2020-07-13 13:15:10 Uploading - Uploading generated training model\n", 351 | "2020-07-13 13:15:10 Failed - Training job failed\n", 352 | "\u001b[34mNamespace(batch_size=128, epochs=6, gpus=1, model_dir=None, output_data_dir=None, test=None, train=None)\u001b[0m\n", 353 | "\u001b[34m2020-07-13 13:14:59,636 sagemaker-containers ERROR ExecuteUserScriptError:\u001b[0m\n", 354 | "\u001b[34mCommand \"/opt/conda/bin/python train.py --batch-size 128 --epochs 6\"\u001b[0m\n", 355 | "\u001b[34mGPU available: True, used: True\u001b[0m\n", 356 | "\u001b[34mTPU available: False, using: 0 TPU cores\u001b[0m\n", 357 | "\u001b[34mCUDA_VISIBLE_DEVICES: [0]\u001b[0m\n", 358 | "\u001b[34mTraceback (most recent call last):\n", 359 | " File \"train.py\", line 36, in \n", 360 | " mnistTrainer.fit(model)\n", 361 | " File \"/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py\", line 944, in fit\n", 362 | " model.prepare_data()\n", 363 | " File 
\"/opt/ml/code/MNISTClassifier.py\", line 64, in prepare_data\n", 364 | " self.train_loader, self.val_loader, self.test_loader = self.load_split_train_test()\n", 365 | " File \"/opt/ml/code/MNISTClassifier.py\", line 43, in load_split_train_test\n", 366 | " train_data = datasets.ImageFolder(self.train_data_dir, transform=train_transforms)\n", 367 | " File \"/opt/conda/lib/python3.6/site-packages/torchvision/datasets/folder.py\", line 209, in __init__\n", 368 | " is_valid_file=is_valid_file)\n", 369 | " File \"/opt/conda/lib/python3.6/site-packages/torchvision/datasets/folder.py\", line 94, in __init__\n", 370 | " samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)\n", 371 | " File \"/opt/conda/lib/python3.6/site-packages/torchvision/datasets/folder.py\", line 37, in make_dataset\n", 372 | " dir = os.path.expanduser(dir)\n", 373 | " File \"/opt/conda/lib/python3.6/posixpath.py\", line 235, in expanduser\n", 374 | " path = os.fspath(path)\u001b[0m\n", 375 | "\u001b[34mTypeError: expected str, bytes or os.PathLike object, not NoneType\u001b[0m\n" 376 | ] 377 | }, 378 | { 379 | "ename": "UnexpectedStatusException", 380 | "evalue": "Error for Training job pytorch-training-2020-07-13-13-04-36-994: Failed. 
Reason: AlgorithmError: ExecuteUserScriptError:\nCommand \"/opt/conda/bin/python train.py --batch-size 128 --epochs 6\"\nGPU available: True, used: True\nTPU available: False, using: 0 TPU cores\nCUDA_VISIBLE_DEVICES: [0]\nTraceback (most recent call last):\n File \"train.py\", line 36, in \n mnistTrainer.fit(model)\n File \"/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py\", line 944, in fit\n model.prepare_data()\n File \"/opt/ml/code/MNISTClassifier.py\", line 64, in prepare_data\n self.train_loader, self.val_loader, self.test_loader = self.load_split_train_test()\n File \"/opt/ml/code/MNISTClassifier.py\", line 43, in load_split_train_test\n train_data = datasets.ImageFolder(self.train_data_dir, transform=train_transforms)\n File \"/opt/conda/lib/python3.6/site-packages/torchvision/datasets/folder.py\", line 209, in __init__\n is_valid_file=is_valid_file)\n File \"/opt/conda/lib/python3.6/site-packages/torchvision/datasets/folder.py\", line 94, in __init__\n samples = make_dataset(sel", 381 | "output_type": "error", 382 | "traceback": [ 383 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 384 | "\u001b[0;31mUnexpectedStatusException\u001b[0m Traceback (most recent call last)", 385 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m estimator.fit({\n\u001b[1;32m 2\u001b[0m \u001b[0;34m'train'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbucket\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;34m'/training'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;34m'test'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbucket\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;34m'/testing'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m })\n", 386 | "\u001b[0;32m~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/sagemaker/estimator.py\u001b[0m in 
\u001b[0;36mfit\u001b[0;34m(self, inputs, wait, logs, job_name, experiment_config)\u001b[0m\n\u001b[1;32m 495\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjobs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlatest_training_job\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 496\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mwait\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 497\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlatest_training_job\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwait\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlogs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlogs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 498\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 499\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_compilation_job_name\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 387 | "\u001b[0;32m~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/sagemaker/estimator.py\u001b[0m in \u001b[0;36mwait\u001b[0;34m(self, logs)\u001b[0m\n\u001b[1;32m 1112\u001b[0m \u001b[0;31m# If logs are requested, call logs_for_jobs.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1113\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlogs\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m\"None\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1114\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msagemaker_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogs_for_job\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjob_name\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mwait\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_type\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlogs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1115\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1116\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msagemaker_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwait_for_job\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjob_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 388 | "\u001b[0;32m~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/sagemaker/session.py\u001b[0m in \u001b[0;36mlogs_for_job\u001b[0;34m(self, job_name, wait, poll, log_type)\u001b[0m\n\u001b[1;32m 3068\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3069\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mwait\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3070\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_check_job_status\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjob_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdescription\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"TrainingJobStatus\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3071\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mdot\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3072\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 389 | "\u001b[0;32m~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/sagemaker/session.py\u001b[0m in \u001b[0;36m_check_job_status\u001b[0;34m(self, job, desc, status_key_name)\u001b[0m\n\u001b[1;32m 2662\u001b[0m ),\n\u001b[1;32m 
2663\u001b[0m \u001b[0mallowed_statuses\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"Completed\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"Stopped\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2664\u001b[0;31m \u001b[0mactual_status\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mstatus\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2665\u001b[0m )\n\u001b[1;32m 2666\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 390 | "\u001b[0;31mUnexpectedStatusException\u001b[0m: Error for Training job pytorch-training-2020-07-13-13-04-36-994: Failed. Reason: AlgorithmError: ExecuteUserScriptError:\nCommand \"/opt/conda/bin/python train.py --batch-size 128 --epochs 6\"\nGPU available: True, used: True\nTPU available: False, using: 0 TPU cores\nCUDA_VISIBLE_DEVICES: [0]\nTraceback (most recent call last):\n File \"train.py\", line 36, in \n mnistTrainer.fit(model)\n File \"/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py\", line 944, in fit\n model.prepare_data()\n File \"/opt/ml/code/MNISTClassifier.py\", line 64, in prepare_data\n self.train_loader, self.val_loader, self.test_loader = self.load_split_train_test()\n File \"/opt/ml/code/MNISTClassifier.py\", line 43, in load_split_train_test\n train_data = datasets.ImageFolder(self.train_data_dir, transform=train_transforms)\n File \"/opt/conda/lib/python3.6/site-packages/torchvision/datasets/folder.py\", line 209, in __init__\n is_valid_file=is_valid_file)\n File \"/opt/conda/lib/python3.6/site-packages/torchvision/datasets/folder.py\", line 94, in __init__\n samples = make_dataset(sel" 391 | ] 392 | } 393 | ], 394 | "source": [ 395 | "estimator.fit({\n", 396 | " 'train': bucket+'/training',\n", 397 | " 'test': bucket+'/testing'\n", 398 | "})" 399 | ] 400 | }, 401 | { 402 | "cell_type": "code", 403 | "execution_count": null, 404 | "metadata": {}, 405 | 
"outputs": [], 406 | "source": [] 407 | } 408 | ], 409 | "metadata": { 410 | "kernelspec": { 411 | "display_name": "conda_pytorch_p36", 412 | "language": "python", 413 | "name": "conda_pytorch_p36" 414 | }, 415 | "language_info": { 416 | "codemirror_mode": { 417 | "name": "ipython", 418 | "version": 3 419 | }, 420 | "file_extension": ".py", 421 | "mimetype": "text/x-python", 422 | "name": "python", 423 | "nbconvert_exporter": "python", 424 | "pygments_lexer": "ipython3", 425 | "version": "3.6.10" 426 | } 427 | }, 428 | "nbformat": 4, 429 | "nbformat_minor": 4 430 | } 431 | -------------------------------------------------------------------------------- /sagemaker-run.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aletheia/mnist_pl_sagemaker/9083684ee22f3dbf369c4a1509d2044760628a07/sagemaker-run.png -------------------------------------------------------------------------------- /sagemaker-sdk.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aletheia/mnist_pl_sagemaker/9083684ee22f3dbf369c4a1509d2044760628a07/sagemaker-sdk.png --------------------------------------------------------------------------------