├── .github ├── CONTRIBUTING.md ├── ISSUE_TEMPLATE │ └── bug_report.md └── PULL_REQUEST_TEMPLATE.md ├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── audio-processing └── pytorch │ └── audio_preprocessing_tutorial.ipynb ├── generative-adversarial-network └── tensorflow │ └── deep_convolutional_gan.ipynb ├── getting-started ├── pytorch │ ├── data_loading_processing.ipynb │ └── tensor_intro.ipynb ├── scikit-learn │ └── handwritten-digit-classifier.ipynb └── tensorflow │ ├── handwritten_digit_classifier.ipynb │ └── image_classification.ipynb ├── image-processing ├── pytorch │ └── object_detection_finetuning.ipynb └── tensorflow │ └── image_classification.ipynb ├── logo.png ├── natural-language-processing ├── pytorch │ └── translation_with_sequence2sequence.ipynb └── tensorflow │ ├── rnn_text_gen.ipynb │ ├── transformer_language_understanding.ipynb │ └── translation_with_sequence2sequence.ipynb ├── reinforcement-learning └── pytorch │ └── deep_q_learning.ipynb └── transfer-learning └── pytorch └── cnn_transfer_learning.ipynb /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | ## Add a New Notebook 4 | 5 | --- 6 | 7 | Before proposing a new notebook to be added, ensure the following 8 | 9 | - It is not a duplicate or very similar to another notebook. If it is, consider collaborating with the other author to merge the code into a single notebook. 10 | - Make sure it has been tested and can run on [Colaboratory](https://colab.research.google.com/). 11 | - If the example requires a dataset, make sure it is hosted and can be downloaded with `wget` or `curl` from the notebook. 12 | 13 | When you are ready to submit a new Notebook 14 | 15 | 1. Give the file a descriptive and concise name that uses [snake case](https://en.wikipedia.org/wiki/Snake_case) 16 | 1. Add the Colab and GitHub buttons with the following snippet with the updated paths: 17 | 18 | ```html 19 | 20 | 28 | 37 |
21 | Run 25 | in Google Colab 27 | 29 | View source on GitHub 36 |
38 | ``` 39 | 40 | 1. Add license and credit as the last section of the notebook 41 | 1. Submit a PR using the [Add Notebook template]() 42 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "[NOTEBOOK NAME] - [ERROR DESCRIPTION]" 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **Notebook information** 14 | - URI: 15 | - Section: 16 | 17 | **To Reproduce** 18 | Steps to reproduce the behavior: 19 | 1. Go to '...' 20 | 2. Click on '....' 21 | 3. Scroll down to '....' 22 | 4. See error 23 | 24 | **Expected behavior** 25 | A clear and concise description of what you expected to happen. 26 | 27 | **Screenshots** 28 | If applicable, add screenshots to help explain your problem. 29 | 30 | **Suggested fix** 31 | Provide a suggestion on how to fix the problem. 32 | 33 | **Additional context** 34 | Add any other context about the problem here. 35 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | Thank you for your contribution to the NotebookExplore repo. 2 | Before submitting this PR, please make sure: 3 | 4 | - [ ] There are no closely similar notebook to the one you are submitting 5 | - [ ] Your notebook runs on Colab without any errors or warnings: [COLAB LINK] 6 | - [ ] Datasets used in your notebook can be downloaded from through the notebook 7 | - [ ] The notebook is being added to the correct location in the repo 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # OS generated files 132 | .DS_Store -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. 
Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at notebookexplore@gmail.com. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | NotebookExplore logo 3 |

NotebookExplore

4 |

5 | 6 | A collection of #MachineLearning #Python Notebooks 🤖 🐍📚 that can be launched to the ☁️ for use and experimentation. No setup needed: just launch it 🚀 7 | 8 | 

9 | NotebookExplore Demo 10 |

11 | 12 | ## Table of Contents 13 | 14 | | [🔍 Search Notebooks](https://github.com/notebookexplore/notebookexplore/find/master) | 15 | | ------------------------------------------------------------------------------------- | 16 | 17 | 18 | - [Getting Started](https://github.com/notebookexplore/notebookexplore/blob/master/getting-started) 19 | - [Audio Processing](https://github.com/notebookexplore/notebookexplore/blob/master/audio-processing) 20 | - [Generative Adversarial Network (GAN)](https://github.com/notebookexplore/notebookexplore/blob/master/generative-adversarial-network) 21 | - [Natural Language Processing (NLP)](https://github.com/notebookexplore/notebookexplore/blob/master/natural-language-processing) 22 | - [Image Processing](https://github.com/notebookexplore/notebookexplore/blob/master/image-processing) 23 | - [Reinforcement Learning](https://github.com/notebookexplore/notebookexplore/blob/master/reinforcement-learning) 24 | - [Transfer Learning](https://github.com/notebookexplore/NotebookExplore/tree/master/transfer-learning) 25 | 26 | ## Contribute 27 | 28 | There are many ways to contribute NotebookExplore. 29 | 30 | - Add a new notebook 31 | - Fix bugs 32 | - Review updates 33 | - Help each other in the community 34 | 35 | Read more about the [contributing process and guidelines](https://github.com/notebookexplore/NotebookExplore/tree/master/.github/CONTRIBUTING.md). 36 | 37 | We also welcome collaborators, so feel free to open a [GitHub issue](https://github.com/notebookexplore/NotebookExplore/issues/new/choose) to let us know how'd you like to contribute. 38 | 39 | ## Connect 40 | 41 | - Follow the project on Twitter [@NotebookExplore](https://twitter.com/NotebookExplore) 42 | - Open an [issue](https://github.com/notebookexplore/notebookexplore/issues/new) 43 | - Slack/Discord (coming soon) 44 | -------------------------------------------------------------------------------- /audio-processing/pytorch/audio_preprocessing_tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "kernelspec": { 6 | "display_name": "Python 3", 7 | "language": "python", 8 | "name": "python3" 9 | }, 10 | "language_info": { 11 | "codemirror_mode": { 12 | "name": "ipython", 13 | "version": 3 14 | }, 15 | "file_extension": ".py", 16 | "mimetype": "text/x-python", 17 | "name": "python", 18 | "nbconvert_exporter": "python", 19 | "pygments_lexer": "ipython3", 20 | "version": "3.6.8" 21 | }, 22 | "colab": { 23 | "name": "audio_preprocessing_tutorial.ipynb", 24 | "provenance": [] 25 | } 26 | }, 27 | "cells": [ 28 | { 29 | "cell_type": "markdown", 30 | "metadata": { 31 | "id": "Ejys4qb9K4yp", 32 | "colab_type": "text" 33 | }, 34 | "source": [ 35 | "\n", 36 | " \n", 44 | " \n", 53 | "
\n", 37 | " Run\n", 41 | " in Google Colab\n", 43 | " \n", 45 | " View source on GitHub\n", 52 | "
" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "metadata": { 59 | "id": "U8adtjNkog1-", 60 | "colab_type": "code", 61 | "colab": {} 62 | }, 63 | "source": [ 64 | "!pip install torch>=1.2.0\n", 65 | "!pip install torchaudio\n", 66 | "%matplotlib inline" 67 | ], 68 | "execution_count": 0, 69 | "outputs": [] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": { 74 | "id": "oHpAbYz2og2G", 75 | "colab_type": "text" 76 | }, 77 | "source": [ 78 | "\n", 79 | "torchaudio Tutorial\n", 80 | "===================\n", 81 | "\n", 82 | "PyTorch is an open source deep learning platform that provides a\n", 83 | "seamless path from research prototyping to production deployment with\n", 84 | "GPU support.\n", 85 | "\n", 86 | "Significant effort in solving machine learning problems goes into data\n", 87 | "preparation. torchaudio leverages PyTorch’s GPU support, and provides\n", 88 | "many tools to make data loading easy and more readable. In this\n", 89 | "tutorial, we will see how to load and preprocess data from a simple\n", 90 | "dataset.\n", 91 | "\n", 92 | "For this tutorial, please make sure the ``matplotlib`` package is\n", 93 | "installed for easier visualization.\n", 94 | "\n", 95 | "\n" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "metadata": { 101 | "id": "AFPA4DPUog2I", 102 | "colab_type": "code", 103 | "colab": {} 104 | }, 105 | "source": [ 106 | "import torch\n", 107 | "import torchaudio\n", 108 | "import matplotlib.pyplot as plt" 109 | ], 110 | "execution_count": 0, 111 | "outputs": [] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "metadata": { 116 | "id": "4HhXSHmGog2P", 117 | "colab_type": "text" 118 | }, 119 | "source": [ 120 | "Opening a dataset\n", 121 | "-----------------\n", 122 | "\n", 123 | "\n" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": { 129 | "id": "QEbIBsDFog2R", 130 | "colab_type": "text" 131 | }, 132 | "source": [ 133 | "torchaudio supports loading sound files in the wav and mp3 format. 
We\n", 134 | "call waveform the resulting raw audio signal.\n", 135 | "\n", 136 | "\n" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "metadata": { 142 | "id": "Uh7cF3wnog2S", 143 | "colab_type": "code", 144 | "colab": {} 145 | }, 146 | "source": [ 147 | "import requests\n", 148 | "\n", 149 | "url = \"https://pytorch.org/tutorials//_static/img/steam-train-whistle-daniel_simon-converted-from-mp3.wav\"\n", 150 | "r = requests.get(url)\n", 151 | "\n", 152 | "with open('steam-train-whistle-daniel_simon-converted-from-mp3.wav', 'wb') as f:\n", 153 | " f.write(r.content)\n", 154 | "\n", 155 | "filename = \"steam-train-whistle-daniel_simon-converted-from-mp3.wav\"\n", 156 | "waveform, sample_rate = torchaudio.load(filename)\n", 157 | "\n", 158 | "print(\"Shape of waveform: {}\".format(waveform.size()))\n", 159 | "print(\"Sample rate of waveform: {}\".format(sample_rate))\n", 160 | "\n", 161 | "plt.figure()\n", 162 | "plt.plot(waveform.t().numpy())" 163 | ], 164 | "execution_count": 0, 165 | "outputs": [] 166 | }, 167 | { 168 | "cell_type": "markdown", 169 | "metadata": { 170 | "id": "sLxzNge9og2X", 171 | "colab_type": "text" 172 | }, 173 | "source": [ 174 | "Transformations\n", 175 | "---------------\n", 176 | "\n", 177 | "torchaudio supports a growing list of\n", 178 | "`transformations `_.\n", 179 | "\n", 180 | "- **Resample**: Resample waveform to a different sample rate.\n", 181 | "- **Spectrogram**: Create a spectrogram from a waveform.\n", 182 | "- **MelScale**: This turns a normal STFT into a Mel-frequency STFT,\n", 183 | " using a conversion matrix.\n", 184 | "- **AmplitudeToDB**: This turns a spectrogram from the\n", 185 | " power/amplitude scale to the decibel scale.\n", 186 | "- **MFCC**: Create the Mel-frequency cepstrum coefficients from a\n", 187 | " waveform.\n", 188 | "- **MelSpectrogram**: Create MEL Spectrograms from a waveform using the\n", 189 | " STFT function in PyTorch.\n", 190 | "- **MuLawEncoding**: Encode waveform based on mu-law companding.\n", 191 | "- **MuLawDecoding**: Decode mu-law encoded waveform.\n", 192 | "\n", 193 | "Since all transforms are nn.Modules or jit.ScriptModules, they can be\n", 194 | "used as part of a neural network at any point.\n", 195 | "\n", 196 | "\n" 197 | ] 198 | }, 199 | { 200 | "cell_type": "markdown", 201 | "metadata": { 202 | "id": "JLoCOqHvog2Z", 203 | "colab_type": "text" 204 | }, 205 | "source": [ 206 | "To start, we can look at the log of the spectrogram on a log scale.\n", 207 | "\n", 208 | "\n" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "metadata": { 214 | "id": "T-JJqufHog2a", 215 | "colab_type": "code", 216 | "colab": {} 217 | }, 218 | "source": [ 219 | "specgram = torchaudio.transforms.Spectrogram()(waveform)\n", 220 | "\n", 221 | "print(\"Shape of spectrogram: {}\".format(specgram.size()))\n", 222 | "\n", 223 | "plt.figure()\n", 224 | "plt.imshow(specgram.log2()[0,:,:].numpy(), cmap='gray')" 225 | ], 226 | "execution_count": 0, 227 | "outputs": [] 228 | }, 229 | { 230 | "cell_type": "markdown", 231 | "metadata": { 232 | "id": "LTshMbr9og2i", 233 | "colab_type": "text" 234 | }, 235 | "source": [ 236 | "Or we can look at the Mel Spectrogram on a log scale.\n", 237 | "\n", 238 | "\n" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "metadata": { 244 | "id": "tL_b8Oa5og2j", 245 | "colab_type": "code", 246 | "colab": {} 247 | }, 248 | "source": [ 249 | "specgram = torchaudio.transforms.MelSpectrogram()(waveform)\n", 250 | "\n", 251 | "print(\"Shape of spectrogram: {}\".format(specgram.size()))\n", 
252 | "\n", 253 | "plt.figure()\n", 254 | "p = plt.imshow(specgram.log2()[0,:,:].detach().numpy(), cmap='gray')" 255 | ], 256 | "execution_count": 0, 257 | "outputs": [] 258 | }, 259 | { 260 | "cell_type": "markdown", 261 | "metadata": { 262 | "id": "0GNui3v-og2n", 263 | "colab_type": "text" 264 | }, 265 | "source": [ 266 | "We can resample the waveform, one channel at a time.\n", 267 | "\n", 268 | "\n" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "metadata": { 274 | "id": "1eHZFUR8og2p", 275 | "colab_type": "code", 276 | "colab": {} 277 | }, 278 | "source": [ 279 | "new_sample_rate = sample_rate/10\n", 280 | "\n", 281 | "# Since Resample applies to a single channel, we resample first channel here\n", 282 | "channel = 0\n", 283 | "transformed = torchaudio.transforms.Resample(sample_rate, new_sample_rate)(waveform[channel,:].view(1,-1))\n", 284 | "\n", 285 | "print(\"Shape of transformed waveform: {}\".format(transformed.size()))\n", 286 | "\n", 287 | "plt.figure()\n", 288 | "plt.plot(transformed[0,:].numpy())" 289 | ], 290 | "execution_count": 0, 291 | "outputs": [] 292 | }, 293 | { 294 | "cell_type": "markdown", 295 | "metadata": { 296 | "id": "1mT305QXog2s", 297 | "colab_type": "text" 298 | }, 299 | "source": [ 300 | "As another example of transformations, we can encode the signal based on\n", 301 | "Mu-Law enconding. But to do so, we need the signal to be between -1 and\n", 302 | "1. Since the tensor is just a regular PyTorch tensor, we can apply\n", 303 | "standard operators on it.\n", 304 | "\n", 305 | "\n" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "metadata": { 311 | "id": "o5hJAVEnog2u", 312 | "colab_type": "code", 313 | "colab": {} 314 | }, 315 | "source": [ 316 | "# Let's check if the tensor is in the interval [-1,1]\n", 317 | "print(\"Min of waveform: {}\\nMax of waveform: {}\\nMean of waveform: {}\".format(waveform.min(), waveform.max(), waveform.mean()))" 318 | ], 319 | "execution_count": 0, 320 | "outputs": [] 321 | }, 322 | { 323 | "cell_type": "markdown", 324 | "metadata": { 325 | "id": "vlk9qo4Rog2y", 326 | "colab_type": "text" 327 | }, 328 | "source": [ 329 | "Since the waveform is already between -1 and 1, we do not need to\n", 330 | "normalize it.\n", 331 | "\n", 332 | "\n" 333 | ] 334 | }, 335 | { 336 | "cell_type": "code", 337 | "metadata": { 338 | "id": "TI3yBBSGog2z", 339 | "colab_type": "code", 340 | "colab": {} 341 | }, 342 | "source": [ 343 | "def normalize(tensor):\n", 344 | " # Subtract the mean, and scale to the interval [-1,1]\n", 345 | " tensor_minusmean = tensor - tensor.mean()\n", 346 | " return tensor_minusmean/tensor_minusmean.abs().max()\n", 347 | "\n", 348 | "# Let's normalize to the full interval [-1,1]\n", 349 | "# waveform = normalize(waveform)" 350 | ], 351 | "execution_count": 0, 352 | "outputs": [] 353 | }, 354 | { 355 | "cell_type": "markdown", 356 | "metadata": { 357 | "id": "3Yu7ubXZog22", 358 | "colab_type": "text" 359 | }, 360 | "source": [ 361 | "Let’s apply encode the waveform.\n", 362 | "\n", 363 | "\n" 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "metadata": { 369 | "id": "qjytiCNYog23", 370 | "colab_type": "code", 371 | "colab": {} 372 | }, 373 | "source": [ 374 | "transformed = torchaudio.transforms.MuLawEncoding()(waveform)\n", 375 | "\n", 376 | "print(\"Shape of transformed waveform: {}\".format(transformed.size()))\n", 377 | "\n", 378 | "plt.figure()\n", 379 | "plt.plot(transformed[0,:].numpy())" 380 | ], 381 | "execution_count": 0, 382 | "outputs": [] 383 | }, 384 | { 385 | "cell_type": 
"markdown", 386 | "metadata": { 387 | "id": "pgApI-D6og27", 388 | "colab_type": "text" 389 | }, 390 | "source": [ 391 | "And now decode.\n", 392 | "\n", 393 | "\n" 394 | ] 395 | }, 396 | { 397 | "cell_type": "code", 398 | "metadata": { 399 | "id": "Ciijz1Kgog28", 400 | "colab_type": "code", 401 | "colab": {} 402 | }, 403 | "source": [ 404 | "reconstructed = torchaudio.transforms.MuLawDecoding()(transformed)\n", 405 | "\n", 406 | "print(\"Shape of recovered waveform: {}\".format(reconstructed.size()))\n", 407 | "\n", 408 | "plt.figure()\n", 409 | "plt.plot(reconstructed[0,:].numpy())" 410 | ], 411 | "execution_count": 0, 412 | "outputs": [] 413 | }, 414 | { 415 | "cell_type": "markdown", 416 | "metadata": { 417 | "id": "fyJ02hGMog3C", 418 | "colab_type": "text" 419 | }, 420 | "source": [ 421 | "We can finally compare the original waveform with its reconstructed\n", 422 | "version.\n", 423 | "\n", 424 | "\n" 425 | ] 426 | }, 427 | { 428 | "cell_type": "code", 429 | "metadata": { 430 | "id": "4B6pyGd5og3D", 431 | "colab_type": "code", 432 | "colab": {} 433 | }, 434 | "source": [ 435 | "# Compute median relative difference\n", 436 | "err = ((waveform-reconstructed).abs() / waveform.abs()).median()\n", 437 | "\n", 438 | "print(\"Median relative difference between original and MuLaw reconstucted signals: {:.2%}\".format(err))" 439 | ], 440 | "execution_count": 0, 441 | "outputs": [] 442 | }, 443 | { 444 | "cell_type": "markdown", 445 | "metadata": { 446 | "id": "ayNUB_keog3H", 447 | "colab_type": "text" 448 | }, 449 | "source": [ 450 | "Migrating to torchaudio from Kaldi\n", 451 | "----------------------------------\n", 452 | "\n", 453 | "Users may be familiar with\n", 454 | "`Kaldi `_, a toolkit for speech\n", 455 | "recognition. torchaudio offers compatibility with it in\n", 456 | "``torchaudio.kaldi_io``. 
It can indeed read from kaldi scp, or ark file\n", 457 | "or streams with:\n", 458 | "\n", 459 | "- read_vec_int_ark\n", 460 | "- read_vec_flt_scp\n", 461 | "- read_vec_flt_arkfile/stream\n", 462 | "- read_mat_scp\n", 463 | "- read_mat_ark\n", 464 | "\n", 465 | "torchaudio provides Kaldi-compatible transforms for ``spectrogram`` and\n", 466 | "``fbank`` with the benefit of GPU support, see\n", 467 | "`here `__ for more information.\n", 468 | "\n", 469 | "\n" 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "metadata": { 475 | "id": "NiSS8s10og3I", 476 | "colab_type": "code", 477 | "colab": {} 478 | }, 479 | "source": [ 480 | "n_fft = 400.0\n", 481 | "frame_length = n_fft / sample_rate * 1000.0\n", 482 | "frame_shift = frame_length / 2.0\n", 483 | "\n", 484 | "params = {\n", 485 | " \"channel\": 0,\n", 486 | " \"dither\": 0.0,\n", 487 | " \"window_type\": \"hanning\",\n", 488 | " \"frame_length\": frame_length,\n", 489 | " \"frame_shift\": frame_shift,\n", 490 | " \"remove_dc_offset\": False,\n", 491 | " \"round_to_power_of_two\": False,\n", 492 | " \"sample_frequency\": sample_rate,\n", 493 | "}\n", 494 | "\n", 495 | "specgram = torchaudio.compliance.kaldi.spectrogram(waveform, **params)\n", 496 | "\n", 497 | "print(\"Shape of spectrogram: {}\".format(specgram.size()))\n", 498 | "\n", 499 | "plt.figure()\n", 500 | "plt.imshow(specgram.t().numpy(), cmap='gray')" 501 | ], 502 | "execution_count": 0, 503 | "outputs": [] 504 | }, 505 | { 506 | "cell_type": "markdown", 507 | "metadata": { 508 | "id": "v5DauoCqog3M", 509 | "colab_type": "text" 510 | }, 511 | "source": [ 512 | "We also support computing the filterbank features from waveforms,\n", 513 | "matching Kaldi’s implementation.\n", 514 | "\n", 515 | "\n" 516 | ] 517 | }, 518 | { 519 | "cell_type": "code", 520 | "metadata": { 521 | "id": "ONLSeJfIog3N", 522 | "colab_type": "code", 523 | "colab": {} 524 | }, 525 | "source": [ 526 | "fbank = torchaudio.compliance.kaldi.fbank(waveform, **params)\n", 527 | "\n", 528 | "print(\"Shape of fbank: {}\".format(fbank.size()))\n", 529 | "\n", 530 | "plt.figure()\n", 531 | "plt.imshow(fbank.t().numpy(), cmap='gray')" 532 | ], 533 | "execution_count": 0, 534 | "outputs": [] 535 | }, 536 | { 537 | "cell_type": "markdown", 538 | "metadata": { 539 | "id": "BX519cRTog3S", 540 | "colab_type": "text" 541 | }, 542 | "source": [ 543 | "Conclusion\n", 544 | "----------\n", 545 | "\n", 546 | "We used an example raw audio signal, or waveform, to illustrate how to\n", 547 | "open an audio file using torchaudio, and how to pre-process and\n", 548 | "transform such waveform. 
Given that torchaudio is built on PyTorch,\n", 549 | "these techniques can be used as building blocks for more advanced audio\n", 550 | "applications, such as speech recognition, while leveraging GPUs.\n", 551 | "\n", 552 | "\n" 553 | ] 554 | } 555 | ] 556 | } -------------------------------------------------------------------------------- /generative-adversarial-network/tensorflow/deep_convolutional_gan.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "accelerator": "GPU", 6 | "colab": { 7 | "name": "dcgan.ipynb", 8 | "provenance": [], 9 | "private_outputs": true, 10 | "collapsed_sections": [], 11 | "toc_visible": true 12 | }, 13 | "kernelspec": { 14 | "display_name": "Python 3", 15 | "name": "python3" 16 | } 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "metadata": { 22 | "id": "rHOjD7FELArb", 23 | "colab_type": "text" 24 | }, 25 | "source": [ 26 | "\n", 27 | " \n", 35 | " \n", 44 | "
\n", 28 | " Run\n", 32 | " in Google Colab\n", 34 | " \n", 36 | " View source on GitHub\n", 43 | "
\n" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": { 50 | "colab_type": "text", 51 | "id": "rF2x3qooyBTI" 52 | }, 53 | "source": [ 54 | "# Deep Convolutional Generative Adversarial Network" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": { 60 | "colab_type": "text", 61 | "id": "ITZuApL56Mny" 62 | }, 63 | "source": [ 64 | "This tutorial demonstrates how to generate images of handwritten digits using a [Deep Convolutional Generative Adversarial Network](https://arxiv.org/pdf/1511.06434.pdf) (DCGAN). The code is written using the [Keras Sequential API](https://www.tensorflow.org/guide/keras) with a `tf.GradientTape` training loop." 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": { 70 | "colab_type": "text", 71 | "id": "2MbKJY38Puy9" 72 | }, 73 | "source": [ 74 | "## What are GANs?\n", 75 | "[Generative Adversarial Networks](https://arxiv.org/abs/1406.2661) (GANs) are one of the most interesting ideas in computer science today. Two models are trained simultaneously by an adversarial process. A *generator* (\"the artist\") learns to create images that look real, while a *discriminator* (\"the art critic\") learns to tell real images apart from fakes.\n", 76 | "\n", 77 | "![A diagram of a generator and discriminator](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/gan1.png?raw=1)\n", 78 | "\n", 79 | "During training, the *generator* progressively becomes better at creating images that look real, while the *discriminator* becomes better at telling them apart. The process reaches equilibrium when the *discriminator* can no longer distinguish real images from fakes.\n", 80 | "\n", 81 | "![A second diagram of a generator and discriminator](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/gan2.png?raw=1)\n", 82 | "\n", 83 | "This notebook demonstrates this process on the MNIST dataset. The following animation shows a series of images produced by the *generator* as it was trained for 50 epochs. The images begin as random noise, and increasingly resemble hand written digits over time.\n", 84 | "\n", 85 | "![sample output](https://tensorflow.org/images/gan/dcgan.gif)\n", 86 | "\n", 87 | "To learn more about GANs, we recommend MIT's [Intro to Deep Learning](http://introtodeeplearning.com/) course." 
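For readers who want the formal statement behind the artist/critic analogy, the adversarial game described above is the minimax objective from the Goodfellow et al. paper linked earlier: the discriminator D maximizes it while the generator G minimizes it, and the equilibrium mentioned above is the point where D can do no better than chance. A compact restatement (added for reference, not part of the original notebook):

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big] +
  \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
```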
88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": { 93 | "colab_type": "text", 94 | "id": "e1_Y75QXJS6h" 95 | }, 96 | "source": [ 97 | "### Import TensorFlow and other libraries" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "metadata": { 103 | "colab_type": "code", 104 | "id": "J5oue0oqCkZZ", 105 | "colab": {} 106 | }, 107 | "source": [ 108 | "from __future__ import absolute_import, division, print_function, unicode_literals" 109 | ], 110 | "execution_count": 0, 111 | "outputs": [] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "metadata": { 116 | "colab_type": "code", 117 | "id": "g5RstiiB8V-z", 118 | "colab": {} 119 | }, 120 | "source": [ 121 | "try:\n", 122 | " # %tensorflow_version only exists in Colab.\n", 123 | " %tensorflow_version 2.x\n", 124 | "except Exception:\n", 125 | " pass\n" 126 | ], 127 | "execution_count": 0, 128 | "outputs": [] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "metadata": { 133 | "colab_type": "code", 134 | "id": "WZKbyU2-AiY-", 135 | "colab": {} 136 | }, 137 | "source": [ 138 | "import tensorflow as tf" 139 | ], 140 | "execution_count": 0, 141 | "outputs": [] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "metadata": { 146 | "colab_type": "code", 147 | "id": "wx-zNbLqB4K8", 148 | "colab": {} 149 | }, 150 | "source": [ 151 | "tf.__version__" 152 | ], 153 | "execution_count": 0, 154 | "outputs": [] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "metadata": { 159 | "colab_type": "code", 160 | "id": "YzTlj4YdCip_", 161 | "colab": {} 162 | }, 163 | "source": [ 164 | "# To generate GIFs\n", 165 | "!pip install imageio" 166 | ], 167 | "execution_count": 0, 168 | "outputs": [] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "metadata": { 173 | "colab_type": "code", 174 | "id": "YfIk2es3hJEd", 175 | "colab": {} 176 | }, 177 | "source": [ 178 | "import glob\n", 179 | "import imageio\n", 180 | "import matplotlib.pyplot as plt\n", 181 | "import numpy as np\n", 182 | "import os\n", 183 | "import PIL\n", 184 | "from tensorflow.keras import layers\n", 185 | "import time\n", 186 | "\n", 187 | "from IPython import display" 188 | ], 189 | "execution_count": 0, 190 | "outputs": [] 191 | }, 192 | { 193 | "cell_type": "markdown", 194 | "metadata": { 195 | "colab_type": "text", 196 | "id": "iYn4MdZnKCey" 197 | }, 198 | "source": [ 199 | "### Load and prepare the dataset\n", 200 | "\n", 201 | "You will use the MNIST dataset to train the generator and the discriminator. The generator will generate handwritten digits resembling the MNIST data." 
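As a quick orientation before the preprocessing cells, here is a minimal sanity check of the raw MNIST arrays (a sketch added for clarity, not part of the original notebook); the rescaling to [-1, 1] performed next lines up with the tanh output range of the generator defined later.

```python
import tensorflow as tf

# Inspect the raw MNIST training images before any preprocessing.
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
print(train_images.shape, train_images.dtype)   # (60000, 28, 28) uint8
print(train_images.min(), train_images.max())   # 0 255
# The next cells reshape to (60000, 28, 28, 1) and map [0, 255] -> [-1, 1],
# matching the generator's tanh output range.
```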
202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "metadata": { 207 | "colab_type": "code", 208 | "id": "a4fYMGxGhrna", 209 | "colab": {} 210 | }, 211 | "source": [ 212 | "(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()" 213 | ], 214 | "execution_count": 0, 215 | "outputs": [] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "metadata": { 220 | "colab_type": "code", 221 | "id": "NFC2ghIdiZYE", 222 | "colab": {} 223 | }, 224 | "source": [ 225 | "train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')\n", 226 | "train_images = (train_images - 127.5) / 127.5 # Normalize the images to [-1, 1]" 227 | ], 228 | "execution_count": 0, 229 | "outputs": [] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "metadata": { 234 | "colab_type": "code", 235 | "id": "S4PIDhoDLbsZ", 236 | "colab": {} 237 | }, 238 | "source": [ 239 | "BUFFER_SIZE = 60000\n", 240 | "BATCH_SIZE = 256" 241 | ], 242 | "execution_count": 0, 243 | "outputs": [] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "metadata": { 248 | "colab_type": "code", 249 | "id": "-yKCCQOoJ7cn", 250 | "colab": {} 251 | }, 252 | "source": [ 253 | "# Batch and shuffle the data\n", 254 | "train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)" 255 | ], 256 | "execution_count": 0, 257 | "outputs": [] 258 | }, 259 | { 260 | "cell_type": "markdown", 261 | "metadata": { 262 | "colab_type": "text", 263 | "id": "THY-sZMiQ4UV" 264 | }, 265 | "source": [ 266 | "## Create the models\n", 267 | "\n", 268 | "Both the generator and discriminator are defined using the [Keras Sequential API](https://www.tensorflow.org/guide/keras#sequential_model)." 269 | ] 270 | }, 271 | { 272 | "cell_type": "markdown", 273 | "metadata": { 274 | "colab_type": "text", 275 | "id": "-tEyxE-GMC48" 276 | }, 277 | "source": [ 278 | "### The Generator\n", 279 | "\n", 280 | "The generator uses `tf.keras.layers.Conv2DTranspose` (upsampling) layers to produce an image from a seed (random noise). Start with a `Dense` layer that takes this seed as input, then upsample several times until you reach the desired image size of 28x28x1. Notice the `tf.keras.layers.LeakyReLU` activation for each layer, except the output layer which uses tanh." 
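To make the upsampling arithmetic concrete, here is a small standalone check (a sketch assuming TensorFlow 2.x, as used above): with `padding='same'`, a `Conv2DTranspose` layer multiplies the spatial size by its stride, which is how the generator grows 7x7 to 14x14 to 28x28.

```python
import tensorflow as tf
from tensorflow.keras import layers

# A dummy 7x7 feature map with 256 channels, like the reshaped Dense output.
x = tf.zeros([1, 7, 7, 256])
up = layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)
print(up(x).shape)   # (1, 14, 14, 64): stride 2 doubles height and width
```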
281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "metadata": { 286 | "colab_type": "code", 287 | "id": "6bpTcDqoLWjY", 288 | "colab": {} 289 | }, 290 | "source": [ 291 | "def make_generator_model():\n", 292 | " model = tf.keras.Sequential()\n", 293 | " model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))\n", 294 | " model.add(layers.BatchNormalization())\n", 295 | " model.add(layers.LeakyReLU())\n", 296 | "\n", 297 | " model.add(layers.Reshape((7, 7, 256)))\n", 298 | " assert model.output_shape == (None, 7, 7, 256) # Note: None is the batch size\n", 299 | "\n", 300 | " model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))\n", 301 | " assert model.output_shape == (None, 7, 7, 128)\n", 302 | " model.add(layers.BatchNormalization())\n", 303 | " model.add(layers.LeakyReLU())\n", 304 | "\n", 305 | " model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))\n", 306 | " assert model.output_shape == (None, 14, 14, 64)\n", 307 | " model.add(layers.BatchNormalization())\n", 308 | " model.add(layers.LeakyReLU())\n", 309 | "\n", 310 | " model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))\n", 311 | " assert model.output_shape == (None, 28, 28, 1)\n", 312 | "\n", 313 | " return model" 314 | ], 315 | "execution_count": 0, 316 | "outputs": [] 317 | }, 318 | { 319 | "cell_type": "markdown", 320 | "metadata": { 321 | "colab_type": "text", 322 | "id": "GyWgG09LCSJl" 323 | }, 324 | "source": [ 325 | "Use the (as yet untrained) generator to create an image." 326 | ] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "metadata": { 331 | "colab_type": "code", 332 | "id": "gl7jcC7TdPTG", 333 | "colab": {} 334 | }, 335 | "source": [ 336 | "generator = make_generator_model()\n", 337 | "\n", 338 | "noise = tf.random.normal([1, 100])\n", 339 | "generated_image = generator(noise, training=False)\n", 340 | "\n", 341 | "plt.imshow(generated_image[0, :, :, 0], cmap='gray')" 342 | ], 343 | "execution_count": 0, 344 | "outputs": [] 345 | }, 346 | { 347 | "cell_type": "markdown", 348 | "metadata": { 349 | "colab_type": "text", 350 | "id": "D0IKnaCtg6WE" 351 | }, 352 | "source": [ 353 | "### The Discriminator\n", 354 | "\n", 355 | "The discriminator is a CNN-based image classifier." 356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "metadata": { 361 | "colab_type": "code", 362 | "id": "dw2tPLmk2pEP", 363 | "colab": {} 364 | }, 365 | "source": [ 366 | "def make_discriminator_model():\n", 367 | " model = tf.keras.Sequential()\n", 368 | " model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',\n", 369 | " input_shape=[28, 28, 1]))\n", 370 | " model.add(layers.LeakyReLU())\n", 371 | " model.add(layers.Dropout(0.3))\n", 372 | "\n", 373 | " model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))\n", 374 | " model.add(layers.LeakyReLU())\n", 375 | " model.add(layers.Dropout(0.3))\n", 376 | "\n", 377 | " model.add(layers.Flatten())\n", 378 | " model.add(layers.Dense(1))\n", 379 | "\n", 380 | " return model" 381 | ], 382 | "execution_count": 0, 383 | "outputs": [] 384 | }, 385 | { 386 | "cell_type": "markdown", 387 | "metadata": { 388 | "colab_type": "text", 389 | "id": "QhPneagzCaQv" 390 | }, 391 | "source": [ 392 | "Use the (as yet untrained) discriminator to classify the generated images as real or fake. The model will be trained to output positive values for real images, and negative values for fake images." 
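A small aside (not part of the original notebook): because the final `Dense(1)` layer has no activation, the discriminator returns a raw logit. To read it as a probability that the input is real, apply a sigmoid yourself; the losses defined below operate directly on logits (`from_logits=True`), so this is only for inspection.

```python
import tensorflow as tf

logit = tf.constant([[0.7]])    # example raw discriminator output
prob_real = tf.sigmoid(logit)   # squash to (0, 1)
print(prob_real.numpy())        # ~0.668, i.e. leaning "real"
```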
393 | ] 394 | }, 395 | { 396 | "cell_type": "code", 397 | "metadata": { 398 | "colab_type": "code", 399 | "id": "gDkA05NE6QMs", 400 | "colab": {} 401 | }, 402 | "source": [ 403 | "discriminator = make_discriminator_model()\n", 404 | "decision = discriminator(generated_image)\n", 405 | "print (decision)" 406 | ], 407 | "execution_count": 0, 408 | "outputs": [] 409 | }, 410 | { 411 | "cell_type": "markdown", 412 | "metadata": { 413 | "colab_type": "text", 414 | "id": "0FMYgY_mPfTi" 415 | }, 416 | "source": [ 417 | "## Define the loss and optimizers\n", 418 | "\n", 419 | "Define loss functions and optimizers for both models.\n" 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "metadata": { 425 | "colab_type": "code", 426 | "id": "psQfmXxYKU3X", 427 | "colab": {} 428 | }, 429 | "source": [ 430 | "# This method returns a helper function to compute cross entropy loss\n", 431 | "cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)" 432 | ], 433 | "execution_count": 0, 434 | "outputs": [] 435 | }, 436 | { 437 | "cell_type": "markdown", 438 | "metadata": { 439 | "colab_type": "text", 440 | "id": "PKY_iPSPNWoj" 441 | }, 442 | "source": [ 443 | "### Discriminator loss\n", 444 | "\n", 445 | "This method quantifies how well the discriminator is able to distinguish real images from fakes. It compares the discriminator's predictions on real images to an array of 1s, and the discriminator's predictions on fake (generated) images to an array of 0s." 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "metadata": { 451 | "colab_type": "code", 452 | "id": "wkMNfBWlT-PV", 453 | "colab": {} 454 | }, 455 | "source": [ 456 | "def discriminator_loss(real_output, fake_output):\n", 457 | " real_loss = cross_entropy(tf.ones_like(real_output), real_output)\n", 458 | " fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)\n", 459 | " total_loss = real_loss + fake_loss\n", 460 | " return total_loss" 461 | ], 462 | "execution_count": 0, 463 | "outputs": [] 464 | }, 465 | { 466 | "cell_type": "markdown", 467 | "metadata": { 468 | "colab_type": "text", 469 | "id": "Jd-3GCUEiKtv" 470 | }, 471 | "source": [ 472 | "### Generator loss\n", 473 | "The generator's loss quantifies how well it was able to trick the discriminator. Intuitively, if the generator is performing well, the discriminator will classify the fake images as real (or 1). Here, we will compare the discriminators decisions on the generated images to an array of 1s." 474 | ] 475 | }, 476 | { 477 | "cell_type": "code", 478 | "metadata": { 479 | "colab_type": "code", 480 | "id": "90BIcCKcDMxz", 481 | "colab": {} 482 | }, 483 | "source": [ 484 | "def generator_loss(fake_output):\n", 485 | " return cross_entropy(tf.ones_like(fake_output), fake_output)" 486 | ], 487 | "execution_count": 0, 488 | "outputs": [] 489 | }, 490 | { 491 | "cell_type": "markdown", 492 | "metadata": { 493 | "colab_type": "text", 494 | "id": "MgIc7i0th_Iu" 495 | }, 496 | "source": [ 497 | "The discriminator and the generator optimizers are different since we will train two networks separately." 
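Before moving on, a rough numerical sanity check of the loss functions defined above (a sketch, not part of the original notebook): an untrained discriminator tends to produce logits near zero, i.e. sigmoid outputs near 0.5, so the generator loss should start near ln 2 ≈ 0.693 and the discriminator loss near 2 ln 2 ≈ 1.386.

```python
import tensorflow as tf

# Pretend logits for a tiny batch, all zero (sigmoid(0) = 0.5).
zero_logits = tf.zeros([4, 1])
print(generator_loss(zero_logits).numpy())                    # ~0.693
print(discriminator_loss(zero_logits, zero_logits).numpy())   # ~1.386
```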
498 | ] 499 | }, 500 | { 501 | "cell_type": "code", 502 | "metadata": { 503 | "colab_type": "code", 504 | "id": "iWCn_PVdEJZ7", 505 | "colab": {} 506 | }, 507 | "source": [ 508 | "generator_optimizer = tf.keras.optimizers.Adam(1e-4)\n", 509 | "discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)" 510 | ], 511 | "execution_count": 0, 512 | "outputs": [] 513 | }, 514 | { 515 | "cell_type": "markdown", 516 | "metadata": { 517 | "colab_type": "text", 518 | "id": "mWtinsGDPJlV" 519 | }, 520 | "source": [ 521 | "### Save checkpoints\n", 522 | "This notebook also demonstrates how to save and restore models, which can be helpful in case a long running training task is interrupted." 523 | ] 524 | }, 525 | { 526 | "cell_type": "code", 527 | "metadata": { 528 | "colab_type": "code", 529 | "id": "CA1w-7s2POEy", 530 | "colab": {} 531 | }, 532 | "source": [ 533 | "checkpoint_dir = './training_checkpoints'\n", 534 | "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", 535 | "checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,\n", 536 | " discriminator_optimizer=discriminator_optimizer,\n", 537 | " generator=generator,\n", 538 | " discriminator=discriminator)" 539 | ], 540 | "execution_count": 0, 541 | "outputs": [] 542 | }, 543 | { 544 | "cell_type": "markdown", 545 | "metadata": { 546 | "colab_type": "text", 547 | "id": "Rw1fkAczTQYh" 548 | }, 549 | "source": [ 550 | "## Define the training loop\n", 551 | "\n" 552 | ] 553 | }, 554 | { 555 | "cell_type": "code", 556 | "metadata": { 557 | "colab_type": "code", 558 | "id": "NS2GWywBbAWo", 559 | "colab": {} 560 | }, 561 | "source": [ 562 | "EPOCHS = 50\n", 563 | "noise_dim = 100\n", 564 | "num_examples_to_generate = 16\n", 565 | "\n", 566 | "# We will reuse this seed overtime (so it's easier)\n", 567 | "# to visualize progress in the animated GIF)\n", 568 | "seed = tf.random.normal([num_examples_to_generate, noise_dim])" 569 | ], 570 | "execution_count": 0, 571 | "outputs": [] 572 | }, 573 | { 574 | "cell_type": "markdown", 575 | "metadata": { 576 | "colab_type": "text", 577 | "id": "jylSonrqSWfi" 578 | }, 579 | "source": [ 580 | "The training loop begins with generator receiving a random seed as input. That seed is used to produce an image. The discriminator is then used to classify real images (drawn from the training set) and fakes images (produced by the generator). The loss is calculated for each of these models, and the gradients are used to update the generator and discriminator." 
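Once the `train_step` function in the next cell has been defined, a single-batch smoke test is a cheap way to confirm everything is wired up before committing to the full 50-epoch run (an optional sketch, not part of the original notebook):

```python
# Run after the train_step cell below has been executed.
one_batch = next(iter(train_dataset))   # first batch, shape (256, 28, 28, 1)
train_step(one_batch)                   # one generator/discriminator update
print("single train_step completed, batch shape:", one_batch.shape)
```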
581 | ] 582 | }, 583 | { 584 | "cell_type": "code", 585 | "metadata": { 586 | "colab_type": "code", 587 | "id": "3t5ibNo05jCB", 588 | "colab": {} 589 | }, 590 | "source": [ 591 | "# Notice the use of `tf.function`\n", 592 | "# This annotation causes the function to be \"compiled\".\n", 593 | "@tf.function\n", 594 | "def train_step(images):\n", 595 | " noise = tf.random.normal([BATCH_SIZE, noise_dim])\n", 596 | "\n", 597 | " with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:\n", 598 | " generated_images = generator(noise, training=True)\n", 599 | "\n", 600 | " real_output = discriminator(images, training=True)\n", 601 | " fake_output = discriminator(generated_images, training=True)\n", 602 | "\n", 603 | " gen_loss = generator_loss(fake_output)\n", 604 | " disc_loss = discriminator_loss(real_output, fake_output)\n", 605 | "\n", 606 | " gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)\n", 607 | " gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)\n", 608 | "\n", 609 | " generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))\n", 610 | " discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))" 611 | ], 612 | "execution_count": 0, 613 | "outputs": [] 614 | }, 615 | { 616 | "cell_type": "code", 617 | "metadata": { 618 | "colab_type": "code", 619 | "id": "2M7LmLtGEMQJ", 620 | "colab": {} 621 | }, 622 | "source": [ 623 | "def train(dataset, epochs):\n", 624 | " for epoch in range(epochs):\n", 625 | " start = time.time()\n", 626 | "\n", 627 | " for image_batch in dataset:\n", 628 | " train_step(image_batch)\n", 629 | "\n", 630 | " # Produce images for the GIF as we go\n", 631 | " display.clear_output(wait=True)\n", 632 | " generate_and_save_images(generator,\n", 633 | " epoch + 1,\n", 634 | " seed)\n", 635 | "\n", 636 | " # Save the model every 15 epochs\n", 637 | " if (epoch + 1) % 15 == 0:\n", 638 | " checkpoint.save(file_prefix = checkpoint_prefix)\n", 639 | "\n", 640 | " print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))\n", 641 | "\n", 642 | " # Generate after the final epoch\n", 643 | " display.clear_output(wait=True)\n", 644 | " generate_and_save_images(generator,\n", 645 | " epochs,\n", 646 | " seed)" 647 | ], 648 | "execution_count": 0, 649 | "outputs": [] 650 | }, 651 | { 652 | "cell_type": "markdown", 653 | "metadata": { 654 | "colab_type": "text", 655 | "id": "2aFF7Hk3XdeW" 656 | }, 657 | "source": [ 658 | "**Generate and save images**\n", 659 | "\n" 660 | ] 661 | }, 662 | { 663 | "cell_type": "code", 664 | "metadata": { 665 | "colab_type": "code", 666 | "id": "RmdVsmvhPxyy", 667 | "colab": {} 668 | }, 669 | "source": [ 670 | "def generate_and_save_images(model, epoch, test_input):\n", 671 | " # Notice `training` is set to False.\n", 672 | " # This is so all layers run in inference mode (batchnorm).\n", 673 | " predictions = model(test_input, training=False)\n", 674 | "\n", 675 | " fig = plt.figure(figsize=(4,4))\n", 676 | "\n", 677 | " for i in range(predictions.shape[0]):\n", 678 | " plt.subplot(4, 4, i+1)\n", 679 | " plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')\n", 680 | " plt.axis('off')\n", 681 | "\n", 682 | " plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))\n", 683 | " plt.show()" 684 | ], 685 | "execution_count": 0, 686 | "outputs": [] 687 | }, 688 | { 689 | "cell_type": "markdown", 690 | "metadata": { 691 | "colab_type": "text", 692 | "id": 
"dZrd4CdjR-Fp" 693 | }, 694 | "source": [ 695 | "## Train the model\n", 696 | "Call the `train()` method defined above to train the generator and discriminator simultaneously. Note, training GANs can be tricky. It's important that the generator and discriminator do not overpower each other (e.g., that they train at a similar rate).\n", 697 | "\n", 698 | "At the beginning of the training, the generated images look like random noise. As training progresses, the generated digits will look increasingly real. After about 50 epochs, they resemble MNIST digits. This may take about one minute / epoch with the default settings on Colab." 699 | ] 700 | }, 701 | { 702 | "cell_type": "code", 703 | "metadata": { 704 | "colab_type": "code", 705 | "id": "Ly3UN0SLLY2l", 706 | "colab": {} 707 | }, 708 | "source": [ 709 | "train(train_dataset, EPOCHS)" 710 | ], 711 | "execution_count": 0, 712 | "outputs": [] 713 | }, 714 | { 715 | "cell_type": "markdown", 716 | "metadata": { 717 | "colab_type": "text", 718 | "id": "rfM4YcPVPkNO" 719 | }, 720 | "source": [ 721 | "Restore the latest checkpoint." 722 | ] 723 | }, 724 | { 725 | "cell_type": "code", 726 | "metadata": { 727 | "colab_type": "code", 728 | "id": "XhXsd0srPo8c", 729 | "colab": {} 730 | }, 731 | "source": [ 732 | "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" 733 | ], 734 | "execution_count": 0, 735 | "outputs": [] 736 | }, 737 | { 738 | "cell_type": "markdown", 739 | "metadata": { 740 | "colab_type": "text", 741 | "id": "P4M_vIbUi7c0" 742 | }, 743 | "source": [ 744 | "## Create a GIF\n" 745 | ] 746 | }, 747 | { 748 | "cell_type": "code", 749 | "metadata": { 750 | "colab_type": "code", 751 | "id": "WfO5wCdclHGL", 752 | "colab": {} 753 | }, 754 | "source": [ 755 | "# Display a single image using the epoch number\n", 756 | "def display_image(epoch_no):\n", 757 | " return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))" 758 | ], 759 | "execution_count": 0, 760 | "outputs": [] 761 | }, 762 | { 763 | "cell_type": "code", 764 | "metadata": { 765 | "colab_type": "code", 766 | "id": "5x3q9_Oe5q0A", 767 | "colab": {} 768 | }, 769 | "source": [ 770 | "display_image(EPOCHS)" 771 | ], 772 | "execution_count": 0, 773 | "outputs": [] 774 | }, 775 | { 776 | "cell_type": "markdown", 777 | "metadata": { 778 | "colab_type": "text", 779 | "id": "NywiH3nL8guF" 780 | }, 781 | "source": [ 782 | "Use `imageio` to create an animated gif using the images saved during training." 
783 | ] 784 | }, 785 | { 786 | "cell_type": "code", 787 | "metadata": { 788 | "colab_type": "code", 789 | "id": "IGKQgENQ8lEI", 790 | "colab": {} 791 | }, 792 | "source": [ 793 | "anim_file = 'dcgan.gif'\n", 794 | "\n", 795 | "with imageio.get_writer(anim_file, mode='I') as writer:\n", 796 | " filenames = glob.glob('image*.png')\n", 797 | " filenames = sorted(filenames)\n", 798 | " last = -1\n", 799 | " for i,filename in enumerate(filenames):\n", 800 | " frame = 2*(i**0.5)\n", 801 | " if round(frame) > round(last):\n", 802 | " last = frame\n", 803 | " else:\n", 804 | " continue\n", 805 | " image = imageio.imread(filename)\n", 806 | " writer.append_data(image)\n", 807 | " image = imageio.imread(filename)\n", 808 | " writer.append_data(image)\n", 809 | "\n", 810 | "import IPython\n", 811 | "if IPython.version_info > (6,2,0,''):\n", 812 | " display.Image(filename=anim_file)" 813 | ], 814 | "execution_count": 0, 815 | "outputs": [] 816 | }, 817 | { 818 | "cell_type": "markdown", 819 | "metadata": { 820 | "colab_type": "text", 821 | "id": "cGhC3-fMWSwl" 822 | }, 823 | "source": [ 824 | "If you're working in Colab you can download the animation with the code below:" 825 | ] 826 | }, 827 | { 828 | "cell_type": "code", 829 | "metadata": { 830 | "colab_type": "code", 831 | "id": "uV0yiKpzNP1b", 832 | "colab": {} 833 | }, 834 | "source": [ 835 | "try:\n", 836 | " from google.colab import files\n", 837 | "except ImportError:\n", 838 | " pass\n", 839 | "else:\n", 840 | " files.download(anim_file)" 841 | ], 842 | "execution_count": 0, 843 | "outputs": [] 844 | }, 845 | { 846 | "cell_type": "markdown", 847 | "metadata": { 848 | "colab_type": "text", 849 | "id": "k6qC-SbjK0yW" 850 | }, 851 | "source": [ 852 | "## Next steps\n" 853 | ] 854 | }, 855 | { 856 | "cell_type": "markdown", 857 | "metadata": { 858 | "colab_type": "text", 859 | "id": "xjjkT9KAK6H7" 860 | }, 861 | "source": [ 862 | "This tutorial has shown the complete code necessary to write and train a GAN. As a next step, you might like to experiment with a different dataset, for example the Large-scale Celeb Faces Attributes (CelebA) dataset [available on Kaggle](https://www.kaggle.com/jessicali9530/celeba-dataset). To learn more about GANs we recommend the [NIPS 2016 Tutorial: Generative Adversarial Networks](https://arxiv.org/abs/1701.00160).\n" 863 | ] 864 | }, 865 | { 866 | "cell_type": "markdown", 867 | "metadata": { 868 | "id": "qK0yItGJK4pf", 869 | "colab_type": "text" 870 | }, 871 | "source": [ 872 | "# License and Credits" 873 | ] 874 | }, 875 | { 876 | "cell_type": "markdown", 877 | "metadata": { 878 | "colab_type": "text", 879 | "id": "_jQ1tEQCxwRx" 880 | }, 881 | "source": [ 882 | "##### Copyright 2019 The TensorFlow Authors." 
883 | ] 884 | }, 885 | { 886 | "cell_type": "code", 887 | "metadata": { 888 | "cellView": "form", 889 | "colab_type": "code", 890 | "id": "V_sgB_5dx1f1", 891 | "colab": {} 892 | }, 893 | "source": [ 894 | "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", 895 | "# you may not use this file except in compliance with the License.\n", 896 | "# You may obtain a copy of the License at\n", 897 | "#\n", 898 | "# https://www.apache.org/licenses/LICENSE-2.0\n", 899 | "#\n", 900 | "# Unless required by applicable law or agreed to in writing, software\n", 901 | "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", 902 | "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", 903 | "# See the License for the specific language governing permissions and\n", 904 | "# limitations under the License." 905 | ], 906 | "execution_count": 0, 907 | "outputs": [] 908 | } 909 | ] 910 | } -------------------------------------------------------------------------------- /getting-started/pytorch/data_loading_processing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "kernelspec": { 6 | "display_name": "Python 3", 7 | "language": "python", 8 | "name": "python3" 9 | }, 10 | "language_info": { 11 | "codemirror_mode": { 12 | "name": "ipython", 13 | "version": 3 14 | }, 15 | "file_extension": ".py", 16 | "mimetype": "text/x-python", 17 | "name": "python", 18 | "nbconvert_exporter": "python", 19 | "pygments_lexer": "ipython3", 20 | "version": "3.6.6" 21 | }, 22 | "colab": { 23 | "name": "data_loading_tutorial.ipynb", 24 | "provenance": [] 25 | } 26 | }, 27 | "cells": [ 28 | { 29 | "cell_type": "code", 30 | "metadata": { 31 | "id": "4O6Fl0wNeQYq", 32 | "colab_type": "code", 33 | "colab": {} 34 | }, 35 | "source": [ 36 | "%matplotlib inline" 37 | ], 38 | "execution_count": 0, 39 | "outputs": [] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": { 44 | "id": "6oBIoiDWe63w", 45 | "colab_type": "text" 46 | }, 47 | "source": [ 48 | "\n", 49 | " \n", 52 | " \n", 55 | "
\n", 50 | " Run in Google Colab\n", 51 | " \n", 53 | " View source on GitHub\n", 54 | "
" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": { 61 | "id": "bYRAYLbbeQYy", 62 | "colab_type": "text" 63 | }, 64 | "source": [ 65 | "Data Loading and Processing Tutorial\n", 66 | "====================================\n", 67 | "\n", 68 | "\n", 69 | "**Author**: `Sasank Chilamkurthy `_\n", 70 | "\n", 71 | "A lot of effort in solving any machine learning problem goes in to\n", 72 | "preparing the data. PyTorch provides many tools to make data loading\n", 73 | "easy and hopefully, to make your code more readable. In this tutorial,\n", 74 | "we will see how to load and preprocess/augment data from a non trivial\n", 75 | "dataset.\n", 76 | "\n", 77 | "To run this tutorial, please make sure the following packages are\n", 78 | "installed:\n", 79 | "\n", 80 | "- ``scikit-image``: For image io and transforms\n", 81 | "- ``pandas``: For easier csv parsing\n", 82 | "\n", 83 | "\n" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "metadata": { 89 | "id": "YIrwL18qeQYz", 90 | "colab_type": "code", 91 | "colab": {} 92 | }, 93 | "source": [ 94 | "from __future__ import print_function, division\n", 95 | "import os\n", 96 | "import torch\n", 97 | "import pandas as pd\n", 98 | "from skimage import io, transform\n", 99 | "import numpy as np\n", 100 | "import matplotlib.pyplot as plt\n", 101 | "from torch.utils.data import Dataset, DataLoader\n", 102 | "from torchvision import transforms, utils\n", 103 | "\n", 104 | "# Ignore warnings\n", 105 | "import warnings\n", 106 | "warnings.filterwarnings(\"ignore\")\n", 107 | "\n", 108 | "plt.ion() # interactive mode" 109 | ], 110 | "execution_count": 0, 111 | "outputs": [] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "metadata": { 116 | "id": "E6hrf5mpeQY4", 117 | "colab_type": "text" 118 | }, 119 | "source": [ 120 | "The dataset we are going to deal with is that of facial pose.\n", 121 | "This means that a face is annotated like this:\n", 122 | "\n", 123 | ".. figure:: /_static/img/landmarked_face2.png\n", 124 | " :width: 400\n", 125 | "\n", 126 | "Over all, 68 different landmark points are annotated for each face.\n", 127 | "\n", 128 | "

Note

Download the dataset from `here `_\n", 129 | " so that the images are in a directory named 'faces/'.\n", 130 | " This dataset was actually\n", 131 | " generated by applying excellent `dlib's pose\n", 132 | " estimation `__\n", 133 | " on a few images from imagenet tagged as 'face'.

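The note above is the only place the dataset link appears, so a minimal download-and-extract sketch is included here for convenience. `DATASET_URL` is a hypothetical placeholder that must be replaced with that link, and the archive is assumed to unpack into a top-level `faces/` directory.

```python
# Sketch: fetch and unpack the faces dataset so the 'faces/' directory used below exists.
# DATASET_URL is a hypothetical placeholder -- substitute the zip link from the note above.
import os
import urllib.request
import zipfile

DATASET_URL = "REPLACE_WITH_FACES_DATASET_ZIP_URL"

if not os.path.isdir("faces"):
    urllib.request.urlretrieve(DATASET_URL, "faces.zip")
    with zipfile.ZipFile("faces.zip") as archive:
        archive.extractall(".")  # assumed to contain a top-level faces/ folder
```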
\n", 134 | "\n", 135 | "Dataset comes with a csv file with annotations which looks like this:\n", 136 | "\n", 137 | "::\n", 138 | "\n", 139 | " image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y\n", 140 | " 0805personali01.jpg,27,83,27,98, ... 84,134\n", 141 | " 1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312\n", 142 | "\n", 143 | "Let's quickly read the CSV and get the annotations in an (N, 2) array where N\n", 144 | "is the number of landmarks.\n", 145 | "\n", 146 | "\n" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "metadata": { 152 | "id": "3L2D5AtVeQY5", 153 | "colab_type": "code", 154 | "colab": {} 155 | }, 156 | "source": [ 157 | "landmarks_frame = pd.read_csv('faces/face_landmarks.csv')\n", 158 | "\n", 159 | "n = 65\n", 160 | "img_name = landmarks_frame.iloc[n, 0]\n", 161 | "landmarks = landmarks_frame.iloc[n, 1:].as_matrix()\n", 162 | "landmarks = landmarks.astype('float').reshape(-1, 2)\n", 163 | "\n", 164 | "print('Image name: {}'.format(img_name))\n", 165 | "print('Landmarks shape: {}'.format(landmarks.shape))\n", 166 | "print('First 4 Landmarks: {}'.format(landmarks[:4]))" 167 | ], 168 | "execution_count": 0, 169 | "outputs": [] 170 | }, 171 | { 172 | "cell_type": "markdown", 173 | "metadata": { 174 | "id": "EPohHit9eQY9", 175 | "colab_type": "text" 176 | }, 177 | "source": [ 178 | "Let's write a simple helper function to show an image and its landmarks\n", 179 | "and use it to show a sample.\n", 180 | "\n", 181 | "\n" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "metadata": { 187 | "id": "-aKwzlHkeQY_", 188 | "colab_type": "code", 189 | "colab": {} 190 | }, 191 | "source": [ 192 | "def show_landmarks(image, landmarks):\n", 193 | " \"\"\"Show image with landmarks\"\"\"\n", 194 | " plt.imshow(image)\n", 195 | " plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')\n", 196 | " plt.pause(0.001) # pause a bit so that plots are updated\n", 197 | "\n", 198 | "plt.figure()\n", 199 | "show_landmarks(io.imread(os.path.join('faces/', img_name)),\n", 200 | " landmarks)\n", 201 | "plt.show()" 202 | ], 203 | "execution_count": 0, 204 | "outputs": [] 205 | }, 206 | { 207 | "cell_type": "markdown", 208 | "metadata": { 209 | "id": "ljNZ5Vt-eQZD", 210 | "colab_type": "text" 211 | }, 212 | "source": [ 213 | "Dataset class\n", 214 | "-------------\n", 215 | "\n", 216 | "``torch.utils.data.Dataset`` is an abstract class representing a\n", 217 | "dataset.\n", 218 | "Your custom dataset should inherit ``Dataset`` and override the following\n", 219 | "methods:\n", 220 | "\n", 221 | "- ``__len__`` so that ``len(dataset)`` returns the size of the dataset.\n", 222 | "- ``__getitem__`` to support the indexing such that ``dataset[i]`` can\n", 223 | " be used to get $i$\\ th sample\n", 224 | "\n", 225 | "Let's create a dataset class for our face landmarks dataset. We will\n", 226 | "read the csv in ``__init__`` but leave the reading of images to\n", 227 | "``__getitem__``. This is memory efficient because all the images are not\n", 228 | "stored in the memory at once but read as required.\n", 229 | "\n", 230 | "Sample of our dataset will be a dict\n", 231 | "``{'image': image, 'landmarks': landmarks}``. Our datset will take an\n", 232 | "optional argument ``transform`` so that any required processing can be\n", 233 | "applied on the sample. 
We will see the usefulness of ``transform`` in the\n", 234 | "next section.\n", 235 | "\n", 236 | "\n" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "metadata": { 242 | "id": "ND5EDjYYeQZE", 243 | "colab_type": "code", 244 | "colab": {} 245 | }, 246 | "source": [ 247 | "class FaceLandmarksDataset(Dataset):\n", 248 | " \"\"\"Face Landmarks dataset.\"\"\"\n", 249 | "\n", 250 | " def __init__(self, csv_file, root_dir, transform=None):\n", 251 | " \"\"\"\n", 252 | " Args:\n", 253 | " csv_file (string): Path to the csv file with annotations.\n", 254 | " root_dir (string): Directory with all the images.\n", 255 | " transform (callable, optional): Optional transform to be applied\n", 256 | " on a sample.\n", 257 | " \"\"\"\n", 258 | " self.landmarks_frame = pd.read_csv(csv_file)\n", 259 | " self.root_dir = root_dir\n", 260 | " self.transform = transform\n", 261 | "\n", 262 | " def __len__(self):\n", 263 | " return len(self.landmarks_frame)\n", 264 | "\n", 265 | " def __getitem__(self, idx):\n", 266 | " img_name = os.path.join(self.root_dir,\n", 267 | " self.landmarks_frame.iloc[idx, 0])\n", 268 | " image = io.imread(img_name)\n", 269 | " landmarks = self.landmarks_frame.iloc[idx, 1:].as_matrix()\n", 270 | " landmarks = landmarks.astype('float').reshape(-1, 2)\n", 271 | " sample = {'image': image, 'landmarks': landmarks}\n", 272 | "\n", 273 | " if self.transform:\n", 274 | " sample = self.transform(sample)\n", 275 | "\n", 276 | " return sample" 277 | ], 278 | "execution_count": 0, 279 | "outputs": [] 280 | }, 281 | { 282 | "cell_type": "markdown", 283 | "metadata": { 284 | "id": "q2dpDjF0eQZK", 285 | "colab_type": "text" 286 | }, 287 | "source": [ 288 | "Let's instantiate this class and iterate through the data samples. We\n", 289 | "will print the sizes of first 4 samples and show their landmarks.\n", 290 | "\n", 291 | "\n" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "metadata": { 297 | "id": "KBzeTG3MeQZM", 298 | "colab_type": "code", 299 | "colab": {} 300 | }, 301 | "source": [ 302 | "face_dataset = FaceLandmarksDataset(csv_file='faces/face_landmarks.csv',\n", 303 | " root_dir='faces/')\n", 304 | "\n", 305 | "fig = plt.figure()\n", 306 | "\n", 307 | "for i in range(len(face_dataset)):\n", 308 | " sample = face_dataset[i]\n", 309 | "\n", 310 | " print(i, sample['image'].shape, sample['landmarks'].shape)\n", 311 | "\n", 312 | " ax = plt.subplot(1, 4, i + 1)\n", 313 | " plt.tight_layout()\n", 314 | " ax.set_title('Sample #{}'.format(i))\n", 315 | " ax.axis('off')\n", 316 | " show_landmarks(**sample)\n", 317 | "\n", 318 | " if i == 3:\n", 319 | " plt.show()\n", 320 | " break" 321 | ], 322 | "execution_count": 0, 323 | "outputs": [] 324 | }, 325 | { 326 | "cell_type": "markdown", 327 | "metadata": { 328 | "id": "hOOWwulEeQZP", 329 | "colab_type": "text" 330 | }, 331 | "source": [ 332 | "Transforms\n", 333 | "----------\n", 334 | "\n", 335 | "One issue we can see from the above is that the samples are not of the\n", 336 | "same size. Most neural networks expect the images of a fixed size.\n", 337 | "Therefore, we will need to write some prepocessing code.\n", 338 | "Let's create three transforms:\n", 339 | "\n", 340 | "- ``Rescale``: to scale the image\n", 341 | "- ``RandomCrop``: to crop from image randomly. 
This is data\n", 342 | " augmentation.\n", 343 | "- ``ToTensor``: to convert the numpy images to torch images (we need to\n", 344 | " swap axes).\n", 345 | "\n", 346 | "We will write them as callable classes instead of simple functions so\n", 347 | "that parameters of the transform need not be passed everytime it's\n", 348 | "called. For this, we just need to implement ``__call__`` method and\n", 349 | "if required, ``__init__`` method. We can then use a transform like this:\n", 350 | "\n", 351 | "::\n", 352 | "\n", 353 | " tsfm = Transform(params)\n", 354 | " transformed_sample = tsfm(sample)\n", 355 | "\n", 356 | "Observe below how these transforms had to be applied both on the image and\n", 357 | "landmarks.\n", 358 | "\n", 359 | "\n" 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "metadata": { 365 | "id": "qqYK6uNzeQZQ", 366 | "colab_type": "code", 367 | "colab": {} 368 | }, 369 | "source": [ 370 | "class Rescale(object):\n", 371 | " \"\"\"Rescale the image in a sample to a given size.\n", 372 | "\n", 373 | " Args:\n", 374 | " output_size (tuple or int): Desired output size. If tuple, output is\n", 375 | " matched to output_size. If int, smaller of image edges is matched\n", 376 | " to output_size keeping aspect ratio the same.\n", 377 | " \"\"\"\n", 378 | "\n", 379 | " def __init__(self, output_size):\n", 380 | " assert isinstance(output_size, (int, tuple))\n", 381 | " self.output_size = output_size\n", 382 | "\n", 383 | " def __call__(self, sample):\n", 384 | " image, landmarks = sample['image'], sample['landmarks']\n", 385 | "\n", 386 | " h, w = image.shape[:2]\n", 387 | " if isinstance(self.output_size, int):\n", 388 | " if h > w:\n", 389 | " new_h, new_w = self.output_size * h / w, self.output_size\n", 390 | " else:\n", 391 | " new_h, new_w = self.output_size, self.output_size * w / h\n", 392 | " else:\n", 393 | " new_h, new_w = self.output_size\n", 394 | "\n", 395 | " new_h, new_w = int(new_h), int(new_w)\n", 396 | "\n", 397 | " img = transform.resize(image, (new_h, new_w))\n", 398 | "\n", 399 | " # h and w are swapped for landmarks because for images,\n", 400 | " # x and y axes are axis 1 and 0 respectively\n", 401 | " landmarks = landmarks * [new_w / w, new_h / h]\n", 402 | "\n", 403 | " return {'image': img, 'landmarks': landmarks}\n", 404 | "\n", 405 | "\n", 406 | "class RandomCrop(object):\n", 407 | " \"\"\"Crop randomly the image in a sample.\n", 408 | "\n", 409 | " Args:\n", 410 | " output_size (tuple or int): Desired output size. 
If int, square crop\n", 411 | " is made.\n", 412 | " \"\"\"\n", 413 | "\n", 414 | " def __init__(self, output_size):\n", 415 | " assert isinstance(output_size, (int, tuple))\n", 416 | " if isinstance(output_size, int):\n", 417 | " self.output_size = (output_size, output_size)\n", 418 | " else:\n", 419 | " assert len(output_size) == 2\n", 420 | " self.output_size = output_size\n", 421 | "\n", 422 | " def __call__(self, sample):\n", 423 | " image, landmarks = sample['image'], sample['landmarks']\n", 424 | "\n", 425 | " h, w = image.shape[:2]\n", 426 | " new_h, new_w = self.output_size\n", 427 | "\n", 428 | " top = np.random.randint(0, h - new_h)\n", 429 | " left = np.random.randint(0, w - new_w)\n", 430 | "\n", 431 | " image = image[top: top + new_h,\n", 432 | " left: left + new_w]\n", 433 | "\n", 434 | " landmarks = landmarks - [left, top]\n", 435 | "\n", 436 | " return {'image': image, 'landmarks': landmarks}\n", 437 | "\n", 438 | "\n", 439 | "class ToTensor(object):\n", 440 | " \"\"\"Convert ndarrays in sample to Tensors.\"\"\"\n", 441 | "\n", 442 | " def __call__(self, sample):\n", 443 | " image, landmarks = sample['image'], sample['landmarks']\n", 444 | "\n", 445 | " # swap color axis because\n", 446 | " # numpy image: H x W x C\n", 447 | " # torch image: C X H X W\n", 448 | " image = image.transpose((2, 0, 1))\n", 449 | " return {'image': torch.from_numpy(image),\n", 450 | " 'landmarks': torch.from_numpy(landmarks)}" 451 | ], 452 | "execution_count": 0, 453 | "outputs": [] 454 | }, 455 | { 456 | "cell_type": "markdown", 457 | "metadata": { 458 | "id": "wkdgdpBReQZX", 459 | "colab_type": "text" 460 | }, 461 | "source": [ 462 | "Compose transforms\n", 463 | "~~~~~~~~~~~~~~~~~~\n", 464 | "\n", 465 | "Now, we apply the transforms on an sample.\n", 466 | "\n", 467 | "Let's say we want to rescale the shorter side of the image to 256 and\n", 468 | "then randomly crop a square of size 224 from it. 
i.e, we want to compose\n", 469 | "``Rescale`` and ``RandomCrop`` transforms.\n", 470 | "``torchvision.transforms.Compose`` is a simple callable class which allows us\n", 471 | "to do this.\n", 472 | "\n", 473 | "\n" 474 | ] 475 | }, 476 | { 477 | "cell_type": "code", 478 | "metadata": { 479 | "id": "rpgZb9c4eQZY", 480 | "colab_type": "code", 481 | "colab": {} 482 | }, 483 | "source": [ 484 | "scale = Rescale(256)\n", 485 | "crop = RandomCrop(128)\n", 486 | "composed = transforms.Compose([Rescale(256),\n", 487 | " RandomCrop(224)])\n", 488 | "\n", 489 | "# Apply each of the above transforms on sample.\n", 490 | "fig = plt.figure()\n", 491 | "sample = face_dataset[65]\n", 492 | "for i, tsfrm in enumerate([scale, crop, composed]):\n", 493 | " transformed_sample = tsfrm(sample)\n", 494 | "\n", 495 | " ax = plt.subplot(1, 3, i + 1)\n", 496 | " plt.tight_layout()\n", 497 | " ax.set_title(type(tsfrm).__name__)\n", 498 | " show_landmarks(**transformed_sample)\n", 499 | "\n", 500 | "plt.show()" 501 | ], 502 | "execution_count": 0, 503 | "outputs": [] 504 | }, 505 | { 506 | "cell_type": "markdown", 507 | "metadata": { 508 | "id": "AXSVP-FqeQZb", 509 | "colab_type": "text" 510 | }, 511 | "source": [ 512 | "Iterating through the dataset\n", 513 | "-----------------------------\n", 514 | "\n", 515 | "Let's put this all together to create a dataset with composed\n", 516 | "transforms.\n", 517 | "To summarize, every time this dataset is sampled:\n", 518 | "\n", 519 | "- An image is read from the file on the fly\n", 520 | "- Transforms are applied on the read image\n", 521 | "- Since one of the transforms is random, data is augmentated on\n", 522 | " sampling\n", 523 | "\n", 524 | "We can iterate over the created dataset with a ``for i in range``\n", 525 | "loop as before.\n", 526 | "\n", 527 | "\n" 528 | ] 529 | }, 530 | { 531 | "cell_type": "code", 532 | "metadata": { 533 | "id": "GnrImAiyeQZc", 534 | "colab_type": "code", 535 | "colab": {} 536 | }, 537 | "source": [ 538 | "transformed_dataset = FaceLandmarksDataset(csv_file='faces/face_landmarks.csv',\n", 539 | " root_dir='faces/',\n", 540 | " transform=transforms.Compose([\n", 541 | " Rescale(256),\n", 542 | " RandomCrop(224),\n", 543 | " ToTensor()\n", 544 | " ]))\n", 545 | "\n", 546 | "for i in range(len(transformed_dataset)):\n", 547 | " sample = transformed_dataset[i]\n", 548 | "\n", 549 | " print(i, sample['image'].size(), sample['landmarks'].size())\n", 550 | "\n", 551 | " if i == 3:\n", 552 | " break" 553 | ], 554 | "execution_count": 0, 555 | "outputs": [] 556 | }, 557 | { 558 | "cell_type": "markdown", 559 | "metadata": { 560 | "id": "zQAvdMiHeQZf", 561 | "colab_type": "text" 562 | }, 563 | "source": [ 564 | "However, we are losing a lot of features by using a simple ``for`` loop to\n", 565 | "iterate over the data. In particular, we are missing out on:\n", 566 | "\n", 567 | "- Batching the data\n", 568 | "- Shuffling the data\n", 569 | "- Load the data in parallel using ``multiprocessing`` workers.\n", 570 | "\n", 571 | "``torch.utils.data.DataLoader`` is an iterator which provides all these\n", 572 | "features. Parameters used below should be clear. One parameter of\n", 573 | "interest is ``collate_fn``. You can specify how exactly the samples need\n", 574 | "to be batched using ``collate_fn``. 
However, default collate should work\n", 575 | "fine for most use cases.\n", 576 | "\n", 577 | "\n" 578 | ] 579 | }, 580 | { 581 | "cell_type": "code", 582 | "metadata": { 583 | "id": "RY_ASVFceQZg", 584 | "colab_type": "code", 585 | "colab": {} 586 | }, 587 | "source": [ 588 | "dataloader = DataLoader(transformed_dataset, batch_size=4,\n", 589 | " shuffle=True, num_workers=4)\n", 590 | "\n", 591 | "\n", 592 | "# Helper function to show a batch\n", 593 | "def show_landmarks_batch(sample_batched):\n", 594 | " \"\"\"Show image with landmarks for a batch of samples.\"\"\"\n", 595 | " images_batch, landmarks_batch = \\\n", 596 | " sample_batched['image'], sample_batched['landmarks']\n", 597 | " batch_size = len(images_batch)\n", 598 | " im_size = images_batch.size(2)\n", 599 | "\n", 600 | " grid = utils.make_grid(images_batch)\n", 601 | " plt.imshow(grid.numpy().transpose((1, 2, 0)))\n", 602 | "\n", 603 | " for i in range(batch_size):\n", 604 | " plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size,\n", 605 | " landmarks_batch[i, :, 1].numpy(),\n", 606 | " s=10, marker='.', c='r')\n", 607 | "\n", 608 | " plt.title('Batch from dataloader')\n", 609 | "\n", 610 | "for i_batch, sample_batched in enumerate(dataloader):\n", 611 | " print(i_batch, sample_batched['image'].size(),\n", 612 | " sample_batched['landmarks'].size())\n", 613 | "\n", 614 | " # observe 4th batch and stop.\n", 615 | " if i_batch == 3:\n", 616 | " plt.figure()\n", 617 | " show_landmarks_batch(sample_batched)\n", 618 | " plt.axis('off')\n", 619 | " plt.ioff()\n", 620 | " plt.show()\n", 621 | " break" 622 | ], 623 | "execution_count": 0, 624 | "outputs": [] 625 | }, 626 | { 627 | "cell_type": "markdown", 628 | "metadata": { 629 | "id": "-IYb-jUUeQZn", 630 | "colab_type": "text" 631 | }, 632 | "source": [ 633 | "Afterword: torchvision\n", 634 | "----------------------\n", 635 | "\n", 636 | "In this tutorial, we have seen how to write and use datasets, transforms\n", 637 | "and dataloader. ``torchvision`` package provides some common datasets and\n", 638 | "transforms. You might not even have to write custom classes. One of the\n", 639 | "more generic datasets available in torchvision is ``ImageFolder``.\n", 640 | "It assumes that images are organized in the following way: ::\n", 641 | "\n", 642 | " root/ants/xxx.png\n", 643 | " root/ants/xxy.jpeg\n", 644 | " root/ants/xxz.png\n", 645 | " .\n", 646 | " .\n", 647 | " .\n", 648 | " root/bees/123.jpg\n", 649 | " root/bees/nsdf3.png\n", 650 | " root/bees/asd932_.png\n", 651 | "\n", 652 | "where 'ants', 'bees' etc. are class labels. Similarly generic transforms\n", 653 | "which operate on ``PIL.Image`` like ``RandomHorizontalFlip``, ``Scale``,\n", 654 | "are also available. 
You can use these to write a dataloader like this: ::\n", 655 | "\n", 656 | " import torch\n", 657 | " from torchvision import transforms, datasets\n", 658 | "\n", 659 | " data_transform = transforms.Compose([\n", 660 | " transforms.RandomSizedCrop(224),\n", 661 | " transforms.RandomHorizontalFlip(),\n", 662 | " transforms.ToTensor(),\n", 663 | " transforms.Normalize(mean=[0.485, 0.456, 0.406],\n", 664 | " std=[0.229, 0.224, 0.225])\n", 665 | " ])\n", 666 | " hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',\n", 667 | " transform=data_transform)\n", 668 | " dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,\n", 669 | " batch_size=4, shuffle=True,\n", 670 | " num_workers=4)\n", 671 | "\n", 672 | "For an example with training code, please see\n", 673 | ":doc:`transfer_learning_tutorial`.\n", 674 | "\n" 675 | ] 676 | } 677 | ] 678 | } -------------------------------------------------------------------------------- /getting-started/pytorch/tensor_intro.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "kernelspec": { 6 | "display_name": "Python 3", 7 | "language": "python", 8 | "name": "python3" 9 | }, 10 | "language_info": { 11 | "codemirror_mode": { 12 | "name": "ipython", 13 | "version": 3 14 | }, 15 | "file_extension": ".py", 16 | "mimetype": "text/x-python", 17 | "name": "python", 18 | "nbconvert_exporter": "python", 19 | "pygments_lexer": "ipython3", 20 | "version": "3.6.6" 21 | }, 22 | "colab": { 23 | "name": "pytorch_tensor_intro.ipynb", 24 | "provenance": [], 25 | "toc_visible": true 26 | } 27 | }, 28 | "cells": [ 29 | { 30 | "cell_type": "markdown", 31 | "metadata": { 32 | "id": "N2CemtUnq4du", 33 | "colab_type": "text" 34 | }, 35 | "source": [ 36 | "\n", 37 | " \n", 40 | " \n", 43 | "
\n", 38 | " Run in Google Colab\n", 39 | " \n", 41 | " View source on GitHub\n", 42 | "
" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": { 49 | "id": "nnIYNAHfwHBk", 50 | "colab_type": "text" 51 | }, 52 | "source": [ 53 | "\n", 54 | "*This notebook was originally published on [pytorch.org](https://pytorch.org)*\n", 55 | "\n", 56 | "What is PyTorch?\n", 57 | "================\n", 58 | "\n", 59 | "It’s a Python-based scientific computing package targeted at two sets of\n", 60 | "audiences:\n", 61 | "\n", 62 | "- A replacement for NumPy to use the power of GPUs\n", 63 | "- a deep learning research platform that provides maximum flexibility\n", 64 | " and speed\n", 65 | "\n", 66 | "Getting Started\n", 67 | "---------------\n", 68 | "\n", 69 | "### Tensors\n", 70 | "\n", 71 | "Tensors are similar to NumPy’s ndarrays, with the addition being that\n", 72 | "Tensors can also be used on a GPU to accelerate computing." 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "metadata": { 78 | "id": "ME_yKqc0wHBd", 79 | "colab_type": "code", 80 | "colab": {} 81 | }, 82 | "source": [ 83 | "%matplotlib inline" 84 | ], 85 | "execution_count": 0, 86 | "outputs": [] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "metadata": { 91 | "id": "ta5vl4I3wHBm", 92 | "colab_type": "code", 93 | "colab": {} 94 | }, 95 | "source": [ 96 | "from __future__ import print_function\n", 97 | "import torch" 98 | ], 99 | "execution_count": 0, 100 | "outputs": [] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "metadata": { 105 | "id": "kh0rvm25wHBq", 106 | "colab_type": "text" 107 | }, 108 | "source": [ 109 | "Construct a 5x3 matrix, uninitialized:\n", 110 | "\n" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "metadata": { 116 | "id": "5aGh48trwHBr", 117 | "colab_type": "code", 118 | "colab": {} 119 | }, 120 | "source": [ 121 | "x = torch.empty(5, 3)\n", 122 | "print(x)" 123 | ], 124 | "execution_count": 0, 125 | "outputs": [] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "metadata": { 130 | "id": "98pDhyICwHBv", 131 | "colab_type": "text" 132 | }, 133 | "source": [ 134 | "Construct a randomly initialized matrix:\n", 135 | "\n" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "metadata": { 141 | "id": "sbLmphVmwHBw", 142 | "colab_type": "code", 143 | "colab": {} 144 | }, 145 | "source": [ 146 | "x = torch.rand(5, 3)\n", 147 | "print(x)" 148 | ], 149 | "execution_count": 0, 150 | "outputs": [] 151 | }, 152 | { 153 | "cell_type": "markdown", 154 | "metadata": { 155 | "id": "5EVgnuNQwHBz", 156 | "colab_type": "text" 157 | }, 158 | "source": [ 159 | "Construct a matrix filled zeros and of dtype long:\n", 160 | "\n" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "metadata": { 166 | "id": "qPMBgfUvwHB1", 167 | "colab_type": "code", 168 | "colab": {} 169 | }, 170 | "source": [ 171 | "x = torch.zeros(5, 3, dtype=torch.long)\n", 172 | "print(x)" 173 | ], 174 | "execution_count": 0, 175 | "outputs": [] 176 | }, 177 | { 178 | "cell_type": "markdown", 179 | "metadata": { 180 | "id": "HurkpidTwHB4", 181 | "colab_type": "text" 182 | }, 183 | "source": [ 184 | "Construct a tensor directly from data:\n", 185 | "\n" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "metadata": { 191 | "id": "5U799A9ywHB7", 192 | "colab_type": "code", 193 | "colab": {} 194 | }, 195 | "source": [ 196 | "x = torch.tensor([5.5, 3])\n", 197 | "print(x)" 198 | ], 199 | "execution_count": 0, 200 | "outputs": [] 201 | }, 202 | { 203 | "cell_type": "markdown", 204 | "metadata": { 205 | "id": "LrtHJF7gwHB-", 206 | "colab_type": "text" 207 | }, 208 | "source": [ 209 | "or create a tensor 
based on an existing tensor. These methods\n", 210 | "will reuse properties of the input tensor, e.g. dtype, unless\n", 211 | "new values are provided by user\n", 212 | "\n" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "metadata": { 218 | "id": "y4QpmHyRwHCA", 219 | "colab_type": "code", 220 | "colab": {} 221 | }, 222 | "source": [ 223 | "x = x.new_ones(5, 3, dtype=torch.double) # new_* methods take in sizes\n", 224 | "print(x)\n", 225 | "\n", 226 | "x = torch.randn_like(x, dtype=torch.float) # override dtype!\n", 227 | "print(x) # result has the same size" 228 | ], 229 | "execution_count": 0, 230 | "outputs": [] 231 | }, 232 | { 233 | "cell_type": "markdown", 234 | "metadata": { 235 | "id": "zb_JCA3zwHCD", 236 | "colab_type": "text" 237 | }, 238 | "source": [ 239 | "Get its size:\n", 240 | "\n" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "metadata": { 246 | "id": "zRZjx7LywHCE", 247 | "colab_type": "code", 248 | "colab": {} 249 | }, 250 | "source": [ 251 | "print(x.size())" 252 | ], 253 | "execution_count": 0, 254 | "outputs": [] 255 | }, 256 | { 257 | "cell_type": "markdown", 258 | "metadata": { 259 | "id": "SI9Zcg8ywHCH", 260 | "colab_type": "text" 261 | }, 262 | "source": [ 263 | "> #### Note\n", 264 | "`torch.Size` is in fact a tuple, so it supports all tuple operations.\n", 265 | "\n", 266 | "### Operations\n", 267 | "There are multiple syntaxes for operations. In the following\n", 268 | "example, we will take a look at the addition operation.\n", 269 | "\n", 270 | "Addition: syntax 1\n", 271 | "\n" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "metadata": { 277 | "id": "SGRBCWv0wHCH", 278 | "colab_type": "code", 279 | "colab": {} 280 | }, 281 | "source": [ 282 | "y = torch.rand(5, 3)\n", 283 | "print(x + y)" 284 | ], 285 | "execution_count": 0, 286 | "outputs": [] 287 | }, 288 | { 289 | "cell_type": "markdown", 290 | "metadata": { 291 | "id": "BSrArdUfwHCK", 292 | "colab_type": "text" 293 | }, 294 | "source": [ 295 | "Addition: syntax 2\n", 296 | "\n" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "metadata": { 302 | "id": "CUIOpY6SwHCL", 303 | "colab_type": "code", 304 | "colab": {} 305 | }, 306 | "source": [ 307 | "print(torch.add(x, y))" 308 | ], 309 | "execution_count": 0, 310 | "outputs": [] 311 | }, 312 | { 313 | "cell_type": "markdown", 314 | "metadata": { 315 | "id": "jP_YRy9rwHCO", 316 | "colab_type": "text" 317 | }, 318 | "source": [ 319 | "Addition: providing an output tensor as argument\n", 320 | "\n" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "metadata": { 326 | "id": "go2wACGhwHCP", 327 | "colab_type": "code", 328 | "colab": {} 329 | }, 330 | "source": [ 331 | "result = torch.empty(5, 3)\n", 332 | "torch.add(x, y, out=result)\n", 333 | "print(result)" 334 | ], 335 | "execution_count": 0, 336 | "outputs": [] 337 | }, 338 | { 339 | "cell_type": "markdown", 340 | "metadata": { 341 | "id": "2YekwAaZwHCS", 342 | "colab_type": "text" 343 | }, 344 | "source": [ 345 | "Addition: in-place\n", 346 | "\n" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "metadata": { 352 | "id": "nfLDWMTqwHCT", 353 | "colab_type": "code", 354 | "colab": {} 355 | }, 356 | "source": [ 357 | "# adds x to y\n", 358 | "y.add_(x)\n", 359 | "print(y)" 360 | ], 361 | "execution_count": 0, 362 | "outputs": [] 363 | }, 364 | { 365 | "cell_type": "markdown", 366 | "metadata": { 367 | "id": "P6U3H2tUwHCX", 368 | "colab_type": "text" 369 | }, 370 | "source": [ 371 | "> #### Note\n", 372 | "Any operation that mutates a tensor in-place is 
post-fixed with an `_`. For example: `x.copy_(y)`, `x.t_()`, will change `x`. \n", 373 | "\n", 374 | "You can use standard NumPy-like indexing with all bells and whistles!\n" 375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "metadata": { 380 | "id": "D5knUwBOwHCY", 381 | "colab_type": "code", 382 | "colab": {} 383 | }, 384 | "source": [ 385 | "print(x[:, 1])" 386 | ], 387 | "execution_count": 0, 388 | "outputs": [] 389 | }, 390 | { 391 | "cell_type": "markdown", 392 | "metadata": { 393 | "id": "AgAoBI6awHCb", 394 | "colab_type": "text" 395 | }, 396 | "source": [ 397 | "Resizing: If you want to resize/reshape tensor, you can use ``torch.view``:\n", 398 | "\n" 399 | ] 400 | }, 401 | { 402 | "cell_type": "code", 403 | "metadata": { 404 | "id": "L_kCg12-wHCd", 405 | "colab_type": "code", 406 | "colab": {} 407 | }, 408 | "source": [ 409 | "x = torch.randn(4, 4)\n", 410 | "y = x.view(16)\n", 411 | "z = x.view(-1, 8) # the size -1 is inferred from other dimensions\n", 412 | "print(x.size(), y.size(), z.size())" 413 | ], 414 | "execution_count": 0, 415 | "outputs": [] 416 | }, 417 | { 418 | "cell_type": "markdown", 419 | "metadata": { 420 | "id": "5M9OwcfbwHCh", 421 | "colab_type": "text" 422 | }, 423 | "source": [ 424 | "If you have a one element tensor, use ``.item()`` to get the value as a\n", 425 | "Python number\n", 426 | "\n" 427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "metadata": { 432 | "id": "e4d4nyAOwHCi", 433 | "colab_type": "code", 434 | "colab": {} 435 | }, 436 | "source": [ 437 | "x = torch.randn(1)\n", 438 | "print(x)\n", 439 | "print(x.item())" 440 | ], 441 | "execution_count": 0, 442 | "outputs": [] 443 | }, 444 | { 445 | "cell_type": "markdown", 446 | "metadata": { 447 | "id": "5Ke5Wr0bwHCk", 448 | "colab_type": "text" 449 | }, 450 | "source": [ 451 | "**Read later:**\n", 452 | "\n", 453 | "\n", 454 | " 100+ Tensor operations, including transposing, indexing, slicing,\n", 455 | " mathematical operations, linear algebra, random numbers, etc.,\n", 456 | " are described\n", 457 | " in the [docs](http://pytorch.org/docs/torch).\n", 458 | "\n", 459 | "NumPy Bridge\n", 460 | "------------\n", 461 | "\n", 462 | "Converting a Torch Tensor to a NumPy array and vice versa is a breeze.\n", 463 | "\n", 464 | "The Torch Tensor and NumPy array will share their underlying memory\n", 465 | "locations, and changing one will change the other.\n", 466 | "\n", 467 | "### Converting a Torch Tensor to a NumPy Array\n", 468 | "\n" 469 | ] 470 | }, 471 | { 472 | "cell_type": "code", 473 | "metadata": { 474 | "id": "xk-M7Ep1wHCl", 475 | "colab_type": "code", 476 | "colab": {} 477 | }, 478 | "source": [ 479 | "a = torch.ones(5)\n", 480 | "print(a)" 481 | ], 482 | "execution_count": 0, 483 | "outputs": [] 484 | }, 485 | { 486 | "cell_type": "code", 487 | "metadata": { 488 | "id": "N60PXRxKwHCo", 489 | "colab_type": "code", 490 | "colab": {} 491 | }, 492 | "source": [ 493 | "b = a.numpy()\n", 494 | "print(b)" 495 | ], 496 | "execution_count": 0, 497 | "outputs": [] 498 | }, 499 | { 500 | "cell_type": "markdown", 501 | "metadata": { 502 | "id": "5xXNJtv0wHCq", 503 | "colab_type": "text" 504 | }, 505 | "source": [ 506 | "See how the numpy array changed in value.\n", 507 | "\n" 508 | ] 509 | }, 510 | { 511 | "cell_type": "code", 512 | "metadata": { 513 | "id": "xt7bKmeWwHCr", 514 | "colab_type": "code", 515 | "colab": {} 516 | }, 517 | "source": [ 518 | "a.add_(1)\n", 519 | "print(a)\n", 520 | "print(b)" 521 | ], 522 | "execution_count": 0, 523 | "outputs": [] 524 | }, 525 | { 526 | 
"cell_type": "markdown", 527 | "metadata": { 528 | "id": "KhNyQeo9wHCu", 529 | "colab_type": "text" 530 | }, 531 | "source": [ 532 | "### Converting NumPy Array to Torch Tensor\n", 533 | "See how changing the np array changed the Torch Tensor automatically\n", 534 | "\n" 535 | ] 536 | }, 537 | { 538 | "cell_type": "code", 539 | "metadata": { 540 | "id": "ysGTrbTTwHCv", 541 | "colab_type": "code", 542 | "colab": {} 543 | }, 544 | "source": [ 545 | "import numpy as np\n", 546 | "a = np.ones(5)\n", 547 | "b = torch.from_numpy(a)\n", 548 | "np.add(a, 1, out=a)\n", 549 | "print(a)\n", 550 | "print(b)" 551 | ], 552 | "execution_count": 0, 553 | "outputs": [] 554 | }, 555 | { 556 | "cell_type": "markdown", 557 | "metadata": { 558 | "id": "05KUOUWAwHCx", 559 | "colab_type": "text" 560 | }, 561 | "source": [ 562 | "All the Tensors on the CPU except a CharTensor support converting to\n", 563 | "NumPy and back.\n", 564 | "\n", 565 | "CUDA Tensors\n", 566 | "------------\n", 567 | "\n", 568 | "Tensors can be moved onto any device using the ``.to`` method.\n", 569 | "\n" 570 | ] 571 | }, 572 | { 573 | "cell_type": "code", 574 | "metadata": { 575 | "id": "I8Un-mKTwHCy", 576 | "colab_type": "code", 577 | "colab": {} 578 | }, 579 | "source": [ 580 | "# let us run this cell only if CUDA is available\n", 581 | "# We will use ``torch.device`` objects to move tensors in and out of GPU\n", 582 | "if torch.cuda.is_available():\n", 583 | " device = torch.device(\"cuda\") # a CUDA device object\n", 584 | " y = torch.ones_like(x, device=device) # directly create a tensor on GPU\n", 585 | " x = x.to(device) # or just use strings ``.to(\"cuda\")``\n", 586 | " z = x + y\n", 587 | " print(z)\n", 588 | " print(z.to(\"cpu\", torch.double)) # ``.to`` can also change dtype together!" 589 | ], 590 | "execution_count": 0, 591 | "outputs": [] 592 | } 593 | ] 594 | } -------------------------------------------------------------------------------- /getting-started/scikit-learn/handwritten-digit-classifier.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "kernelspec": { 6 | "display_name": "Python 3", 7 | "language": "python", 8 | "name": "python3" 9 | }, 10 | "language_info": { 11 | "codemirror_mode": { 12 | "name": "ipython", 13 | "version": 3 14 | }, 15 | "file_extension": ".py", 16 | "mimetype": "text/x-python", 17 | "name": "python", 18 | "nbconvert_exporter": "python", 19 | "pygments_lexer": "ipython3", 20 | "version": "3.7.6" 21 | }, 22 | "colab": { 23 | "name": "digit_classification.ipynb", 24 | "provenance": [] 25 | } 26 | }, 27 | "cells": [ 28 | { 29 | "cell_type": "markdown", 30 | "metadata": { 31 | "id": "7P-chWAmvl2p", 32 | "colab_type": "text" 33 | }, 34 | "source": [ 35 | "\n", 36 | " \n", 44 | " \n", 53 | "
\n", 37 | " Run\n", 41 | " in Google Colab\n", 43 | " \n", 45 | " View source on GitHub\n", 52 | "
" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "metadata": { 59 | "id": "TTXtHwqjvgaJ", 60 | "colab_type": "code", 61 | "colab": {} 62 | }, 63 | "source": [ 64 | "%matplotlib inline" 65 | ], 66 | "execution_count": 0, 67 | "outputs": [] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": { 72 | "id": "lSC3gf8CvgaO", 73 | "colab_type": "text" 74 | }, 75 | "source": [ 76 | "\n", 77 | "# Recognizing hand-written digits\n", 78 | "\n", 79 | "\n", 80 | "An example showing how the scikit-learn can be used to recognize images of\n", 81 | "hand-written digits.\n", 82 | "\n", 83 | "This example is commented in the\n", 84 | "`tutorial section of the user manual `.\n" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "metadata": { 90 | "id": "TUy-iA_OvgaP", 91 | "colab_type": "code", 92 | "colab": {} 93 | }, 94 | "source": [ 95 | "print(__doc__)\n", 96 | "\n", 97 | "# Author: Gael Varoquaux \n", 98 | "# License: BSD 3 clause\n", 99 | "\n", 100 | "# Standard scientific Python imports\n", 101 | "import matplotlib.pyplot as plt\n", 102 | "\n", 103 | "# Import datasets, classifiers and performance metrics\n", 104 | "from sklearn import datasets, svm, metrics\n", 105 | "from sklearn.model_selection import train_test_split\n", 106 | "\n", 107 | "# The digits dataset\n", 108 | "digits = datasets.load_digits()\n", 109 | "\n", 110 | "# The data that we are interested in is made of 8x8 images of digits, let's\n", 111 | "# have a look at the first 4 images, stored in the `images` attribute of the\n", 112 | "# dataset. If we were working from image files, we could load them using\n", 113 | "# matplotlib.pyplot.imread. Note that each image must have the same size. For these\n", 114 | "# images, we know which digit they represent: it is given in the 'target' of\n", 115 | "# the dataset.\n", 116 | "_, axes = plt.subplots(2, 4)\n", 117 | "images_and_labels = list(zip(digits.images, digits.target))\n", 118 | "for ax, (image, label) in zip(axes[0, :], images_and_labels[:4]):\n", 119 | " ax.set_axis_off()\n", 120 | " ax.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')\n", 121 | " ax.set_title('Training: %i' % label)\n", 122 | "\n", 123 | "# To apply a classifier on this data, we need to flatten the image, to\n", 124 | "# turn the data in a (samples, feature) matrix:\n", 125 | "n_samples = len(digits.images)\n", 126 | "data = digits.images.reshape((n_samples, -1))\n", 127 | "\n", 128 | "# Create a classifier: a support vector classifier\n", 129 | "classifier = svm.SVC(gamma=0.001)\n", 130 | "\n", 131 | "# Split data into train and test subsets\n", 132 | "X_train, X_test, y_train, y_test = train_test_split(\n", 133 | " data, digits.target, test_size=0.5, shuffle=False)\n", 134 | "\n", 135 | "# We learn the digits on the first half of the digits\n", 136 | "classifier.fit(X_train, y_train)\n", 137 | "\n", 138 | "# Now predict the value of the digit on the second half:\n", 139 | "predicted = classifier.predict(X_test)\n", 140 | "\n", 141 | "images_and_predictions = list(zip(digits.images[n_samples // 2:], predicted))\n", 142 | "for ax, (image, prediction) in zip(axes[1, :], images_and_predictions[:4]):\n", 143 | " ax.set_axis_off()\n", 144 | " ax.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')\n", 145 | " ax.set_title('Prediction: %i' % prediction)\n", 146 | "\n", 147 | "print(\"Classification report for classifier %s:\\n%s\\n\"\n", 148 | " % (classifier, metrics.classification_report(y_test, predicted)))\n", 149 | "disp = metrics.plot_confusion_matrix(classifier, X_test, y_test)\n", 
150 | "disp.figure_.suptitle(\"Confusion Matrix\")\n", 151 | "print(\"Confusion matrix:\\n%s\" % disp.confusion_matrix)\n", 152 | "\n", 153 | "plt.show()" 154 | ], 155 | "execution_count": 0, 156 | "outputs": [] 157 | } 158 | ] 159 | } -------------------------------------------------------------------------------- /getting-started/tensorflow/handwritten_digit_classifier.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "beginner.ipynb", 7 | "provenance": [], 8 | "private_outputs": true, 9 | "collapsed_sections": [ 10 | "rX8mhOLljYeM" 11 | ], 12 | "toc_visible": true 13 | }, 14 | "kernelspec": { 15 | "display_name": "Python 3", 16 | "name": "python3" 17 | } 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "markdown", 22 | "metadata": { 23 | "colab_type": "text", 24 | "id": "Uc6IaxDSr3C9" 25 | }, 26 | "source": [ 27 | "\n", 28 | " \n", 31 | " \n", 34 | "
\n", 29 | " Run in Google Colab\n", 30 | " \n", 32 | " View source on GitHub\n", 33 | "
" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": { 40 | "colab_type": "text", 41 | "id": "3wF5wszaj97Y" 42 | }, 43 | "source": [ 44 | "# TensorFlow 2 quickstart for beginners" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": { 50 | "colab_type": "text", 51 | "id": "04QgGZc9bF5D" 52 | }, 53 | "source": [ 54 | "This short introduction uses [Keras](https://www.tensorflow.org/guide/keras/overview) to:\n", 55 | "\n", 56 | "1. Build a neural network that classifies images.\n", 57 | "2. Train this neural network.\n", 58 | "3. And, finally, evaluate the accuracy of the model." 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": { 64 | "colab_type": "text", 65 | "id": "hiH7AC-NTniF" 66 | }, 67 | "source": [ 68 | "This is a [Google Colaboratory](https://colab.research.google.com/notebooks/welcome.ipynb) notebook file. Python programs are run directly in the browser—a great way to learn and use TensorFlow. To follow this tutorial, run the notebook in Google Colab by clicking the button at the top of this page.\n", 69 | "\n", 70 | "1. In Colab, connect to a Python runtime: At the top-right of the menu bar, select *CONNECT*.\n", 71 | "2. Run all the notebook code cells: Select *Runtime* > *Run all*." 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": { 77 | "colab_type": "text", 78 | "id": "nnrWf3PCEzXL" 79 | }, 80 | "source": [ 81 | "Download and install the TensorFlow 2 package. Import TensorFlow into your program:" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 0, 87 | "metadata": { 88 | "colab": {}, 89 | "colab_type": "code", 90 | "id": "0trJmd6DjqBZ" 91 | }, 92 | "outputs": [], 93 | "source": [ 94 | "from __future__ import absolute_import, division, print_function, unicode_literals\n", 95 | "\n", 96 | "# Install TensorFlow\n", 97 | "try:\n", 98 | " # %tensorflow_version only exists in Colab.\n", 99 | " %tensorflow_version 2.x\n", 100 | "except Exception:\n", 101 | " pass\n", 102 | "\n", 103 | "import tensorflow as tf" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "metadata": { 109 | "colab_type": "text", 110 | "id": "7NAbSZiaoJ4z" 111 | }, 112 | "source": [ 113 | "Load and prepare the [MNIST dataset](http://yann.lecun.com/exdb/mnist/). Convert the samples from integers to floating-point numbers:" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 0, 119 | "metadata": { 120 | "colab": {}, 121 | "colab_type": "code", 122 | "id": "7FP5258xjs-v" 123 | }, 124 | "outputs": [], 125 | "source": [ 126 | "mnist = tf.keras.datasets.mnist\n", 127 | "\n", 128 | "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n", 129 | "x_train, x_test = x_train / 255.0, x_test / 255.0" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": { 135 | "colab_type": "text", 136 | "id": "BPZ68wASog_I" 137 | }, 138 | "source": [ 139 | "Build the `tf.keras.Sequential` model by stacking layers. 
Choose an optimizer and loss function for training:" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 0, 145 | "metadata": { 146 | "colab": {}, 147 | "colab_type": "code", 148 | "id": "h3IKyzTCDNGo" 149 | }, 150 | "outputs": [], 151 | "source": [ 152 | "model = tf.keras.models.Sequential([\n", 153 | " tf.keras.layers.Flatten(input_shape=(28, 28)),\n", 154 | " tf.keras.layers.Dense(128, activation='relu'),\n", 155 | " tf.keras.layers.Dropout(0.2),\n", 156 | " tf.keras.layers.Dense(10)\n", 157 | "])" 158 | ] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "metadata": { 163 | "colab_type": "text", 164 | "id": "l2hiez2eIUz8" 165 | }, 166 | "source": [ 167 | "For each example the model returns a vector of \"[logits](https://developers.google.com/machine-learning/glossary#logits)\" or \"[log-odds](https://developers.google.com/machine-learning/glossary#log-odds)\" scores, one for each class." 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 0, 173 | "metadata": { 174 | "colab": {}, 175 | "colab_type": "code", 176 | "id": "OeOrNdnkEEcR" 177 | }, 178 | "outputs": [], 179 | "source": [ 180 | "predictions = model(x_train[:1]).numpy()\n", 181 | "predictions" 182 | ] 183 | }, 184 | { 185 | "cell_type": "markdown", 186 | "metadata": { 187 | "colab_type": "text", 188 | "id": "tgjhDQGcIniO" 189 | }, 190 | "source": [ 191 | "The `tf.nn.softmax` function converts these logits to \"probabilities\" for each class: " 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 0, 197 | "metadata": { 198 | "colab": {}, 199 | "colab_type": "code", 200 | "id": "zWSRnQ0WI5eq" 201 | }, 202 | "outputs": [], 203 | "source": [ 204 | "tf.nn.softmax(predictions).numpy()" 205 | ] 206 | }, 207 | { 208 | "cell_type": "markdown", 209 | "metadata": { 210 | "colab_type": "text", 211 | "id": "he5u_okAYS4a" 212 | }, 213 | "source": [ 214 | "Note: It is possible to bake this `tf.nn.softmax` in as the activation function for the last layer of the network. While this can make the model output more directly interpretable, this approach is discouraged as it's impossible to\n", 215 | "provide an exact and numerically stable loss calculation for all models when using a softmax output. " 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "metadata": { 221 | "colab_type": "text", 222 | "id": "hQyugpgRIyrA" 223 | }, 224 | "source": [ 225 | "The `losses.SparseCategoricalCrossentropy` loss takes a vector of logits and a `True` index and returns a scalar loss for each example." 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": 0, 231 | "metadata": { 232 | "colab": {}, 233 | "colab_type": "code", 234 | "id": "RSkzdv8MD0tT" 235 | }, 236 | "outputs": [], 237 | "source": [ 238 | "loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)" 239 | ] 240 | }, 241 | { 242 | "cell_type": "markdown", 243 | "metadata": { 244 | "colab_type": "text", 245 | "id": "SfR4MsSDU880" 246 | }, 247 | "source": [ 248 | "This loss is equal to the negative log probability of the the true class:\n", 249 | "It is zero if the model is sure of the correct class.\n", 250 | "\n", 251 | "This untrained model gives probabilities close to random (1/10 for each class), so the initial loss should be close to `-tf.log(1/10) ~= 2.3`." 
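As a quick sanity check of that figure, the expected starting loss for a 10-class model that outputs roughly uniform probabilities can be computed directly. The small sketch below uses plain NumPy and is not part of the original notebook.

```python
# Sketch: the cross-entropy an untrained 10-class model should report when its
# predictions are roughly uniform, i.e. -log(1/10).
import numpy as np

print(-np.log(1 / 10))  # ~2.302
```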
252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": 0, 257 | "metadata": { 258 | "colab": {}, 259 | "colab_type": "code", 260 | "id": "NJWqEVrrJ7ZB" 261 | }, 262 | "outputs": [], 263 | "source": [ 264 | "loss_fn(y_train[:1], predictions).numpy()" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 0, 270 | "metadata": { 271 | "colab": {}, 272 | "colab_type": "code", 273 | "id": "9foNKHzTD2Vo" 274 | }, 275 | "outputs": [], 276 | "source": [ 277 | "model.compile(optimizer='adam',\n", 278 | " loss=loss_fn,\n", 279 | " metrics=['accuracy'])" 280 | ] 281 | }, 282 | { 283 | "cell_type": "markdown", 284 | "metadata": { 285 | "colab_type": "text", 286 | "id": "ix4mEL65on-w" 287 | }, 288 | "source": [ 289 | "The `Model.fit` method adjusts the model parameters to minimize the loss: " 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 0, 295 | "metadata": { 296 | "colab": {}, 297 | "colab_type": "code", 298 | "id": "y7suUbJXVLqP" 299 | }, 300 | "outputs": [], 301 | "source": [ 302 | "model.fit(x_train, y_train, epochs=5)" 303 | ] 304 | }, 305 | { 306 | "cell_type": "markdown", 307 | "metadata": { 308 | "colab_type": "text", 309 | "id": "4mDAAPFqVVgn" 310 | }, 311 | "source": [ 312 | "The `Model.evaluate` method checks the models performance, usually on a \"[Validation-set](https://developers.google.com/machine-learning/glossary#validation-set)\"." 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": 0, 318 | "metadata": { 319 | "colab": {}, 320 | "colab_type": "code", 321 | "id": "F7dTAzgHDUh7" 322 | }, 323 | "outputs": [], 324 | "source": [ 325 | "model.evaluate(x_test, y_test, verbose=2)" 326 | ] 327 | }, 328 | { 329 | "cell_type": "markdown", 330 | "metadata": { 331 | "colab_type": "text", 332 | "id": "T4JfEh7kvx6m" 333 | }, 334 | "source": [ 335 | "The image classifier is now trained to ~98% accuracy on this dataset. To learn more, read the [TensorFlow tutorials](https://www.tensorflow.org/tutorials/)." 336 | ] 337 | }, 338 | { 339 | "cell_type": "markdown", 340 | "metadata": { 341 | "colab_type": "text", 342 | "id": "Aj8NrlzlJqDG" 343 | }, 344 | "source": [ 345 | "If you want your model to return a probability, you can wrap the trained model, and attach the softmax to it:" 346 | ] 347 | }, 348 | { 349 | "cell_type": "code", 350 | "execution_count": 0, 351 | "metadata": { 352 | "colab": {}, 353 | "colab_type": "code", 354 | "id": "rYb6DrEH0GMv" 355 | }, 356 | "outputs": [], 357 | "source": [ 358 | "probability_model = tf.keras.Sequential([\n", 359 | " model,\n", 360 | " tf.keras.layers.Softmax()\n", 361 | "])" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": 0, 367 | "metadata": { 368 | "colab": {}, 369 | "colab_type": "code", 370 | "id": "cnqOZtUp1YR_" 371 | }, 372 | "outputs": [], 373 | "source": [ 374 | "probability_model(x_test[:5])" 375 | ] 376 | }, 377 | { 378 | "cell_type": "markdown", 379 | "metadata": { 380 | "colab_type": "text", 381 | "id": "xfvvdBoqtOxs" 382 | }, 383 | "source": [ 384 | "# License\n", 385 | "\n", 386 | "##### Copyright 2019 The TensorFlow Authors." 
387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": 0, 392 | "metadata": { 393 | "cellView": "both", 394 | "colab": {}, 395 | "colab_type": "code", 396 | "id": "BZSlp3DAjdYf" 397 | }, 398 | "outputs": [], 399 | "source": [ 400 | "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", 401 | "# you may not use this file except in compliance with the License.\n", 402 | "# You may obtain a copy of the License at\n", 403 | "#\n", 404 | "# https://www.apache.org/licenses/LICENSE-2.0\n", 405 | "#\n", 406 | "# Unless required by applicable law or agreed to in writing, software\n", 407 | "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", 408 | "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", 409 | "# See the License for the specific language governing permissions and\n", 410 | "# limitations under the License." 411 | ] 412 | } 413 | ] 414 | } -------------------------------------------------------------------------------- /getting-started/tensorflow/image_classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "classification.ipynb", 7 | "provenance": [], 8 | "private_outputs": true, 9 | "collapsed_sections": [], 10 | "toc_visible": true 11 | }, 12 | "kernelspec": { 13 | "display_name": "Python 3", 14 | "name": "python3" 15 | } 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "colab_type": "text", 22 | "id": "S5Uhzt6vVIB2" 23 | }, 24 | "source": [ 25 | "\n", 26 | " \n", 29 | " \n", 32 | "
\n", 27 | " Run in Google Colab\n", 28 | " \n", 30 | " View source on GitHub\n", 31 | "
" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": { 38 | "colab_type": "text", 39 | "id": "jYysdyb-CaWM" 40 | }, 41 | "source": [ 42 | "# Basic classification: Classify images of clothing" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "metadata": { 48 | "colab_type": "code", 49 | "id": "jL3OqFKZ9dFg", 50 | "colab": {} 51 | }, 52 | "source": [ 53 | "try:\n", 54 | " # %tensorflow_version only exists in Colab.\n", 55 | " %tensorflow_version 2.x\n", 56 | "except Exception:\n", 57 | " pass\n" 58 | ], 59 | "execution_count": 0, 60 | "outputs": [] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "metadata": { 65 | "colab_type": "code", 66 | "id": "dzLKpmZICaWN", 67 | "colab": {} 68 | }, 69 | "source": [ 70 | "from __future__ import absolute_import, division, print_function, unicode_literals\n", 71 | "\n", 72 | "# TensorFlow and tf.keras\n", 73 | "import tensorflow as tf\n", 74 | "from tensorflow import keras\n", 75 | "\n", 76 | "# Helper libraries\n", 77 | "import numpy as np\n", 78 | "import matplotlib.pyplot as plt\n", 79 | "\n", 80 | "print(tf.__version__)" 81 | ], 82 | "execution_count": 0, 83 | "outputs": [] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": { 88 | "colab_type": "text", 89 | "id": "yR0EdgrLCaWR" 90 | }, 91 | "source": [ 92 | "## Import the Fashion MNIST dataset" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": { 98 | "colab_type": "text", 99 | "id": "DLdCchMdCaWQ" 100 | }, 101 | "source": [ 102 | "This guide uses the [Fashion MNIST](https://github.com/zalandoresearch/fashion-mnist) dataset which contains 70,000 grayscale images in 10 categories. The images show individual articles of clothing at low resolution (28 by 28 pixels), as seen here:\n", 103 | "\n", 104 | "\n", 105 | " \n", 109 | " \n", 112 | "
\n", 106 | " \"Fashion\n", 108 | "
\n", 110 | " Figure 1. Fashion-MNIST samples (by Zalando, MIT License).
 \n", 111 | "
\n", 113 | "\n", 114 | "Fashion MNIST is intended as a drop-in replacement for the classic [MNIST](http://yann.lecun.com/exdb/mnist/) dataset—often used as the \"Hello, World\" of machine learning programs for computer vision. The MNIST dataset contains images of handwritten digits (0, 1, 2, etc.) in a format identical to that of the articles of clothing you'll use here.\n", 115 | "\n", 116 | "This guide uses Fashion MNIST for variety, and because it's a slightly more challenging problem than regular MNIST. Both datasets are relatively small and are used to verify that an algorithm works as expected. They're good starting points to test and debug code.\n", 117 | "\n", 118 | "Here, 60,000 images are used to train the network and 10,000 images to evaluate how accurately the network learned to classify images. You can access the Fashion MNIST directly from TensorFlow. Import and load the Fashion MNIST data directly from TensorFlow:" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "metadata": { 124 | "colab_type": "code", 125 | "id": "7MqDQO0KCaWS", 126 | "colab": {} 127 | }, 128 | "source": [ 129 | "fashion_mnist = keras.datasets.fashion_mnist\n", 130 | "\n", 131 | "(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()" 132 | ], 133 | "execution_count": 0, 134 | "outputs": [] 135 | }, 136 | { 137 | "cell_type": "markdown", 138 | "metadata": { 139 | "colab_type": "text", 140 | "id": "t9FDsUlxCaWW" 141 | }, 142 | "source": [ 143 | "Loading the dataset returns four NumPy arrays:\n", 144 | "\n", 145 | "* The `train_images` and `train_labels` arrays are the *training set*—the data the model uses to learn.\n", 146 | "* The model is tested against the *test set*, the `test_images`, and `test_labels` arrays.\n", 147 | "\n", 148 | "The images are 28x28 NumPy arrays, with pixel values ranging from 0 to 255. The *labels* are an array of integers, ranging from 0 to 9. These correspond to the *class* of clothing the image represents:\n", 149 | "\n", 150 | "\n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | "
| Label | Class       |
|-------|-------------|
| 0     | T-shirt/top |
| 1     | Trouser     |
| 2     | Pullover    |
| 3     | Dress       |
| 4     | Coat        |
| 5     | Sandal      |
| 6     | Shirt       |
| 7     | Sneaker     |
| 8     | Bag         |
| 9     | Ankle boot  |
\n", 196 | "\n", 197 | "Each image is mapped to a single label. Since the *class names* are not included with the dataset, store them here to use later when plotting the images:" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "metadata": { 203 | "colab_type": "code", 204 | "id": "IjnLH5S2CaWx", 205 | "colab": {} 206 | }, 207 | "source": [ 208 | "class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',\n", 209 | " 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']" 210 | ], 211 | "execution_count": 0, 212 | "outputs": [] 213 | }, 214 | { 215 | "cell_type": "markdown", 216 | "metadata": { 217 | "colab_type": "text", 218 | "id": "Brm0b_KACaWX" 219 | }, 220 | "source": [ 221 | "## Explore the data\n", 222 | "\n", 223 | "Let's explore the format of the dataset before training the model. The following shows there are 60,000 images in the training set, with each image represented as 28 x 28 pixels:" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "metadata": { 229 | "colab_type": "code", 230 | "id": "zW5k_xz1CaWX", 231 | "colab": {} 232 | }, 233 | "source": [ 234 | "train_images.shape" 235 | ], 236 | "execution_count": 0, 237 | "outputs": [] 238 | }, 239 | { 240 | "cell_type": "markdown", 241 | "metadata": { 242 | "colab_type": "text", 243 | "id": "cIAcvQqMCaWf" 244 | }, 245 | "source": [ 246 | "Likewise, there are 60,000 labels in the training set:" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "metadata": { 252 | "colab_type": "code", 253 | "id": "TRFYHB2mCaWb", 254 | "colab": {} 255 | }, 256 | "source": [ 257 | "len(train_labels)" 258 | ], 259 | "execution_count": 0, 260 | "outputs": [] 261 | }, 262 | { 263 | "cell_type": "markdown", 264 | "metadata": { 265 | "colab_type": "text", 266 | "id": "YSlYxFuRCaWk" 267 | }, 268 | "source": [ 269 | "Each label is an integer between 0 and 9:" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "metadata": { 275 | "colab_type": "code", 276 | "id": "XKnCTHz4CaWg", 277 | "colab": {} 278 | }, 279 | "source": [ 280 | "train_labels" 281 | ], 282 | "execution_count": 0, 283 | "outputs": [] 284 | }, 285 | { 286 | "cell_type": "markdown", 287 | "metadata": { 288 | "colab_type": "text", 289 | "id": "TMPI88iZpO2T" 290 | }, 291 | "source": [ 292 | "There are 10,000 images in the test set. Again, each image is represented as 28 x 28 pixels:" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "metadata": { 298 | "colab_type": "code", 299 | "id": "2KFnYlcwCaWl", 300 | "colab": {} 301 | }, 302 | "source": [ 303 | "test_images.shape" 304 | ], 305 | "execution_count": 0, 306 | "outputs": [] 307 | }, 308 | { 309 | "cell_type": "markdown", 310 | "metadata": { 311 | "colab_type": "text", 312 | "id": "rd0A0Iu0CaWq" 313 | }, 314 | "source": [ 315 | "And the test set contains 10,000 images labels:" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "metadata": { 321 | "colab_type": "code", 322 | "id": "iJmPr5-ACaWn", 323 | "colab": {} 324 | }, 325 | "source": [ 326 | "len(test_labels)" 327 | ], 328 | "execution_count": 0, 329 | "outputs": [] 330 | }, 331 | { 332 | "cell_type": "markdown", 333 | "metadata": { 334 | "colab_type": "text", 335 | "id": "ES6uQoLKCaWr" 336 | }, 337 | "source": [ 338 | "## Preprocess the data\n", 339 | "\n", 340 | "The data must be preprocessed before training the network. 
If you inspect the first image in the training set, you will see that the pixel values fall in the range of 0 to 255:" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "metadata": { 346 | "colab_type": "code", 347 | "id": "m4VEw8Ud9Quh", 348 | "colab": {} 349 | }, 350 | "source": [ 351 | "plt.figure()\n", 352 | "plt.imshow(train_images[0])\n", 353 | "plt.colorbar()\n", 354 | "plt.grid(False)\n", 355 | "plt.show()" 356 | ], 357 | "execution_count": 0, 358 | "outputs": [] 359 | }, 360 | { 361 | "cell_type": "markdown", 362 | "metadata": { 363 | "colab_type": "text", 364 | "id": "Wz7l27Lz9S1P" 365 | }, 366 | "source": [ 367 | "Scale these values to a range of 0 to 1 before feeding them to the neural network model. To do so, divide the values by 255. It's important that the *training set* and the *testing set* be preprocessed in the same way:" 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "metadata": { 373 | "colab_type": "code", 374 | "id": "bW5WzIPlCaWv", 375 | "colab": {} 376 | }, 377 | "source": [ 378 | "train_images = train_images / 255.0\n", 379 | "\n", 380 | "test_images = test_images / 255.0" 381 | ], 382 | "execution_count": 0, 383 | "outputs": [] 384 | }, 385 | { 386 | "cell_type": "markdown", 387 | "metadata": { 388 | "colab_type": "text", 389 | "id": "Ee638AlnCaWz" 390 | }, 391 | "source": [ 392 | "To verify that the data is in the correct format and that you're ready to build and train the network, let's display the first 25 images from the *training set* and display the class name below each image." 393 | ] 394 | }, 395 | { 396 | "cell_type": "code", 397 | "metadata": { 398 | "colab_type": "code", 399 | "id": "oZTImqg_CaW1", 400 | "colab": {} 401 | }, 402 | "source": [ 403 | "plt.figure(figsize=(10,10))\n", 404 | "for i in range(25):\n", 405 | " plt.subplot(5,5,i+1)\n", 406 | " plt.xticks([])\n", 407 | " plt.yticks([])\n", 408 | " plt.grid(False)\n", 409 | " plt.imshow(train_images[i], cmap=plt.cm.binary)\n", 410 | " plt.xlabel(class_names[train_labels[i]])\n", 411 | "plt.show()" 412 | ], 413 | "execution_count": 0, 414 | "outputs": [] 415 | }, 416 | { 417 | "cell_type": "markdown", 418 | "metadata": { 419 | "colab_type": "text", 420 | "id": "59veuiEZCaW4" 421 | }, 422 | "source": [ 423 | "## Build the model\n", 424 | "\n", 425 | "Building the neural network requires configuring the layers of the model, then compiling the model." 426 | ] 427 | }, 428 | { 429 | "cell_type": "markdown", 430 | "metadata": { 431 | "colab_type": "text", 432 | "id": "Gxg1XGm0eOBy" 433 | }, 434 | "source": [ 435 | "### Set up the layers\n", 436 | "\n", 437 | "The basic building block of a neural network is the *layer*. Layers extract representations from the data fed into them. Hopefully, these representations are meaningful for the problem at hand.\n", 438 | "\n", 439 | "Most of deep learning consists of chaining together simple layers. Most layers, such as `tf.keras.layers.Dense`, have parameters that are learned during training." 
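To make the idea of learned parameters concrete, here is a small, hypothetical check (not part of the original notebook) that counts the weights of a single `Dense` layer like the one used in the model below; the 784 input size is just the flattened 28 x 28 image.

```python
import tensorflow as tf

# One fully connected layer with 128 units, built for 784 flattened pixels.
layer = tf.keras.layers.Dense(128, activation='relu')
layer.build(input_shape=(None, 784))

# 784 * 128 kernel weights + 128 biases = 100,480 learned parameters.
print(layer.count_params())  # 100480
```

Once the full model below has been built, `model.summary()` reports the same per-layer counts.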
440 | ] 441 | }, 442 | { 443 | "cell_type": "code", 444 | "metadata": { 445 | "colab_type": "code", 446 | "id": "9ODch-OFCaW4", 447 | "colab": {} 448 | }, 449 | "source": [ 450 | "model = keras.Sequential([\n", 451 | " keras.layers.Flatten(input_shape=(28, 28)),\n", 452 | " keras.layers.Dense(128, activation='relu'),\n", 453 | " keras.layers.Dense(10)\n", 454 | "])" 455 | ], 456 | "execution_count": 0, 457 | "outputs": [] 458 | }, 459 | { 460 | "cell_type": "markdown", 461 | "metadata": { 462 | "colab_type": "text", 463 | "id": "gut8A_7rCaW6" 464 | }, 465 | "source": [ 466 | "The first layer in this network, `tf.keras.layers.Flatten`, transforms the format of the images from a two-dimensional array (of 28 by 28 pixels) to a one-dimensional array (of 28 * 28 = 784 pixels). Think of this layer as unstacking rows of pixels in the image and lining them up. This layer has no parameters to learn; it only reformats the data.\n", 467 | "\n", 468 | "After the pixels are flattened, the network consists of a sequence of two `tf.keras.layers.Dense` layers. These are densely connected, or fully connected, neural layers. The first `Dense` layer has 128 nodes (or neurons). The second (and last) layer is a 10-node *softmax* layer that returns an array of 10 probability scores that sum to 1. Each node contains a score that indicates the probability that the current image belongs to one of the 10 classes.\n", 469 | "\n", 470 | "### Compile the model\n", 471 | "\n", 472 | "Before the model is ready for training, it needs a few more settings. These are added during the model's *compile* step:\n", 473 | "\n", 474 | "* *Loss function* —This measures how accurate the model is during training. You want to minimize this function to \"steer\" the model in the right direction.\n", 475 | "* *Optimizer* —This is how the model is updated based on the data it sees and its loss function.\n", 476 | "* *Metrics* —Used to monitor the training and testing steps. The following example uses *accuracy*, the fraction of the images that are correctly classified." 477 | ] 478 | }, 479 | { 480 | "cell_type": "code", 481 | "metadata": { 482 | "colab_type": "code", 483 | "id": "Lhan11blCaW7", 484 | "colab": {} 485 | }, 486 | "source": [ 487 | "model.compile(optimizer='adam',\n", 488 | " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", 489 | " metrics=['accuracy'])" 490 | ], 491 | "execution_count": 0, 492 | "outputs": [] 493 | }, 494 | { 495 | "cell_type": "markdown", 496 | "metadata": { 497 | "colab_type": "text", 498 | "id": "qKF6uW-BCaW-" 499 | }, 500 | "source": [ 501 | "## Train the model\n", 502 | "\n", 503 | "Training the neural network model requires the following steps:\n", 504 | "\n", 505 | "1. Feed the training data to the model. In this example, the training data is in the `train_images` and `train_labels` arrays.\n", 506 | "2. The model learns to associate images and labels.\n", 507 | "3. You ask the model to make predictions about a test set—in this example, the `test_images` array.\n", 508 | "4. 
Verify that the predictions match the labels from the `test_labels` array.\n", 509 | "\n" 510 | ] 511 | }, 512 | { 513 | "cell_type": "markdown", 514 | "metadata": { 515 | "colab_type": "text", 516 | "id": "Z4P4zIV7E28Z" 517 | }, 518 | "source": [ 519 | "### Feed the model\n", 520 | "\n", 521 | "To start training, call the `model.fit` method—so called because it \"fits\" the model to the training data:" 522 | ] 523 | }, 524 | { 525 | "cell_type": "code", 526 | "metadata": { 527 | "colab_type": "code", 528 | "id": "xvwvpA64CaW_", 529 | "colab": {} 530 | }, 531 | "source": [ 532 | "model.fit(train_images, train_labels, epochs=10)" 533 | ], 534 | "execution_count": 0, 535 | "outputs": [] 536 | }, 537 | { 538 | "cell_type": "markdown", 539 | "metadata": { 540 | "colab_type": "text", 541 | "id": "W3ZVOhugCaXA" 542 | }, 543 | "source": [ 544 | "As the model trains, the loss and accuracy metrics are displayed. This model reaches an accuracy of about 0.91 (or 91%) on the training data." 545 | ] 546 | }, 547 | { 548 | "cell_type": "markdown", 549 | "metadata": { 550 | "colab_type": "text", 551 | "id": "wCpr6DGyE28h" 552 | }, 553 | "source": [ 554 | "### Evaluate accuracy\n", 555 | "\n", 556 | "Next, compare how the model performs on the test dataset:" 557 | ] 558 | }, 559 | { 560 | "cell_type": "code", 561 | "metadata": { 562 | "colab_type": "code", 563 | "id": "VflXLEeECaXC", 564 | "colab": {} 565 | }, 566 | "source": [ 567 | "test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)\n", 568 | "\n", 569 | "print('\\nTest accuracy:', test_acc)" 570 | ], 571 | "execution_count": 0, 572 | "outputs": [] 573 | }, 574 | { 575 | "cell_type": "markdown", 576 | "metadata": { 577 | "colab_type": "text", 578 | "id": "yWfgsmVXCaXG" 579 | }, 580 | "source": [ 581 | "It turns out that the accuracy on the test dataset is a little less than the accuracy on the training dataset. This gap between training accuracy and test accuracy represents *overfitting*. Overfitting is when a machine learning model performs worse on new, previously unseen inputs than on the training data. An overfitted model \"memorizes\" the training data—with less accuracy on testing data. For more information, see the following:\n", 582 | "* [Demonstrate overfitting](https://www.tensorflow.org/tutorials/keras/overfit_and_underfit#demonstrate_overfitting)\n", 583 | "* [Strategies to prevent overfitting](https://www.tensorflow.org/tutorials/keras/overfit_and_underfit#strategies_to_prevent_overfitting)" 584 | ] 585 | }, 586 | { 587 | "cell_type": "markdown", 588 | "metadata": { 589 | "colab_type": "text", 590 | "id": "v-PyD1SYE28q" 591 | }, 592 | "source": [ 593 | "### Make predictions\n", 594 | "\n", 595 | "With the model trained, you can use it to make predictions about some images." 596 | ] 597 | }, 598 | { 599 | "cell_type": "code", 600 | "metadata": { 601 | "colab_type": "code", 602 | "id": "Gl91RPhdCaXI", 603 | "colab": {} 604 | }, 605 | "source": [ 606 | "predictions = model.predict(test_images)" 607 | ], 608 | "execution_count": 0, 609 | "outputs": [] 610 | }, 611 | { 612 | "cell_type": "markdown", 613 | "metadata": { 614 | "colab_type": "text", 615 | "id": "x9Kk1voUCaXJ" 616 | }, 617 | "source": [ 618 | "Here, the model has predicted the label for each image in the testing set. 
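As a quick sanity check (a sketch, assuming the `predictions` array and `test_labels` from the cells above), you can take the argmax over the 10 class scores for every test image and recover the same accuracy that `model.evaluate` reported:

```python
import numpy as np

# predictions has shape (10000, 10): one score per class for each test image.
predicted_labels = np.argmax(predictions, axis=1)

# Fraction of test images whose highest-scoring class matches the true label.
print(np.mean(predicted_labels == test_labels))
```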
Let's take a look at the first prediction:" 619 | ] 620 | }, 621 | { 622 | "cell_type": "code", 623 | "metadata": { 624 | "colab_type": "code", 625 | "id": "3DmJEUinCaXK", 626 | "colab": {} 627 | }, 628 | "source": [ 629 | "predictions[0]" 630 | ], 631 | "execution_count": 0, 632 | "outputs": [] 633 | }, 634 | { 635 | "cell_type": "markdown", 636 | "metadata": { 637 | "colab_type": "text", 638 | "id": "-hw1hgeSCaXN" 639 | }, 640 | "source": [ 641 | "A prediction is an array of 10 numbers. They represent the model's \"confidence\" that the image corresponds to each of the 10 different articles of clothing. You can see which label has the highest confidence value:" 642 | ] 643 | }, 644 | { 645 | "cell_type": "code", 646 | "metadata": { 647 | "colab_type": "code", 648 | "id": "qsqenuPnCaXO", 649 | "colab": {} 650 | }, 651 | "source": [ 652 | "np.argmax(predictions[0])" 653 | ], 654 | "execution_count": 0, 655 | "outputs": [] 656 | }, 657 | { 658 | "cell_type": "markdown", 659 | "metadata": { 660 | "colab_type": "text", 661 | "id": "E51yS7iCCaXO" 662 | }, 663 | "source": [ 664 | "So, the model is most confident that this image is an ankle boot, or `class_names[9]`. Examining the test label shows that this classification is correct:" 665 | ] 666 | }, 667 | { 668 | "cell_type": "code", 669 | "metadata": { 670 | "colab_type": "code", 671 | "id": "Sd7Pgsu6CaXP", 672 | "colab": {} 673 | }, 674 | "source": [ 675 | "test_labels[0]" 676 | ], 677 | "execution_count": 0, 678 | "outputs": [] 679 | }, 680 | { 681 | "cell_type": "markdown", 682 | "metadata": { 683 | "colab_type": "text", 684 | "id": "ygh2yYC972ne" 685 | }, 686 | "source": [ 687 | "Graph this to look at the full set of 10 class predictions." 688 | ] 689 | }, 690 | { 691 | "cell_type": "code", 692 | "metadata": { 693 | "colab_type": "code", 694 | "id": "DvYmmrpIy6Y1", 695 | "colab": {} 696 | }, 697 | "source": [ 698 | "def plot_image(i, predictions_array, true_label, img):\n", 699 | " predictions_array, true_label, img = predictions_array, true_label[i], img[i]\n", 700 | " plt.grid(False)\n", 701 | " plt.xticks([])\n", 702 | " plt.yticks([])\n", 703 | "\n", 704 | " plt.imshow(img, cmap=plt.cm.binary)\n", 705 | "\n", 706 | " predicted_label = np.argmax(predictions_array)\n", 707 | " if predicted_label == true_label:\n", 708 | " color = 'blue'\n", 709 | " else:\n", 710 | " color = 'red'\n", 711 | "\n", 712 | " plt.xlabel(\"{} {:2.0f}% ({})\".format(class_names[predicted_label],\n", 713 | " 100*np.max(predictions_array),\n", 714 | " class_names[true_label]),\n", 715 | " color=color)\n", 716 | "\n", 717 | "def plot_value_array(i, predictions_array, true_label):\n", 718 | " predictions_array, true_label = predictions_array, true_label[i]\n", 719 | " plt.grid(False)\n", 720 | " plt.xticks(range(10))\n", 721 | " plt.yticks([])\n", 722 | " thisplot = plt.bar(range(10), predictions_array, color=\"#777777\")\n", 723 | " plt.ylim([0, 1])\n", 724 | " predicted_label = np.argmax(predictions_array)\n", 725 | "\n", 726 | " thisplot[predicted_label].set_color('red')\n", 727 | " thisplot[true_label].set_color('blue')" 728 | ], 729 | "execution_count": 0, 730 | "outputs": [] 731 | }, 732 | { 733 | "cell_type": "markdown", 734 | "metadata": { 735 | "colab_type": "text", 736 | "id": "Zh9yABaME29S" 737 | }, 738 | "source": [ 739 | "### Verify predictions\n", 740 | "\n", 741 | "With the model trained, you can use it to make predictions about some images." 
742 | ] 743 | }, 744 | { 745 | "cell_type": "markdown", 746 | "metadata": { 747 | "colab_type": "text", 748 | "id": "d4Ov9OFDMmOD" 749 | }, 750 | "source": [ 751 | "Let's look at the 0th image, predictions, and prediction array. Correct prediction labels are blue and incorrect prediction labels are red. The number gives the percentage (out of 100) for the predicted label." 752 | ] 753 | }, 754 | { 755 | "cell_type": "code", 756 | "metadata": { 757 | "colab_type": "code", 758 | "id": "HV5jw-5HwSmO", 759 | "colab": {} 760 | }, 761 | "source": [ 762 | "i = 0\n", 763 | "plt.figure(figsize=(6,3))\n", 764 | "plt.subplot(1,2,1)\n", 765 | "plot_image(i, predictions[i], test_labels, test_images)\n", 766 | "plt.subplot(1,2,2)\n", 767 | "plot_value_array(i, predictions[i], test_labels)\n", 768 | "plt.show()" 769 | ], 770 | "execution_count": 0, 771 | "outputs": [] 772 | }, 773 | { 774 | "cell_type": "code", 775 | "metadata": { 776 | "colab_type": "code", 777 | "id": "Ko-uzOufSCSe", 778 | "colab": {} 779 | }, 780 | "source": [ 781 | "i = 12\n", 782 | "plt.figure(figsize=(6,3))\n", 783 | "plt.subplot(1,2,1)\n", 784 | "plot_image(i, predictions[i], test_labels, test_images)\n", 785 | "plt.subplot(1,2,2)\n", 786 | "plot_value_array(i, predictions[i], test_labels)\n", 787 | "plt.show()" 788 | ], 789 | "execution_count": 0, 790 | "outputs": [] 791 | }, 792 | { 793 | "cell_type": "markdown", 794 | "metadata": { 795 | "colab_type": "text", 796 | "id": "kgdvGD52CaXR" 797 | }, 798 | "source": [ 799 | "Let's plot several images with their predictions. Note that the model can be wrong even when very confident." 800 | ] 801 | }, 802 | { 803 | "cell_type": "code", 804 | "metadata": { 805 | "colab_type": "code", 806 | "id": "hQlnbqaw2Qu_", 807 | "colab": {} 808 | }, 809 | "source": [ 810 | "# Plot the first X test images, their predicted labels, and the true labels.\n", 811 | "# Color correct predictions in blue and incorrect predictions in red.\n", 812 | "num_rows = 5\n", 813 | "num_cols = 3\n", 814 | "num_images = num_rows*num_cols\n", 815 | "plt.figure(figsize=(2*2*num_cols, 2*num_rows))\n", 816 | "for i in range(num_images):\n", 817 | " plt.subplot(num_rows, 2*num_cols, 2*i+1)\n", 818 | " plot_image(i, predictions[i], test_labels, test_images)\n", 819 | " plt.subplot(num_rows, 2*num_cols, 2*i+2)\n", 820 | " plot_value_array(i, predictions[i], test_labels)\n", 821 | "plt.tight_layout()\n", 822 | "plt.show()" 823 | ], 824 | "execution_count": 0, 825 | "outputs": [] 826 | }, 827 | { 828 | "cell_type": "markdown", 829 | "metadata": { 830 | "colab_type": "text", 831 | "id": "R32zteKHCaXT" 832 | }, 833 | "source": [ 834 | "## Use the trained model\n", 835 | "\n", 836 | "Finally, use the trained model to make a prediction about a single image." 837 | ] 838 | }, 839 | { 840 | "cell_type": "code", 841 | "metadata": { 842 | "colab_type": "code", 843 | "id": "yRJ7JU7JCaXT", 844 | "colab": {} 845 | }, 846 | "source": [ 847 | "# Grab an image from the test dataset.\n", 848 | "img = test_images[1]\n", 849 | "\n", 850 | "print(img.shape)" 851 | ], 852 | "execution_count": 0, 853 | "outputs": [] 854 | }, 855 | { 856 | "cell_type": "markdown", 857 | "metadata": { 858 | "colab_type": "text", 859 | "id": "vz3bVp21CaXV" 860 | }, 861 | "source": [ 862 | "`tf.keras` models are optimized to make predictions on a *batch*, or collection, of examples at once. 
Accordingly, even though you're using a single image, you need to add it to a list:" 863 | ] 864 | }, 865 | { 866 | "cell_type": "code", 867 | "metadata": { 868 | "colab_type": "code", 869 | "id": "lDFh5yF_CaXW", 870 | "colab": {} 871 | }, 872 | "source": [ 873 | "# Add the image to a batch where it's the only member.\n", 874 | "img = (np.expand_dims(img,0))\n", 875 | "\n", 876 | "print(img.shape)" 877 | ], 878 | "execution_count": 0, 879 | "outputs": [] 880 | }, 881 | { 882 | "cell_type": "markdown", 883 | "metadata": { 884 | "colab_type": "text", 885 | "id": "EQ5wLTkcCaXY" 886 | }, 887 | "source": [ 888 | "Now predict the correct label for this image:" 889 | ] 890 | }, 891 | { 892 | "cell_type": "code", 893 | "metadata": { 894 | "colab_type": "code", 895 | "id": "o_rzNSdrCaXY", 896 | "colab": {} 897 | }, 898 | "source": [ 899 | "predictions_single = model.predict(img)\n", 900 | "\n", 901 | "print(predictions_single)" 902 | ], 903 | "execution_count": 0, 904 | "outputs": [] 905 | }, 906 | { 907 | "cell_type": "code", 908 | "metadata": { 909 | "colab_type": "code", 910 | "id": "6Ai-cpLjO-3A", 911 | "colab": {} 912 | }, 913 | "source": [ 914 | "plot_value_array(1, predictions_single[0], test_labels)\n", 915 | "_ = plt.xticks(range(10), class_names, rotation=45)" 916 | ], 917 | "execution_count": 0, 918 | "outputs": [] 919 | }, 920 | { 921 | "cell_type": "markdown", 922 | "metadata": { 923 | "colab_type": "text", 924 | "id": "cU1Y2OAMCaXb" 925 | }, 926 | "source": [ 927 | "`model.predict` returns a list of lists—one list for each image in the batch of data. Grab the predictions for our (only) image in the batch:" 928 | ] 929 | }, 930 | { 931 | "cell_type": "code", 932 | "metadata": { 933 | "colab_type": "code", 934 | "id": "2tRmdq_8CaXb", 935 | "colab": {} 936 | }, 937 | "source": [ 938 | "np.argmax(predictions_single[0])" 939 | ], 940 | "execution_count": 0, 941 | "outputs": [] 942 | }, 943 | { 944 | "cell_type": "markdown", 945 | "metadata": { 946 | "colab_type": "text", 947 | "id": "YFc2HbEVCaXd" 948 | }, 949 | "source": [ 950 | "And the model predicts a label as expected." 951 | ] 952 | }, 953 | { 954 | "cell_type": "markdown", 955 | "metadata": { 956 | "colab_type": "text", 957 | "id": "MhoQ0WE77laV" 958 | }, 959 | "source": [ 960 | "# Copyright 2018 The TensorFlow Authors." 961 | ] 962 | }, 963 | { 964 | "cell_type": "code", 965 | "metadata": { 966 | "cellView": "form", 967 | "colab_type": "code", 968 | "id": "_ckMIh7O7s6D", 969 | "colab": {} 970 | }, 971 | "source": [ 972 | "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", 973 | "# you may not use this file except in compliance with the License.\n", 974 | "# You may obtain a copy of the License at\n", 975 | "#\n", 976 | "# https://www.apache.org/licenses/LICENSE-2.0\n", 977 | "#\n", 978 | "# Unless required by applicable law or agreed to in writing, software\n", 979 | "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", 980 | "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", 981 | "# See the License for the specific language governing permissions and\n", 982 | "# limitations under the License." 
983 | ], 984 | "execution_count": 0, 985 | "outputs": [] 986 | }, 987 | { 988 | "cell_type": "code", 989 | "metadata": { 990 | "cellView": "form", 991 | "colab_type": "code", 992 | "id": "vasWnqRgy1H4", 993 | "colab": {} 994 | }, 995 | "source": [ 996 | "#@title MIT License\n", 997 | "#\n", 998 | "# Copyright (c) 2017 François Chollet\n", 999 | "#\n", 1000 | "# Permission is hereby granted, free of charge, to any person obtaining a\n", 1001 | "# copy of this software and associated documentation files (the \"Software\"),\n", 1002 | "# to deal in the Software without restriction, including without limitation\n", 1003 | "# the rights to use, copy, modify, merge, publish, distribute, sublicense,\n", 1004 | "# and/or sell copies of the Software, and to permit persons to whom the\n", 1005 | "# Software is furnished to do so, subject to the following conditions:\n", 1006 | "#\n", 1007 | "# The above copyright notice and this permission notice shall be included in\n", 1008 | "# all copies or substantial portions of the Software.\n", 1009 | "#\n", 1010 | "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", 1011 | "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", 1012 | "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL\n", 1013 | "# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", 1014 | "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n", 1015 | "# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\n", 1016 | "# DEALINGS IN THE SOFTWARE." 1017 | ], 1018 | "execution_count": 0, 1019 | "outputs": [] 1020 | } 1021 | ] 1022 | } -------------------------------------------------------------------------------- /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notebookexplore/NotebookExplore/63d8db772e482cf003eb9696729984f916ec453f/logo.png -------------------------------------------------------------------------------- /reinforcement-learning/pytorch/deep_q_learning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "kernelspec": { 6 | "display_name": "Python 3", 7 | "language": "python", 8 | "name": "python3" 9 | }, 10 | "language_info": { 11 | "codemirror_mode": { 12 | "name": "ipython", 13 | "version": 3 14 | }, 15 | "file_extension": ".py", 16 | "mimetype": "text/x-python", 17 | "name": "python", 18 | "nbconvert_exporter": "python", 19 | "pygments_lexer": "ipython3", 20 | "version": "3.6.6" 21 | }, 22 | "colab": { 23 | "name": "reinforcement_q_learning.ipynb", 24 | "provenance": [] 25 | } 26 | }, 27 | "cells": [ 28 | { 29 | "cell_type": "markdown", 30 | "metadata": { 31 | "id": "DDYNPppiBHsX", 32 | "colab_type": "text" 33 | }, 34 | "source": [ 35 | "\n", 36 | " \n", 44 | " \n", 53 | "
\n", 37 | " Run\n", 41 | " in Google Colab\n", 43 | " \n", 45 | " View source on GitHub\n", 52 | "
" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "metadata": { 59 | "id": "g_36dg5DBDfU", 60 | "colab_type": "code", 61 | "colab": {} 62 | }, 63 | "source": [ 64 | "%matplotlib inline" 65 | ], 66 | "execution_count": 0, 67 | "outputs": [] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": { 72 | "id": "Dkuou2uxBDfv", 73 | "colab_type": "text" 74 | }, 75 | "source": [ 76 | "\n", 77 | "Reinforcement Learning (DQN) Tutorial\n", 78 | "=====================================\n", 79 | "**Author**: `Adam Paszke `_\n", 80 | "\n", 81 | "\n", 82 | "This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent\n", 83 | "on the CartPole-v0 task from the `OpenAI Gym `__.\n", 84 | "\n", 85 | "**Task**\n", 86 | "\n", 87 | "The agent has to decide between two actions - moving the cart left or\n", 88 | "right - so that the pole attached to it stays upright. You can find an\n", 89 | "official leaderboard with various algorithms and visualizations at the\n", 90 | "`Gym website `__.\n", 91 | "\n", 92 | ".. figure:: /_static/img/cartpole.gif\n", 93 | " :alt: cartpole\n", 94 | "\n", 95 | " cartpole\n", 96 | "\n", 97 | "As the agent observes the current state of the environment and chooses\n", 98 | "an action, the environment *transitions* to a new state, and also\n", 99 | "returns a reward that indicates the consequences of the action. In this\n", 100 | "task, the environment terminates if the pole falls over too far.\n", 101 | "\n", 102 | "The CartPole task is designed so that the inputs to the agent are 4 real\n", 103 | "values representing the environment state (position, velocity, etc.).\n", 104 | "However, neural networks can solve the task purely by looking at the\n", 105 | "scene, so we'll use a patch of the screen centered on the cart as an\n", 106 | "input. Because of this, our results aren't directly comparable to the\n", 107 | "ones from the official leaderboard - our task is much harder.\n", 108 | "Unfortunately this does slow down the training, because we have to\n", 109 | "render all the frames.\n", 110 | "\n", 111 | "Strictly speaking, we will present the state as the difference between\n", 112 | "the current screen patch and the previous one. This will allow the agent\n", 113 | "to take the velocity of the pole into account from one image.\n", 114 | "\n", 115 | "**Packages**\n", 116 | "\n", 117 | "\n", 118 | "First, let's import needed packages. 
Firstly, we need\n", 119 | "`gym `__ for the environment\n", 120 | "(Install using `pip install gym`).\n", 121 | "We'll also use the following from PyTorch:\n", 122 | "\n", 123 | "- neural networks (``torch.nn``)\n", 124 | "- optimization (``torch.optim``)\n", 125 | "- automatic differentiation (``torch.autograd``)\n", 126 | "- utilities for vision tasks (``torchvision`` - `a separate\n", 127 | " package `__).\n", 128 | "\n", 129 | "\n" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "metadata": { 135 | "id": "vwVMSbAZBDfx", 136 | "colab_type": "code", 137 | "colab": {} 138 | }, 139 | "source": [ 140 | "import gym\n", 141 | "import math\n", 142 | "import random\n", 143 | "import numpy as np\n", 144 | "import matplotlib\n", 145 | "import matplotlib.pyplot as plt\n", 146 | "from collections import namedtuple\n", 147 | "from itertools import count\n", 148 | "from PIL import Image\n", 149 | "\n", 150 | "import torch\n", 151 | "import torch.nn as nn\n", 152 | "import torch.optim as optim\n", 153 | "import torch.nn.functional as F\n", 154 | "import torchvision.transforms as T\n", 155 | "\n", 156 | "\n", 157 | "env = gym.make('CartPole-v0').unwrapped\n", 158 | "\n", 159 | "# set up matplotlib\n", 160 | "is_ipython = 'inline' in matplotlib.get_backend()\n", 161 | "if is_ipython:\n", 162 | " from IPython import display\n", 163 | "\n", 164 | "plt.ion()\n", 165 | "\n", 166 | "# if gpu is to be used\n", 167 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" 168 | ], 169 | "execution_count": 0, 170 | "outputs": [] 171 | }, 172 | { 173 | "cell_type": "markdown", 174 | "metadata": { 175 | "id": "6r39vwOWBDf3", 176 | "colab_type": "text" 177 | }, 178 | "source": [ 179 | "Replay Memory\n", 180 | "-------------\n", 181 | "\n", 182 | "We'll be using experience replay memory for training our DQN. It stores\n", 183 | "the transitions that the agent observes, allowing us to reuse this data\n", 184 | "later. By sampling from it randomly, the transitions that build up a\n", 185 | "batch are decorrelated. It has been shown that this greatly stabilizes\n", 186 | "and improves the DQN training procedure.\n", 187 | "\n", 188 | "For this, we're going to need two classses:\n", 189 | "\n", 190 | "- ``Transition`` - a named tuple representing a single transition in\n", 191 | " our environment\n", 192 | "- ``ReplayMemory`` - a cyclic buffer of bounded size that holds the\n", 193 | " transitions observed recently. 
It also implements a ``.sample()``\n", 194 | " method for selecting a random batch of transitions for training.\n", 195 | "\n", 196 | "\n" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "metadata": { 202 | "id": "w-I8VNEYBDf9", 203 | "colab_type": "code", 204 | "colab": {} 205 | }, 206 | "source": [ 207 | "Transition = namedtuple('Transition',\n", 208 | " ('state', 'action', 'next_state', 'reward'))\n", 209 | "\n", 210 | "\n", 211 | "class ReplayMemory(object):\n", 212 | "\n", 213 | " def __init__(self, capacity):\n", 214 | " self.capacity = capacity\n", 215 | " self.memory = []\n", 216 | " self.position = 0\n", 217 | "\n", 218 | " def push(self, *args):\n", 219 | " \"\"\"Saves a transition.\"\"\"\n", 220 | " if len(self.memory) < self.capacity:\n", 221 | " self.memory.append(None)\n", 222 | " self.memory[self.position] = Transition(*args)\n", 223 | " self.position = (self.position + 1) % self.capacity\n", 224 | "\n", 225 | " def sample(self, batch_size):\n", 226 | " return random.sample(self.memory, batch_size)\n", 227 | "\n", 228 | " def __len__(self):\n", 229 | " return len(self.memory)" 230 | ], 231 | "execution_count": 0, 232 | "outputs": [] 233 | }, 234 | { 235 | "cell_type": "markdown", 236 | "metadata": { 237 | "id": "mkHPZbDaBDgB", 238 | "colab_type": "text" 239 | }, 240 | "source": [ 241 | "Now, let's define our model. But first, let quickly recap what a DQN is.\n", 242 | "\n", 243 | "DQN algorithm\n", 244 | "-------------\n", 245 | "\n", 246 | "Our environment is deterministic, so all equations presented here are\n", 247 | "also formulated deterministically for the sake of simplicity. In the\n", 248 | "reinforcement learning literature, they would also contain expectations\n", 249 | "over stochastic transitions in the environment.\n", 250 | "\n", 251 | "Our aim will be to train a policy that tries to maximize the discounted,\n", 252 | "cumulative reward\n", 253 | "$R_{t_0} = \\sum_{t=t_0}^{\\infty} \\gamma^{t - t_0} r_t$, where\n", 254 | "$R_{t_0}$ is also known as the *return*. The discount,\n", 255 | "$\\gamma$, should be a constant between $0$ and $1$\n", 256 | "that ensures the sum converges. It makes rewards from the uncertain far\n", 257 | "future less important for our agent than the ones in the near future\n", 258 | "that it can be fairly confident about.\n", 259 | "\n", 260 | "The main idea behind Q-learning is that if we had a function\n", 261 | "$Q^*: State \\times Action \\rightarrow \\mathbb{R}$, that could tell\n", 262 | "us what our return would be, if we were to take an action in a given\n", 263 | "state, then we could easily construct a policy that maximizes our\n", 264 | "rewards:\n", 265 | "\n", 266 | "\\begin{align}\\pi^*(s) = \\arg\\!\\max_a \\ Q^*(s, a)\\end{align}\n", 267 | "\n", 268 | "However, we don't know everything about the world, so we don't have\n", 269 | "access to $Q^*$. 
But, since neural networks are universal function\n", 270 | "approximators, we can simply create one and train it to resemble\n", 271 | "$Q^*$.\n", 272 | "\n", 273 | "For our training update rule, we'll use a fact that every $Q$\n", 274 | "function for some policy obeys the Bellman equation:\n", 275 | "\n", 276 | "\\begin{align}Q^{\\pi}(s, a) = r + \\gamma Q^{\\pi}(s', \\pi(s'))\\end{align}\n", 277 | "\n", 278 | "The difference between the two sides of the equality is known as the\n", 279 | "temporal difference error, $\\delta$:\n", 280 | "\n", 281 | "\\begin{align}\\delta = Q(s, a) - (r + \\gamma \\max_a Q(s', a))\\end{align}\n", 282 | "\n", 283 | "To minimise this error, we will use the `Huber\n", 284 | "loss `__. The Huber loss acts\n", 285 | "like the mean squared error when the error is small, but like the mean\n", 286 | "absolute error when the error is large - this makes it more robust to\n", 287 | "outliers when the estimates of $Q$ are very noisy. We calculate\n", 288 | "this over a batch of transitions, $B$, sampled from the replay\n", 289 | "memory:\n", 290 | "\n", 291 | "\\begin{align}\\mathcal{L} = \\frac{1}{|B|}\\sum_{(s, a, s', r) \\ \\in \\ B} \\mathcal{L}(\\delta)\\end{align}\n", 292 | "\n", 293 | "\\begin{align}\\text{where} \\quad \\mathcal{L}(\\delta) = \\begin{cases}\n", 294 | " \\frac{1}{2}{\\delta^2} & \\text{for } |\\delta| \\le 1, \\\\\n", 295 | " |\\delta| - \\frac{1}{2} & \\text{otherwise.}\n", 296 | " \\end{cases}\\end{align}\n", 297 | "\n", 298 | "Q-network\n", 299 | "^^^^^^^^^\n", 300 | "\n", 301 | "Our model will be a convolutional neural network that takes in the\n", 302 | "difference between the current and previous screen patches. It has two\n", 303 | "outputs, representing $Q(s, \\mathrm{left})$ and\n", 304 | "$Q(s, \\mathrm{right})$ (where $s$ is the input to the\n", 305 | "network). In effect, the network is trying to predict the *quality* of\n", 306 | "taking each action given the current input.\n", 307 | "\n", 308 | "\n" 309 | ] 310 | }, 311 | { 312 | "cell_type": "code", 313 | "metadata": { 314 | "id": "viCLZRk6BDgD", 315 | "colab_type": "code", 316 | "colab": {} 317 | }, 318 | "source": [ 319 | "class DQN(nn.Module):\n", 320 | "\n", 321 | " def __init__(self):\n", 322 | " super(DQN, self).__init__()\n", 323 | " self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)\n", 324 | " self.bn1 = nn.BatchNorm2d(16)\n", 325 | " self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)\n", 326 | " self.bn2 = nn.BatchNorm2d(32)\n", 327 | " self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)\n", 328 | " self.bn3 = nn.BatchNorm2d(32)\n", 329 | " self.head = nn.Linear(448, 2)\n", 330 | "\n", 331 | " def forward(self, x):\n", 332 | " x = F.relu(self.bn1(self.conv1(x)))\n", 333 | " x = F.relu(self.bn2(self.conv2(x)))\n", 334 | " x = F.relu(self.bn3(self.conv3(x)))\n", 335 | " return self.head(x.view(x.size(0), -1))" 336 | ], 337 | "execution_count": 0, 338 | "outputs": [] 339 | }, 340 | { 341 | "cell_type": "markdown", 342 | "metadata": { 343 | "id": "mOfzT798BDgR", 344 | "colab_type": "text" 345 | }, 346 | "source": [ 347 | "Input extraction\n", 348 | "^^^^^^^^^^^^^^^^\n", 349 | "\n", 350 | "The code below are utilities for extracting and processing rendered\n", 351 | "images from the environment. It uses the ``torchvision`` package, which\n", 352 | "makes it easy to compose image transforms. 
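As a small standalone illustration of that composition (assumed values, not taken from the tutorial), each transform in a ``T.Compose`` pipeline is applied in order, so a fake screen tensor can be turned into a resized patch tensor in one call:

```python
import torch
import torchvision.transforms as T

# Tensor -> PIL image -> resize shorter side to 40 px -> back to a CHW tensor.
to_patch = T.Compose([T.ToPILImage(),
                      T.Resize(40),
                      T.ToTensor()])

dummy_screen = torch.rand(3, 160, 320)   # fake (C, H, W) screen crop
print(to_patch(dummy_screen).shape)      # torch.Size([3, 40, 80])
```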
Once you run the cell it will\n", 353 | "display an example patch that it extracted.\n", 354 | "\n", 355 | "\n" 356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "metadata": { 361 | "id": "3vu_v8JFBDgT", 362 | "colab_type": "code", 363 | "colab": {} 364 | }, 365 | "source": [ 366 | "resize = T.Compose([T.ToPILImage(),\n", 367 | " T.Resize(40, interpolation=Image.CUBIC),\n", 368 | " T.ToTensor()])\n", 369 | "\n", 370 | "# This is based on the code from gym.\n", 371 | "screen_width = 600\n", 372 | "\n", 373 | "\n", 374 | "def get_cart_location():\n", 375 | " world_width = env.x_threshold * 2\n", 376 | " scale = screen_width / world_width\n", 377 | " return int(env.state[0] * scale + screen_width / 2.0) # MIDDLE OF CART\n", 378 | "\n", 379 | "\n", 380 | "def get_screen():\n", 381 | " screen = env.render(mode='rgb_array').transpose(\n", 382 | " (2, 0, 1)) # transpose into torch order (CHW)\n", 383 | " # Strip off the top and bottom of the screen\n", 384 | " screen = screen[:, 160:320]\n", 385 | " view_width = 320\n", 386 | " cart_location = get_cart_location()\n", 387 | " if cart_location < view_width // 2:\n", 388 | " slice_range = slice(view_width)\n", 389 | " elif cart_location > (screen_width - view_width // 2):\n", 390 | " slice_range = slice(-view_width, None)\n", 391 | " else:\n", 392 | " slice_range = slice(cart_location - view_width // 2,\n", 393 | " cart_location + view_width // 2)\n", 394 | " # Strip off the edges, so that we have a square image centered on a cart\n", 395 | " screen = screen[:, :, slice_range]\n", 396 | " # Convert to float, rescare, convert to torch tensor\n", 397 | " # (this doesn't require a copy)\n", 398 | " screen = np.ascontiguousarray(screen, dtype=np.float32) / 255\n", 399 | " screen = torch.from_numpy(screen)\n", 400 | " # Resize, and add a batch dimension (BCHW)\n", 401 | " return resize(screen).unsqueeze(0).to(device)\n", 402 | "\n", 403 | "\n", 404 | "env.reset()\n", 405 | "plt.figure()\n", 406 | "plt.imshow(get_screen().cpu().squeeze(0).permute(1, 2, 0).numpy(),\n", 407 | " interpolation='none')\n", 408 | "plt.title('Example extracted screen')\n", 409 | "plt.show()" 410 | ], 411 | "execution_count": 0, 412 | "outputs": [] 413 | }, 414 | { 415 | "cell_type": "markdown", 416 | "metadata": { 417 | "id": "y6Ft2SUNBDga", 418 | "colab_type": "text" 419 | }, 420 | "source": [ 421 | "Training\n", 422 | "--------\n", 423 | "\n", 424 | "Hyperparameters and utilities\n", 425 | "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", 426 | "This cell instantiates our model and its optimizer, and defines some\n", 427 | "utilities:\n", 428 | "\n", 429 | "- ``select_action`` - will select an action accordingly to an epsilon\n", 430 | " greedy policy. Simply put, we'll sometimes use our model for choosing\n", 431 | " the action, and sometimes we'll just sample one uniformly. The\n", 432 | " probability of choosing a random action will start at ``EPS_START``\n", 433 | " and will decay exponentially towards ``EPS_END``. ``EPS_DECAY``\n", 434 | " controls the rate of the decay.\n", 435 | "- ``plot_durations`` - a helper for plotting the durations of episodes,\n", 436 | " along with an average over the last 100 episodes (the measure used in\n", 437 | " the official evaluations). 
The plot will be underneath the cell\n", 438 | " containing the main training loop, and will update after every\n", 439 | " episode.\n", 440 | "\n", 441 | "\n" 442 | ] 443 | }, 444 | { 445 | "cell_type": "code", 446 | "metadata": { 447 | "id": "JTmK3hZbBDgc", 448 | "colab_type": "code", 449 | "colab": {} 450 | }, 451 | "source": [ 452 | "BATCH_SIZE = 128\n", 453 | "GAMMA = 0.999\n", 454 | "EPS_START = 0.9\n", 455 | "EPS_END = 0.05\n", 456 | "EPS_DECAY = 200\n", 457 | "TARGET_UPDATE = 10\n", 458 | "\n", 459 | "policy_net = DQN().to(device)\n", 460 | "target_net = DQN().to(device)\n", 461 | "target_net.load_state_dict(policy_net.state_dict())\n", 462 | "target_net.eval()\n", 463 | "\n", 464 | "optimizer = optim.RMSprop(policy_net.parameters())\n", 465 | "memory = ReplayMemory(10000)\n", 466 | "\n", 467 | "\n", 468 | "steps_done = 0\n", 469 | "\n", 470 | "\n", 471 | "def select_action(state):\n", 472 | " global steps_done\n", 473 | " sample = random.random()\n", 474 | " eps_threshold = EPS_END + (EPS_START - EPS_END) * \\\n", 475 | " math.exp(-1. * steps_done / EPS_DECAY)\n", 476 | " steps_done += 1\n", 477 | " if sample > eps_threshold:\n", 478 | " with torch.no_grad():\n", 479 | " return policy_net(state).max(1)[1].view(1, 1)\n", 480 | " else:\n", 481 | " return torch.tensor([[random.randrange(2)]], device=device, dtype=torch.long)\n", 482 | "\n", 483 | "\n", 484 | "episode_durations = []\n", 485 | "\n", 486 | "\n", 487 | "def plot_durations():\n", 488 | " plt.figure(2)\n", 489 | " plt.clf()\n", 490 | " durations_t = torch.tensor(episode_durations, dtype=torch.float)\n", 491 | " plt.title('Training...')\n", 492 | " plt.xlabel('Episode')\n", 493 | " plt.ylabel('Duration')\n", 494 | " plt.plot(durations_t.numpy())\n", 495 | " # Take 100 episode averages and plot them too\n", 496 | " if len(durations_t) >= 100:\n", 497 | " means = durations_t.unfold(0, 100, 1).mean(1).view(-1)\n", 498 | " means = torch.cat((torch.zeros(99), means))\n", 499 | " plt.plot(means.numpy())\n", 500 | "\n", 501 | " plt.pause(0.001) # pause a bit so that plots are updated\n", 502 | " if is_ipython:\n", 503 | " display.clear_output(wait=True)\n", 504 | " display.display(plt.gcf())" 505 | ], 506 | "execution_count": 0, 507 | "outputs": [] 508 | }, 509 | { 510 | "cell_type": "markdown", 511 | "metadata": { 512 | "id": "ATOyYQMhBDgg", 513 | "colab_type": "text" 514 | }, 515 | "source": [ 516 | "Training loop\n", 517 | "^^^^^^^^^^^^^\n", 518 | "\n", 519 | "Finally, the code for training our model.\n", 520 | "\n", 521 | "Here, you can find an ``optimize_model`` function that performs a\n", 522 | "single step of the optimization. It first samples a batch, concatenates\n", 523 | "all the tensors into a single one, computes $Q(s_t, a_t)$ and\n", 524 | "$V(s_{t+1}) = \\max_a Q(s_{t+1}, a)$, and combines them into our\n", 525 | "loss. By defition we set $V(s) = 0$ if $s$ is a terminal\n", 526 | "state. We also use a target network to compute $V(s_{t+1})$ for\n", 527 | "added stability. 
The target network has its weights kept frozen most of\n", 528 | "the time, but is updated with the policy network's weights every so often.\n", 529 | "This is usually a set number of steps but we shall use episodes for\n", 530 | "simplicity.\n", 531 | "\n", 532 | "\n" 533 | ] 534 | }, 535 | { 536 | "cell_type": "code", 537 | "metadata": { 538 | "id": "193lEYtxBDgh", 539 | "colab_type": "code", 540 | "colab": {} 541 | }, 542 | "source": [ 543 | "def optimize_model():\n", 544 | " if len(memory) < BATCH_SIZE:\n", 545 | " return\n", 546 | " transitions = memory.sample(BATCH_SIZE)\n", 547 | " # Transpose the batch (see http://stackoverflow.com/a/19343/3343043 for\n", 548 | " # detailed explanation).\n", 549 | " batch = Transition(*zip(*transitions))\n", 550 | "\n", 551 | " # Compute a mask of non-final states and concatenate the batch elements\n", 552 | " non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,\n", 553 | " batch.next_state)), device=device, dtype=torch.uint8)\n", 554 | " non_final_next_states = torch.cat([s for s in batch.next_state\n", 555 | " if s is not None])\n", 556 | " state_batch = torch.cat(batch.state)\n", 557 | " action_batch = torch.cat(batch.action)\n", 558 | " reward_batch = torch.cat(batch.reward)\n", 559 | "\n", 560 | " # Compute Q(s_t, a) - the model computes Q(s_t), then we select the\n", 561 | " # columns of actions taken\n", 562 | " state_action_values = policy_net(state_batch).gather(1, action_batch)\n", 563 | "\n", 564 | " # Compute V(s_{t+1}) for all next states.\n", 565 | " next_state_values = torch.zeros(BATCH_SIZE, device=device)\n", 566 | " next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()\n", 567 | " # Compute the expected Q values\n", 568 | " expected_state_action_values = (next_state_values * GAMMA) + reward_batch\n", 569 | "\n", 570 | " # Compute Huber loss\n", 571 | " loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))\n", 572 | "\n", 573 | " # Optimize the model\n", 574 | " optimizer.zero_grad()\n", 575 | " loss.backward()\n", 576 | " for param in policy_net.parameters():\n", 577 | " param.grad.data.clamp_(-1, 1)\n", 578 | " optimizer.step()" 579 | ], 580 | "execution_count": 0, 581 | "outputs": [] 582 | }, 583 | { 584 | "cell_type": "markdown", 585 | "metadata": { 586 | "id": "VxbezziZBDgk", 587 | "colab_type": "text" 588 | }, 589 | "source": [ 590 | "Below, you can find the main training loop. At the beginning we reset\n", 591 | "the environment and initialize the ``state`` Tensor. Then, we sample\n", 592 | "an action, execute it, observe the next screen and the reward (always\n", 593 | "1), and optimize our model once. When the episode ends (our model\n", 594 | "fails), we restart the loop.\n", 595 | "\n", 596 | "Below, `num_episodes` is set small. 
You should download\n", 597 | "the notebook and run lot more epsiodes.\n", 598 | "\n", 599 | "\n" 600 | ] 601 | }, 602 | { 603 | "cell_type": "code", 604 | "metadata": { 605 | "id": "n6-AG1b_BDgl", 606 | "colab_type": "code", 607 | "colab": {} 608 | }, 609 | "source": [ 610 | "num_episodes = 50\n", 611 | "for i_episode in range(num_episodes):\n", 612 | " # Initialize the environment and state\n", 613 | " env.reset()\n", 614 | " last_screen = get_screen()\n", 615 | " current_screen = get_screen()\n", 616 | " state = current_screen - last_screen\n", 617 | " for t in count():\n", 618 | " # Select and perform an action\n", 619 | " action = select_action(state)\n", 620 | " _, reward, done, _ = env.step(action.item())\n", 621 | " reward = torch.tensor([reward], device=device)\n", 622 | "\n", 623 | " # Observe new state\n", 624 | " last_screen = current_screen\n", 625 | " current_screen = get_screen()\n", 626 | " if not done:\n", 627 | " next_state = current_screen - last_screen\n", 628 | " else:\n", 629 | " next_state = None\n", 630 | "\n", 631 | " # Store the transition in memory\n", 632 | " memory.push(state, action, next_state, reward)\n", 633 | "\n", 634 | " # Move to the next state\n", 635 | " state = next_state\n", 636 | "\n", 637 | " # Perform one step of the optimization (on the target network)\n", 638 | " optimize_model()\n", 639 | " if done:\n", 640 | " episode_durations.append(t + 1)\n", 641 | " plot_durations()\n", 642 | " break\n", 643 | " # Update the target network\n", 644 | " if i_episode % TARGET_UPDATE == 0:\n", 645 | " target_net.load_state_dict(policy_net.state_dict())\n", 646 | "\n", 647 | "print('Complete')\n", 648 | "env.render()\n", 649 | "env.close()\n", 650 | "plt.ioff()\n", 651 | "plt.show()" 652 | ], 653 | "execution_count": 0, 654 | "outputs": [] 655 | } 656 | ] 657 | } -------------------------------------------------------------------------------- /transfer-learning/pytorch/cnn_transfer_learning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "kernelspec": { 6 | "display_name": "Python 3", 7 | "language": "python", 8 | "name": "python3" 9 | }, 10 | "language_info": { 11 | "codemirror_mode": { 12 | "name": "ipython", 13 | "version": 3 14 | }, 15 | "file_extension": ".py", 16 | "mimetype": "text/x-python", 17 | "name": "python", 18 | "nbconvert_exporter": "python", 19 | "pygments_lexer": "ipython3", 20 | "version": "3.6.6" 21 | }, 22 | "colab": { 23 | "name": "transfer_learning_tutorial.ipynb", 24 | "provenance": [] 25 | } 26 | }, 27 | "cells": [ 28 | { 29 | "cell_type": "markdown", 30 | "metadata": { 31 | "id": "rUB3UHbfhab1", 32 | "colab_type": "text" 33 | }, 34 | "source": [ 35 | "\n", 36 | " \n", 44 | " \n", 53 | "
\n", 37 | " Run\n", 41 | " in Google Colab\n", 43 | " \n", 45 | " View source on GitHub\n", 52 | "
" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "metadata": { 59 | "id": "fG9DkesQg6hE", 60 | "colab_type": "code", 61 | "colab": {} 62 | }, 63 | "source": [ 64 | "%matplotlib inline" 65 | ], 66 | "execution_count": 0, 67 | "outputs": [] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": { 72 | "id": "rMnwyaZGg6hN", 73 | "colab_type": "text" 74 | }, 75 | "source": [ 76 | "\n", 77 | "Transfer Learning Tutorial\n", 78 | "==========================\n", 79 | "**Author**: [Sasank Chilamkurthy](https://chsasank.github.io)\n", 80 | "\n", 81 | "In this tutorial, you will learn how to train your network using\n", 82 | "transfer learning. You can read more about the transfer learning at [cs231n](http://cs231n.github.io/transfer-learning/).\n", 83 | "\n", 84 | "Quoting these notes,\n", 85 | "\n", 86 | " In practice, very few people train an entire Convolutional Network\n", 87 | " from scratch (with random initialization), because it is relatively\n", 88 | " rare to have a dataset of sufficient size. Instead, it is common to\n", 89 | " pretrain a ConvNet on a very large dataset (e.g. ImageNet, which\n", 90 | " contains 1.2 million images with 1000 categories), and then use the\n", 91 | " ConvNet either as an initialization or a fixed feature extractor for\n", 92 | " the task of interest.\n", 93 | "\n", 94 | "These two major transfer learning scenarios look as follows:\n", 95 | "\n", 96 | "- **Finetuning the convnet**: Instead of random initializaion, we\n", 97 | " initialize the network with a pretrained network, like the one that is\n", 98 | " trained on imagenet 1000 dataset. Rest of the training looks as\n", 99 | " usual.\n", 100 | "- **ConvNet as fixed feature extractor**: Here, we will freeze the weights\n", 101 | " for all of the network except that of the final fully connected\n", 102 | " layer. This last fully connected layer is replaced with a new one\n", 103 | " with random weights and only this layer is trained.\n", 104 | "\n", 105 | "\n" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "metadata": { 111 | "id": "RsxLS88ag6hP", 112 | "colab_type": "code", 113 | "colab": {} 114 | }, 115 | "source": [ 116 | "# License: BSD\n", 117 | "# Author: Sasank Chilamkurthy\n", 118 | "\n", 119 | "from __future__ import print_function, division\n", 120 | "\n", 121 | "import torch\n", 122 | "import torch.nn as nn\n", 123 | "import torch.optim as optim\n", 124 | "from torch.optim import lr_scheduler\n", 125 | "import numpy as np\n", 126 | "import torchvision\n", 127 | "from torchvision import datasets, models, transforms\n", 128 | "import matplotlib.pyplot as plt\n", 129 | "import time\n", 130 | "import os\n", 131 | "import copy\n", 132 | "\n", 133 | "plt.ion() # interactive mode" 134 | ], 135 | "execution_count": 0, 136 | "outputs": [] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "metadata": { 141 | "id": "ilD2-Npvg6hT", 142 | "colab_type": "text" 143 | }, 144 | "source": [ 145 | "Load Data\n", 146 | "---------\n", 147 | "\n", 148 | "We will use torchvision and torch.utils.data packages for loading the\n", 149 | "data.\n", 150 | "\n", 151 | "The problem we're going to solve today is to train a model to classify\n", 152 | "**ants** and **bees**. We have about 120 training images each for ants and bees.\n", 153 | "There are 75 validation images for each class. Usually, this is a very\n", 154 | "small dataset to generalize upon, if trained from scratch. 
Since we\n", 155 | "are using transfer learning, we should be able to generalize reasonably\n", 156 | "well.\n", 157 | "\n", 158 | "This dataset is a very small subset of imagenet.\n", 159 | "\n", 160 | ".. Note ::\n", 161 | " Download the data from\n", 162 | " `here `_\n", 163 | " and extract it to the current directory.\n", 164 | "\n" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "metadata": { 170 | "id": "I92RS7m9g6hV", 171 | "colab_type": "code", 172 | "colab": {} 173 | }, 174 | "source": [ 175 | "# Data augmentation and normalization for training\n", 176 | "# Just normalization for validation\n", 177 | "data_transforms = {\n", 178 | " 'train': transforms.Compose([\n", 179 | " transforms.RandomResizedCrop(224),\n", 180 | " transforms.RandomHorizontalFlip(),\n", 181 | " transforms.ToTensor(),\n", 182 | " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", 183 | " ]),\n", 184 | " 'val': transforms.Compose([\n", 185 | " transforms.Resize(256),\n", 186 | " transforms.CenterCrop(224),\n", 187 | " transforms.ToTensor(),\n", 188 | " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", 189 | " ]),\n", 190 | "}\n", 191 | "\n", 192 | "data_dir = 'hymenoptera_data'\n", 193 | "image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),\n", 194 | " data_transforms[x])\n", 195 | " for x in ['train', 'val']}\n", 196 | "dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,\n", 197 | " shuffle=True, num_workers=4)\n", 198 | " for x in ['train', 'val']}\n", 199 | "dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}\n", 200 | "class_names = image_datasets['train'].classes\n", 201 | "\n", 202 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" 203 | ], 204 | "execution_count": 0, 205 | "outputs": [] 206 | }, 207 | { 208 | "cell_type": "markdown", 209 | "metadata": { 210 | "id": "yBw1Osdvg6hZ", 211 | "colab_type": "text" 212 | }, 213 | "source": [ 214 | "Visualize a few images\n", 215 | "^^^^^^^^^^^^^^^^^^^^^^\n", 216 | "Let's visualize a few training images so as to understand the data\n", 217 | "augmentations.\n", 218 | "\n" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "metadata": { 224 | "id": "cMwVVMbWg6ha", 225 | "colab_type": "code", 226 | "colab": {} 227 | }, 228 | "source": [ 229 | "def imshow(inp, title=None):\n", 230 | " \"\"\"Imshow for Tensor.\"\"\"\n", 231 | " inp = inp.numpy().transpose((1, 2, 0))\n", 232 | " mean = np.array([0.485, 0.456, 0.406])\n", 233 | " std = np.array([0.229, 0.224, 0.225])\n", 234 | " inp = std * inp + mean\n", 235 | " inp = np.clip(inp, 0, 1)\n", 236 | " plt.imshow(inp)\n", 237 | " if title is not None:\n", 238 | " plt.title(title)\n", 239 | " plt.pause(0.001) # pause a bit so that plots are updated\n", 240 | "\n", 241 | "\n", 242 | "# Get a batch of training data\n", 243 | "inputs, classes = next(iter(dataloaders['train']))\n", 244 | "\n", 245 | "# Make a grid from batch\n", 246 | "out = torchvision.utils.make_grid(inputs)\n", 247 | "\n", 248 | "imshow(out, title=[class_names[x] for x in classes])" 249 | ], 250 | "execution_count": 0, 251 | "outputs": [] 252 | }, 253 | { 254 | "cell_type": "markdown", 255 | "metadata": { 256 | "id": "pF4UDc35g6he", 257 | "colab_type": "text" 258 | }, 259 | "source": [ 260 | "Training the model\n", 261 | "------------------\n", 262 | "\n", 263 | "Now, let's write a general function to train a model. 
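As an aside on the data pipeline above: the `imshow` helper undoes `transforms.Normalize` by multiplying by the per-channel standard deviation and adding the mean back. A quick standalone check of that round trip, assuming only that `torch` and `torchvision` are installed:

```python
import torch
from torchvision import transforms

# Standalone check (not part of the tutorial): denormalizing with
# `std * x + mean` exactly inverts transforms.Normalize(mean, std).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

img = torch.rand(3, 224, 224)                  # fake image tensor in [0, 1]
normalized = transforms.Normalize(mean, std)(img)

restored = normalized * torch.tensor(std).view(3, 1, 1) + torch.tensor(mean).view(3, 1, 1)
print(torch.allclose(restored, img, atol=1e-6))  # True: the round trip is exact up to float error
```

The only difference from `imshow` is that the helper works on a channels-last numpy array after `transpose((1, 2, 0))`, while this check stays channels-first; the per-channel arithmetic is identical.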
Here, we will\n", 264 | "illustrate:\n", 265 | "\n", 266 | "- Scheduling the learning rate\n", 267 | "- Saving the best model\n", 268 | "\n", 269 | "In the following, parameter ``scheduler`` is an LR scheduler object from\n", 270 | "``torch.optim.lr_scheduler``.\n", 271 | "\n" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "metadata": { 277 | "id": "dPh0TU2ng6hf", 278 | "colab_type": "code", 279 | "colab": {} 280 | }, 281 | "source": [ 282 | "def train_model(model, criterion, optimizer, scheduler, num_epochs=25):\n", 283 | " since = time.time()\n", 284 | "\n", 285 | " best_model_wts = copy.deepcopy(model.state_dict())\n", 286 | " best_acc = 0.0\n", 287 | "\n", 288 | " for epoch in range(num_epochs):\n", 289 | " print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n", 290 | " print('-' * 10)\n", 291 | "\n", 292 | " # Each epoch has a training and validation phase\n", 293 | " for phase in ['train', 'val']:\n", 294 | " if phase == 'train':\n", 295 | " scheduler.step()\n", 296 | " model.train() # Set model to training mode\n", 297 | " else:\n", 298 | " model.eval() # Set model to evaluate mode\n", 299 | "\n", 300 | " running_loss = 0.0\n", 301 | " running_corrects = 0\n", 302 | "\n", 303 | " # Iterate over data.\n", 304 | " for inputs, labels in dataloaders[phase]:\n", 305 | " inputs = inputs.to(device)\n", 306 | " labels = labels.to(device)\n", 307 | "\n", 308 | " # zero the parameter gradients\n", 309 | " optimizer.zero_grad()\n", 310 | "\n", 311 | " # forward\n", 312 | " # track history if only in train\n", 313 | " with torch.set_grad_enabled(phase == 'train'):\n", 314 | " outputs = model(inputs)\n", 315 | " _, preds = torch.max(outputs, 1)\n", 316 | " loss = criterion(outputs, labels)\n", 317 | "\n", 318 | " # backward + optimize only if in training phase\n", 319 | " if phase == 'train':\n", 320 | " loss.backward()\n", 321 | " optimizer.step()\n", 322 | "\n", 323 | " # statistics\n", 324 | " running_loss += loss.item() * inputs.size(0)\n", 325 | " running_corrects += torch.sum(preds == labels.data)\n", 326 | "\n", 327 | " epoch_loss = running_loss / dataset_sizes[phase]\n", 328 | " epoch_acc = running_corrects.double() / dataset_sizes[phase]\n", 329 | "\n", 330 | " print('{} Loss: {:.4f} Acc: {:.4f}'.format(\n", 331 | " phase, epoch_loss, epoch_acc))\n", 332 | "\n", 333 | " # deep copy the model\n", 334 | " if phase == 'val' and epoch_acc > best_acc:\n", 335 | " best_acc = epoch_acc\n", 336 | " best_model_wts = copy.deepcopy(model.state_dict())\n", 337 | "\n", 338 | " print()\n", 339 | "\n", 340 | " time_elapsed = time.time() - since\n", 341 | " print('Training complete in {:.0f}m {:.0f}s'.format(\n", 342 | " time_elapsed // 60, time_elapsed % 60))\n", 343 | " print('Best val Acc: {:4f}'.format(best_acc))\n", 344 | "\n", 345 | " # load best model weights\n", 346 | " model.load_state_dict(best_model_wts)\n", 347 | " return model" 348 | ], 349 | "execution_count": 0, 350 | "outputs": [] 351 | }, 352 | { 353 | "cell_type": "markdown", 354 | "metadata": { 355 | "id": "0FMRK-Fjg6hj", 356 | "colab_type": "text" 357 | }, 358 | "source": [ 359 | "Visualizing the model predictions\n", 360 | "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", 361 | "\n", 362 | "Generic function to display predictions for a few images\n", 363 | "\n", 364 | "\n" 365 | ] 366 | }, 367 | { 368 | "cell_type": "code", 369 | "metadata": { 370 | "id": "dJqpiavKg6hk", 371 | "colab_type": "code", 372 | "colab": {} 373 | }, 374 | "source": [ 375 | "def visualize_model(model, num_images=6):\n", 376 | " was_training = 
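For reference, the `scheduler` passed into `train_model` is a `StepLR` object, which multiplies the learning rate by `gamma` once every `step_size` epochs. A self-contained sketch of that behaviour follows; the tiny model and empty training step are placeholders, not part of the tutorial:

```python
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

# Placeholder model and optimizer, just to watch the schedule evolve.
model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

for epoch in range(21):
    # ... one epoch of training and validation would run here ...
    optimizer.step()        # recent PyTorch expects optimizer.step() before scheduler.step()
    scheduler.step()
    if epoch % 7 == 6:
        print(epoch, optimizer.param_groups[0]['lr'])
# Prints roughly 1e-4 after epoch 6, 1e-5 after epoch 13, 1e-6 after epoch 20.
```

Note that `train_model` above calls `scheduler.step()` at the start of each training phase, which matches the PyTorch version the tutorial targets; newer releases may warn and prefer stepping the scheduler after the epoch's optimizer updates.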
model.training\n", 377 | " model.eval()\n", 378 | " images_so_far = 0\n", 379 | " fig = plt.figure()\n", 380 | "\n", 381 | " with torch.no_grad():\n", 382 | " for i, (inputs, labels) in enumerate(dataloaders['val']):\n", 383 | " inputs = inputs.to(device)\n", 384 | " labels = labels.to(device)\n", 385 | "\n", 386 | " outputs = model(inputs)\n", 387 | " _, preds = torch.max(outputs, 1)\n", 388 | "\n", 389 | " for j in range(inputs.size()[0]):\n", 390 | " images_so_far += 1\n", 391 | " ax = plt.subplot(num_images//2, 2, images_so_far)\n", 392 | " ax.axis('off')\n", 393 | " ax.set_title('predicted: {}'.format(class_names[preds[j]]))\n", 394 | " imshow(inputs.cpu().data[j])\n", 395 | "\n", 396 | " if images_so_far == num_images:\n", 397 | " model.train(mode=was_training)\n", 398 | " return\n", 399 | " model.train(mode=was_training)" 400 | ], 401 | "execution_count": 0, 402 | "outputs": [] 403 | }, 404 | { 405 | "cell_type": "markdown", 406 | "metadata": { 407 | "id": "f4eRYWz4g6hn", 408 | "colab_type": "text" 409 | }, 410 | "source": [ 411 | "Finetuning the convnet\n", 412 | "----------------------\n", 413 | "\n", 414 | "Load a pretrained model and reset final fully connected layer.\n", 415 | "\n", 416 | "\n" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "metadata": { 422 | "id": "_nircGErg6ho", 423 | "colab_type": "code", 424 | "colab": {} 425 | }, 426 | "source": [ 427 | "model_ft = models.resnet18(pretrained=True)\n", 428 | "num_ftrs = model_ft.fc.in_features\n", 429 | "model_ft.fc = nn.Linear(num_ftrs, 2)\n", 430 | "\n", 431 | "model_ft = model_ft.to(device)\n", 432 | "\n", 433 | "criterion = nn.CrossEntropyLoss()\n", 434 | "\n", 435 | "# Observe that all parameters are being optimized\n", 436 | "optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)\n", 437 | "\n", 438 | "# Decay LR by a factor of 0.1 every 7 epochs\n", 439 | "exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)" 440 | ], 441 | "execution_count": 0, 442 | "outputs": [] 443 | }, 444 | { 445 | "cell_type": "markdown", 446 | "metadata": { 447 | "id": "ShuNlu2fg6hw", 448 | "colab_type": "text" 449 | }, 450 | "source": [ 451 | "Train and evaluate\n", 452 | "^^^^^^^^^^^^^^^^^^\n", 453 | "\n", 454 | "It should take around 15-25 min on CPU. On GPU though, it takes less than a\n", 455 | "minute.\n", 456 | "\n", 457 | "\n" 458 | ] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "metadata": { 463 | "id": "UBmUB5cpg6hx", 464 | "colab_type": "code", 465 | "colab": {} 466 | }, 467 | "source": [ 468 | "model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,\n", 469 | " num_epochs=25)" 470 | ], 471 | "execution_count": 0, 472 | "outputs": [] 473 | }, 474 | { 475 | "cell_type": "code", 476 | "metadata": { 477 | "id": "GKLtRValg6h0", 478 | "colab_type": "code", 479 | "colab": {} 480 | }, 481 | "source": [ 482 | "visualize_model(model_ft)" 483 | ], 484 | "execution_count": 0, 485 | "outputs": [] 486 | }, 487 | { 488 | "cell_type": "markdown", 489 | "metadata": { 490 | "id": "DcGuNMw5g6h3", 491 | "colab_type": "text" 492 | }, 493 | "source": [ 494 | "ConvNet as fixed feature extractor\n", 495 | "----------------------------------\n", 496 | "\n", 497 | "Here, we need to freeze all the network except the final layer. 
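The finetuning setup above replaces the 1000-way ImageNet head of `resnet18` with a fresh 2-way linear layer and then optimizes every parameter. The sketch below generalizes the same idea to an arbitrary class count; the helper name is ours, and the `weights=` remark applies to newer torchvision releases only:

```python
import torch.nn as nn
from torchvision import models

def build_finetune_model(num_classes):
    """Illustrative helper (not from the tutorial): resnet18 with a new head."""
    # Downloads ImageNet weights on first use. Newer torchvision prefers
    # models.resnet18(weights=models.ResNet18_Weights.DEFAULT).
    model = models.resnet18(pretrained=True)
    num_ftrs = model.fc.in_features        # 512 for resnet18
    model.fc = nn.Linear(num_ftrs, num_classes)
    return model

model = build_finetune_model(num_classes=2)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'{trainable:,} trainable parameters')   # every layer is updated when finetuning
```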
We need\n", 498 | "to set ``requires_grad == False`` to freeze the parameters so that the\n", 499 | "gradients are not computed in ``backward()``.\n", 500 | "\n", 501 | "You can read more about this in the documentation\n", 502 | "`here `__.\n", 503 | "\n", 504 | "\n" 505 | ] 506 | }, 507 | { 508 | "cell_type": "code", 509 | "metadata": { 510 | "id": "bcYamd5Ig6h4", 511 | "colab_type": "code", 512 | "colab": {} 513 | }, 514 | "source": [ 515 | "model_conv = torchvision.models.resnet18(pretrained=True)\n", 516 | "for param in model_conv.parameters():\n", 517 | " param.requires_grad = False\n", 518 | "\n", 519 | "# Parameters of newly constructed modules have requires_grad=True by default\n", 520 | "num_ftrs = model_conv.fc.in_features\n", 521 | "model_conv.fc = nn.Linear(num_ftrs, 2)\n", 522 | "\n", 523 | "model_conv = model_conv.to(device)\n", 524 | "\n", 525 | "criterion = nn.CrossEntropyLoss()\n", 526 | "\n", 527 | "# Observe that only parameters of final layer are being optimized as\n", 528 | "# opoosed to before.\n", 529 | "optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)\n", 530 | "\n", 531 | "# Decay LR by a factor of 0.1 every 7 epochs\n", 532 | "exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)" 533 | ], 534 | "execution_count": 0, 535 | "outputs": [] 536 | }, 537 | { 538 | "cell_type": "markdown", 539 | "metadata": { 540 | "id": "ullF4ZHWg6h7", 541 | "colab_type": "text" 542 | }, 543 | "source": [ 544 | "Train and evaluate\n", 545 | "^^^^^^^^^^^^^^^^^^\n", 546 | "\n", 547 | "On CPU this will take about half the time compared to previous scenario.\n", 548 | "This is expected as gradients don't need to be computed for most of the\n", 549 | "network. However, forward does need to be computed.\n", 550 | "\n", 551 | "\n" 552 | ] 553 | }, 554 | { 555 | "cell_type": "code", 556 | "metadata": { 557 | "id": "BGx3qzX_g6h7", 558 | "colab_type": "code", 559 | "colab": {} 560 | }, 561 | "source": [ 562 | "model_conv = train_model(model_conv, criterion, optimizer_conv,\n", 563 | " exp_lr_scheduler, num_epochs=25)" 564 | ], 565 | "execution_count": 0, 566 | "outputs": [] 567 | }, 568 | { 569 | "cell_type": "code", 570 | "metadata": { 571 | "id": "h9ygnBsvg6h-", 572 | "colab_type": "code", 573 | "colab": {} 574 | }, 575 | "source": [ 576 | "visualize_model(model_conv)\n", 577 | "\n", 578 | "plt.ioff()\n", 579 | "plt.show()" 580 | ], 581 | "execution_count": 0, 582 | "outputs": [] 583 | } 584 | ] 585 | } --------------------------------------------------------------------------------
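To round off the fixed-feature-extractor variant, here is a small standalone check of which parameters remain trainable after freezing the backbone and replacing `fc`. The snippet mirrors the cells above but is only an illustration:

```python
import torch.nn as nn
from torchvision import models

# Freeze the backbone, then attach a fresh 2-way head as in the tutorial cells.
model = models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False

model.fc = nn.Linear(model.fc.in_features, 2)  # new modules default to requires_grad=True

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)                               # expected: ['fc.weight', 'fc.bias']

frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f'{frozen:,} frozen backbone parameters')
```

Because only `fc.weight` and `fc.bias` receive gradients, `backward()` skips the frozen backbone entirely, which is why this variant runs in roughly half the time of full finetuning on CPU, as noted above.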