├── .dockerignore ├── .gitignore ├── LICENCE ├── README.md ├── data └── .gitignore ├── docker-compose.gpu.yml ├── docker-compose.yml ├── docker ├── Dockerfile └── Dockerfile.gpu ├── docs ├── docker.md ├── googlecloud.md ├── history.csv ├── model_sizes.png └── timeline.png ├── img ├── book_cover.png └── structure.png ├── notebooks ├── 02_deeplearning │ ├── 01_mlp │ │ ├── checkpoint │ │ │ └── .gitignore │ │ ├── logs │ │ │ └── .gitignore │ │ ├── mlp.ipynb │ │ ├── models │ │ │ └── .gitignore │ │ └── output │ │ │ └── .gitignore │ └── 02_cnn │ │ ├── checkpoint │ │ └── .gitignore │ │ ├── cnn.ipynb │ │ ├── convolutions.ipynb │ │ ├── logs │ │ └── .gitignore │ │ ├── models │ │ └── .gitignore │ │ └── output │ │ └── .gitignore ├── 03_vae │ ├── 01_autoencoder │ │ ├── autoencoder.ipynb │ │ ├── checkpoint │ │ │ └── .gitignore │ │ ├── logs │ │ │ └── .gitignore │ │ ├── models │ │ │ └── .gitignore │ │ └── output │ │ │ └── .gitignore │ ├── 02_vae_fashion │ │ ├── checkpoint │ │ │ └── .gitignore │ │ ├── logs │ │ │ └── .gitignore │ │ ├── models │ │ │ └── .gitignore │ │ ├── output │ │ │ └── .gitignore │ │ └── vae_fashion.ipynb │ └── 03_vae_faces │ │ ├── checkpoint │ │ └── .gitignore │ │ ├── logs │ │ └── .gitignore │ │ ├── models │ │ └── .gitignore │ │ ├── output │ │ └── .gitignore │ │ ├── vae_faces.ipynb │ │ └── vae_utils.py ├── 04_gan │ ├── 01_dcgan │ │ ├── checkpoint │ │ │ └── .gitignore │ │ ├── dcgan.ipynb │ │ ├── logs │ │ │ └── .gitignore │ │ ├── models │ │ │ └── .gitignore │ │ └── output │ │ │ └── .gitignore │ ├── 02_wgan_gp │ │ ├── checkpoint │ │ │ └── .gitignore │ │ ├── logs │ │ │ └── .gitignore │ │ ├── models │ │ │ └── .gitignore │ │ ├── output │ │ │ └── .gitignore │ │ └── wgan_gp.ipynb │ └── 03_cgan │ │ ├── cgan.ipynb │ │ ├── checkpoint │ │ └── .gitignore │ │ ├── logs │ │ └── .gitignore │ │ ├── models │ │ └── .gitignore │ │ └── output │ │ └── .gitignore ├── 05_autoregressive │ ├── 01_lstm │ │ ├── checkpoint │ │ │ └── .gitignore │ │ ├── logs │ │ │ └── .gitignore │ │ ├── lstm.ipynb │ │ ├── models │ │ │ └── .gitignore │ │ └── output │ │ │ └── .gitignore │ ├── 02_pixelcnn │ │ ├── checkpoint │ │ │ └── .gitignore │ │ ├── logs │ │ │ └── .gitignore │ │ ├── models │ │ │ └── .gitignore │ │ ├── output │ │ │ └── .gitignore │ │ └── pixelcnn.ipynb │ └── 03_pixelcnn_md │ │ ├── checkpoint │ │ └── .gitignore │ │ ├── logs │ │ └── .gitignore │ │ ├── models │ │ └── .gitignore │ │ ├── output │ │ └── .gitignore │ │ └── pixelcnn_md.ipynb ├── 06_normflow │ └── 01_realnvp │ │ ├── checkpoint │ │ └── .gitignore │ │ ├── logs │ │ └── .gitignore │ │ ├── models │ │ └── .gitignore │ │ ├── output │ │ └── .gitignore │ │ └── realnvp.ipynb ├── 07_ebm │ └── 01_ebm │ │ ├── checkpoint │ │ └── .gitignore │ │ ├── ebm.ipynb │ │ ├── logs │ │ └── .gitignore │ │ ├── models │ │ └── .gitignore │ │ └── output │ │ └── .gitignore ├── 08_diffusion │ └── 01_ddm │ │ ├── checkpoint │ │ └── .gitignore │ │ ├── ddm.ipynb │ │ ├── logs │ │ └── .gitignore │ │ ├── models │ │ └── .gitignore │ │ ├── output │ │ └── .gitignore │ │ └── sinusoidal_embedding.ipynb ├── 09_transformer │ └── gpt │ │ ├── checkpoint │ │ └── .gitignore │ │ ├── gpt.ipynb │ │ ├── logs │ │ └── .gitignore │ │ ├── models │ │ └── .gitignore │ │ └── output │ │ └── .gitignore ├── 11_music │ ├── 01_transformer │ │ ├── checkpoint │ │ │ └── .gitignore │ │ ├── logs │ │ │ └── .gitignore │ │ ├── models │ │ │ └── .gitignore │ │ ├── output │ │ │ └── .gitignore │ │ ├── parsed_data │ │ │ └── .gitignore │ │ ├── transformer.ipynb │ │ └── transformer_utils.py │ └── 02_musegan │ │ ├── checkpoint │ │ └── .gitignore 
│ │ ├── logs │ │ └── .gitignore │ │ ├── models │ │ └── .gitignore │ │ ├── musegan.ipynb │ │ ├── musegan_utils.py │ │ └── output │ │ └── .gitignore └── utils.py ├── requirements.txt ├── sample.env └── scripts ├── download.sh ├── downloaders ├── download_bach_cello_data.sh ├── download_bach_chorale_data.sh └── download_kaggle_data.sh ├── format.sh └── tensorboard.sh /.dockerignore: -------------------------------------------------------------------------------- 1 | data/ 2 | .git/ -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | docs/history.twbx 2 | 3 | .DS_Store 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # Environments 89 | .env 90 | .venv 91 | env/ 92 | venv/ 93 | ENV/ 94 | env.bak/ 95 | venv.bak/ 96 | 97 | # Spyder project settings 98 | .spyderproject 99 | .spyproject 100 | 101 | # Rope project settings 102 | .ropeproject 103 | 104 | # mkdocs documentation 105 | /site 106 | 107 | # mypy 108 | .mypy_cache/ 109 | 110 | -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🦜 Generative Deep Learning - 2nd Edition Codebase 2 | 3 | The official code repository for the second edition of the O'Reilly book *Generative Deep Learning: Teaching Machines to Paint, Write, Compose and Play*. 4 | 5 | [O'Reilly link](https://www.oreilly.com/library/view/generative-deep-learning/9781098134174/) 6 | 7 | [Amazon US link](https://www.amazon.com/Generative-Deep-Learning-Teaching-Machines/dp/1098134184/) 8 | 9 | 10 | 11 | ## 📖 Book Chapters 12 | 13 | Below is an outline of the book chapters: 14 | 15 | *Part I: Introduction to Generative Deep Learning* 16 | 17 | 1. Generative Modeling 18 | 2. Deep Learning 19 | 20 | *Part II: Methods* 21 | 22 | 3. Variational Autoencoders 23 | 4. Generative Adversarial Networks 24 | 5. Autoregressive Models 25 | 6. Normalizing Flows 26 | 7. Energy-Based Models 27 | 8. Diffusion Models 28 | 29 | *Part III: Applications* 30 | 31 | 9. Transformers 32 | 10. Advanced GANs 33 | 11. Music Generation 34 | 12. World Models 35 | 13. Multimodal Models 36 | 14. Conclusion 37 | 38 | ## 🌟 Star History 39 | 40 | 41 | 42 | ## 🚀 Getting Started 43 | 44 | ### Kaggle API 45 | 46 | To download some of the datasets for the book, you will need a Kaggle account and an API token. To use the Kaggle API: 47 | 48 | 1. Sign up for a [Kaggle account](https://www.kaggle.com). 49 | 2. Go to the 'Account' tab of your user profile. 50 | 3. Select 'Create API Token'. This will trigger the download of `kaggle.json`, a file containing your API credentials. 51 | 52 | ### The .env file 53 | 54 | Create a file called `.env` in the root directory, containing the following values (replacing the Kaggle username and API key with the values from the JSON): 55 | 56 | ``` 57 | JUPYTER_PORT=8888 58 | TENSORBOARD_PORT=6006 59 | KAGGLE_USERNAME= 60 | KAGGLE_KEY= 61 | ``` 62 | 63 | ### Get set up with Docker 64 | 65 | This codebase is designed to be run with [Docker](https://docs.docker.com/get-docker/). 66 | 67 | If you've never used Docker before, don't worry! I have included a guide to Docker in the [Docker README](./docs/docker.md) file in this repository. This includes a full run-through of why Docker is awesome and a brief guide to the `Dockerfile` and `docker-compose.yml` for this project. 68 | 69 | ### Building the Docker image 70 | 71 | If you do not have a GPU, run the following command: 72 | 73 | ``` 74 | docker compose build 75 | ``` 76 | 77 | If you do have a GPU that you wish to use, run the following command: 78 | 79 | ``` 80 | docker compose -f docker-compose.gpu.yml build 81 | ``` 82 | 83 | ### Running the container 84 | 85 | If you do not have a GPU, run the following command: 86 | 87 | ``` 88 | docker compose up 89 | ``` 90 | 91 | If you do have a GPU that you wish to use, run the following command: 92 | 93 | ``` 94 | docker compose -f docker-compose.gpu.yml up 95 | ``` 96 | 97 | Jupyter will be available in your local browser, on the port specified in your `.env` file - for example: 98 | 99 | ``` 100 | http://localhost:8888 101 | ``` 102 | 103 | The notebooks that accompany the book are available in the `/notebooks` folder, organized by chapter and example. 104 | 105 | ## 🏞️ Downloading data 106 | 107 | The codebase comes with an in-built data downloader helper script.
108 | 109 | Run the data downloader as follows (from outside the container), choosing one of the named datasets below: 110 | 111 | ``` 112 | bash scripts/download.sh [faces, bricks, recipes, flowers, wines, cellosuites, chorales] 113 | ``` 114 | 115 | ## 📈 Tensorboard 116 | 117 | Tensorboard is really useful for monitoring experiments and seeing how your model training is progressing. 118 | 119 | To launch Tensorboard, run the following script (from outside the container), passing the following two arguments: 120 | * `<chapter>` - the required chapter (e.g. `03_vae`) 121 | * `<example>` - the required example (e.g. `02_vae_fashion`) 122 | 123 | ``` 124 | bash scripts/tensorboard.sh <chapter> <example> 125 | ``` 126 | 127 | Tensorboard will be available in your local browser on the port specified in your `.env` file - for example: 128 | ``` 129 | http://localhost:6006 130 | ``` 131 | 132 | ## ☁️ Using a cloud virtual machine 133 | 134 | To set up a virtual machine with GPU in Google Cloud Platform, follow the instructions in the [Google Cloud README](./docs/googlecloud.md) file in this repository. 135 | 136 | ## 📦 Other resources 137 | 138 | Some of the examples in this book are adapted from the excellent open source implementations that are available through the [Keras website](https://keras.io/examples/generative/). I highly recommend you check out this resource as new models and examples are constantly being added. 139 | 140 | 141 | -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /docker-compose.gpu.yml: -------------------------------------------------------------------------------- 1 | version: '3' 2 | 3 | services: 4 | app: 5 | build: 6 | context: . 7 | dockerfile: ./docker/Dockerfile.gpu 8 | tty: true 9 | volumes: 10 | - ./data/:/app/data 11 | - ./notebooks/:/app/notebooks 12 | - ./scripts/:/app/scripts 13 | deploy: 14 | resources: 15 | reservations: 16 | devices: 17 | - driver: nvidia 18 | count: 1 19 | capabilities: [gpu] 20 | ports: 21 | - "$JUPYTER_PORT:$JUPYTER_PORT" 22 | - "$TENSORBOARD_PORT:$TENSORBOARD_PORT" 23 | env_file: 24 | - ./.env 25 | entrypoint: jupyter lab --ip 0.0.0.0 --port=$JUPYTER_PORT --no-browser --allow-root 26 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3' 2 | 3 | services: 4 | app: 5 | build: 6 | context: .
7 | dockerfile: ./docker/Dockerfile 8 | tty: true 9 | volumes: 10 | - ./data/:/app/data 11 | - ./notebooks/:/app/notebooks 12 | - ./scripts/:/app/scripts 13 | ports: 14 | - "$JUPYTER_PORT:$JUPYTER_PORT" 15 | - "$TENSORBOARD_PORT:$TENSORBOARD_PORT" 16 | env_file: 17 | - ./.env 18 | entrypoint: jupyter lab --ip 0.0.0.0 --port=$JUPYTER_PORT --no-browser --allow-root 19 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:20.04 2 | 3 | ENV DEBIAN_FRONTEND noninteractive 4 | 5 | RUN apt-get update 6 | RUN apt-get install -y unzip graphviz curl musescore3 python3-pip 7 | 8 | RUN pip install --upgrade pip 9 | 10 | WORKDIR /app 11 | 12 | COPY ./requirements.txt /app 13 | RUN pip install -r /app/requirements.txt 14 | 15 | # Hack to get around tensorflow-io issue - https://github.com/tensorflow/io/issues/1755 16 | RUN pip install tensorflow-io 17 | RUN pip uninstall -y tensorflow-io 18 | 19 | COPY /notebooks/. /app/notebooks 20 | COPY /scripts/. /app/scripts 21 | 22 | ENV PYTHONPATH="${PYTHONPATH}:/app" -------------------------------------------------------------------------------- /docker/Dockerfile.gpu: -------------------------------------------------------------------------------- 1 | FROM tensorflow/tensorflow:2.10.1-gpu 2 | 3 | RUN rm /etc/apt/sources.list.d/cuda.list 4 | 5 | RUN apt-get update 6 | RUN apt-get install -y unzip graphviz curl musescore3 7 | 8 | RUN pip install --upgrade pip 9 | 10 | WORKDIR /app 11 | 12 | COPY ./requirements.txt /app 13 | RUN pip install -r /app/requirements.txt 14 | 15 | # Hack to get around tensorflow-io issue - https://github.com/tensorflow/io/issues/1755 16 | RUN pip install tensorflow-io 17 | RUN pip uninstall -y tensorflow-io 18 | 19 | COPY /notebooks/. /app/notebooks 20 | COPY /scripts/. /app/scripts 21 | 22 | ENV PYTHONPATH="${PYTHONPATH}:/app" -------------------------------------------------------------------------------- /docs/docker.md: -------------------------------------------------------------------------------- 1 | # 🐳 Getting started with Docker 2 | 3 | Before we start, you will need to download Docker Desktop for free from https://www.docker.com/products/docker-desktop/ 4 | 5 | In this section, we will explain the reason for using Docker and how to get started. 6 | 7 | The overall structure of the codebase and how Docker can be used to interact with it is shown below. 8 | 9 | 10 | 11 | We will run through everything in the diagram in this `README`. 12 | 13 | ## 🙋‍♂️ Why is Docker useful? 14 | 15 | A common problem you may have faced when using other people's code is that it simply doesn't run on your machine. This can be especially frustrating, especially when it runs just fine on the developer's machine and they cannot help you with specific bugs that occur when deploying onto your operating system. 16 | 17 | There are two reasons why this might be the case: 18 | 19 | 1. The underlying Python version or packages that you have installed are not identical to the ones that were used during development of the codebase 20 | 21 | 2. There are fundamental differences between your operating system and the developer's operating system that mean that the code does not run in the same way. 22 | 23 | The first of these problems can be mostly solved by using *virtual environments*. 
The idea here is that you create an environment for each project that contains the required Python interpreter and specific required packages and is separated from all other projects on your machine. For example, you could have one project running Python 3.8 with TensorFlow 2.10 and another running Python 3.7 with TensorFlow 2.8. Both versions of Python and TensorFlow would exist on your machine, but within isolated environments so they cannot conflict with each other. 24 | 25 | Anaconda and Virtualenv are two popular ways to create virtual environments. Typically, codebases are shipped with a `requirements.txt` file that contains the specific Python package versions that you need to run the codebase. 26 | 27 | However, solely using virtual environments does not solve the second problem - how can you ensure that your operating system is set up in the same way as the developer's? For example, the code may have been developed on a Windows machine, but you have a Mac and the errors you are seeing are specific to macOS. Or perhaps the codebase manipulates the filesystem using scripts written in `bash`, which you do not have on your machine. 28 | 29 | Docker solves this problem. Instead of specifying only the required Python packages, the developer specifies a recipe to build what is known as an *image*, which contains everything required to run the app, including an operating system and all dependencies. The recipe is called the `Dockerfile` - it is a bit like the `requirements.txt` files, but for the entire runtime environment rather than just the Python packages. 30 | 31 | Let's first take a look at the Dockerfile for the Generative Deep Learning codebase and see how it contains all the information required to build the image. 32 | 33 | ## 📝 The Dockerfile 34 | 35 | In the codebase that you pulled from GitHub, there is a file simply called 'Dockerfile' inside the `docker` folder. This is the recipe that Docker will use to build the image. We'll walk through line by line, explaining what each step does. 36 | 37 | ``` 38 | FROM ubuntu:20.04 #<1> 39 | ENV DEBIAN_FRONTEND noninteractive 40 | 41 | RUN apt-get update #<2> 42 | RUN apt-get install -y unzip graphviz curl musescore3 python3-pip 43 | 44 | RUN pip install --upgrade pip #<3> 45 | 46 | WORKDIR /app #<4> 47 | 48 | COPY ./requirements.txt /app #<5> 49 | RUN pip install -r /app/requirements.txt 50 | 51 | # Hack to get around tensorflow-io issue - https://github.com/tensorflow/io/issues/1755 52 | RUN pip install tensorflow-io 53 | RUN pip uninstall -y tensorflow-io 54 | 55 | COPY /notebooks/. /app/notebooks #<6> 56 | COPY /scripts/. /app/scripts 57 | 58 | ENV PYTHONPATH="${PYTHONPATH}:/app" #<7> 59 | ``` 60 | 61 | 1. The first line defines the base image. Our base image is an Ubuntu 20.04 (Linux) operating system. This is pulled from DockerHub - the online store of publicly available images (`https://hub.docker.com/_/ubuntu`). 62 | 2. Update `apt-get`, the Linux package manager and install relevant packages 63 | 3. Upgrade `pip` the Python package manager 64 | 4. Change the working directory to `/app`. 65 | 5. Copy the `requirements.txt` file into the image and use `pip` to install all relevant Python packages 66 | 6. Copy relevant folders into the image 67 | 7. Update the `PYTHONPATH` so that we can import functions that we write from our `/app` directory 68 | 69 | You can see how the Dockerfile can be thought of as a recipe for building a particular run-time environment. 
The magic of Docker is that you do not need to worry about installing a resource-intensive virtual machine on your computer - Docker is lightweight and allows you to build an environment template purely using code. 70 | 71 | A running version of an image is called a *container*. You can think of the image as a cookie cutter that can be used to create a particular cookie (the container). There is one other file that we need to look at before we finally get to build our image and run the container - the docker-compose.yml file. 72 | 73 | ## 🎼 The docker-compose.yml file 74 | 75 | Docker Compose is an extension to Docker that allows you to define how you would like your containers to run, through a simple YAML file, called 'docker-compose.yml'. 76 | 77 | For example, you can specify which ports and folders should be mapped between your machine and the container. Folder mapping allows the container to treat folders on your machine as if they were folders inside the container. Therefore, any changes you make to mapped files and folders on your machine will be immediately reflected inside the container. Port mapping will forward any traffic on a local port through to the container. For example, we could map port 8888 on your local machine to port 8888 in the container, so that if you visit `localhost:8888` in your browser, you will see Jupyter, which is running on port 8888 inside the container. The ports do not have to be the same - for example, you could map port 8000 on your machine to port 8888 in the container if you wanted. 78 | 79 | The alternative to using Docker Compose is to specify all of these parameters on the command line when using the `docker run` command. However, this is cumbersome and it is much easier to just use a Docker Compose YAML file. Docker Compose also gives you the ability to specify multiple services that should be run at the same time (for example, an application container and a database), and how they should talk to each other. However, for our purposes, we only need a single service - the container that runs Jupyter, so that we can interact with the generative deep learning codebase. 80 | 81 | Let's now take a look at the Docker Compose YAML file. 82 | 83 | ``` 84 | version: '3' #<1> 85 | services: #<2> 86 | app: #<3> 87 | build: #<4> 88 | context: . 89 | dockerfile: ./docker/Dockerfile 90 | tty: true #<5> 91 | volumes: #<6> 92 | - ./data/:/app/data 93 | - ./notebooks/:/app/notebooks 94 | - ./scripts/:/app/scripts 95 | ports: #<7> 96 | - "$JUPYTER_PORT:$JUPYTER_PORT" 97 | - "$TENSORBOARD_PORT:$TENSORBOARD_PORT" 98 | env_file: #<8> 99 | - ./.env 100 | entrypoint: jupyter lab --ip 0.0.0.0 --port=$JUPYTER_PORT --no-browser --allow-root #<9> 101 | ``` 102 | 103 | 1. This specifies the version of Docker Compose to use (currently version 3) 104 | 2. Here, we specify the services we wish to launch 105 | 3. We only have one service, which we call `app` 106 | 4. Here, we tell Docker where to find the Dockerfile (inside the `docker` folder), with the build context set to the directory containing the docker-compose.yml file 107 | 5. This allows us to open up an interactive command line inside the container, if we wish 108 | 6. Here, we map folders on our local machine (e.g. `./data`) to folders inside the container (e.g. `/app/data`). 109 | 7. Here, we specify the port mappings - the dollar sign means that it will use the ports as specified in the `.env` file (e.g. `JUPYTER_PORT=8888`) - see the quick check after this list 110 | 8. The location of the `.env` file on your local machine. 111 | 9. The command that should be run when the container runs - here, we run JupyterLab.
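As a quick check that the values in your `.env` file are being picked up, you can ask Docker Compose to print the fully resolved configuration before building anything. This is a minimal sketch - the concrete port numbers mentioned in the comments assume the sample `.env` values from the main `README` (`JUPYTER_PORT=8888` and `TENSORBOARD_PORT=6006`):

```
# Render docker-compose.yml with the variables from .env substituted in
docker compose config

# In the output, the ports section should show your concrete ports
# (e.g. 8888 and 6006) rather than the raw $JUPYTER_PORT and
# $TENSORBOARD_PORT placeholders. If it does not, check that your
# .env file sits in the same directory as docker-compose.yml.
```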
112 | 113 | ## 🧱 Building the image and running the container 114 | 115 | We're now at a point where we have everything we need to build our image and run the container. Building the image is simply a case of running the command shown below in your terminal, from the root folder. 116 | 117 | ``` 118 | docker compose build 119 | ``` 120 | 121 | You should see Docker start to run through the steps in the Dockerfile. You can use the command `docker images` to see a list of the images that have been built on your machine. 122 | 123 | To run the container based on the image we have just created, we use the command shown below: 124 | 125 | ``` 126 | docker compose up 127 | ``` 128 | 129 | You should see that Docker launches the Jupyter notebook server within the container and provides you with a URL to the running server. 130 | 131 | Because we have mapped port 8888 in the container to port 8888 on your machine, you can simply navigate to the address starting `http://127.0.0.1:8888/lab?token=` in a web browser and you should see the running Jupyter server. The folders that we mapped across in the `docker-compose.yml` file should be visible on the left-hand side. 132 | 133 | Congratulations! You now have a functioning Docker container that you can use to start working through the Generative Deep Learning codebase! To stop the running Jupyter server, use `Ctrl-C`, and to bring down the running container, use the command `docker compose down`. Because the volumes are mapped, you won't lose any of your work that you save whilst working in the Jupyter notebooks, even if you bring the container down. 134 | 135 | ## ⚡️ Using a GPU 136 | 137 | The default `Dockerfile` and `docker-compose.yml` file assume that you do not want to use a local GPU to train your models. If you do have a GPU that you wish to use (for example, you are using a cloud VM), I have provided two extra files, `Dockerfile.gpu` and `docker-compose.gpu.yml`, that can be used in place of the default files. 138 | 139 | For example, to build an image that includes GPU support, use the command shown below: 140 | 141 | ``` 142 | docker compose -f docker-compose.gpu.yml build 143 | ``` 144 | 145 | To run this image, use the following command: 146 | 147 | ``` 148 | docker compose -f docker-compose.gpu.yml up 149 | ``` 150 | -------------------------------------------------------------------------------- /docs/googlecloud.md: -------------------------------------------------------------------------------- 1 | # ⚡️ Setting up a VM with GPU in Google Cloud Platform 2 | 3 | ## ☁️ Google Cloud Platform 4 | 5 | Google Cloud Platform is a public cloud vendor that offers a wide variety of computing services, running on Google infrastructure. It is extremely useful for spinning up resources on demand that would otherwise require a large upfront investment in hardware and setup time. 6 | 7 | There are many cloud vendors - the largest being Amazon Web Services (AWS), Microsoft Azure and Google Cloud Platform (GCP). Through these cloud platforms, you can spin up servers, databases and other services through the online user interface and easily turn them off or tear them down when you are finished using them. This makes cloud computing an excellent choice for anyone who wants to get started with state-of-the-art machine learning, because you can get access to powerful resources at the click of a button, without having to invest significant amounts of money in your own hardware.
8 | 9 | In this book, we will demonstrate how to set up a virtual server with a GPU on Google Cloud Platform. You can use this environment to run the codebase that accompanies this book on accelerated GPU hardware, for faster training. 10 | 11 | ## 🚀 Getting started 12 | 13 | ### Get a Google Cloud Platform Account 14 | 15 | Firstly, you'll need a Google Cloud Console account - visit https://cloud.google.com/cloud-console/ to get set up. 16 | 17 | If you've never used GCP, you get access to $300 of free credit over 90 days (correct at time of writing), which is more than enough to run all of the training examples in this book. 18 | 19 | Once you're logged in, the first thing you need to do is create a project. You can call this whatever you like, for example `generative-deep-learning`. Then you have to set up a billing account for the project - if you are using the free trial credits, you will not be automatically billed after the credits are used up. 20 | 21 | You can now go ahead and set up a compute engine with an attached GPU, within this project. You can see which project you are currently in next to the Google Cloud Platform logo in the top left corner of the screen. 22 | 23 | ### Set up a compute engine 24 | 25 | Firstly, search for 'Compute Engine' in the search bar and click 'Create Instance' (or navigate straight to https://console.cloud.google.com/compute/instancesAdd). 26 | 27 | You will see a screen where you can select the configuration of your virtual machine (VM). Below, we run through a description of each field and a recommended value to choose. 28 | 29 | | Option | Description | Example Value 30 | | --- | --- | --- 31 | | Name | What would you like to call the compute instance? | gdl-compute 32 | | Region | Where should the compute instance be physically located? | us-central1-c (Iowa) 33 | | Machine family | What type of machine would you like? | GPU 34 | | GPU type | What GPU would you like? | NVIDIA Tesla T4 35 | | Machine type | What specification of machine would you like? | n1-standard-4 (4vCPU, 15 GB memory) 36 | | Boot disk > Operating System | What boot disk OS would you like? | Deep Learning on Linux 37 | 38 | All other options can be left as the default. 39 | 40 | The NVIDIA Tesla T4 is a sufficiently powerful GPU to run the examples in this book. Note that other, more powerful GPUs are available, but these are more expensive per hour. Choosing the 'Deep Learning on Linux' boot disk means that the NVIDIA CUDA stack will be installed automatically at no extra cost. 41 | 42 | You may wish to select a different region for your machine - price does vary between regions and some regions are not compatible with GPUs (see https://cloud.google.com/compute/docs/gpus/gpu-regions-zones for the full list of GPU regions). In the `us-central1-c` region, the whole virtual machine should be priced at around $0.55 per hour (without any sustained use discount applied). 43 | 44 | Click 'Create' to build your virtual machine - this may take a few minutes. You'll be able to see when the machine is ready to access, as there will be a green tick next to the machine name. 45 | 46 | ### Accessing the VM 47 | 48 | The easiest way to access the VM is by using the Google Cloud CLI and Visual Studio Code. 49 | 50 | Firstly, you'll need to install the Google Cloud CLI onto your machine.
You can do this by following the instructions at the following link: https://cloud.google.com/sdk/docs/install 51 | 52 | Then, to set up the configuration for the SSH connection to your virtual machine, first run the command below in your local terminal: 53 | 54 | ``` 55 | gcloud compute config-ssh 56 | ``` 57 | 58 | Then run the command shown below to connect to the virtual machine, replacing `<instance-name>` with the name of your virtual machine (e.g. `gdl-compute`). 59 | 60 | ``` 61 | gcloud compute ssh <instance-name> 62 | ``` 63 | 64 | The first time you connect, it will ask if you want to install NVIDIA drivers - select 'yes'. 65 | 66 | To start coding on the VM, I recommend installing VSCode from https://code.visualstudio.com/download, which is a modern Integrated Development Environment (IDE) through which you can easily access your virtual machine and manipulate the file system. 67 | 68 | Make sure you have the Remote SSH extension installed (see https://code.visualstudio.com/blogs/2019/10/03/remote-ssh-tips-and-tricks). 69 | 70 | In the bottom left-hand corner of VSCode, you'll see a green box with two arrows - click this and then select 'Connect Current Window to Host'. Then select your virtual machine. VSCode will then connect via SSH. Congratulations - you are now connected to the virtual machine! 71 | 72 | ### Cloning the codebase 73 | 74 | You can then clone the codebase onto the VM: 75 | 76 | ``` 77 | git clone https://github.com/davidADSP/Generative_Deep_Learning_2nd_Edition.git 78 | ``` 79 | 80 | Then in VSCode, if you click 'Open Folder', you can select the folder that you have just cloned, to see the files in the repository. 81 | 82 | Lastly, you'll need to make sure that Docker Compose is installed on the virtual machine, using the commands shown below: 83 | 84 | ``` 85 | sudo apt-get update 86 | sudo apt-get install docker-compose-plugin 87 | ``` 88 | 89 | Now you can build the image and run the container, using the GPU-specific commands discussed in the main `README` of this codebase. 90 | 91 | Congratulations - you're now running the Generative Deep Learning codebase on a cloud virtual machine, with attached GPU! 92 | 93 | If you navigate to `http://127.0.0.1:8888/` in your browser, you should see Jupyter Lab, as VS Code automatically forwards the Jupyter Lab port in the container to the same port on your machine. If not, check that the ports are mapped under the 'Ports' tab in VSCode. 94 | 95 | ## 🚦 Starting and stopping the virtual machine 96 | 97 | When you are not using the virtual machine, it's important to turn it off so that you don't incur unnecessary charges. 98 | 99 | You can do this by selecting 'Stop' in the dropdown menu next to your virtual machine, in GCP. A stopped virtual machine will not be accessible via SSH, but the storage will be persisted, so you will not lose any of your files or progress. The storage will still incur a small charge. 100 | 101 | To turn the virtual machine back on, select 'Start / Resume' from the dropdown menu next to your virtual machine, in GCP. This will take a few minutes. When the green tick appears next to the machine description, you'll be able to access your machine in VS Code as before. 102 | 103 | ### Changing IP address 104 | 105 | When you stop and restart your virtual machine, the external IP address of the server may change. You can see the current external IP address of the virtual machine in the console.
If this happens, you will need to update the external IP address of the aliased SSH connection, which you can do within VS Code. 106 | 107 | Firstly, click on the green SSH connection arrows in the bottom left-hand corner of VS Code and then 'Open SSH Configuration File...'. Select your SSH configuration file and you should see a block of text for your virtual machine. This is the text added by the Google Cloud CLI to give you an easy way to connect to your virtual machine. 108 | 109 | To update the external IP address, you just need to change the `HostName` value and save the file. 110 | 111 | ### Destroying the VM 112 | 113 | To completely destroy the machine, select 'Delete' from the dropdown menu next to your virtual machine in GCP. This will completely remove the virtual machine and all attached storage, and there will be no further incurred charges. 114 | 115 | In this section, we have seen how to build a virtual machine with attached GPU in Google Cloud Platform, from scratch. This should allow you to build sophisticated generative deep learning models, trained on powerful hardware that you can scale to your requirements. -------------------------------------------------------------------------------- /docs/history.csv: -------------------------------------------------------------------------------- 1 | Name,Paper Date,paper,hide,type,parameters 2 | LSTM,01/11/1997,http://www.bioinf.jku.at/publications/older/2604.pdf,1,Autoregressive / Transformer, 3 | VAE,20/12/2013,https://arxiv.org/abs/1312.6114,0,Variational Autoencoder, 4 | Encoder Decoder,03/06/2014,https://arxiv.org/abs/1406.1078,1,Autoregressive / Transformer, 5 | GAN,10/06/2014,https://arxiv.org/abs/1406.2661,0,Generative Adversarial Network, 6 | Attention,01/09/2014,https://arxiv.org/abs/1409.0473,1,General, 7 | GRU,03/09/2014,https://arxiv.org/abs/1409.1259,0,Autoregressive / Transformer, 8 | CGAN,06/11/2014,https://arxiv.org/abs/1411.1784,0,Generative Adversarial Network, 9 | Diffusion Process,12/03/2015,https://arxiv.org/abs/1503.03585,1,Energy-Based / Diffusion Models, 10 | UNet,18/05/2015,https://arxiv.org/abs/1505.04597,0,General, 11 | Neural Style,26/08/2015,https://arxiv.org/abs/1508.06576,1,General, 12 | DCGAN,19/11/2015,https://arxiv.org/abs/1511.06434,0,Generative Adversarial Network, 13 | ResNet,10/12/2015,https://arxiv.org/abs/1512.03385,0,General, 14 | VAE-GAN,31/12/2015,https://arxiv.org/abs/1512.09300,0,Variational Autoencoder, 15 | Self Attention,25/01/2016,https://arxiv.org/abs/1601.06733,1,Autoregressive / Transformer, 16 | PixelRNN,25/01/2016,https://arxiv.org/abs/1601.06759,0,Autoregressive / Transformer, 17 | RealNVP,27/05/2016,https://arxiv.org/abs/1605.08803v3,0,Normalizing Flow, 18 | PixelCNN,16/06/2016,https://arxiv.org/abs/1606.05328,0,Autoregressive / Transformer, 19 | pix2pix,21/11/2016,https://arxiv.org/abs/1611.07004,0,Generative Adversarial Network, 20 | Stack GAN,10/12/2016,https://arxiv.org/abs/1612.03242,1,Generative Adversarial Network, 21 | PixelCNN++,19/01/2017,https://arxiv.org/abs/1701.05517,0,Autoregressive / Transformer, 22 | WGAN,26/01/2017,https://arxiv.org/abs/1701.07875,0,Generative Adversarial Network, 23 | CycleGAN,30/03/2017,https://arxiv.org/abs/1703.10593,0,Generative Adversarial Network, 24 | WGAN GP,31/03/2017,https://arxiv.org/abs/1704.00028,1,Generative Adversarial Network, 25 | Transformers,12/06/2017,https://arxiv.org/abs/1706.03762,0,Autoregressive / Transformer, 26 | MuseGAN,19/09/2017,https://arxiv.org/abs/1709.06298,0,Generative Adversarial Network, 27 |
ProGAN,27/10/2017,https://arxiv.org/abs/1710.10196,0,Generative Adversarial Network, 28 | VQ-VAE,02/11/2017,https://arxiv.org/abs/1711.00937v2,0,Variational Autoencoder, 29 | World Models,27/03/2018,https://arxiv.org/abs/1803.10122,0,Variational Autoencoder, 30 | SAGAN,21/05/2018,https://arxiv.org/abs/1805.08318v2,0,Generative Adversarial Network, 31 | GPT,11/06/2018,https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf,0,Autoregressive / Transformer,0.117 32 | GLOW,09/07/2018,https://arxiv.org/abs/1807.03039,0,Normalizing Flow, 33 | Universal Transformer,10/07/2018,https://arxiv.org/abs/1807.03819,1,Autoregressive / Transformer, 34 | BigGAN,28/09/2018,https://arxiv.org/abs/1809.11096,0,Generative Adversarial Network, 35 | FFJORD,02/10/2018,https://arxiv.org/abs/1810.01367,0,Normalizing Flow, 36 | BERT,11/10/2018,https://arxiv.org/abs/1810.04805,0,Autoregressive / Transformer,0.345 37 | StyleGAN,12/12/2018,https://arxiv.org/abs/1812.04948,0,Generative Adversarial Network, 38 | Music Transformer,12/12/2018,https://arxiv.org/abs/1809.04281,0,Autoregressive / Transformer, 39 | GPT-2,14/02/2019,https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf,0,Autoregressive / Transformer,1.5 40 | MuseNet,25/04/2019,https://openai.com/blog/musenet/,0,Autoregressive / Transformer, 41 | VQ-VAE-2,02/06/2019,https://arxiv.org/abs/1906.00446v1,0,Variational Autoencoder, 42 | NCSN,12/07/2019,https://arxiv.org/abs/1907.05600,0,Energy-Based / Diffusion Models, 43 | T5,23/10/2019,https://arxiv.org/abs/1910.10683,0,Autoregressive / Transformer,11 44 | StyleGAN2,03/12/2019,https://arxiv.org/abs/1912.04958,0,Generative Adversarial Network, 45 | NeRF,19/03/2020,https://arxiv.org/abs/2003.08934,0,General, 46 | GPT-3,28/05/2020,https://arxiv.org/abs/2005.14165,0,Autoregressive / Transformer,175 47 | DDPM,19/06/2020,https://arxiv.org/abs/2006.11239,0,Energy-Based / Diffusion Models, 48 | DDIM,06/10/2020,https://arxiv.org/abs/2010.02502,0,Energy-Based / Diffusion Models, 49 | Vision Transformer,22/10/2020,https://arxiv.org/abs/2010.11929,0,Autoregressive / Transformer, 50 | VQ-GAN,17/12/2020,https://arxiv.org/abs/2012.09841,0,Generative Adversarial Network, 51 | DALL.E,24/02/2021,https://arxiv.org/abs/2102.12092,0,Multimodal Models,12 52 | CLIP,26/02/2021,https://arxiv.org/abs/2103.00020,0,Multimodal Models, 53 | GPT-Neo,21/03/2021,https://github.com/EleutherAI/gpt-neo,0,Autoregressive / Transformer,2.7 54 | GPT-J,10/06/2021,https://github.com/kingoflolz/mesh-transformer-jax,0,Autoregressive / Transformer,6 55 | StyleGAN3,23/06/2021,https://arxiv.org/abs/2106.12423,0,Generative Adversarial Network, 56 | Codex,07/07/2021,https://arxiv.org/abs/2107.03374,0,Autoregressive / Transformer, 57 | ViT VQ-GAN,09/10/2021,https://arxiv.org/abs/2110.04627,0,Generative Adversarial Network, 58 | Megatron-Turing NLG,11/10/2021,https://developer.nvidia.com/blog/using-deepspeed-and-megatron-to-train-megatron-turing-nlg-530b-the-worlds-largest-and-most-powerful-generative-language-model/,0,Autoregressive / Transformer,530 59 | Gopher,08/12/2021,https://arxiv.org/abs/2112.11446,0,Autoregressive / Transformer,280 60 | GLIDE,20/12/2021,https://arxiv.org/abs/2112.10741,0,Multimodal Models,5 61 | Latent Diffusion,20/12/2021,https://arxiv.org/abs/2112.10752,0,Energy-Based / Diffusion Models, 62 | LaMDA,20/01/2022,https://arxiv.org/abs/2201.08239,0,Autoregressive / Transformer,137 63 | StyleGAN-XL,01/02/2022,https://arxiv.org/abs/2202.00273v2,0,Generative 
Adversarial Network,0 64 | GPT-NeoX,02/02/2022,https://github.com/EleutherAI/gpt-neox,0,Autoregressive / Transformer,20 65 | Chinchilla,29/03/2022,https://arxiv.org/abs/2203.15556v1,0,Autoregressive / Transformer,70 66 | PaLM,05/04/2022,https://arxiv.org/abs/2204.02311,0,Autoregressive / Transformer,540 67 | DALL.E 2,13/04/2022,https://arxiv.org/abs/2204.06125,0,Multimodal Models,3.5 68 | Flamingo,29/04/2022,https://arxiv.org/abs/2204.14198,0,Multimodal Models,80 69 | OPT,02/05/2022,https://arxiv.org/abs/2205.01068,0,Autoregressive / Transformer,175 70 | Imagen,23/05/2022,https://arxiv.org/abs/2205.11487,0,Multimodal Models,4.6 71 | Parti,22/06/2022,https://arxiv.org/abs/2206.10789,0,Multimodal Models,20 72 | BLOOM,16/07/2022,https://arxiv.org/abs/2211.05100,0,Autoregressive / Transformer,176 73 | Stable Diffusion,22/08/2022,https://stability.ai/blog/stable-diffusion-public-release,0,Multimodal Models,0.89 74 | ChatGPT,30/11/2022,https://chat.openai.com/,0,Autoregressive / Transformer, 75 | MUSE,02/01/2023,https://arxiv.org/abs/2301.00704,0,Multimodal Models,3 76 | MusicLM,26/01/2023,https://arxiv.org/abs/2301.11325,0,Multimodal Models, 77 | Dreamix,02/02/2023,https://arxiv.org/pdf/2302.01329.pdf,0,Multimodal Models, 78 | Toolformer,09/02/2023,https://arxiv.org/pdf/2302.04761.pdf,0,Autoregressive / Transformer, 79 | ControlNet,10/02/2023,https://arxiv.org/abs/2302.05543,0,Multimodal Models, 80 | LLaMA,24/02/2023,https://arxiv.org/abs/2302.13971,0,Autoregressive / Transformer,65 81 | PaLM-E,06/03/2023,https://arxiv.org/abs/2303.03378,0,Multimodal Models,562 82 | Visual ChatGPT,08/03/2023,https://arxiv.org/abs/2303.04671,0,Multimodal Models, 83 | Alpaca,13/03/2023,https://github.com/tatsu-lab/stanford_alpaca,1,Autoregressive / Transformer, 84 | GPT-4,16/03/2023,https://cdn.openai.com/papers/gpt-4.pdf,0,Autoregressive / Transformer,1000 85 | Luminous,14/04/2022,https://www.aleph-alpha.com/luminous,0,Autoregressive / Transformer, 86 | Flan-T5,20/10/2022,https://arxiv.org/abs/2210.11416,0,Autoregressive / Transformer,11 87 | Falcon,17/03/2023,https://falconllm.tii.ae/,0,Autoregressive / Transformer,40 88 | PaLM 2,10/05/2023,https://ai.google/discover/palm2/,0,Autoregressive / Transformer,340 89 | PanGu-Σ,20/03/2023,https://arxiv.org/abs/2303.10845,1,Autoregressive / Transformer,1085 90 | GPT-3.5,30/11/2022,,0,Autoregressive / Transformer,175 91 | Llama 2,17/07/2023,,0,Autoregressive / Transformer,70 -------------------------------------------------------------------------------- /docs/model_sizes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidADSP/Generative_Deep_Learning_2nd_Edition/9b1048dbe0d698b486ed16d529cf16fcd3aea29d/docs/model_sizes.png -------------------------------------------------------------------------------- /docs/timeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidADSP/Generative_Deep_Learning_2nd_Edition/9b1048dbe0d698b486ed16d529cf16fcd3aea29d/docs/timeline.png -------------------------------------------------------------------------------- /img/book_cover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidADSP/Generative_Deep_Learning_2nd_Edition/9b1048dbe0d698b486ed16d529cf16fcd3aea29d/img/book_cover.png -------------------------------------------------------------------------------- /img/structure.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidADSP/Generative_Deep_Learning_2nd_Edition/9b1048dbe0d698b486ed16d529cf16fcd3aea29d/img/structure.png -------------------------------------------------------------------------------- /notebooks/02_deeplearning/01_mlp/checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/02_deeplearning/01_mlp/logs/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/02_deeplearning/01_mlp/mlp.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 👀 Multilayer perceptron (MLP)" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "In this notebook, we'll walk through the steps required to train your own multilayer perceptron on the CIFAR dataset" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import numpy as np\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "\n", 26 | "from tensorflow.keras import layers, models, optimizers, utils, datasets\n", 27 | "from notebooks.utils import display" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": { 33 | "tags": [] 34 | }, 35 | "source": [ 36 | "## 0. Parameters " 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "NUM_CLASSES = 10" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "## 1. Prepare the Data " 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "(x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "x_train = x_train.astype(\"float32\") / 255.0\n", 71 | "x_test = x_test.astype(\"float32\") / 255.0\n", 72 | "\n", 73 | "y_train = utils.to_categorical(y_train, NUM_CLASSES)\n", 74 | "y_test = utils.to_categorical(y_test, NUM_CLASSES)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "display(x_train[:10])\n", 84 | "print(y_train[:10])" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "metadata": {}, 90 | "source": [ 91 | "## 2. 
Build the model " 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "input_layer = layers.Input((32, 32, 3))\n", 101 | "\n", 102 | "x = layers.Flatten()(input_layer)\n", 103 | "x = layers.Dense(200, activation=\"relu\")(x)\n", 104 | "x = layers.Dense(150, activation=\"relu\")(x)\n", 105 | "\n", 106 | "output_layer = layers.Dense(NUM_CLASSES, activation=\"softmax\")(x)\n", 107 | "\n", 108 | "model = models.Model(input_layer, output_layer)\n", 109 | "\n", 110 | "model.summary()" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "metadata": { 116 | "tags": [] 117 | }, 118 | "source": [ 119 | "## 3. Train the model " 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "opt = optimizers.Adam(learning_rate=0.0005)\n", 129 | "model.compile(\n", 130 | " loss=\"categorical_crossentropy\", optimizer=opt, metrics=[\"accuracy\"]\n", 131 | ")" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "model.fit(x_train, y_train, batch_size=32, epochs=10, shuffle=True)" 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "metadata": { 146 | "tags": [] 147 | }, 148 | "source": [ 149 | "## 4. Evaluation " 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "model.evaluate(x_test, y_test)" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "CLASSES = np.array(\n", 168 | " [\n", 169 | " \"airplane\",\n", 170 | " \"automobile\",\n", 171 | " \"bird\",\n", 172 | " \"cat\",\n", 173 | " \"deer\",\n", 174 | " \"dog\",\n", 175 | " \"frog\",\n", 176 | " \"horse\",\n", 177 | " \"ship\",\n", 178 | " \"truck\",\n", 179 | " ]\n", 180 | ")\n", 181 | "\n", 182 | "preds = model.predict(x_test)\n", 183 | "preds_single = CLASSES[np.argmax(preds, axis=-1)]\n", 184 | "actual_single = CLASSES[np.argmax(y_test, axis=-1)]" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "metadata": { 191 | "scrolled": true 192 | }, 193 | "outputs": [], 194 | "source": [ 195 | "n_to_show = 10\n", 196 | "indices = np.random.choice(range(len(x_test)), n_to_show)\n", 197 | "\n", 198 | "fig = plt.figure(figsize=(15, 3))\n", 199 | "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n", 200 | "\n", 201 | "for i, idx in enumerate(indices):\n", 202 | " img = x_test[idx]\n", 203 | " ax = fig.add_subplot(1, n_to_show, i + 1)\n", 204 | " ax.axis(\"off\")\n", 205 | " ax.text(\n", 206 | " 0.5,\n", 207 | " -0.35,\n", 208 | " \"pred = \" + str(preds_single[idx]),\n", 209 | " fontsize=10,\n", 210 | " ha=\"center\",\n", 211 | " transform=ax.transAxes,\n", 212 | " )\n", 213 | " ax.text(\n", 214 | " 0.5,\n", 215 | " -0.7,\n", 216 | " \"act = \" + str(actual_single[idx]),\n", 217 | " fontsize=10,\n", 218 | " ha=\"center\",\n", 219 | " transform=ax.transAxes,\n", 220 | " )\n", 221 | " ax.imshow(img)" 222 | ] 223 | } 224 | ], 225 | "metadata": { 226 | "kernelspec": { 227 | "display_name": "Python 3 (ipykernel)", 228 | "language": "python", 229 | "name": "python3" 230 | }, 231 | "language_info": { 232 | "codemirror_mode": { 233 | "name": "ipython", 234 | "version": 3 235 | }, 236 | "file_extension": ".py", 237 | "mimetype": "text/x-python", 238 
| "name": "python", 239 | "nbconvert_exporter": "python", 240 | "pygments_lexer": "ipython3", 241 | "version": "3.8.10" 242 | }, 243 | "vscode": { 244 | "interpreter": { 245 | "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" 246 | } 247 | } 248 | }, 249 | "nbformat": 4, 250 | "nbformat_minor": 4 251 | } 252 | -------------------------------------------------------------------------------- /notebooks/02_deeplearning/01_mlp/models/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/02_deeplearning/01_mlp/output/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/02_deeplearning/02_cnn/checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/02_deeplearning/02_cnn/cnn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 🏞 Convolutional Neural Network" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "In this notebook, we'll walk through the steps required to train your own convolutional neural network (CNN) on the CIFAR dataset" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import numpy as np\n", 24 | "\n", 25 | "from tensorflow.keras import layers, models, optimizers, utils, datasets\n", 26 | "from notebooks.utils import display" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": { 32 | "tags": [] 33 | }, 34 | "source": [ 35 | "## 0. Parameters " 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "NUM_CLASSES = 10" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "## 1. Prepare the Data " 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "(x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "x_train = x_train.astype(\"float32\") / 255.0\n", 70 | "x_test = x_test.astype(\"float32\") / 255.0\n", 71 | "\n", 72 | "y_train = utils.to_categorical(y_train, NUM_CLASSES)\n", 73 | "y_test = utils.to_categorical(y_test, NUM_CLASSES)" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "display(x_train[:10])\n", 83 | "print(y_train[:10])" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "## 2. 
Build the model " 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "input_layer = layers.Input((32, 32, 3))\n", 100 | "\n", 101 | "x = layers.Conv2D(filters=32, kernel_size=3, strides=1, padding=\"same\")(\n", 102 | " input_layer\n", 103 | ")\n", 104 | "x = layers.BatchNormalization()(x)\n", 105 | "x = layers.LeakyReLU()(x)\n", 106 | "\n", 107 | "x = layers.Conv2D(filters=32, kernel_size=3, strides=2, padding=\"same\")(x)\n", 108 | "x = layers.BatchNormalization()(x)\n", 109 | "x = layers.LeakyReLU()(x)\n", 110 | "\n", 111 | "x = layers.Conv2D(filters=64, kernel_size=3, strides=1, padding=\"same\")(x)\n", 112 | "x = layers.BatchNormalization()(x)\n", 113 | "x = layers.LeakyReLU()(x)\n", 114 | "\n", 115 | "x = layers.Conv2D(filters=64, kernel_size=3, strides=2, padding=\"same\")(x)\n", 116 | "x = layers.BatchNormalization()(x)\n", 117 | "x = layers.LeakyReLU()(x)\n", 118 | "\n", 119 | "x = layers.Flatten()(x)\n", 120 | "\n", 121 | "x = layers.Dense(128)(x)\n", 122 | "x = layers.BatchNormalization()(x)\n", 123 | "x = layers.LeakyReLU()(x)\n", 124 | "x = layers.Dropout(rate=0.5)(x)\n", 125 | "\n", 126 | "x = layers.Dense(NUM_CLASSES)(x)\n", 127 | "output_layer = layers.Activation(\"softmax\")(x)\n", 128 | "\n", 129 | "model = models.Model(input_layer, output_layer)\n", 130 | "\n", 131 | "model.summary()" 132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "metadata": { 137 | "tags": [] 138 | }, 139 | "source": [ 140 | "## 3. Train the model " 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "opt = optimizers.Adam(learning_rate=0.0005)\n", 150 | "model.compile(\n", 151 | " loss=\"categorical_crossentropy\", optimizer=opt, metrics=[\"accuracy\"]\n", 152 | ")" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "metadata": { 159 | "tags": [] 160 | }, 161 | "outputs": [], 162 | "source": [ 163 | "model.fit(\n", 164 | " x_train,\n", 165 | " y_train,\n", 166 | " batch_size=32,\n", 167 | " epochs=10,\n", 168 | " shuffle=True,\n", 169 | " validation_data=(x_test, y_test),\n", 170 | ")" 171 | ] 172 | }, 173 | { 174 | "cell_type": "markdown", 175 | "metadata": { 176 | "tags": [] 177 | }, 178 | "source": [ 179 | "## 4. 
Evaluation " 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "model.evaluate(x_test, y_test, batch_size=1000)" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": null, 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [ 197 | "CLASSES = np.array(\n", 198 | " [\n", 199 | " \"airplane\",\n", 200 | " \"automobile\",\n", 201 | " \"bird\",\n", 202 | " \"cat\",\n", 203 | " \"deer\",\n", 204 | " \"dog\",\n", 205 | " \"frog\",\n", 206 | " \"horse\",\n", 207 | " \"ship\",\n", 208 | " \"truck\",\n", 209 | " ]\n", 210 | ")\n", 211 | "\n", 212 | "preds = model.predict(x_test)\n", 213 | "preds_single = CLASSES[np.argmax(preds, axis=-1)]\n", 214 | "actual_single = CLASSES[np.argmax(y_test, axis=-1)]" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "import matplotlib.pyplot as plt\n", 224 | "\n", 225 | "n_to_show = 10\n", 226 | "indices = np.random.choice(range(len(x_test)), n_to_show)\n", 227 | "\n", 228 | "fig = plt.figure(figsize=(15, 3))\n", 229 | "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n", 230 | "\n", 231 | "for i, idx in enumerate(indices):\n", 232 | " img = x_test[idx]\n", 233 | " ax = fig.add_subplot(1, n_to_show, i + 1)\n", 234 | " ax.axis(\"off\")\n", 235 | " ax.text(\n", 236 | " 0.5,\n", 237 | " -0.35,\n", 238 | " \"pred = \" + str(preds_single[idx]),\n", 239 | " fontsize=10,\n", 240 | " ha=\"center\",\n", 241 | " transform=ax.transAxes,\n", 242 | " )\n", 243 | " ax.text(\n", 244 | " 0.5,\n", 245 | " -0.7,\n", 246 | " \"act = \" + str(actual_single[idx]),\n", 247 | " fontsize=10,\n", 248 | " ha=\"center\",\n", 249 | " transform=ax.transAxes,\n", 250 | " )\n", 251 | " ax.imshow(img)" 252 | ] 253 | } 254 | ], 255 | "metadata": { 256 | "kernelspec": { 257 | "display_name": "Python 3 (ipykernel)", 258 | "language": "python", 259 | "name": "python3" 260 | }, 261 | "language_info": { 262 | "codemirror_mode": { 263 | "name": "ipython", 264 | "version": 3 265 | }, 266 | "file_extension": ".py", 267 | "mimetype": "text/x-python", 268 | "name": "python", 269 | "nbconvert_exporter": "python", 270 | "pygments_lexer": "ipython3", 271 | "version": "3.8.10" 272 | }, 273 | "vscode": { 274 | "interpreter": { 275 | "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" 276 | } 277 | } 278 | }, 279 | "nbformat": 4, 280 | "nbformat_minor": 4 281 | } 282 | -------------------------------------------------------------------------------- /notebooks/02_deeplearning/02_cnn/convolutions.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 🎆 Convolutions" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "In this notebook, we'll walk through how convolutional filters can pick out different aspects of an image" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "%matplotlib inline\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "import numpy as np\n", 26 | "from skimage import data\n", 27 | "from skimage.color import rgb2gray\n", 28 | "from skimage.transform import resize" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "## 0. 
Original Input Image " 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "im = rgb2gray(data.coffee())\n", 45 | "im = resize(im, (64, 64))\n", 46 | "print(im.shape)\n", 47 | "\n", 48 | "plt.axis(\"off\")\n", 49 | "plt.imshow(im, cmap=\"gray\")" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "## Horizontal Edge Filter " 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "filter1 = np.array([[1, 1, 1], [0, 0, 0], [-1, -1, -1]])\n", 66 | "\n", 67 | "new_image = np.zeros(im.shape)\n", 68 | "\n", 69 | "im_pad = np.pad(im, 1, \"constant\")\n", 70 | "\n", 71 | "for i in range(im.shape[0]):\n", 72 | " for j in range(im.shape[1]):\n", 73 | " try:\n", 74 | " new_image[i, j] = (\n", 75 | " im_pad[i - 1, j - 1] * filter1[0, 0]\n", 76 | " + im_pad[i - 1, j] * filter1[0, 1]\n", 77 | " + im_pad[i - 1, j + 1] * filter1[0, 2]\n", 78 | " + im_pad[i, j - 1] * filter1[1, 0]\n", 79 | " + im_pad[i, j] * filter1[1, 1]\n", 80 | " + im_pad[i, j + 1] * filter1[1, 2]\n", 81 | " + im_pad[i + 1, j - 1] * filter1[2, 0]\n", 82 | " + im_pad[i + 1, j] * filter1[2, 1]\n", 83 | " + im_pad[i + 1, j + 1] * filter1[2, 2]\n", 84 | " )\n", 85 | " except:\n", 86 | " pass\n", 87 | "\n", 88 | "plt.axis(\"off\")\n", 89 | "plt.imshow(new_image, cmap=\"Greys\")" 90 | ] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "metadata": {}, 95 | "source": [ 96 | "## Vertical Edge Filter " 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "filter2 = np.array([[-1, 0, 1], [-1, 0, 1], [-1, 0, 1]])\n", 106 | "\n", 107 | "new_image = np.zeros(im.shape)\n", 108 | "\n", 109 | "im_pad = np.pad(im, 1, \"constant\")\n", 110 | "\n", 111 | "for i in range(im.shape[0]):\n", 112 | " for j in range(im.shape[1]):\n", 113 | " try:\n", 114 | " new_image[i, j] = (\n", 115 | " im_pad[i - 1, j - 1] * filter2[0, 0]\n", 116 | " + im_pad[i - 1, j] * filter2[0, 1]\n", 117 | " + im_pad[i - 1, j + 1] * filter2[0, 2]\n", 118 | " + im_pad[i, j - 1] * filter2[1, 0]\n", 119 | " + im_pad[i, j] * filter2[1, 1]\n", 120 | " + im_pad[i, j + 1] * filter2[1, 2]\n", 121 | " + im_pad[i + 1, j - 1] * filter2[2, 0]\n", 122 | " + im_pad[i + 1, j] * filter2[2, 1]\n", 123 | " + im_pad[i + 1, j + 1] * filter2[2, 2]\n", 124 | " )\n", 125 | " except:\n", 126 | " pass\n", 127 | "\n", 128 | "plt.axis(\"off\")\n", 129 | "plt.imshow(new_image, cmap=\"Greys\")" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": {}, 135 | "source": [ 136 | "## Horizontal Edge Filter with Stride 2 " 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "filter1 = np.array([[1, 1, 1], [0, 0, 0], [-1, -1, -1]])\n", 146 | "\n", 147 | "stride = 2\n", 148 | "\n", 149 | "new_image = np.zeros((int(im.shape[0] / stride), int(im.shape[1] / stride)))\n", 150 | "\n", 151 | "im_pad = np.pad(im, 1, \"constant\")\n", 152 | "\n", 153 | "for i in range(0, im.shape[0], stride):\n", 154 | " for j in range(0, im.shape[1], stride):\n", 155 | " try:\n", 156 | " new_image[int(i / stride), int(j / stride)] = (\n", 157 | " im_pad[i - 1, j - 1] * filter1[0, 0]\n", 158 | " + im_pad[i - 1, j] * filter1[0, 1]\n", 159 | " + im_pad[i - 1, j + 1] * filter1[0, 2]\n", 160 | " + im_pad[i, j - 1] * filter1[1, 
0]\n", 161 | " + im_pad[i, j] * filter1[1, 1]\n", 162 | " + im_pad[i, j + 1] * filter1[1, 2]\n", 163 | " + im_pad[i + 1, j - 1] * filter1[2, 0]\n", 164 | " + im_pad[i + 1, j] * filter1[2, 1]\n", 165 | " + im_pad[i + 1, j + 1] * filter1[2, 2]\n", 166 | " )\n", 167 | " except:\n", 168 | " pass\n", 169 | "\n", 170 | "plt.axis(\"off\")\n", 171 | "plt.imshow(new_image, cmap=\"Greys\")" 172 | ] 173 | }, 174 | { 175 | "cell_type": "markdown", 176 | "metadata": {}, 177 | "source": [ 178 | "## Vertical Edge Filter with Stride 2 " 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": null, 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [ 187 | "filter2 = np.array([[-1, 0, 1], [-1, 0, 1], [-1, 0, 1]])\n", 188 | "\n", 189 | "stride = 2\n", 190 | "\n", 191 | "new_image = np.zeros((int(im.shape[0] / stride), int(im.shape[1] / stride)))\n", 192 | "\n", 193 | "im_pad = np.pad(im, 1, \"constant\")\n", 194 | "\n", 195 | "for i in range(0, im.shape[0], stride):\n", 196 | " for j in range(0, im.shape[1], stride):\n", 197 | " try:\n", 198 | " new_image[int(i / stride), int(j / stride)] = (\n", 199 | " im_pad[i - 1, j - 1] * filter2[0, 0]\n", 200 | " + im_pad[i - 1, j] * filter2[0, 1]\n", 201 | " + im_pad[i - 1, j + 1] * filter2[0, 2]\n", 202 | " + im_pad[i, j - 1] * filter2[1, 0]\n", 203 | " + im_pad[i, j] * filter2[1, 1]\n", 204 | " + im_pad[i, j + 1] * filter2[1, 2]\n", 205 | " + im_pad[i + 1, j - 1] * filter2[2, 0]\n", 206 | " + im_pad[i + 1, j] * filter2[2, 1]\n", 207 | " + im_pad[i + 1, j + 1] * filter2[2, 2]\n", 208 | " )\n", 209 | " except:\n", 210 | " pass\n", 211 | "\n", 212 | "plt.axis(\"off\")\n", 213 | "plt.imshow(new_image, cmap=\"Greys\")" 214 | ] 215 | } 216 | ], 217 | "metadata": { 218 | "kernelspec": { 219 | "display_name": "Python 3 (ipykernel)", 220 | "language": "python", 221 | "name": "python3" 222 | }, 223 | "language_info": { 224 | "codemirror_mode": { 225 | "name": "ipython", 226 | "version": 3 227 | }, 228 | "file_extension": ".py", 229 | "mimetype": "text/x-python", 230 | "name": "python", 231 | "nbconvert_exporter": "python", 232 | "pygments_lexer": "ipython3", 233 | "version": "3.8.10" 234 | }, 235 | "vscode": { 236 | "interpreter": { 237 | "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" 238 | } 239 | } 240 | }, 241 | "nbformat": 4, 242 | "nbformat_minor": 4 243 | } 244 | -------------------------------------------------------------------------------- /notebooks/02_deeplearning/02_cnn/logs/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/02_deeplearning/02_cnn/models/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/02_deeplearning/02_cnn/output/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/03_vae/01_autoencoder/autoencoder.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b076bd1a-b236-4fbc-953d-8295b25122ae", 6 | "metadata": {}, 7 | "source": [ 8 | "# 👖 Autoencoders on Fashion MNIST" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": 
"9235cbd1-f136-411c-88d9-f69f270c0b96", 14 | "metadata": {}, 15 | "source": [ 16 | "In this notebook, we'll walk through the steps required to train your own autoencoder on the fashion MNIST dataset." 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "id": "84acc7be-6764-4668-b2bb-178f63deeed3", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "%load_ext autoreload\n", 27 | "%autoreload 2\n", 28 | "\n", 29 | "import numpy as np\n", 30 | "import matplotlib.pyplot as plt\n", 31 | "\n", 32 | "from tensorflow.keras import layers, models, datasets, callbacks\n", 33 | "import tensorflow.keras.backend as K\n", 34 | "\n", 35 | "from notebooks.utils import display" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "id": "339e6268-ebd7-4feb-86db-1fe7abccdbe5", 41 | "metadata": {}, 42 | "source": [ 43 | "## 0. Parameters " 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "id": "1b2ee6ce-129f-4833-b0c5-fa567381c4e0", 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "IMAGE_SIZE = 32\n", 54 | "CHANNELS = 1\n", 55 | "BATCH_SIZE = 100\n", 56 | "BUFFER_SIZE = 1000\n", 57 | "VALIDATION_SPLIT = 0.2\n", 58 | "EMBEDDING_DIM = 2\n", 59 | "EPOCHS = 3" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "id": "b7716fac-0010-49b0-b98e-53be2259edde", 65 | "metadata": {}, 66 | "source": [ 67 | "## 1. Prepare the data " 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "id": "9a73e5a4-1638-411c-8d3c-29f823424458", 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "# Load the data\n", 78 | "(x_train, y_train), (x_test, y_test) = datasets.fashion_mnist.load_data()" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "id": "ebae2f0d-59fd-4796-841f-7213eae638de", 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "# Preprocess the data\n", 89 | "\n", 90 | "\n", 91 | "def preprocess(imgs):\n", 92 | " \"\"\"\n", 93 | " Normalize and reshape the images\n", 94 | " \"\"\"\n", 95 | " imgs = imgs.astype(\"float32\") / 255.0\n", 96 | " imgs = np.pad(imgs, ((0, 0), (2, 2), (2, 2)), constant_values=0.0)\n", 97 | " imgs = np.expand_dims(imgs, -1)\n", 98 | " return imgs\n", 99 | "\n", 100 | "\n", 101 | "x_train = preprocess(x_train)\n", 102 | "x_test = preprocess(x_test)" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "id": "fa53709f-7f3f-483b-9db8-2e5f9b9942c2", 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "# Show some items of clothing from the training set\n", 113 | "display(x_train)" 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "id": "aff50401-3abe-4c10-bba8-b35bc13ad7d5", 119 | "metadata": { 120 | "tags": [] 121 | }, 122 | "source": [ 123 | "## 2. 
Build the autoencoder " 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "id": "086e2584-c60d-4990-89f4-2092c44e023e", 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "# Encoder\n", 134 | "encoder_input = layers.Input(\n", 135 | " shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS), name=\"encoder_input\"\n", 136 | ")\n", 137 | "x = layers.Conv2D(32, (3, 3), strides=2, activation=\"relu\", padding=\"same\")(\n", 138 | " encoder_input\n", 139 | ")\n", 140 | "x = layers.Conv2D(64, (3, 3), strides=2, activation=\"relu\", padding=\"same\")(x)\n", 141 | "x = layers.Conv2D(128, (3, 3), strides=2, activation=\"relu\", padding=\"same\")(x)\n", 142 | "shape_before_flattening = K.int_shape(x)[1:] # the decoder will need this!\n", 143 | "\n", 144 | "x = layers.Flatten()(x)\n", 145 | "encoder_output = layers.Dense(EMBEDDING_DIM, name=\"encoder_output\")(x)\n", 146 | "\n", 147 | "encoder = models.Model(encoder_input, encoder_output)\n", 148 | "encoder.summary()" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "id": "6c409e63-1aea-42e2-8324-c3e2a12073ee", 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "# Decoder\n", 159 | "decoder_input = layers.Input(shape=(EMBEDDING_DIM,), name=\"decoder_input\")\n", 160 | "x = layers.Dense(np.prod(shape_before_flattening))(decoder_input)\n", 161 | "x = layers.Reshape(shape_before_flattening)(x)\n", 162 | "x = layers.Conv2DTranspose(\n", 163 | " 128, (3, 3), strides=2, activation=\"relu\", padding=\"same\"\n", 164 | ")(x)\n", 165 | "x = layers.Conv2DTranspose(\n", 166 | " 64, (3, 3), strides=2, activation=\"relu\", padding=\"same\"\n", 167 | ")(x)\n", 168 | "x = layers.Conv2DTranspose(\n", 169 | " 32, (3, 3), strides=2, activation=\"relu\", padding=\"same\"\n", 170 | ")(x)\n", 171 | "decoder_output = layers.Conv2D(\n", 172 | " CHANNELS,\n", 173 | " (3, 3),\n", 174 | " strides=1,\n", 175 | " activation=\"sigmoid\",\n", 176 | " padding=\"same\",\n", 177 | " name=\"decoder_output\",\n", 178 | ")(x)\n", 179 | "\n", 180 | "decoder = models.Model(decoder_input, decoder_output)\n", 181 | "decoder.summary()" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "id": "34dc7c69-26a8-4c17-aa24-792f1b0a69b4", 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "# Autoencoder\n", 192 | "autoencoder = models.Model(\n", 193 | " encoder_input, decoder(encoder_output)\n", 194 | ") # decoder(encoder_output)\n", 195 | "autoencoder.summary()" 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "id": "35b14665-4359-447b-be58-3fd58ba69084", 201 | "metadata": {}, 202 | "source": [ 203 | "## 3. 
Train the autoencoder " 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": null, 209 | "id": "b429fdad-ea9c-45a2-a556-eb950d793824", 210 | "metadata": {}, 211 | "outputs": [], 212 | "source": [ 213 | "# Compile the autoencoder\n", 214 | "autoencoder.compile(optimizer=\"adam\", loss=\"binary_crossentropy\")" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "id": "c525e44b-b3bb-489c-9d35-fcfe3e714e6a", 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [ 224 | "# Create a model save checkpoint\n", 225 | "model_checkpoint_callback = callbacks.ModelCheckpoint(\n", 226 | " filepath=\"./checkpoint\",\n", 227 | " save_weights_only=False,\n", 228 | " save_freq=\"epoch\",\n", 229 | " monitor=\"loss\",\n", 230 | " mode=\"min\",\n", 231 | " save_best_only=True,\n", 232 | " verbose=0,\n", 233 | ")\n", 234 | "tensorboard_callback = callbacks.TensorBoard(log_dir=\"./logs\")" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": null, 240 | "id": "b3c497b7-fa40-48df-b2bf-541239cc9400", 241 | "metadata": { 242 | "tags": [] 243 | }, 244 | "outputs": [], 245 | "source": [ 246 | "autoencoder.fit(\n", 247 | " x_train,\n", 248 | " x_train,\n", 249 | " epochs=EPOCHS,\n", 250 | " batch_size=BATCH_SIZE,\n", 251 | " shuffle=True,\n", 252 | " validation_data=(x_test, x_test),\n", 253 | " callbacks=[model_checkpoint_callback, tensorboard_callback],\n", 254 | ")" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": null, 260 | "id": "edb847d1-c22d-4923-ba92-0ecde0f12fca", 261 | "metadata": {}, 262 | "outputs": [], 263 | "source": [ 264 | "# Save the final models\n", 265 | "autoencoder.save(\"./models/autoencoder\")\n", 266 | "encoder.save(\"./models/encoder\")\n", 267 | "decoder.save(\"./models/decoder\")" 268 | ] 269 | }, 270 | { 271 | "cell_type": "markdown", 272 | "id": "bc0f31bc-77e6-49e8-bb76-51bca124744c", 273 | "metadata": { 274 | "tags": [] 275 | }, 276 | "source": [ 277 | "## 4. Reconstruct using the autoencoder " 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": null, 283 | "id": "d4d83729-71a2-4494-86a5-e17830974ef0", 284 | "metadata": {}, 285 | "outputs": [], 286 | "source": [ 287 | "n_to_predict = 5000\n", 288 | "example_images = x_test[:n_to_predict]\n", 289 | "example_labels = y_test[:n_to_predict]" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": null, 295 | "id": "5c9b2a91-7cea-4595-a857-11f5ab00875e", 296 | "metadata": {}, 297 | "outputs": [], 298 | "source": [ 299 | "predictions = autoencoder.predict(example_images)\n", 300 | "\n", 301 | "print(\"Example real clothing items\")\n", 302 | "display(example_images)\n", 303 | "print(\"Reconstructions\")\n", 304 | "display(predictions)" 305 | ] 306 | }, 307 | { 308 | "cell_type": "markdown", 309 | "id": "b77c88bb-ada4-4091-94e3-764f1385f1fc", 310 | "metadata": {}, 311 | "source": [ 312 | "## 5. 
Embed using the encoder " 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": null, 318 | "id": "5e723c1c-136b-47e5-9972-ee964712d148", 319 | "metadata": {}, 320 | "outputs": [], 321 | "source": [ 322 | "# Encode the example images\n", 323 | "embeddings = encoder.predict(example_images)" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": null, 329 | "id": "2ed4e9bd-df14-4832-a765-dfaf36d49fca", 330 | "metadata": {}, 331 | "outputs": [], 332 | "source": [ 333 | "# Some examples of the embeddings\n", 334 | "print(embeddings[:10])" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": null, 340 | "id": "3bb208e8-6351-49ac-a68c-679a830f13bf", 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": [ 344 | "# Show the encoded points in 2D space\n", 345 | "figsize = 8\n", 346 | "\n", 347 | "plt.figure(figsize=(figsize, figsize))\n", 348 | "plt.scatter(embeddings[:, 0], embeddings[:, 1], c=\"black\", alpha=0.5, s=3)\n", 349 | "plt.show()" 350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "execution_count": null, 355 | "id": "138a34ca-67b4-42b7-a9fa-f7ffe397df49", 356 | "metadata": {}, 357 | "outputs": [], 358 | "source": [ 359 | "# Colour the embeddings by their label (clothing type - see table)\n", 360 | "example_labels = y_test[:n_to_predict]\n", 361 | "\n", 362 | "figsize = 8\n", 363 | "plt.figure(figsize=(figsize, figsize))\n", 364 | "plt.scatter(\n", 365 | " embeddings[:, 0],\n", 366 | " embeddings[:, 1],\n", 367 | " cmap=\"rainbow\",\n", 368 | " c=example_labels,\n", 369 | " alpha=0.8,\n", 370 | " s=3,\n", 371 | ")\n", 372 | "plt.colorbar()\n", 373 | "plt.show()" 374 | ] 375 | }, 376 | { 377 | "cell_type": "markdown", 378 | "id": "e0616b71-3354-419c-8ddb-f64fc29850ca", 379 | "metadata": {}, 380 | "source": [ 381 | "## 6. Generate using the decoder " 382 | ] 383 | }, 384 | { 385 | "cell_type": "code", 386 | "execution_count": null, 387 | "id": "2d494893-059f-42e4-825e-31c06fa3cd09", 388 | "metadata": {}, 389 | "outputs": [], 390 | "source": [ 391 | "# Get the range of the existing embeddings\n", 392 | "mins, maxs = np.min(embeddings, axis=0), np.max(embeddings, axis=0)\n", 393 | "\n", 394 | "# Sample some points in the latent space\n", 395 | "grid_width, grid_height = (6, 3)\n", 396 | "sample = np.random.uniform(\n", 397 | " mins, maxs, size=(grid_width * grid_height, EMBEDDING_DIM)\n", 398 | ")" 399 | ] 400 | }, 401 | { 402 | "cell_type": "code", 403 | "execution_count": null, 404 | "id": "ba3b1c66-c89d-436a-b009-19f1f5a785e5", 405 | "metadata": {}, 406 | "outputs": [], 407 | "source": [ 408 | "# Decode the sampled points\n", 409 | "reconstructions = decoder.predict(sample)" 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": null, 415 | "id": "feea9b9d-8d3e-43f5-9ead-cd9e38367c00", 416 | "metadata": {}, 417 | "outputs": [], 418 | "source": [ 419 | "# Draw a plot of...\n", 420 | "figsize = 8\n", 421 | "plt.figure(figsize=(figsize, figsize))\n", 422 | "\n", 423 | "# ... the original embeddings ...\n", 424 | "plt.scatter(embeddings[:, 0], embeddings[:, 1], c=\"black\", alpha=0.5, s=2)\n", 425 | "\n", 426 | "# ... 
and the newly generated points in the latent space\n", 427 | "plt.scatter(sample[:, 0], sample[:, 1], c=\"#00B0F0\", alpha=1, s=40)\n", 428 | "plt.show()\n", 429 | "\n", 430 | "# Add underneath a grid of the decoded images\n", 431 | "fig = plt.figure(figsize=(figsize, grid_height * 2))\n", 432 | "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n", 433 | "\n", 434 | "for i in range(grid_width * grid_height):\n", 435 | " ax = fig.add_subplot(grid_height, grid_width, i + 1)\n", 436 | " ax.axis(\"off\")\n", 437 | " ax.text(\n", 438 | " 0.5,\n", 439 | " -0.35,\n", 440 | " str(np.round(sample[i, :], 1)),\n", 441 | " fontsize=10,\n", 442 | " ha=\"center\",\n", 443 | " transform=ax.transAxes,\n", 444 | " )\n", 445 | " ax.imshow(reconstructions[i, :, :], cmap=\"Greys\")" 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": null, 451 | "id": "f64434a4-41c5-4225-ad31-9cf83f8797e1", 452 | "metadata": {}, 453 | "outputs": [], 454 | "source": [ 455 | "# Colour the embeddings by their label (clothing type - see table)\n", 456 | "figsize = 12\n", 457 | "grid_size = 15\n", 458 | "plt.figure(figsize=(figsize, figsize))\n", 459 | "plt.scatter(\n", 460 | " embeddings[:, 0],\n", 461 | " embeddings[:, 1],\n", 462 | " cmap=\"rainbow\",\n", 463 | " c=example_labels,\n", 464 | " alpha=0.8,\n", 465 | " s=300,\n", 466 | ")\n", 467 | "plt.colorbar()\n", 468 | "\n", 469 | "x = np.linspace(min(embeddings[:, 0]), max(embeddings[:, 0]), grid_size)\n", 470 | "y = np.linspace(max(embeddings[:, 1]), min(embeddings[:, 1]), grid_size)\n", 471 | "xv, yv = np.meshgrid(x, y)\n", 472 | "xv = xv.flatten()\n", 473 | "yv = yv.flatten()\n", 474 | "grid = np.array(list(zip(xv, yv)))\n", 475 | "\n", 476 | "reconstructions = decoder.predict(grid)\n", 477 | "# plt.scatter(grid[:, 0], grid[:, 1], c=\"black\", alpha=1, s=10)\n", 478 | "plt.show()\n", 479 | "\n", 480 | "fig = plt.figure(figsize=(figsize, figsize))\n", 481 | "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n", 482 | "for i in range(grid_size**2):\n", 483 | " ax = fig.add_subplot(grid_size, grid_size, i + 1)\n", 484 | " ax.axis(\"off\")\n", 485 | " ax.imshow(reconstructions[i, :, :], cmap=\"Greys\")" 486 | ] 487 | } 488 | ], 489 | "metadata": { 490 | "kernelspec": { 491 | "display_name": "Python 3 (ipykernel)", 492 | "language": "python", 493 | "name": "python3" 494 | }, 495 | "language_info": { 496 | "codemirror_mode": { 497 | "name": "ipython", 498 | "version": 3 499 | }, 500 | "file_extension": ".py", 501 | "mimetype": "text/x-python", 502 | "name": "python", 503 | "nbconvert_exporter": "python", 504 | "pygments_lexer": "ipython3", 505 | "version": "3.8.10" 506 | }, 507 | "vscode": { 508 | "interpreter": { 509 | "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" 510 | } 511 | } 512 | }, 513 | "nbformat": 4, 514 | "nbformat_minor": 5 515 | } 516 | -------------------------------------------------------------------------------- /notebooks/03_vae/01_autoencoder/checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/03_vae/01_autoencoder/logs/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/03_vae/01_autoencoder/models/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | 
!.gitignore -------------------------------------------------------------------------------- /notebooks/03_vae/01_autoencoder/output/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/03_vae/02_vae_fashion/checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/03_vae/02_vae_fashion/logs/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/03_vae/02_vae_fashion/models/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/03_vae/02_vae_fashion/output/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/03_vae/03_vae_faces/checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/03_vae/03_vae_faces/logs/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/03_vae/03_vae_faces/models/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/03_vae/03_vae_faces/output/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/03_vae/03_vae_faces/vae_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | def get_vector_from_label(data, vae, embedding_dim, label): 6 | current_sum_POS = np.zeros(shape=embedding_dim, dtype="float32") 7 | current_n_POS = 0 8 | current_mean_POS = np.zeros(shape=embedding_dim, dtype="float32") 9 | 10 | current_sum_NEG = np.zeros(shape=embedding_dim, dtype="float32") 11 | current_n_NEG = 0 12 | current_mean_NEG = np.zeros(shape=embedding_dim, dtype="float32") 13 | 14 | current_vector = np.zeros(shape=embedding_dim, dtype="float32") 15 | current_dist = 0 16 | 17 | print("label: " + label) 18 | print("images : POS move : NEG move :distance : 𝛥 distance") 19 | while current_n_POS < 10000: 20 | batch = list(data.take(1).get_single_element()) 21 | im = batch[0] 22 | attribute = batch[1] 23 | 24 | _, _, z = vae.encoder.predict(np.array(im), verbose=0) 25 | 26 | z_POS = z[attribute == 1] 27 | z_NEG = z[attribute == -1] 28 | 29 | if len(z_POS) > 0: 30 | current_sum_POS = current_sum_POS + np.sum(z_POS, axis=0) 31 | current_n_POS += len(z_POS) 32 | new_mean_POS = current_sum_POS / current_n_POS 33 | movement_POS = np.linalg.norm(new_mean_POS - current_mean_POS) 34 | 35 | if len(z_NEG) > 0: 36 | 
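            # Accumulate latent vectors of images labelled -1 (attribute absent) and track
            # how far their running mean moves; the attribute vector computed below is the
            # difference between the positive-label mean and this negative-label mean.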
current_sum_NEG = current_sum_NEG + np.sum(z_NEG, axis=0) 37 | current_n_NEG += len(z_NEG) 38 | new_mean_NEG = current_sum_NEG / current_n_NEG 39 | movement_NEG = np.linalg.norm(new_mean_NEG - current_mean_NEG) 40 | 41 | current_vector = new_mean_POS - new_mean_NEG 42 | new_dist = np.linalg.norm(current_vector) 43 | dist_change = new_dist - current_dist 44 | 45 | print( 46 | str(current_n_POS) 47 | + " : " 48 | + str(np.round(movement_POS, 3)) 49 | + " : " 50 | + str(np.round(movement_NEG, 3)) 51 | + " : " 52 | + str(np.round(new_dist, 3)) 53 | + " : " 54 | + str(np.round(dist_change, 3)) 55 | ) 56 | 57 | current_mean_POS = np.copy(new_mean_POS) 58 | current_mean_NEG = np.copy(new_mean_NEG) 59 | current_dist = np.copy(new_dist) 60 | 61 | if np.sum([movement_POS, movement_NEG]) < 0.08: 62 | current_vector = current_vector / current_dist 63 | print("Found the " + label + " vector") 64 | break 65 | 66 | return current_vector 67 | 68 | 69 | def add_vector_to_images(data, vae, feature_vec): 70 | n_to_show = 5 71 | factors = [-4, -3, -2, -1, 0, 1, 2, 3, 4] 72 | 73 | example_batch = list(data.take(1).get_single_element()) 74 | example_images = example_batch[0] 75 | 76 | _, _, z_points = vae.encoder.predict(example_images, verbose=0) 77 | 78 | fig = plt.figure(figsize=(18, 10)) 79 | 80 | counter = 1 81 | 82 | for i in range(n_to_show): 83 | img = example_images[i] 84 | sub = fig.add_subplot(n_to_show, len(factors) + 1, counter) 85 | sub.axis("off") 86 | sub.imshow(img) 87 | 88 | counter += 1 89 | 90 | for factor in factors: 91 | changed_z_point = z_points[i] + feature_vec * factor 92 | changed_image = vae.decoder.predict( 93 | np.array([changed_z_point]), verbose=0 94 | )[0] 95 | 96 | sub = fig.add_subplot(n_to_show, len(factors) + 1, counter) 97 | sub.axis("off") 98 | sub.imshow(changed_image) 99 | 100 | counter += 1 101 | 102 | plt.show() 103 | 104 | 105 | def morph_faces(data, vae): 106 | factors = np.arange(0, 1, 0.1) 107 | 108 | example_batch = list(data.take(1).get_single_element())[:2] 109 | example_images = example_batch[0] 110 | _, _, z_points = vae.encoder.predict(example_images, verbose=0) 111 | 112 | fig = plt.figure(figsize=(18, 8)) 113 | 114 | counter = 1 115 | 116 | img = example_images[0] 117 | sub = fig.add_subplot(1, len(factors) + 2, counter) 118 | sub.axis("off") 119 | sub.imshow(img) 120 | 121 | counter += 1 122 | 123 | for factor in factors: 124 | changed_z_point = z_points[0] * (1 - factor) + z_points[1] * factor 125 | changed_image = vae.decoder.predict( 126 | np.array([changed_z_point]), verbose=0 127 | )[0] 128 | sub = fig.add_subplot(1, len(factors) + 2, counter) 129 | sub.axis("off") 130 | sub.imshow(changed_image) 131 | 132 | counter += 1 133 | 134 | img = example_images[1] 135 | sub = fig.add_subplot(1, len(factors) + 2, counter) 136 | sub.axis("off") 137 | sub.imshow(img) 138 | 139 | plt.show() 140 | -------------------------------------------------------------------------------- /notebooks/04_gan/01_dcgan/checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/04_gan/01_dcgan/dcgan.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b076bd1a-b236-4fbc-953d-8295b25122ae", 6 | "metadata": {}, 7 | "source": [ 8 | "# 🧱 DCGAN - Bricks Data" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": 
"1603ea4b-8345-4e2e-ae7c-01c9953900e8", 14 | "metadata": {}, 15 | "source": [ 16 | "In this notebook, we'll walk through the steps required to train your own DCGAN on the bricks dataset" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "id": "4e0d56cc-4773-4029-97d8-26f882ba79c9", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "%load_ext autoreload\n", 27 | "%autoreload 2\n", 28 | "import numpy as np\n", 29 | "import matplotlib.pyplot as plt\n", 30 | "\n", 31 | "import tensorflow as tf\n", 32 | "from tensorflow.keras import (\n", 33 | " layers,\n", 34 | " models,\n", 35 | " callbacks,\n", 36 | " losses,\n", 37 | " utils,\n", 38 | " metrics,\n", 39 | " optimizers,\n", 40 | ")\n", 41 | "\n", 42 | "from notebooks.utils import display, sample_batch" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "id": "339e6268-ebd7-4feb-86db-1fe7abccdbe5", 48 | "metadata": {}, 49 | "source": [ 50 | "## 0. Parameters " 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "id": "1b2ee6ce-129f-4833-b0c5-fa567381c4e0", 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "IMAGE_SIZE = 64\n", 61 | "CHANNELS = 1\n", 62 | "BATCH_SIZE = 128\n", 63 | "Z_DIM = 100\n", 64 | "EPOCHS = 300\n", 65 | "LOAD_MODEL = False\n", 66 | "ADAM_BETA_1 = 0.5\n", 67 | "ADAM_BETA_2 = 0.999\n", 68 | "LEARNING_RATE = 0.0002\n", 69 | "NOISE_PARAM = 0.1" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "id": "b7716fac-0010-49b0-b98e-53be2259edde", 75 | "metadata": {}, 76 | "source": [ 77 | "## 1. Prepare the data " 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "id": "73f4c594-3f6d-4c8e-94c1-2c2ba7bce076", 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "train_data = utils.image_dataset_from_directory(\n", 88 | " \"/app/data/lego-brick-images/dataset/\",\n", 89 | " labels=None,\n", 90 | " color_mode=\"grayscale\",\n", 91 | " image_size=(IMAGE_SIZE, IMAGE_SIZE),\n", 92 | " batch_size=BATCH_SIZE,\n", 93 | " shuffle=True,\n", 94 | " seed=42,\n", 95 | " interpolation=\"bilinear\",\n", 96 | ")" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "id": "a995473f-c389-4158-92d2-93a2fa937916", 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "def preprocess(img):\n", 107 | " \"\"\"\n", 108 | " Normalize and reshape the images\n", 109 | " \"\"\"\n", 110 | " img = (tf.cast(img, \"float32\") - 127.5) / 127.5\n", 111 | " return img\n", 112 | "\n", 113 | "\n", 114 | "train = train_data.map(lambda x: preprocess(x))" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "id": "80bcdbdd-fb1e-451f-b89c-03fd9b80deb5", 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "train_sample = sample_batch(train)" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "id": "fa53709f-7f3f-483b-9db8-2e5f9b9942c2", 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [ 134 | "display(train_sample)" 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "id": "aff50401-3abe-4c10-bba8-b35bc13ad7d5", 140 | "metadata": { 141 | "tags": [] 142 | }, 143 | "source": [ 144 | "## 2. 
Build the GAN " 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "id": "9230b5bf-b4a8-48d5-b73b-6899a598f296", 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "discriminator_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS))\n", 155 | "x = layers.Conv2D(64, kernel_size=4, strides=2, padding=\"same\", use_bias=False)(\n", 156 | " discriminator_input\n", 157 | ")\n", 158 | "x = layers.LeakyReLU(0.2)(x)\n", 159 | "x = layers.Dropout(0.3)(x)\n", 160 | "x = layers.Conv2D(\n", 161 | " 128, kernel_size=4, strides=2, padding=\"same\", use_bias=False\n", 162 | ")(x)\n", 163 | "x = layers.BatchNormalization(momentum=0.9)(x)\n", 164 | "x = layers.LeakyReLU(0.2)(x)\n", 165 | "x = layers.Dropout(0.3)(x)\n", 166 | "x = layers.Conv2D(\n", 167 | " 256, kernel_size=4, strides=2, padding=\"same\", use_bias=False\n", 168 | ")(x)\n", 169 | "x = layers.BatchNormalization(momentum=0.9)(x)\n", 170 | "x = layers.LeakyReLU(0.2)(x)\n", 171 | "x = layers.Dropout(0.3)(x)\n", 172 | "x = layers.Conv2D(\n", 173 | " 512, kernel_size=4, strides=2, padding=\"same\", use_bias=False\n", 174 | ")(x)\n", 175 | "x = layers.BatchNormalization(momentum=0.9)(x)\n", 176 | "x = layers.LeakyReLU(0.2)(x)\n", 177 | "x = layers.Dropout(0.3)(x)\n", 178 | "x = layers.Conv2D(\n", 179 | " 1,\n", 180 | " kernel_size=4,\n", 181 | " strides=1,\n", 182 | " padding=\"valid\",\n", 183 | " use_bias=False,\n", 184 | " activation=\"sigmoid\",\n", 185 | ")(x)\n", 186 | "discriminator_output = layers.Flatten()(x)\n", 187 | "\n", 188 | "discriminator = models.Model(discriminator_input, discriminator_output)\n", 189 | "discriminator.summary()" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": null, 195 | "id": "b30dcc08-3869-4b67-a295-61f13d5d4e4c", 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "generator_input = layers.Input(shape=(Z_DIM,))\n", 200 | "x = layers.Reshape((1, 1, Z_DIM))(generator_input)\n", 201 | "x = layers.Conv2DTranspose(\n", 202 | " 512, kernel_size=4, strides=1, padding=\"valid\", use_bias=False\n", 203 | ")(x)\n", 204 | "x = layers.BatchNormalization(momentum=0.9)(x)\n", 205 | "x = layers.LeakyReLU(0.2)(x)\n", 206 | "x = layers.Conv2DTranspose(\n", 207 | " 256, kernel_size=4, strides=2, padding=\"same\", use_bias=False\n", 208 | ")(x)\n", 209 | "x = layers.BatchNormalization(momentum=0.9)(x)\n", 210 | "x = layers.LeakyReLU(0.2)(x)\n", 211 | "x = layers.Conv2DTranspose(\n", 212 | " 128, kernel_size=4, strides=2, padding=\"same\", use_bias=False\n", 213 | ")(x)\n", 214 | "x = layers.BatchNormalization(momentum=0.9)(x)\n", 215 | "x = layers.LeakyReLU(0.2)(x)\n", 216 | "x = layers.Conv2DTranspose(\n", 217 | " 64, kernel_size=4, strides=2, padding=\"same\", use_bias=False\n", 218 | ")(x)\n", 219 | "x = layers.BatchNormalization(momentum=0.9)(x)\n", 220 | "x = layers.LeakyReLU(0.2)(x)\n", 221 | "generator_output = layers.Conv2DTranspose(\n", 222 | " CHANNELS,\n", 223 | " kernel_size=4,\n", 224 | " strides=2,\n", 225 | " padding=\"same\",\n", 226 | " use_bias=False,\n", 227 | " activation=\"tanh\",\n", 228 | ")(x)\n", 229 | "generator = models.Model(generator_input, generator_output)\n", 230 | "generator.summary()" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": null, 236 | "id": "ed493725-488b-4390-8c64-661f3b97a632", 237 | "metadata": { 238 | "tags": [] 239 | }, 240 | "outputs": [], 241 | "source": [ 242 | "class DCGAN(models.Model):\n", 243 | " def __init__(self, discriminator, generator, 
latent_dim):\n", 244 | " super(DCGAN, self).__init__()\n", 245 | " self.discriminator = discriminator\n", 246 | " self.generator = generator\n", 247 | " self.latent_dim = latent_dim\n", 248 | "\n", 249 | " def compile(self, d_optimizer, g_optimizer):\n", 250 | " super(DCGAN, self).compile()\n", 251 | " self.loss_fn = losses.BinaryCrossentropy()\n", 252 | " self.d_optimizer = d_optimizer\n", 253 | " self.g_optimizer = g_optimizer\n", 254 | " self.d_loss_metric = metrics.Mean(name=\"d_loss\")\n", 255 | " self.d_real_acc_metric = metrics.BinaryAccuracy(name=\"d_real_acc\")\n", 256 | " self.d_fake_acc_metric = metrics.BinaryAccuracy(name=\"d_fake_acc\")\n", 257 | " self.d_acc_metric = metrics.BinaryAccuracy(name=\"d_acc\")\n", 258 | " self.g_loss_metric = metrics.Mean(name=\"g_loss\")\n", 259 | " self.g_acc_metric = metrics.BinaryAccuracy(name=\"g_acc\")\n", 260 | "\n", 261 | " @property\n", 262 | " def metrics(self):\n", 263 | " return [\n", 264 | " self.d_loss_metric,\n", 265 | " self.d_real_acc_metric,\n", 266 | " self.d_fake_acc_metric,\n", 267 | " self.d_acc_metric,\n", 268 | " self.g_loss_metric,\n", 269 | " self.g_acc_metric,\n", 270 | " ]\n", 271 | "\n", 272 | " def train_step(self, real_images):\n", 273 | " # Sample random points in the latent space\n", 274 | " batch_size = tf.shape(real_images)[0]\n", 275 | " random_latent_vectors = tf.random.normal(\n", 276 | " shape=(batch_size, self.latent_dim)\n", 277 | " )\n", 278 | "\n", 279 | " # Train the discriminator on fake images\n", 280 | " with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:\n", 281 | " generated_images = self.generator(\n", 282 | " random_latent_vectors, training=True\n", 283 | " )\n", 284 | " real_predictions = self.discriminator(real_images, training=True)\n", 285 | " fake_predictions = self.discriminator(\n", 286 | " generated_images, training=True\n", 287 | " )\n", 288 | "\n", 289 | " real_labels = tf.ones_like(real_predictions)\n", 290 | " real_noisy_labels = real_labels + NOISE_PARAM * tf.random.uniform(\n", 291 | " tf.shape(real_predictions)\n", 292 | " )\n", 293 | " fake_labels = tf.zeros_like(fake_predictions)\n", 294 | " fake_noisy_labels = fake_labels - NOISE_PARAM * tf.random.uniform(\n", 295 | " tf.shape(fake_predictions)\n", 296 | " )\n", 297 | "\n", 298 | " d_real_loss = self.loss_fn(real_noisy_labels, real_predictions)\n", 299 | " d_fake_loss = self.loss_fn(fake_noisy_labels, fake_predictions)\n", 300 | " d_loss = (d_real_loss + d_fake_loss) / 2.0\n", 301 | "\n", 302 | " g_loss = self.loss_fn(real_labels, fake_predictions)\n", 303 | "\n", 304 | " gradients_of_discriminator = disc_tape.gradient(\n", 305 | " d_loss, self.discriminator.trainable_variables\n", 306 | " )\n", 307 | " gradients_of_generator = gen_tape.gradient(\n", 308 | " g_loss, self.generator.trainable_variables\n", 309 | " )\n", 310 | "\n", 311 | " self.d_optimizer.apply_gradients(\n", 312 | " zip(gradients_of_discriminator, discriminator.trainable_variables)\n", 313 | " )\n", 314 | " self.g_optimizer.apply_gradients(\n", 315 | " zip(gradients_of_generator, generator.trainable_variables)\n", 316 | " )\n", 317 | "\n", 318 | " # Update metrics\n", 319 | " self.d_loss_metric.update_state(d_loss)\n", 320 | " self.d_real_acc_metric.update_state(real_labels, real_predictions)\n", 321 | " self.d_fake_acc_metric.update_state(fake_labels, fake_predictions)\n", 322 | " self.d_acc_metric.update_state(\n", 323 | " [real_labels, fake_labels], [real_predictions, fake_predictions]\n", 324 | " )\n", 325 | " 
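        # Generator metrics: g_loss compares the discriminator's predictions on the
        # generated images against "real" labels, so it falls as the discriminator is fooled.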
self.g_loss_metric.update_state(g_loss)\n", 326 | " self.g_acc_metric.update_state(real_labels, fake_predictions)\n", 327 | "\n", 328 | " return {m.name: m.result() for m in self.metrics}" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": null, 334 | "id": "e898dd8e-f562-4517-8351-fc2f8b617a24", 335 | "metadata": {}, 336 | "outputs": [], 337 | "source": [ 338 | "# Create a DCGAN\n", 339 | "dcgan = DCGAN(\n", 340 | " discriminator=discriminator, generator=generator, latent_dim=Z_DIM\n", 341 | ")" 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "execution_count": null, 347 | "id": "800a3c6e-fb11-4792-b6bc-9a43a7c977ad", 348 | "metadata": {}, 349 | "outputs": [], 350 | "source": [ 351 | "if LOAD_MODEL:\n", 352 | " dcgan.load_weights(\"./checkpoint/checkpoint.ckpt\")" 353 | ] 354 | }, 355 | { 356 | "cell_type": "markdown", 357 | "id": "35b14665-4359-447b-be58-3fd58ba69084", 358 | "metadata": {}, 359 | "source": [ 360 | "## 3. Train the GAN " 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": null, 366 | "id": "245e6374-5f5b-4efa-be0a-07b182f82d2d", 367 | "metadata": {}, 368 | "outputs": [], 369 | "source": [ 370 | "dcgan.compile(\n", 371 | " d_optimizer=optimizers.Adam(\n", 372 | " learning_rate=LEARNING_RATE, beta_1=ADAM_BETA_1, beta_2=ADAM_BETA_2\n", 373 | " ),\n", 374 | " g_optimizer=optimizers.Adam(\n", 375 | " learning_rate=LEARNING_RATE, beta_1=ADAM_BETA_1, beta_2=ADAM_BETA_2\n", 376 | " ),\n", 377 | ")" 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": null, 383 | "id": "349865fe-ffbe-450e-97be-043ae1740e78", 384 | "metadata": {}, 385 | "outputs": [], 386 | "source": [ 387 | "# Create a model save checkpoint\n", 388 | "model_checkpoint_callback = callbacks.ModelCheckpoint(\n", 389 | " filepath=\"./checkpoint/checkpoint.ckpt\",\n", 390 | " save_weights_only=True,\n", 391 | " save_freq=\"epoch\",\n", 392 | " verbose=0,\n", 393 | ")\n", 394 | "\n", 395 | "tensorboard_callback = callbacks.TensorBoard(log_dir=\"./logs\")\n", 396 | "\n", 397 | "\n", 398 | "class ImageGenerator(callbacks.Callback):\n", 399 | " def __init__(self, num_img, latent_dim):\n", 400 | " self.num_img = num_img\n", 401 | " self.latent_dim = latent_dim\n", 402 | "\n", 403 | " def on_epoch_end(self, epoch, logs=None):\n", 404 | " random_latent_vectors = tf.random.normal(\n", 405 | " shape=(self.num_img, self.latent_dim)\n", 406 | " )\n", 407 | " generated_images = self.model.generator(random_latent_vectors)\n", 408 | " generated_images = generated_images * 127.5 + 127.5\n", 409 | " generated_images = generated_images.numpy()\n", 410 | " display(\n", 411 | " generated_images,\n", 412 | " save_to=\"./output/generated_img_%03d.png\" % (epoch),\n", 413 | " )" 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": null, 419 | "id": "a8913a77-f472-4008-9039-dba00e6db980", 420 | "metadata": { 421 | "tags": [] 422 | }, 423 | "outputs": [], 424 | "source": [ 425 | "dcgan.fit(\n", 426 | " train,\n", 427 | " epochs=EPOCHS,\n", 428 | " callbacks=[\n", 429 | " model_checkpoint_callback,\n", 430 | " tensorboard_callback,\n", 431 | " ImageGenerator(num_img=10, latent_dim=Z_DIM),\n", 432 | " ],\n", 433 | ")" 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": null, 439 | "id": "369bde44-2e39-4bc6-8549-a3a27ecce55c", 440 | "metadata": {}, 441 | "outputs": [], 442 | "source": [ 443 | "# Save the final models\n", 444 | "generator.save(\"./models/generator\")\n", 445 | 
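# Each save() call with an extension-less path writes a TensorFlow SavedModel directory
# under ./models/ (the generator above, the discriminator below).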
"discriminator.save(\"./models/discriminator\")" 446 | ] 447 | }, 448 | { 449 | "cell_type": "markdown", 450 | "id": "26999087-0e85-4ddf-ba5f-13036466fce7", 451 | "metadata": {}, 452 | "source": [ 453 | "## 3. Generate new images " 454 | ] 455 | }, 456 | { 457 | "cell_type": "code", 458 | "execution_count": null, 459 | "id": "48e90117-2e0e-4f4b-9138-b25dce9870f6", 460 | "metadata": {}, 461 | "outputs": [], 462 | "source": [ 463 | "# Sample some points in the latent space, from the standard normal distribution\n", 464 | "grid_width, grid_height = (10, 3)\n", 465 | "z_sample = np.random.normal(size=(grid_width * grid_height, Z_DIM))" 466 | ] 467 | }, 468 | { 469 | "cell_type": "code", 470 | "execution_count": null, 471 | "id": "2e185509-3861-425c-882d-4fe16d82d355", 472 | "metadata": {}, 473 | "outputs": [], 474 | "source": [ 475 | "# Decode the sampled points\n", 476 | "reconstructions = generator.predict(z_sample)" 477 | ] 478 | }, 479 | { 480 | "cell_type": "code", 481 | "execution_count": null, 482 | "id": "5e5e43c0-ef06-4d32-acf6-09f00cf2fa9c", 483 | "metadata": {}, 484 | "outputs": [], 485 | "source": [ 486 | "# Draw a plot of decoded images\n", 487 | "fig = plt.figure(figsize=(18, 5))\n", 488 | "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n", 489 | "\n", 490 | "# Output the grid of faces\n", 491 | "for i in range(grid_width * grid_height):\n", 492 | " ax = fig.add_subplot(grid_height, grid_width, i + 1)\n", 493 | " ax.axis(\"off\")\n", 494 | " ax.imshow(reconstructions[i, :, :], cmap=\"Greys\")" 495 | ] 496 | }, 497 | { 498 | "cell_type": "code", 499 | "execution_count": null, 500 | "id": "59bfd4e4-7fdc-488a-86df-2c131c904803", 501 | "metadata": {}, 502 | "outputs": [], 503 | "source": [ 504 | "def compare_images(img1, img2):\n", 505 | " return np.mean(np.abs(img1 - img2))" 506 | ] 507 | }, 508 | { 509 | "cell_type": "code", 510 | "execution_count": null, 511 | "id": "b568995a-d4ad-478c-98b2-d9a1cdb9e841", 512 | "metadata": {}, 513 | "outputs": [], 514 | "source": [ 515 | "all_data = []\n", 516 | "for i in train.as_numpy_iterator():\n", 517 | " all_data.extend(i)\n", 518 | "all_data = np.array(all_data)" 519 | ] 520 | }, 521 | { 522 | "cell_type": "code", 523 | "execution_count": null, 524 | "id": "7c4b5bb1-3581-49b3-81ce-920400d6f3f7", 525 | "metadata": {}, 526 | "outputs": [], 527 | "source": [ 528 | "r, c = 3, 5\n", 529 | "fig, axs = plt.subplots(r, c, figsize=(10, 6))\n", 530 | "fig.suptitle(\"Generated images\", fontsize=20)\n", 531 | "\n", 532 | "noise = np.random.normal(size=(r * c, Z_DIM))\n", 533 | "gen_imgs = generator.predict(noise)\n", 534 | "\n", 535 | "cnt = 0\n", 536 | "for i in range(r):\n", 537 | " for j in range(c):\n", 538 | " axs[i, j].imshow(gen_imgs[cnt], cmap=\"gray_r\")\n", 539 | " axs[i, j].axis(\"off\")\n", 540 | " cnt += 1\n", 541 | "\n", 542 | "plt.show()" 543 | ] 544 | }, 545 | { 546 | "cell_type": "code", 547 | "execution_count": null, 548 | "id": "51923e98-bf0e-4de4-948a-05147c486b72", 549 | "metadata": {}, 550 | "outputs": [], 551 | "source": [ 552 | "fig, axs = plt.subplots(r, c, figsize=(10, 6))\n", 553 | "fig.suptitle(\"Closest images in the training set\", fontsize=20)\n", 554 | "\n", 555 | "cnt = 0\n", 556 | "for i in range(r):\n", 557 | " for j in range(c):\n", 558 | " c_diff = 99999\n", 559 | " c_img = None\n", 560 | " for k_idx, k in enumerate(all_data):\n", 561 | " diff = compare_images(gen_imgs[cnt], k)\n", 562 | " if diff < c_diff:\n", 563 | " c_img = np.copy(k)\n", 564 | " c_diff = diff\n", 565 | " axs[i, j].imshow(c_img, 
cmap=\"gray_r\")\n", 566 | " axs[i, j].axis(\"off\")\n", 567 | " cnt += 1\n", 568 | "\n", 569 | "plt.show()" 570 | ] 571 | } 572 | ], 573 | "metadata": { 574 | "kernelspec": { 575 | "display_name": "Python 3 (ipykernel)", 576 | "language": "python", 577 | "name": "python3" 578 | }, 579 | "language_info": { 580 | "codemirror_mode": { 581 | "name": "ipython", 582 | "version": 3 583 | }, 584 | "file_extension": ".py", 585 | "mimetype": "text/x-python", 586 | "name": "python", 587 | "nbconvert_exporter": "python", 588 | "pygments_lexer": "ipython3", 589 | "version": "3.8.10" 590 | } 591 | }, 592 | "nbformat": 4, 593 | "nbformat_minor": 5 594 | } 595 | -------------------------------------------------------------------------------- /notebooks/04_gan/01_dcgan/logs/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/04_gan/01_dcgan/models/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/04_gan/01_dcgan/output/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/04_gan/02_wgan_gp/checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/04_gan/02_wgan_gp/logs/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/04_gan/02_wgan_gp/models/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/04_gan/02_wgan_gp/output/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/04_gan/02_wgan_gp/wgan_gp.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b076bd1a-b236-4fbc-953d-8295b25122ae", 6 | "metadata": {}, 7 | "source": [ 8 | "# 🤪 WGAN - CelebA Faces" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "9235cbd1-f136-411c-88d9-f69f270c0b96", 14 | "metadata": {}, 15 | "source": [ 16 | "In this notebook, we'll walk through the steps required to train your own Wasserstein GAN on the CelebA faces dataset" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "id": "098731a6-4193-4fda-9018-b14114c54250", 22 | "metadata": {}, 23 | "source": [ 24 | "The code has been adapted from the excellent [WGAN-GP tutorial](https://keras.io/examples/generative/wgan_gp/) created by Aakash Kumar Nain, available on the Keras website." 
25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "id": "84acc7be-6764-4668-b2bb-178f63deeed3", 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "%load_ext autoreload\n", 35 | "%autoreload 2\n", 36 | "import numpy as np\n", 37 | "\n", 38 | "import tensorflow as tf\n", 39 | "from tensorflow.keras import (\n", 40 | " layers,\n", 41 | " models,\n", 42 | " callbacks,\n", 43 | " utils,\n", 44 | " metrics,\n", 45 | " optimizers,\n", 46 | ")\n", 47 | "\n", 48 | "from notebooks.utils import display, sample_batch" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "id": "339e6268-ebd7-4feb-86db-1fe7abccdbe5", 54 | "metadata": {}, 55 | "source": [ 56 | "## 0. Parameters " 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "id": "1b2ee6ce-129f-4833-b0c5-fa567381c4e0", 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "IMAGE_SIZE = 64\n", 67 | "CHANNELS = 3\n", 68 | "BATCH_SIZE = 512\n", 69 | "NUM_FEATURES = 64\n", 70 | "Z_DIM = 128\n", 71 | "LEARNING_RATE = 0.0002\n", 72 | "ADAM_BETA_1 = 0.5\n", 73 | "ADAM_BETA_2 = 0.999\n", 74 | "EPOCHS = 200\n", 75 | "CRITIC_STEPS = 3\n", 76 | "GP_WEIGHT = 10.0\n", 77 | "LOAD_MODEL = False\n", 78 | "ADAM_BETA_1 = 0.5\n", 79 | "ADAM_BETA_2 = 0.9" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "id": "b7716fac-0010-49b0-b98e-53be2259edde", 85 | "metadata": {}, 86 | "source": [ 87 | "## 1. Prepare the data " 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "id": "9a73e5a4-1638-411c-8d3c-29f823424458", 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "# Load the data\n", 98 | "train_data = utils.image_dataset_from_directory(\n", 99 | " \"/app/data/celeba-dataset/img_align_celeba/img_align_celeba\",\n", 100 | " labels=None,\n", 101 | " color_mode=\"rgb\",\n", 102 | " image_size=(IMAGE_SIZE, IMAGE_SIZE),\n", 103 | " batch_size=BATCH_SIZE,\n", 104 | " shuffle=True,\n", 105 | " seed=42,\n", 106 | " interpolation=\"bilinear\",\n", 107 | ")" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "id": "ebae2f0d-59fd-4796-841f-7213eae638de", 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "# Preprocess the data\n", 118 | "def preprocess(img):\n", 119 | " \"\"\"\n", 120 | " Normalize and reshape the images\n", 121 | " \"\"\"\n", 122 | " img = (tf.cast(img, \"float32\") - 127.5) / 127.5\n", 123 | " return img\n", 124 | "\n", 125 | "\n", 126 | "train = train_data.map(lambda x: preprocess(x))" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "id": "fa53709f-7f3f-483b-9db8-2e5f9b9942c2", 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "# Show some faces from the training set\n", 137 | "train_sample = sample_batch(train)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "id": "b86c15ef-82b2-4a75-99f7-2d8810440403", 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "display(train_sample, cmap=None)" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "id": "aff50401-3abe-4c10-bba8-b35bc13ad7d5", 153 | "metadata": { 154 | "tags": [] 155 | }, 156 | "source": [ 157 | "## 2. 
Build the WGAN-GP " 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": null, 163 | "id": "371eb69d-e534-4666-a412-b5b6fe24689a", 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "critic_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS))\n", 168 | "x = layers.Conv2D(64, kernel_size=4, strides=2, padding=\"same\")(critic_input)\n", 169 | "x = layers.LeakyReLU(0.2)(x)\n", 170 | "x = layers.Conv2D(128, kernel_size=4, strides=2, padding=\"same\")(x)\n", 171 | "x = layers.LeakyReLU()(x)\n", 172 | "x = layers.Dropout(0.3)(x)\n", 173 | "x = layers.Conv2D(256, kernel_size=4, strides=2, padding=\"same\")(x)\n", 174 | "x = layers.LeakyReLU(0.2)(x)\n", 175 | "x = layers.Dropout(0.3)(x)\n", 176 | "x = layers.Conv2D(512, kernel_size=4, strides=2, padding=\"same\")(x)\n", 177 | "x = layers.LeakyReLU(0.2)(x)\n", 178 | "x = layers.Dropout(0.3)(x)\n", 179 | "x = layers.Conv2D(1, kernel_size=4, strides=1, padding=\"valid\")(x)\n", 180 | "critic_output = layers.Flatten()(x)\n", 181 | "\n", 182 | "critic = models.Model(critic_input, critic_output)\n", 183 | "critic.summary()" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": null, 189 | "id": "086e2584-c60d-4990-89f4-2092c44e023e", 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "generator_input = layers.Input(shape=(Z_DIM,))\n", 194 | "x = layers.Reshape((1, 1, Z_DIM))(generator_input)\n", 195 | "x = layers.Conv2DTranspose(\n", 196 | " 512, kernel_size=4, strides=1, padding=\"valid\", use_bias=False\n", 197 | ")(x)\n", 198 | "x = layers.BatchNormalization(momentum=0.9)(x)\n", 199 | "x = layers.LeakyReLU(0.2)(x)\n", 200 | "x = layers.Conv2DTranspose(\n", 201 | " 256, kernel_size=4, strides=2, padding=\"same\", use_bias=False\n", 202 | ")(x)\n", 203 | "x = layers.BatchNormalization(momentum=0.9)(x)\n", 204 | "x = layers.LeakyReLU(0.2)(x)\n", 205 | "x = layers.Conv2DTranspose(\n", 206 | " 128, kernel_size=4, strides=2, padding=\"same\", use_bias=False\n", 207 | ")(x)\n", 208 | "x = layers.BatchNormalization(momentum=0.9)(x)\n", 209 | "x = layers.LeakyReLU(0.2)(x)\n", 210 | "x = layers.Conv2DTranspose(\n", 211 | " 64, kernel_size=4, strides=2, padding=\"same\", use_bias=False\n", 212 | ")(x)\n", 213 | "x = layers.BatchNormalization(momentum=0.9)(x)\n", 214 | "x = layers.LeakyReLU(0.2)(x)\n", 215 | "generator_output = layers.Conv2DTranspose(\n", 216 | " CHANNELS, kernel_size=4, strides=2, padding=\"same\", activation=\"tanh\"\n", 217 | ")(x)\n", 218 | "generator = models.Model(generator_input, generator_output)\n", 219 | "generator.summary()" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "id": "88010f20-fb61-498c-b2b2-dac96f6c03b3", 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "class WGANGP(models.Model):\n", 230 | " def __init__(self, critic, generator, latent_dim, critic_steps, gp_weight):\n", 231 | " super(WGANGP, self).__init__()\n", 232 | " self.critic = critic\n", 233 | " self.generator = generator\n", 234 | " self.latent_dim = latent_dim\n", 235 | " self.critic_steps = critic_steps\n", 236 | " self.gp_weight = gp_weight\n", 237 | "\n", 238 | " def compile(self, c_optimizer, g_optimizer):\n", 239 | " super(WGANGP, self).compile()\n", 240 | " self.c_optimizer = c_optimizer\n", 241 | " self.g_optimizer = g_optimizer\n", 242 | " self.c_wass_loss_metric = metrics.Mean(name=\"c_wass_loss\")\n", 243 | " self.c_gp_metric = metrics.Mean(name=\"c_gp\")\n", 244 | " self.c_loss_metric = 
metrics.Mean(name=\"c_loss\")\n", 245 | " self.g_loss_metric = metrics.Mean(name=\"g_loss\")\n", 246 | "\n", 247 | " @property\n", 248 | " def metrics(self):\n", 249 | " return [\n", 250 | " self.c_loss_metric,\n", 251 | " self.c_wass_loss_metric,\n", 252 | " self.c_gp_metric,\n", 253 | " self.g_loss_metric,\n", 254 | " ]\n", 255 | "\n", 256 | " def gradient_penalty(self, batch_size, real_images, fake_images):\n", 257 | " alpha = tf.random.normal([batch_size, 1, 1, 1], 0.0, 1.0)\n", 258 | " diff = fake_images - real_images\n", 259 | " interpolated = real_images + alpha * diff\n", 260 | "\n", 261 | " with tf.GradientTape() as gp_tape:\n", 262 | " gp_tape.watch(interpolated)\n", 263 | " pred = self.critic(interpolated, training=True)\n", 264 | "\n", 265 | " grads = gp_tape.gradient(pred, [interpolated])[0]\n", 266 | " norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))\n", 267 | " gp = tf.reduce_mean((norm - 1.0) ** 2)\n", 268 | " return gp\n", 269 | "\n", 270 | " def train_step(self, real_images):\n", 271 | " batch_size = tf.shape(real_images)[0]\n", 272 | "\n", 273 | " for i in range(self.critic_steps):\n", 274 | " random_latent_vectors = tf.random.normal(\n", 275 | " shape=(batch_size, self.latent_dim)\n", 276 | " )\n", 277 | "\n", 278 | " with tf.GradientTape() as tape:\n", 279 | " fake_images = self.generator(\n", 280 | " random_latent_vectors, training=True\n", 281 | " )\n", 282 | " fake_predictions = self.critic(fake_images, training=True)\n", 283 | " real_predictions = self.critic(real_images, training=True)\n", 284 | "\n", 285 | " c_wass_loss = tf.reduce_mean(fake_predictions) - tf.reduce_mean(\n", 286 | " real_predictions\n", 287 | " )\n", 288 | " c_gp = self.gradient_penalty(\n", 289 | " batch_size, real_images, fake_images\n", 290 | " )\n", 291 | " c_loss = c_wass_loss + c_gp * self.gp_weight\n", 292 | "\n", 293 | " c_gradient = tape.gradient(c_loss, self.critic.trainable_variables)\n", 294 | " self.c_optimizer.apply_gradients(\n", 295 | " zip(c_gradient, self.critic.trainable_variables)\n", 296 | " )\n", 297 | "\n", 298 | " random_latent_vectors = tf.random.normal(\n", 299 | " shape=(batch_size, self.latent_dim)\n", 300 | " )\n", 301 | " with tf.GradientTape() as tape:\n", 302 | " fake_images = self.generator(random_latent_vectors, training=True)\n", 303 | " fake_predictions = self.critic(fake_images, training=True)\n", 304 | " g_loss = -tf.reduce_mean(fake_predictions)\n", 305 | "\n", 306 | " gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)\n", 307 | " self.g_optimizer.apply_gradients(\n", 308 | " zip(gen_gradient, self.generator.trainable_variables)\n", 309 | " )\n", 310 | "\n", 311 | " self.c_loss_metric.update_state(c_loss)\n", 312 | " self.c_wass_loss_metric.update_state(c_wass_loss)\n", 313 | " self.c_gp_metric.update_state(c_gp)\n", 314 | " self.g_loss_metric.update_state(g_loss)\n", 315 | "\n", 316 | " return {m.name: m.result() for m in self.metrics}" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": null, 322 | "id": "edf2f892-9209-42ee-b251-1e7604df5335", 323 | "metadata": {}, 324 | "outputs": [], 325 | "source": [ 326 | "# Create a GAN\n", 327 | "wgangp = WGANGP(\n", 328 | " critic=critic,\n", 329 | " generator=generator,\n", 330 | " latent_dim=Z_DIM,\n", 331 | " critic_steps=CRITIC_STEPS,\n", 332 | " gp_weight=GP_WEIGHT,\n", 333 | ")" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": null, 339 | "id": "b2f48907-fa82-41b5-8caa-813b2f232c79", 340 | "metadata": { 341 | "tags": [] 
342 | }, 343 | "outputs": [], 344 | "source": [ 345 | "if LOAD_MODEL:\n", 346 | " wgangp.load_weights(\"./checkpoint/checkpoint.ckpt\")" 347 | ] 348 | }, 349 | { 350 | "cell_type": "markdown", 351 | "id": "35b14665-4359-447b-be58-3fd58ba69084", 352 | "metadata": {}, 353 | "source": [ 354 | "## 3. Train the GAN " 355 | ] 356 | }, 357 | { 358 | "cell_type": "code", 359 | "execution_count": null, 360 | "id": "b429fdad-ea9c-45a2-a556-eb950d793824", 361 | "metadata": {}, 362 | "outputs": [], 363 | "source": [ 364 | "# Compile the GAN\n", 365 | "wgangp.compile(\n", 366 | " c_optimizer=optimizers.Adam(\n", 367 | " learning_rate=LEARNING_RATE, beta_1=ADAM_BETA_1, beta_2=ADAM_BETA_2\n", 368 | " ),\n", 369 | " g_optimizer=optimizers.Adam(\n", 370 | " learning_rate=LEARNING_RATE, beta_1=ADAM_BETA_1, beta_2=ADAM_BETA_2\n", 371 | " ),\n", 372 | ")" 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": null, 378 | "id": "c525e44b-b3bb-489c-9d35-fcfe3e714e6a", 379 | "metadata": {}, 380 | "outputs": [], 381 | "source": [ 382 | "# Create a model save checkpoint\n", 383 | "model_checkpoint_callback = callbacks.ModelCheckpoint(\n", 384 | " filepath=\"./checkpoint/checkpoint.ckpt\",\n", 385 | " save_weights_only=True,\n", 386 | " save_freq=\"epoch\",\n", 387 | " verbose=0,\n", 388 | ")\n", 389 | "\n", 390 | "tensorboard_callback = callbacks.TensorBoard(log_dir=\"./logs\")\n", 391 | "\n", 392 | "\n", 393 | "class ImageGenerator(callbacks.Callback):\n", 394 | " def __init__(self, num_img, latent_dim):\n", 395 | " self.num_img = num_img\n", 396 | " self.latent_dim = latent_dim\n", 397 | "\n", 398 | " def on_epoch_end(self, epoch, logs=None):\n", 399 | " random_latent_vectors = tf.random.normal(\n", 400 | " shape=(self.num_img, self.latent_dim)\n", 401 | " )\n", 402 | " generated_images = self.model.generator(random_latent_vectors)\n", 403 | " generated_images = generated_images * 127.5 + 127.5\n", 404 | " generated_images = generated_images.numpy()\n", 405 | " display(\n", 406 | " generated_images,\n", 407 | " save_to=\"./output/generated_img_%03d.png\" % (epoch),\n", 408 | " cmap=None,\n", 409 | " )" 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": null, 415 | "id": "b3c497b7-fa40-48df-b2bf-541239cc9400", 416 | "metadata": { 417 | "tags": [] 418 | }, 419 | "outputs": [], 420 | "source": [ 421 | "wgangp.fit(\n", 422 | " train,\n", 423 | " epochs=EPOCHS,\n", 424 | " steps_per_epoch=2,\n", 425 | " callbacks=[\n", 426 | " model_checkpoint_callback,\n", 427 | " tensorboard_callback,\n", 428 | " ImageGenerator(num_img=10, latent_dim=Z_DIM),\n", 429 | " ],\n", 430 | ")" 431 | ] 432 | }, 433 | { 434 | "cell_type": "code", 435 | "execution_count": null, 436 | "id": "028138af-d3a5-4134-b980-d3a8a703e70f", 437 | "metadata": {}, 438 | "outputs": [], 439 | "source": [ 440 | "# Save the final models\n", 441 | "generator.save(\"./models/generator\")\n", 442 | "critic.save(\"./models/critic\")" 443 | ] 444 | }, 445 | { 446 | "cell_type": "markdown", 447 | "id": "0765b66b-d12c-42c4-90fa-2ff851a9b3f5", 448 | "metadata": {}, 449 | "source": [ 450 | "## Generate images" 451 | ] 452 | }, 453 | { 454 | "cell_type": "code", 455 | "execution_count": null, 456 | "id": "86576e84-afc4-443a-b68d-9a5ee13ce730", 457 | "metadata": {}, 458 | "outputs": [], 459 | "source": [ 460 | "z_sample = np.random.normal(size=(10, Z_DIM))\n", 461 | "imgs = wgangp.generator.predict(z_sample)\n", 462 | "display(imgs, cmap=None)" 463 | ] 464 | } 465 | ], 466 | "metadata": { 467 | "kernelspec": { 468 | 
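As a follow-on to the sampling cell above, one common experiment (not part of the notebook) is to decode a straight line between two latent points; this is only a sketch, assuming the wgangp model and the display helper from the cells above.

import numpy as np

Z_DIM = 128  # as set in the parameters cell above

z_start = np.random.normal(size=(1, Z_DIM))
z_end = np.random.normal(size=(1, Z_DIM))
steps = np.linspace(0.0, 1.0, 10).reshape(-1, 1)

# Linear path through latent space, decoded by the trained generator
z_path = z_start * (1.0 - steps) + z_end * steps   # shape (10, Z_DIM)
imgs = wgangp.generator.predict(z_path)
display(imgs, cmap=None)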
"display_name": "Python 3 (ipykernel)", 469 | "language": "python", 470 | "name": "python3" 471 | }, 472 | "language_info": { 473 | "codemirror_mode": { 474 | "name": "ipython", 475 | "version": 3 476 | }, 477 | "file_extension": ".py", 478 | "mimetype": "text/x-python", 479 | "name": "python", 480 | "nbconvert_exporter": "python", 481 | "pygments_lexer": "ipython3", 482 | "version": "3.8.10" 483 | }, 484 | "vscode": { 485 | "interpreter": { 486 | "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49" 487 | } 488 | } 489 | }, 490 | "nbformat": 4, 491 | "nbformat_minor": 5 492 | } 493 | -------------------------------------------------------------------------------- /notebooks/04_gan/03_cgan/checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/04_gan/03_cgan/logs/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/04_gan/03_cgan/models/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/04_gan/03_cgan/output/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/05_autoregressive/01_lstm/checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/05_autoregressive/01_lstm/logs/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/05_autoregressive/01_lstm/lstm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b076bd1a-b236-4fbc-953d-8295b25122ae", 6 | "metadata": {}, 7 | "source": [ 8 | "# 🥙 LSTM on Recipe Data" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "658a95da-9645-4bcf-bd9d-4b95a4b6f582", 14 | "metadata": {}, 15 | "source": [ 16 | "In this notebook, we'll walk through the steps required to train your own LSTM on the recipes dataset" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "id": "4e0d56cc-4773-4029-97d8-26f882ba79c9", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "%load_ext autoreload\n", 27 | "%autoreload 2\n", 28 | "\n", 29 | "import numpy as np\n", 30 | "import json\n", 31 | "import re\n", 32 | "import string\n", 33 | "\n", 34 | "import tensorflow as tf\n", 35 | "from tensorflow.keras import layers, models, callbacks, losses" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "id": "339e6268-ebd7-4feb-86db-1fe7abccdbe5", 41 | "metadata": {}, 42 | "source": [ 43 | "## 0. 
Parameters " 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "id": "2d8352af-343e-4c2e-8c91-95f8bac1c8a1", 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "VOCAB_SIZE = 10000\n", 54 | "MAX_LEN = 200\n", 55 | "EMBEDDING_DIM = 100\n", 56 | "N_UNITS = 128\n", 57 | "VALIDATION_SPLIT = 0.2\n", 58 | "SEED = 42\n", 59 | "LOAD_MODEL = False\n", 60 | "BATCH_SIZE = 32\n", 61 | "EPOCHS = 25" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "id": "b7716fac-0010-49b0-b98e-53be2259edde", 67 | "metadata": {}, 68 | "source": [ 69 | "## 1. Load the data " 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "id": "93cf6b0f-9667-4146-8911-763a8a2925d3", 76 | "metadata": { 77 | "tags": [] 78 | }, 79 | "outputs": [], 80 | "source": [ 81 | "# Load the full dataset\n", 82 | "with open(\"/app/data/epirecipes/full_format_recipes.json\") as json_data:\n", 83 | " recipe_data = json.load(json_data)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "id": "23a74eca-f1b7-4a46-9a1f-b5806a4ed361", 90 | "metadata": { 91 | "tags": [] 92 | }, 93 | "outputs": [], 94 | "source": [ 95 | "# Filter the dataset\n", 96 | "filtered_data = [\n", 97 | " \"Recipe for \" + x[\"title\"] + \" | \" + \" \".join(x[\"directions\"])\n", 98 | " for x in recipe_data\n", 99 | " if \"title\" in x\n", 100 | " and x[\"title\"] is not None\n", 101 | " and \"directions\" in x\n", 102 | " and x[\"directions\"] is not None\n", 103 | "]" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "id": "389c20de-0422-4c48-a7b4-6ee12a7bf0e2", 110 | "metadata": { 111 | "tags": [] 112 | }, 113 | "outputs": [], 114 | "source": [ 115 | "# Count the recipes\n", 116 | "n_recipes = len(filtered_data)\n", 117 | "print(f\"{n_recipes} recipes loaded\")" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "id": "1b2e3cf7-e416-460e-874a-0dd9637bca36", 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "example = filtered_data[9]\n", 128 | "print(example)" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "id": "3f871aaf-d873-41c7-8946-e4eef7ac17c1", 134 | "metadata": {}, 135 | "source": [ 136 | "## 2. 
Tokenise the data" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "id": "5b2064fb-5dcc-4657-b470-0928d10e2ddc", 143 | "metadata": { 144 | "tags": [] 145 | }, 146 | "outputs": [], 147 | "source": [ 148 | "# Pad the punctuation, to treat them as separate 'words'\n", 149 | "def pad_punctuation(s):\n", 150 | " s = re.sub(f\"([{string.punctuation}])\", r\" \\1 \", s)\n", 151 | " s = re.sub(\" +\", \" \", s)\n", 152 | " return s\n", 153 | "\n", 154 | "\n", 155 | "text_data = [pad_punctuation(x) for x in filtered_data]" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "id": "b87d7c65-9a46-492a-a5c0-a043b0d252f3", 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "# Display an example of a recipe\n", 166 | "example_data = text_data[9]\n", 167 | "example_data" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "id": "9834f916-b21a-4104-acc9-f28d3bd7a8c1", 174 | "metadata": { 175 | "tags": [] 176 | }, 177 | "outputs": [], 178 | "source": [ 179 | "# Convert to a Tensorflow Dataset\n", 180 | "text_ds = (\n", 181 | " tf.data.Dataset.from_tensor_slices(text_data)\n", 182 | " .batch(BATCH_SIZE)\n", 183 | " .shuffle(1000)\n", 184 | ")" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "id": "884c0bcb-0807-45a1-8f7e-a32f2c6fa4de", 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "# Create a vectorisation layer\n", 195 | "vectorize_layer = layers.TextVectorization(\n", 196 | " standardize=\"lower\",\n", 197 | " max_tokens=VOCAB_SIZE,\n", 198 | " output_mode=\"int\",\n", 199 | " output_sequence_length=MAX_LEN + 1,\n", 200 | ")" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "id": "4d6dd34a-d905-497b-926a-405380ebcf98", 207 | "metadata": {}, 208 | "outputs": [], 209 | "source": [ 210 | "# Adapt the layer to the training set\n", 211 | "vectorize_layer.adapt(text_ds)\n", 212 | "vocab = vectorize_layer.get_vocabulary()" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "id": "f6c1c7ce-3cf0-40d4-a3dc-ab7090f69f2f", 219 | "metadata": {}, 220 | "outputs": [], 221 | "source": [ 222 | "# Display some token:word mappings\n", 223 | "for i, word in enumerate(vocab[:10]):\n", 224 | " print(f\"{i}: {word}\")" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": null, 230 | "id": "1cc30186-7ec6-4eb6-b29a-65df6714d321", 231 | "metadata": {}, 232 | "outputs": [], 233 | "source": [ 234 | "# Display the same example converted to ints\n", 235 | "example_tokenised = vectorize_layer(example_data)\n", 236 | "print(example_tokenised.numpy())" 237 | ] 238 | }, 239 | { 240 | "cell_type": "markdown", 241 | "id": "8c195efb-84c6-4be0-a989-a7542188ad35", 242 | "metadata": {}, 243 | "source": [ 244 | "## 3. 
Create the Training Set" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": null, 250 | "id": "740294a1-1a6b-4c89-92f2-036d7d1b788b", 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "# Create the training set of recipes and the same text shifted by one word\n", 255 | "def prepare_inputs(text):\n", 256 | " text = tf.expand_dims(text, -1)\n", 257 | " tokenized_sentences = vectorize_layer(text)\n", 258 | " x = tokenized_sentences[:, :-1]\n", 259 | " y = tokenized_sentences[:, 1:]\n", 260 | " return x, y\n", 261 | "\n", 262 | "\n", 263 | "train_ds = text_ds.map(prepare_inputs)" 264 | ] 265 | }, 266 | { 267 | "cell_type": "markdown", 268 | "id": "aff50401-3abe-4c10-bba8-b35bc13ad7d5", 269 | "metadata": { 270 | "tags": [] 271 | }, 272 | "source": [ 273 | "## 4. Build the LSTM " 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": null, 279 | "id": "9230b5bf-b4a8-48d5-b73b-6899a598f296", 280 | "metadata": {}, 281 | "outputs": [], 282 | "source": [ 283 | "inputs = layers.Input(shape=(None,), dtype=\"int32\")\n", 284 | "x = layers.Embedding(VOCAB_SIZE, EMBEDDING_DIM)(inputs)\n", 285 | "x = layers.LSTM(N_UNITS, return_sequences=True)(x)\n", 286 | "outputs = layers.Dense(VOCAB_SIZE, activation=\"softmax\")(x)\n", 287 | "lstm = models.Model(inputs, outputs)\n", 288 | "lstm.summary()" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": null, 294 | "id": "800a3c6e-fb11-4792-b6bc-9a43a7c977ad", 295 | "metadata": { 296 | "tags": [] 297 | }, 298 | "outputs": [], 299 | "source": [ 300 | "if LOAD_MODEL:\n", 301 | " # model.load_weights('./models/model')\n", 302 | " lstm = models.load_model(\"./models/lstm\", compile=False)" 303 | ] 304 | }, 305 | { 306 | "cell_type": "markdown", 307 | "id": "35b14665-4359-447b-be58-3fd58ba69084", 308 | "metadata": {}, 309 | "source": [ 310 | "## 5. 
Train the LSTM " 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": null, 316 | "id": "ffb1bd3b-6fd9-4536-973e-6375bbcbf16d", 317 | "metadata": {}, 318 | "outputs": [], 319 | "source": [ 320 | "loss_fn = losses.SparseCategoricalCrossentropy()\n", 321 | "lstm.compile(\"adam\", loss_fn)" 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": null, 327 | "id": "3ddcff5f-829d-4449-99d2-9a3cb68f7d72", 328 | "metadata": {}, 329 | "outputs": [], 330 | "source": [ 331 | "# Create a TextGenerator checkpoint\n", 332 | "class TextGenerator(callbacks.Callback):\n", 333 | " def __init__(self, index_to_word, top_k=10):\n", 334 | " self.index_to_word = index_to_word\n", 335 | " self.word_to_index = {\n", 336 | " word: index for index, word in enumerate(index_to_word)\n", 337 | " } # <1>\n", 338 | "\n", 339 | " def sample_from(self, probs, temperature): # <2>\n", 340 | " probs = probs ** (1 / temperature)\n", 341 | " probs = probs / np.sum(probs)\n", 342 | " return np.random.choice(len(probs), p=probs), probs\n", 343 | "\n", 344 | " def generate(self, start_prompt, max_tokens, temperature):\n", 345 | " start_tokens = [\n", 346 | " self.word_to_index.get(x, 1) for x in start_prompt.split()\n", 347 | " ] # <3>\n", 348 | " sample_token = None\n", 349 | " info = []\n", 350 | " while len(start_tokens) < max_tokens and sample_token != 0: # <4>\n", 351 | " x = np.array([start_tokens])\n", 352 | " y = self.model.predict(x, verbose=0) # <5>\n", 353 | " sample_token, probs = self.sample_from(y[0][-1], temperature) # <6>\n", 354 | " info.append({\"prompt\": start_prompt, \"word_probs\": probs})\n", 355 | " start_tokens.append(sample_token) # <7>\n", 356 | " start_prompt = start_prompt + \" \" + self.index_to_word[sample_token]\n", 357 | " print(f\"\\ngenerated text:\\n{start_prompt}\\n\")\n", 358 | " return info\n", 359 | "\n", 360 | " def on_epoch_end(self, epoch, logs=None):\n", 361 | " self.generate(\"recipe for\", max_tokens=100, temperature=1.0)" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": null, 367 | "id": "349865fe-ffbe-450e-97be-043ae1740e78", 368 | "metadata": {}, 369 | "outputs": [], 370 | "source": [ 371 | "# Create a model save checkpoint\n", 372 | "model_checkpoint_callback = callbacks.ModelCheckpoint(\n", 373 | " filepath=\"./checkpoint/checkpoint.ckpt\",\n", 374 | " save_weights_only=True,\n", 375 | " save_freq=\"epoch\",\n", 376 | " verbose=0,\n", 377 | ")\n", 378 | "\n", 379 | "tensorboard_callback = callbacks.TensorBoard(log_dir=\"./logs\")\n", 380 | "\n", 381 | "# Tokenize starting prompt\n", 382 | "text_generator = TextGenerator(vocab)" 383 | ] 384 | }, 385 | { 386 | "cell_type": "code", 387 | "execution_count": null, 388 | "id": "461c2b3e-b5ae-4def-8bd9-e7bab8c63d8e", 389 | "metadata": { 390 | "tags": [] 391 | }, 392 | "outputs": [], 393 | "source": [ 394 | "lstm.fit(\n", 395 | " train_ds,\n", 396 | " epochs=EPOCHS,\n", 397 | " callbacks=[model_checkpoint_callback, tensorboard_callback, text_generator],\n", 398 | ")" 399 | ] 400 | }, 401 | { 402 | "cell_type": "code", 403 | "execution_count": null, 404 | "id": "369bde44-2e39-4bc6-8549-a3a27ecce55c", 405 | "metadata": { 406 | "tags": [] 407 | }, 408 | "outputs": [], 409 | "source": [ 410 | "# Save the final model\n", 411 | "lstm.save(\"./models/lstm\")" 412 | ] 413 | }, 414 | { 415 | "cell_type": "markdown", 416 | "id": "d64e02d2-84dc-40c8-8446-40c09adf1e20", 417 | "metadata": {}, 418 | "source": [ 419 | "## 6. 
Generate text using the LSTM" 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": null, 425 | "id": "4ad23adb-3ec9-4e9a-9a59-b9f9bafca649", 426 | "metadata": {}, 427 | "outputs": [], 428 | "source": [ 429 | "def print_probs(info, vocab, top_k=5):\n", 430 | " for i in info:\n", 431 | " print(f\"\\nPROMPT: {i['prompt']}\")\n", 432 | " word_probs = i[\"word_probs\"]\n", 433 | " p_sorted = np.sort(word_probs)[::-1][:top_k]\n", 434 | " i_sorted = np.argsort(word_probs)[::-1][:top_k]\n", 435 | " for p, i in zip(p_sorted, i_sorted):\n", 436 | " print(f\"{vocab[i]}: \\t{np.round(100*p,2)}%\")\n", 437 | " print(\"--------\\n\")" 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": null, 443 | "id": "3cf25578-d47c-4b26-8252-fcdf2316a4ac", 444 | "metadata": {}, 445 | "outputs": [], 446 | "source": [ 447 | "info = text_generator.generate(\n", 448 | " \"recipe for roasted vegetables | chop 1 /\", max_tokens=10, temperature=1.0\n", 449 | ")" 450 | ] 451 | }, 452 | { 453 | "cell_type": "code", 454 | "execution_count": null, 455 | "id": "9df72866-b483-4489-8e26-d5e1466410fa", 456 | "metadata": { 457 | "tags": [] 458 | }, 459 | "outputs": [], 460 | "source": [ 461 | "print_probs(info, vocab)" 462 | ] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "execution_count": null, 467 | "id": "562e1fe8-cbcb-438f-9637-2f2a6279c924", 468 | "metadata": {}, 469 | "outputs": [], 470 | "source": [ 471 | "info = text_generator.generate(\n", 472 | " \"recipe for roasted vegetables | chop 1 /\", max_tokens=10, temperature=0.2\n", 473 | ")" 474 | ] 475 | }, 476 | { 477 | "cell_type": "code", 478 | "execution_count": null, 479 | "id": "56356f21-04ac-40e5-94ff-291eca6a7054", 480 | "metadata": {}, 481 | "outputs": [], 482 | "source": [ 483 | "print_probs(info, vocab)" 484 | ] 485 | }, 486 | { 487 | "cell_type": "code", 488 | "execution_count": null, 489 | "id": "2e434497-07f3-4989-a68d-3e31cf8fa4fe", 490 | "metadata": {}, 491 | "outputs": [], 492 | "source": [ 493 | "info = text_generator.generate(\n", 494 | " \"recipe for chocolate ice cream |\", max_tokens=7, temperature=1.0\n", 495 | ")\n", 496 | "print_probs(info, vocab)" 497 | ] 498 | }, 499 | { 500 | "cell_type": "code", 501 | "execution_count": null, 502 | "id": "011cd0e0-956c-4a63-8ec3-f7dfed31764e", 503 | "metadata": {}, 504 | "outputs": [], 505 | "source": [ 506 | "info = text_generator.generate(\n", 507 | " \"recipe for chocolate ice cream |\", max_tokens=7, temperature=0.2\n", 508 | ")\n", 509 | "print_probs(info, vocab)" 510 | ] 511 | } 512 | ], 513 | "metadata": { 514 | "kernelspec": { 515 | "display_name": "Python 3 (ipykernel)", 516 | "language": "python", 517 | "name": "python3" 518 | }, 519 | "language_info": { 520 | "codemirror_mode": { 521 | "name": "ipython", 522 | "version": 3 523 | }, 524 | "file_extension": ".py", 525 | "mimetype": "text/x-python", 526 | "name": "python", 527 | "nbconvert_exporter": "python", 528 | "pygments_lexer": "ipython3", 529 | "version": "3.8.10" 530 | } 531 | }, 532 | "nbformat": 4, 533 | "nbformat_minor": 5 534 | } 535 | -------------------------------------------------------------------------------- /notebooks/05_autoregressive/01_lstm/models/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/05_autoregressive/01_lstm/output/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 
| !.gitignore -------------------------------------------------------------------------------- /notebooks/05_autoregressive/02_pixelcnn/checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/05_autoregressive/02_pixelcnn/logs/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/05_autoregressive/02_pixelcnn/models/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/05_autoregressive/02_pixelcnn/output/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/05_autoregressive/02_pixelcnn/pixelcnn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "206a93a2-e9b0-4ea2-a43f-696faa83ea03", 6 | "metadata": {}, 7 | "source": [ 8 | "# 👾 PixelCNN from scratch" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "1af9e216-7e84-4f5b-a2db-26aca3bea464", 14 | "metadata": {}, 15 | "source": [ 16 | "In this notebook, we'll walk through the steps required to train your own PixelCNN on the fashion MNIST dataset from scratch" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "id": "9c1e6bbc-6f3b-48ac-a4f3-fde6f739f0ca", 22 | "metadata": {}, 23 | "source": [ 24 | "The code has been adapted from the excellent [PixelCNN tutorial](https://keras.io/examples/generative/pixelcnn/) created by ADMoreau, available on the Keras website." 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "id": "6acebfa8-4546-41fd-adaa-2307c65b1b8e", 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "%load_ext autoreload\n", 35 | "%autoreload 2\n", 36 | "import numpy as np\n", 37 | "\n", 38 | "import tensorflow as tf\n", 39 | "from tensorflow.keras import datasets, layers, models, optimizers, callbacks\n", 40 | "\n", 41 | "from notebooks.utils import display" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "id": "8543166d-f4c7-43f8-a452-21ccbf2a0496", 47 | "metadata": {}, 48 | "source": [ 49 | "## 0. Parameters " 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "id": "444d84de-2843-40d6-8e2e-93691a5393ab", 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "IMAGE_SIZE = 16\n", 60 | "PIXEL_LEVELS = 4\n", 61 | "N_FILTERS = 128\n", 62 | "RESIDUAL_BLOCKS = 5\n", 63 | "BATCH_SIZE = 128\n", 64 | "EPOCHS = 150" 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "id": "d65dac68-d20b-4ed9-a136-eed57095ce4f", 70 | "metadata": {}, 71 | "source": [ 72 | "## 1. 
Prepare the data " 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "id": "0ed0fc56-d1b0-4d42-b029-f4198f78e666", 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "# Load the data\n", 83 | "(x_train, _), (_, _) = datasets.fashion_mnist.load_data()" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "id": "b667e78c-8fa7-4e5b-a2c0-69e50166ef77", 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "# Preprocess the data\n", 94 | "def preprocess(imgs_int):\n", 95 | " imgs_int = np.expand_dims(imgs_int, -1)\n", 96 | " imgs_int = tf.image.resize(imgs_int, (IMAGE_SIZE, IMAGE_SIZE)).numpy()\n", 97 | " imgs_int = (imgs_int / (256 / PIXEL_LEVELS)).astype(int)\n", 98 | " imgs = imgs_int.astype(\"float32\")\n", 99 | " imgs = imgs / PIXEL_LEVELS\n", 100 | " return imgs, imgs_int\n", 101 | "\n", 102 | "\n", 103 | "input_data, output_data = preprocess(x_train)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "id": "e3c2b304-8385-4931-8291-9b7cc462c95e", 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "# Show some items of clothing from the training set\n", 114 | "display(input_data)" 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "id": "5ccd5cb2-8c7b-4667-8adb-4902f3fa60cf", 120 | "metadata": {}, 121 | "source": [ 122 | "## 2. Build the PixelCNN" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "id": "847050a5-e4e6-4134-9bfc-c690cb8cb44d", 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "# The first layer is the PixelCNN layer. This layer simply\n", 133 | "# builds on the 2D convolutional layer, but includes masking.\n", 134 | "class MaskedConv2D(layers.Layer):\n", 135 | " def __init__(self, mask_type, **kwargs):\n", 136 | " super(MaskedConv2D, self).__init__()\n", 137 | " self.mask_type = mask_type\n", 138 | " self.conv = layers.Conv2D(**kwargs)\n", 139 | "\n", 140 | " def build(self, input_shape):\n", 141 | " # Build the conv2d layer to initialize kernel variables\n", 142 | " self.conv.build(input_shape)\n", 143 | " # Use the initialized kernel to create the mask\n", 144 | " kernel_shape = self.conv.kernel.get_shape()\n", 145 | " self.mask = np.zeros(shape=kernel_shape)\n", 146 | " self.mask[: kernel_shape[0] // 2, ...] = 1.0\n", 147 | " self.mask[kernel_shape[0] // 2, : kernel_shape[1] // 2, ...] = 1.0\n", 148 | " if self.mask_type == \"B\":\n", 149 | " self.mask[kernel_shape[0] // 2, kernel_shape[1] // 2, ...] 
= 1.0\n", 150 | "\n", 151 | " def call(self, inputs):\n", 152 | " self.conv.kernel.assign(self.conv.kernel * self.mask)\n", 153 | " return self.conv(inputs)\n", 154 | "\n", 155 | " def get_config(self):\n", 156 | " cfg = super().get_config()\n", 157 | " return cfg" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": null, 163 | "id": "a52f7795-790e-47b0-b724-80be3e3c3666", 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "class ResidualBlock(layers.Layer):\n", 168 | " def __init__(self, filters, **kwargs):\n", 169 | " super(ResidualBlock, self).__init__(**kwargs)\n", 170 | " self.conv1 = layers.Conv2D(\n", 171 | " filters=filters // 2, kernel_size=1, activation=\"relu\"\n", 172 | " )\n", 173 | " self.pixel_conv = MaskedConv2D(\n", 174 | " mask_type=\"B\",\n", 175 | " filters=filters // 2,\n", 176 | " kernel_size=3,\n", 177 | " activation=\"relu\",\n", 178 | " padding=\"same\",\n", 179 | " )\n", 180 | " self.conv2 = layers.Conv2D(\n", 181 | " filters=filters, kernel_size=1, activation=\"relu\"\n", 182 | " )\n", 183 | "\n", 184 | " def call(self, inputs):\n", 185 | " x = self.conv1(inputs)\n", 186 | " x = self.pixel_conv(x)\n", 187 | " x = self.conv2(x)\n", 188 | " return layers.add([inputs, x])\n", 189 | "\n", 190 | " def get_config(self):\n", 191 | " cfg = super().get_config()\n", 192 | " return cfg" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": null, 198 | "id": "19b4508f-84de-42a9-a77f-950fb493db13", 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "inputs = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 1))\n", 203 | "x = MaskedConv2D(\n", 204 | " mask_type=\"A\",\n", 205 | " filters=N_FILTERS,\n", 206 | " kernel_size=7,\n", 207 | " activation=\"relu\",\n", 208 | " padding=\"same\",\n", 209 | ")(inputs)\n", 210 | "\n", 211 | "for _ in range(RESIDUAL_BLOCKS):\n", 212 | " x = ResidualBlock(filters=N_FILTERS)(x)\n", 213 | "\n", 214 | "for _ in range(2):\n", 215 | " x = MaskedConv2D(\n", 216 | " mask_type=\"B\",\n", 217 | " filters=N_FILTERS,\n", 218 | " kernel_size=1,\n", 219 | " strides=1,\n", 220 | " activation=\"relu\",\n", 221 | " padding=\"valid\",\n", 222 | " )(x)\n", 223 | "\n", 224 | "out = layers.Conv2D(\n", 225 | " filters=PIXEL_LEVELS,\n", 226 | " kernel_size=1,\n", 227 | " strides=1,\n", 228 | " activation=\"softmax\",\n", 229 | " padding=\"valid\",\n", 230 | ")(x)\n", 231 | "\n", 232 | "pixel_cnn = models.Model(inputs, out)\n", 233 | "pixel_cnn.summary()" 234 | ] 235 | }, 236 | { 237 | "cell_type": "markdown", 238 | "id": "442b5ffa-67a3-4b15-a342-eb1eed5e87ac", 239 | "metadata": {}, 240 | "source": [ 241 | "## 3. 
Train the PixelCNN " 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": null, 247 | "id": "7204789a-2ad3-48bf-b7e8-00d4cab10d9c", 248 | "metadata": {}, 249 | "outputs": [], 250 | "source": [ 251 | "adam = optimizers.Adam(learning_rate=0.0005)\n", 252 | "pixel_cnn.compile(optimizer=adam, loss=\"sparse_categorical_crossentropy\")" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": null, 258 | "id": "09d327fc-aff8-40e6-b390-d1bff4c06ea6", 259 | "metadata": {}, 260 | "outputs": [], 261 | "source": [ 262 | "tensorboard_callback = callbacks.TensorBoard(log_dir=\"./logs\")\n", 263 | "\n", 264 | "\n", 265 | "class ImageGenerator(callbacks.Callback):\n", 266 | " def __init__(self, num_img):\n", 267 | " self.num_img = num_img\n", 268 | "\n", 269 | " def sample_from(self, probs, temperature): # <2>\n", 270 | " probs = probs ** (1 / temperature)\n", 271 | " probs = probs / np.sum(probs)\n", 272 | " return np.random.choice(len(probs), p=probs)\n", 273 | "\n", 274 | " def generate(self, temperature):\n", 275 | " generated_images = np.zeros(\n", 276 | " shape=(self.num_img,) + (pixel_cnn.input_shape)[1:]\n", 277 | " )\n", 278 | " batch, rows, cols, channels = generated_images.shape\n", 279 | "\n", 280 | " for row in range(rows):\n", 281 | " for col in range(cols):\n", 282 | " for channel in range(channels):\n", 283 | " probs = self.model.predict(generated_images, verbose=0)[\n", 284 | " :, row, col, :\n", 285 | " ]\n", 286 | " generated_images[:, row, col, channel] = [\n", 287 | " self.sample_from(x, temperature) for x in probs\n", 288 | " ]\n", 289 | " generated_images[:, row, col, channel] /= PIXEL_LEVELS\n", 290 | "\n", 291 | " return generated_images\n", 292 | "\n", 293 | " def on_epoch_end(self, epoch, logs=None):\n", 294 | " generated_images = self.generate(temperature=1.0)\n", 295 | " display(\n", 296 | " generated_images,\n", 297 | " save_to=\"./output/generated_img_%03d.png\" % (epoch),\n", 298 | " )\n", 299 | "\n", 300 | "\n", 301 | "img_generator_callback = ImageGenerator(num_img=10)" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": null, 307 | "id": "85231056-d4a4-4897-ab91-065325a18d93", 308 | "metadata": { 309 | "tags": [] 310 | }, 311 | "outputs": [], 312 | "source": [ 313 | "pixel_cnn.fit(\n", 314 | " input_data,\n", 315 | " output_data,\n", 316 | " batch_size=BATCH_SIZE,\n", 317 | " epochs=EPOCHS,\n", 318 | " callbacks=[tensorboard_callback, img_generator_callback],\n", 319 | ")" 320 | ] 321 | }, 322 | { 323 | "cell_type": "markdown", 324 | "id": "cfb4fa72-dd2d-44c1-ad18-9c965060683e", 325 | "metadata": {}, 326 | "source": [ 327 | "## 4. 
Generate images " 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": null, 333 | "id": "7bbd4643-be09-49ba-b7bc-a524a2f00806", 334 | "metadata": {}, 335 | "outputs": [], 336 | "source": [ 337 | "generated_images = img_generator_callback.generate(temperature=1.0)" 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": null, 343 | "id": "52cadb4b-ae2c-42a9-92ac-68e2131380ef", 344 | "metadata": {}, 345 | "outputs": [], 346 | "source": [ 347 | "display(generated_images)" 348 | ] 349 | } 350 | ], 351 | "metadata": { 352 | "kernelspec": { 353 | "display_name": "Python 3 (ipykernel)", 354 | "language": "python", 355 | "name": "python3" 356 | }, 357 | "language_info": { 358 | "codemirror_mode": { 359 | "name": "ipython", 360 | "version": 3 361 | }, 362 | "file_extension": ".py", 363 | "mimetype": "text/x-python", 364 | "name": "python", 365 | "nbconvert_exporter": "python", 366 | "pygments_lexer": "ipython3", 367 | "version": "3.8.10" 368 | } 369 | }, 370 | "nbformat": 4, 371 | "nbformat_minor": 5 372 | } 373 | -------------------------------------------------------------------------------- /notebooks/05_autoregressive/03_pixelcnn_md/checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/05_autoregressive/03_pixelcnn_md/logs/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/05_autoregressive/03_pixelcnn_md/models/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/05_autoregressive/03_pixelcnn_md/output/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/05_autoregressive/03_pixelcnn_md/pixelcnn_md.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b076bd1a-b236-4fbc-953d-8295b25122ae", 6 | "metadata": {}, 7 | "source": [ 8 | "# 👾 PixelCNN using Tensorflow distributions" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "9235cbd1-f136-411c-88d9-f69f270c0b96", 14 | "metadata": {}, 15 | "source": [ 16 | "In this notebook, we'll walk through the steps required to train your own PixelCNN on the fashion MNIST dataset using Tensorflow distributions" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "id": "84acc7be-6764-4668-b2bb-178f63deeed3", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "%load_ext autoreload\n", 27 | "%autoreload 2\n", 28 | "import numpy as np\n", 29 | "\n", 30 | "import tensorflow as tf\n", 31 | "from tensorflow.keras import datasets, layers, models, optimizers, callbacks\n", 32 | "import tensorflow_probability as tfp\n", 33 | "\n", 34 | "from notebooks.utils import display" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "id": "339e6268-ebd7-4feb-86db-1fe7abccdbe5", 40 | "metadata": {}, 41 | "source": [ 42 | "## 0. 
Parameters " 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "id": "1b2ee6ce-129f-4833-b0c5-fa567381c4e0", 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "IMAGE_SIZE = 32\n", 53 | "N_COMPONENTS = 5\n", 54 | "EPOCHS = 10\n", 55 | "BATCH_SIZE = 128" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "id": "b7716fac-0010-49b0-b98e-53be2259edde", 61 | "metadata": {}, 62 | "source": [ 63 | "## 1. Prepare the data " 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "id": "9a73e5a4-1638-411c-8d3c-29f823424458", 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "# Load the data\n", 74 | "(x_train, _), (_, _) = datasets.fashion_mnist.load_data()" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "id": "ebae2f0d-59fd-4796-841f-7213eae638de", 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "# Preprocess the data\n", 85 | "\n", 86 | "\n", 87 | "def preprocess(imgs):\n", 88 | " imgs = np.expand_dims(imgs, -1)\n", 89 | " imgs = tf.image.resize(imgs, (IMAGE_SIZE, IMAGE_SIZE)).numpy()\n", 90 | " return imgs\n", 91 | "\n", 92 | "\n", 93 | "input_data = preprocess(x_train)" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "id": "fa53709f-7f3f-483b-9db8-2e5f9b9942c2", 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "# Show some items of clothing from the training set\n", 104 | "display(input_data)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "id": "aff50401-3abe-4c10-bba8-b35bc13ad7d5", 110 | "metadata": { 111 | "tags": [] 112 | }, 113 | "source": [ 114 | "## 2. Build the PixelCNN " 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "id": "71a2a4a1-690e-4c94-b323-86f0e5b691d5", 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "# Define a Pixel CNN network\n", 125 | "dist = tfp.distributions.PixelCNN(\n", 126 | " image_shape=(IMAGE_SIZE, IMAGE_SIZE, 1),\n", 127 | " num_resnet=1,\n", 128 | " num_hierarchies=2,\n", 129 | " num_filters=32,\n", 130 | " num_logistic_mix=N_COMPONENTS,\n", 131 | " dropout_p=0.3,\n", 132 | ")\n", 133 | "\n", 134 | "# Define the model input\n", 135 | "image_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 1))\n", 136 | "\n", 137 | "# Define the log likelihood for the loss fn\n", 138 | "log_prob = dist.log_prob(image_input)\n", 139 | "\n", 140 | "# Define the model\n", 141 | "pixelcnn = models.Model(inputs=image_input, outputs=log_prob)\n", 142 | "pixelcnn.add_loss(-tf.reduce_mean(log_prob))" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "id": "35b14665-4359-447b-be58-3fd58ba69084", 148 | "metadata": {}, 149 | "source": [ 150 | "## 3. 
Train the PixelCNN " 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "id": "d9ec362d-41fa-473a-ad56-ebeec6cfd3b8", 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "# Compile and train the model\n", 161 | "pixelcnn.compile(\n", 162 | " optimizer=optimizers.Adam(0.001),\n", 163 | ")" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "id": "c525e44b-b3bb-489c-9d35-fcfe3e714e6a", 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "tensorboard_callback = callbacks.TensorBoard(log_dir=\"./logs\")\n", 174 | "\n", 175 | "\n", 176 | "class ImageGenerator(callbacks.Callback):\n", 177 | " def __init__(self, num_img):\n", 178 | " self.num_img = num_img\n", 179 | "\n", 180 | " def generate(self):\n", 181 | " return dist.sample(self.num_img).numpy()\n", 182 | "\n", 183 | " def on_epoch_end(self, epoch, logs=None):\n", 184 | " generated_images = self.generate()\n", 185 | " display(\n", 186 | " generated_images,\n", 187 | " n=self.num_img,\n", 188 | " save_to=\"./output/generated_img_%03d.png\" % (epoch),\n", 189 | " )\n", 190 | "\n", 191 | "\n", 192 | "img_generator_callback = ImageGenerator(num_img=2)" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": null, 198 | "id": "bd6a5a71-eb55-4ec0-9c8c-cb11a382ff90", 199 | "metadata": { 200 | "tags": [] 201 | }, 202 | "outputs": [], 203 | "source": [ 204 | "pixelcnn.fit(\n", 205 | " input_data,\n", 206 | " batch_size=BATCH_SIZE,\n", 207 | " epochs=EPOCHS,\n", 208 | " verbose=True,\n", 209 | " callbacks=[tensorboard_callback, img_generator_callback],\n", 210 | ")" 211 | ] 212 | }, 213 | { 214 | "cell_type": "markdown", 215 | "id": "fb1f295f-ade0-4040-a6a5-a7b428b08ebc", 216 | "metadata": {}, 217 | "source": [ 218 | "## 4. 
Generate images " 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": null, 224 | "id": "8db3cfe3-339e-463d-8af5-fbd403385fca", 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [ 228 | "generated_images = img_generator_callback.generate()" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "id": "80087297-3f47-4e0c-ac89-8758d4386d7c", 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "display(generated_images, n=img_generator_callback.num_img)" 239 | ] 240 | } 241 | ], 242 | "metadata": { 243 | "kernelspec": { 244 | "display_name": "Python 3 (ipykernel)", 245 | "language": "python", 246 | "name": "python3" 247 | }, 248 | "language_info": { 249 | "codemirror_mode": { 250 | "name": "ipython", 251 | "version": 3 252 | }, 253 | "file_extension": ".py", 254 | "mimetype": "text/x-python", 255 | "name": "python", 256 | "nbconvert_exporter": "python", 257 | "pygments_lexer": "ipython3", 258 | "version": "3.8.10" 259 | } 260 | }, 261 | "nbformat": 4, 262 | "nbformat_minor": 5 263 | } 264 | -------------------------------------------------------------------------------- /notebooks/06_normflow/01_realnvp/checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/06_normflow/01_realnvp/logs/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/06_normflow/01_realnvp/models/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/06_normflow/01_realnvp/output/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/06_normflow/01_realnvp/realnvp.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b076bd1a-b236-4fbc-953d-8295b25122ae", 6 | "metadata": {}, 7 | "source": [ 8 | "# 🌀 RealNVP" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "9235cbd1-f136-411c-88d9-f69f270c0b96", 14 | "metadata": {}, 15 | "source": [ 16 | "In this notebook, we'll walk through the steps required to train your own RealNVP network to predict the distribution of a demo dataset" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "id": "2db9a506-bf8f-4a40-ab27-703b0c82371b", 22 | "metadata": {}, 23 | "source": [ 24 | "The code has been adapted from the excellent [RealNVP tutorial](https://keras.io/examples/generative/real_nvp) created by Mandolini Giorgio Maria, Sanna Daniele and Zannini Quirini Giorgio available on the Keras website." 
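A minimal sketch of using the TensorFlow Probability PixelCNN distribution above directly, assuming the dist object from that notebook: it exposes sample and log_prob, and the negative mean log-probability is exactly the loss attached via add_loss.

import tensorflow as tf

# Assumes `dist` from the PixelCNN notebook above (tfp.distributions.PixelCNN)
imgs = dist.sample(4)               # shape (4, IMAGE_SIZE, IMAGE_SIZE, 1)
log_probs = dist.log_prob(imgs)     # per-image log-likelihoods, shape (4,)
nll = -tf.reduce_mean(log_probs)    # the quantity minimised during training
print(float(nll))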
25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "id": "84acc7be-6764-4668-b2bb-178f63deeed3", 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "%load_ext autoreload\n", 35 | "%autoreload 2\n", 36 | "\n", 37 | "import numpy as np\n", 38 | "import matplotlib.pyplot as plt\n", 39 | "\n", 40 | "from sklearn import datasets\n", 41 | "\n", 42 | "import tensorflow as tf\n", 43 | "from tensorflow.keras import (\n", 44 | " layers,\n", 45 | " models,\n", 46 | " regularizers,\n", 47 | " metrics,\n", 48 | " optimizers,\n", 49 | " callbacks,\n", 50 | ")\n", 51 | "import tensorflow_probability as tfp" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "id": "339e6268-ebd7-4feb-86db-1fe7abccdbe5", 57 | "metadata": {}, 58 | "source": [ 59 | "## 0. Parameters " 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "id": "1b2ee6ce-129f-4833-b0c5-fa567381c4e0", 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "COUPLING_DIM = 256\n", 70 | "COUPLING_LAYERS = 2\n", 71 | "INPUT_DIM = 2\n", 72 | "REGULARIZATION = 0.01\n", 73 | "BATCH_SIZE = 256\n", 74 | "EPOCHS = 300" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "id": "9a73e5a4-1638-411c-8d3c-29f823424458", 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "# Load the data\n", 85 | "data = datasets.make_moons(30000, noise=0.05)[0].astype(\"float32\")\n", 86 | "norm = layers.Normalization()\n", 87 | "norm.adapt(data)\n", 88 | "normalized_data = norm(data)\n", 89 | "plt.scatter(\n", 90 | " normalized_data.numpy()[:, 0], normalized_data.numpy()[:, 1], c=\"green\"\n", 91 | ")\n", 92 | "plt.show()" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "id": "aff50401-3abe-4c10-bba8-b35bc13ad7d5", 98 | "metadata": { 99 | "tags": [] 100 | }, 101 | "source": [ 102 | "## 2. 
Build the RealNVP network " 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "id": "71a2a4a1-690e-4c94-b323-86f0e5b691d5", 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "def Coupling(input_dim, coupling_dim, reg):\n", 113 | " input_layer = layers.Input(shape=input_dim)\n", 114 | "\n", 115 | " s_layer_1 = layers.Dense(\n", 116 | " coupling_dim, activation=\"relu\", kernel_regularizer=regularizers.l2(reg)\n", 117 | " )(input_layer)\n", 118 | " s_layer_2 = layers.Dense(\n", 119 | " coupling_dim, activation=\"relu\", kernel_regularizer=regularizers.l2(reg)\n", 120 | " )(s_layer_1)\n", 121 | " s_layer_3 = layers.Dense(\n", 122 | " coupling_dim, activation=\"relu\", kernel_regularizer=regularizers.l2(reg)\n", 123 | " )(s_layer_2)\n", 124 | " s_layer_4 = layers.Dense(\n", 125 | " coupling_dim, activation=\"relu\", kernel_regularizer=regularizers.l2(reg)\n", 126 | " )(s_layer_3)\n", 127 | " s_layer_5 = layers.Dense(\n", 128 | " input_dim, activation=\"tanh\", kernel_regularizer=regularizers.l2(reg)\n", 129 | " )(s_layer_4)\n", 130 | "\n", 131 | " t_layer_1 = layers.Dense(\n", 132 | " coupling_dim, activation=\"relu\", kernel_regularizer=regularizers.l2(reg)\n", 133 | " )(input_layer)\n", 134 | " t_layer_2 = layers.Dense(\n", 135 | " coupling_dim, activation=\"relu\", kernel_regularizer=regularizers.l2(reg)\n", 136 | " )(t_layer_1)\n", 137 | " t_layer_3 = layers.Dense(\n", 138 | " coupling_dim, activation=\"relu\", kernel_regularizer=regularizers.l2(reg)\n", 139 | " )(t_layer_2)\n", 140 | " t_layer_4 = layers.Dense(\n", 141 | " coupling_dim, activation=\"relu\", kernel_regularizer=regularizers.l2(reg)\n", 142 | " )(t_layer_3)\n", 143 | " t_layer_5 = layers.Dense(\n", 144 | " input_dim, activation=\"linear\", kernel_regularizer=regularizers.l2(reg)\n", 145 | " )(t_layer_4)\n", 146 | "\n", 147 | " return models.Model(inputs=input_layer, outputs=[s_layer_5, t_layer_5])" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "id": "0f4dcd44-4189-4f39-b262-7afedb00a5a9", 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "class RealNVP(models.Model):\n", 158 | " def __init__(\n", 159 | " self, input_dim, coupling_layers, coupling_dim, regularization\n", 160 | " ):\n", 161 | " super(RealNVP, self).__init__()\n", 162 | " self.coupling_layers = coupling_layers\n", 163 | " self.distribution = tfp.distributions.MultivariateNormalDiag(\n", 164 | " loc=[0.0, 0.0], scale_diag=[1.0, 1.0]\n", 165 | " )\n", 166 | " self.masks = np.array(\n", 167 | " [[0, 1], [1, 0]] * (coupling_layers // 2), dtype=\"float32\"\n", 168 | " )\n", 169 | " self.loss_tracker = metrics.Mean(name=\"loss\")\n", 170 | " self.layers_list = [\n", 171 | " Coupling(input_dim, coupling_dim, regularization)\n", 172 | " for i in range(coupling_layers)\n", 173 | " ]\n", 174 | "\n", 175 | " @property\n", 176 | " def metrics(self):\n", 177 | " return [self.loss_tracker]\n", 178 | "\n", 179 | " def call(self, x, training=True):\n", 180 | " log_det_inv = 0\n", 181 | " direction = 1\n", 182 | " if training:\n", 183 | " direction = -1\n", 184 | " for i in range(self.coupling_layers)[::direction]:\n", 185 | " x_masked = x * self.masks[i]\n", 186 | " reversed_mask = 1 - self.masks[i]\n", 187 | " s, t = self.layers_list[i](x_masked)\n", 188 | " s *= reversed_mask\n", 189 | " t *= reversed_mask\n", 190 | " gate = (direction - 1) / 2\n", 191 | " x = (\n", 192 | " reversed_mask\n", 193 | " * (x * tf.exp(direction * s) + direction * t * tf.exp(gate * 
s))\n", 194 | " + x_masked\n", 195 | " )\n", 196 | " log_det_inv += gate * tf.reduce_sum(s, axis=1)\n", 197 | " return x, log_det_inv\n", 198 | "\n", 199 | " def log_loss(self, x):\n", 200 | " y, logdet = self(x)\n", 201 | " log_likelihood = self.distribution.log_prob(y) + logdet\n", 202 | " return -tf.reduce_mean(log_likelihood)\n", 203 | "\n", 204 | " def train_step(self, data):\n", 205 | " with tf.GradientTape() as tape:\n", 206 | " loss = self.log_loss(data)\n", 207 | " g = tape.gradient(loss, self.trainable_variables)\n", 208 | " self.optimizer.apply_gradients(zip(g, self.trainable_variables))\n", 209 | " self.loss_tracker.update_state(loss)\n", 210 | " return {\"loss\": self.loss_tracker.result()}\n", 211 | "\n", 212 | " def test_step(self, data):\n", 213 | " loss = self.log_loss(data)\n", 214 | " self.loss_tracker.update_state(loss)\n", 215 | " return {\"loss\": self.loss_tracker.result()}\n", 216 | "\n", 217 | "\n", 218 | "model = RealNVP(\n", 219 | " input_dim=INPUT_DIM,\n", 220 | " coupling_layers=COUPLING_LAYERS,\n", 221 | " coupling_dim=COUPLING_DIM,\n", 222 | " regularization=REGULARIZATION,\n", 223 | ")" 224 | ] 225 | }, 226 | { 227 | "cell_type": "markdown", 228 | "id": "35b14665-4359-447b-be58-3fd58ba69084", 229 | "metadata": {}, 230 | "source": [ 231 | "## 3. Train the RealNVP network " 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "id": "d9ec362d-41fa-473a-ad56-ebeec6cfd3b8", 238 | "metadata": {}, 239 | "outputs": [], 240 | "source": [ 241 | "# Compile and train the model\n", 242 | "model.compile(optimizer=optimizers.Adam(learning_rate=0.0001))" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": null, 248 | "id": "c525e44b-b3bb-489c-9d35-fcfe3e714e6a", 249 | "metadata": {}, 250 | "outputs": [], 251 | "source": [ 252 | "tensorboard_callback = callbacks.TensorBoard(log_dir=\"./logs\")\n", 253 | "\n", 254 | "\n", 255 | "class ImageGenerator(callbacks.Callback):\n", 256 | " def __init__(self, num_samples):\n", 257 | " self.num_samples = num_samples\n", 258 | "\n", 259 | " def generate(self):\n", 260 | " # From data to latent space.\n", 261 | " z, _ = model(normalized_data)\n", 262 | "\n", 263 | " # From latent space to data.\n", 264 | " samples = model.distribution.sample(self.num_samples)\n", 265 | " x, _ = model.predict(samples, verbose=0)\n", 266 | "\n", 267 | " return x, z, samples\n", 268 | "\n", 269 | " def display(self, x, z, samples, save_to=None):\n", 270 | " f, axes = plt.subplots(2, 2)\n", 271 | " f.set_size_inches(8, 5)\n", 272 | "\n", 273 | " axes[0, 0].scatter(\n", 274 | " normalized_data[:, 0], normalized_data[:, 1], color=\"r\", s=1\n", 275 | " )\n", 276 | " axes[0, 0].set(title=\"Data space X\", xlabel=\"x_1\", ylabel=\"x_2\")\n", 277 | " axes[0, 0].set_xlim([-2, 2])\n", 278 | " axes[0, 0].set_ylim([-2, 2])\n", 279 | " axes[0, 1].scatter(z[:, 0], z[:, 1], color=\"r\", s=1)\n", 280 | " axes[0, 1].set(title=\"f(X)\", xlabel=\"z_1\", ylabel=\"z_2\")\n", 281 | " axes[0, 1].set_xlim([-2, 2])\n", 282 | " axes[0, 1].set_ylim([-2, 2])\n", 283 | " axes[1, 0].scatter(samples[:, 0], samples[:, 1], color=\"g\", s=1)\n", 284 | " axes[1, 0].set(title=\"Latent space Z\", xlabel=\"z_1\", ylabel=\"z_2\")\n", 285 | " axes[1, 0].set_xlim([-2, 2])\n", 286 | " axes[1, 0].set_ylim([-2, 2])\n", 287 | " axes[1, 1].scatter(x[:, 0], x[:, 1], color=\"g\", s=1)\n", 288 | " axes[1, 1].set(title=\"g(Z)\", xlabel=\"x_1\", ylabel=\"x_2\")\n", 289 | " axes[1, 1].set_xlim([-2, 2])\n", 290 | " axes[1, 1].set_ylim([-2, 2])\n", 
291 | "\n", 292 | " plt.subplots_adjust(wspace=0.3, hspace=0.6)\n", 293 | " if save_to:\n", 294 | " plt.savefig(save_to)\n", 295 | " print(f\"\\nSaved to {save_to}\")\n", 296 | "\n", 297 | " plt.show()\n", 298 | "\n", 299 | " def on_epoch_end(self, epoch, logs=None):\n", 300 | " if epoch % 10 == 0:\n", 301 | " x, z, samples = self.generate()\n", 302 | " self.display(\n", 303 | " x,\n", 304 | " z,\n", 305 | " samples,\n", 306 | " save_to=\"./output/generated_img_%03d.png\" % (epoch),\n", 307 | " )\n", 308 | "\n", 309 | "\n", 310 | "img_generator_callback = ImageGenerator(num_samples=3000)" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": null, 316 | "id": "bd6a5a71-eb55-4ec0-9c8c-cb11a382ff90", 317 | "metadata": { 318 | "tags": [] 319 | }, 320 | "outputs": [], 321 | "source": [ 322 | "model.fit(\n", 323 | " normalized_data,\n", 324 | " batch_size=BATCH_SIZE,\n", 325 | " epochs=EPOCHS,\n", 326 | " callbacks=[tensorboard_callback, img_generator_callback],\n", 327 | ")" 328 | ] 329 | }, 330 | { 331 | "cell_type": "markdown", 332 | "id": "fb1f295f-ade0-4040-a6a5-a7b428b08ebc", 333 | "metadata": {}, 334 | "source": [ 335 | "## 4. Generate images " 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": null, 341 | "id": "8db3cfe3-339e-463d-8af5-fbd403385fca", 342 | "metadata": {}, 343 | "outputs": [], 344 | "source": [ 345 | "x, z, samples = img_generator_callback.generate()" 346 | ] 347 | }, 348 | { 349 | "cell_type": "code", 350 | "execution_count": null, 351 | "id": "80087297-3f47-4e0c-ac89-8758d4386d7c", 352 | "metadata": {}, 353 | "outputs": [], 354 | "source": [ 355 | "img_generator_callback.display(x, z, samples)" 356 | ] 357 | } 358 | ], 359 | "metadata": { 360 | "kernelspec": { 361 | "display_name": "Python 3 (ipykernel)", 362 | "language": "python", 363 | "name": "python3" 364 | }, 365 | "language_info": { 366 | "codemirror_mode": { 367 | "name": "ipython", 368 | "version": 3 369 | }, 370 | "file_extension": ".py", 371 | "mimetype": "text/x-python", 372 | "name": "python", 373 | "nbconvert_exporter": "python", 374 | "pygments_lexer": "ipython3", 375 | "version": "3.8.10" 376 | } 377 | }, 378 | "nbformat": 4, 379 | "nbformat_minor": 5 380 | } 381 | -------------------------------------------------------------------------------- /notebooks/07_ebm/01_ebm/checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/07_ebm/01_ebm/ebm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b076bd1a-b236-4fbc-953d-8295b25122ae", 6 | "metadata": {}, 7 | "source": [ 8 | "# ⚡️ Energy-Based Models" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "9235cbd1-f136-411c-88d9-f69f270c0b96", 14 | "metadata": {}, 15 | "source": [ 16 | "In this notebook, we'll walk through the steps required to train your own Energy Based Model to predict the distribution of a demo dataset" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "id": "2531aef5-c81a-4b53-a344-4b979dd4eec5", 22 | "metadata": {}, 23 | "source": [ 24 | "The code is adapted from the excellent ['Deep Energy-Based Generative Models' tutorial](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial8/Deep_Energy_Models.html) created by Phillip Lippe." 
25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "id": "84acc7be-6764-4668-b2bb-178f63deeed3", 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "%load_ext autoreload\n", 35 | "%autoreload 2\n", 36 | "\n", 37 | "import numpy as np\n", 38 | "\n", 39 | "import tensorflow as tf\n", 40 | "from tensorflow.keras import (\n", 41 | " datasets,\n", 42 | " layers,\n", 43 | " models,\n", 44 | " optimizers,\n", 45 | " activations,\n", 46 | " metrics,\n", 47 | " callbacks,\n", 48 | ")\n", 49 | "\n", 50 | "from notebooks.utils import display, sample_batch\n", 51 | "import random" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "id": "339e6268-ebd7-4feb-86db-1fe7abccdbe5", 57 | "metadata": {}, 58 | "source": [ 59 | "## 0. Parameters " 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "id": "1b2ee6ce-129f-4833-b0c5-fa567381c4e0", 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "IMAGE_SIZE = 32\n", 70 | "CHANNELS = 1\n", 71 | "STEP_SIZE = 10\n", 72 | "STEPS = 60\n", 73 | "NOISE = 0.005\n", 74 | "ALPHA = 0.1\n", 75 | "GRADIENT_CLIP = 0.03\n", 76 | "BATCH_SIZE = 128\n", 77 | "BUFFER_SIZE = 8192\n", 78 | "LEARNING_RATE = 0.0001\n", 79 | "EPOCHS = 60\n", 80 | "LOAD_MODEL = False" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "id": "9a73e5a4-1638-411c-8d3c-29f823424458", 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "# Load the data\n", 91 | "(x_train, _), (x_test, _) = datasets.mnist.load_data()" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "id": "20697102-8c8d-4582-88d4-f8e2af84e060", 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "# Preprocess the data\n", 102 | "\n", 103 | "\n", 104 | "def preprocess(imgs):\n", 105 | " \"\"\"\n", 106 | " Normalize and reshape the images\n", 107 | " \"\"\"\n", 108 | " imgs = (imgs.astype(\"float32\") - 127.5) / 127.5\n", 109 | " imgs = np.pad(imgs, ((0, 0), (2, 2), (2, 2)), constant_values=-1.0)\n", 110 | " imgs = np.expand_dims(imgs, -1)\n", 111 | " return imgs\n", 112 | "\n", 113 | "\n", 114 | "x_train = preprocess(x_train)\n", 115 | "x_test = preprocess(x_test)" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "id": "13668819-2e42-4661-8682-33ff2c24ae8b", 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "x_train = tf.data.Dataset.from_tensor_slices(x_train).batch(BATCH_SIZE)\n", 126 | "x_test = tf.data.Dataset.from_tensor_slices(x_test).batch(BATCH_SIZE)" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "id": "a7e1a420-699e-4869-8d10-3c049dbad030", 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "# Show some items of clothing from the training set\n", 137 | "train_sample = sample_batch(x_train)\n", 138 | "display(train_sample)" 139 | ] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "id": "f53945d9-b7c5-49d0-a356-bcf1d1e1798b", 144 | "metadata": {}, 145 | "source": [ 146 | "## 2. 
Build the EBM network " 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "id": "8936d951-3281-4424-9cce-59433976bf2f", 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "ebm_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS))\n", 157 | "x = layers.Conv2D(\n", 158 | " 16, kernel_size=5, strides=2, padding=\"same\", activation=activations.swish\n", 159 | ")(ebm_input)\n", 160 | "x = layers.Conv2D(\n", 161 | " 32, kernel_size=3, strides=2, padding=\"same\", activation=activations.swish\n", 162 | ")(x)\n", 163 | "x = layers.Conv2D(\n", 164 | " 64, kernel_size=3, strides=2, padding=\"same\", activation=activations.swish\n", 165 | ")(x)\n", 166 | "x = layers.Conv2D(\n", 167 | " 64, kernel_size=3, strides=2, padding=\"same\", activation=activations.swish\n", 168 | ")(x)\n", 169 | "x = layers.Flatten()(x)\n", 170 | "x = layers.Dense(64, activation=activations.swish)(x)\n", 171 | "ebm_output = layers.Dense(1)(x)\n", 172 | "model = models.Model(ebm_input, ebm_output)\n", 173 | "model.summary()" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "id": "32221908-8819-48fa-8e57-0dc5179ca2cf", 180 | "metadata": { 181 | "tags": [] 182 | }, 183 | "outputs": [], 184 | "source": [ 185 | "if LOAD_MODEL:\n", 186 | " model.load_weights(\"./models/model.h5\")" 187 | ] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "id": "1f392424-45a9-49cc-8ea0-c1bec9064d74", 192 | "metadata": {}, 193 | "source": [ 194 | "## 2. Set up a Langevin sampler function " 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "id": "bf10775a-0fbf-42df-aca5-be4b256a0c2b", 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "# Function to generate samples using Langevin Dynamics\n", 205 | "def generate_samples(\n", 206 | " model, inp_imgs, steps, step_size, noise, return_img_per_step=False\n", 207 | "):\n", 208 | " imgs_per_step = []\n", 209 | " for _ in range(steps):\n", 210 | " inp_imgs += tf.random.normal(inp_imgs.shape, mean=0, stddev=noise)\n", 211 | " inp_imgs = tf.clip_by_value(inp_imgs, -1.0, 1.0)\n", 212 | " with tf.GradientTape() as tape:\n", 213 | " tape.watch(inp_imgs)\n", 214 | " out_score = model(inp_imgs)\n", 215 | " grads = tape.gradient(out_score, inp_imgs)\n", 216 | " grads = tf.clip_by_value(grads, -GRADIENT_CLIP, GRADIENT_CLIP)\n", 217 | " inp_imgs += step_size * grads\n", 218 | " inp_imgs = tf.clip_by_value(inp_imgs, -1.0, 1.0)\n", 219 | " if return_img_per_step:\n", 220 | " imgs_per_step.append(inp_imgs)\n", 221 | " if return_img_per_step:\n", 222 | " return tf.stack(imgs_per_step, axis=0)\n", 223 | " else:\n", 224 | " return inp_imgs" 225 | ] 226 | }, 227 | { 228 | "cell_type": "markdown", 229 | "id": "180fb0a1-ed16-47c2-b326-ad66071cd6e2", 230 | "metadata": {}, 231 | "source": [ 232 | "## 3. 
Set up a buffer to store examples " 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": null, 238 | "id": "52615dcd-be2b-4e05-b729-0ec45ea6ef98", 239 | "metadata": {}, 240 | "outputs": [], 241 | "source": [ 242 | "class Buffer:\n", 243 | " def __init__(self, model):\n", 244 | " super().__init__()\n", 245 | " self.model = model\n", 246 | " self.examples = [\n", 247 | " tf.random.uniform(shape=(1, IMAGE_SIZE, IMAGE_SIZE, CHANNELS)) * 2\n", 248 | " - 1\n", 249 | " for _ in range(BATCH_SIZE)\n", 250 | " ]\n", 251 | "\n", 252 | " def sample_new_exmps(self, steps, step_size, noise):\n", 253 | " n_new = np.random.binomial(BATCH_SIZE, 0.05)\n", 254 | " rand_imgs = (\n", 255 | " tf.random.uniform((n_new, IMAGE_SIZE, IMAGE_SIZE, CHANNELS)) * 2 - 1\n", 256 | " )\n", 257 | " old_imgs = tf.concat(\n", 258 | " random.choices(self.examples, k=BATCH_SIZE - n_new), axis=0\n", 259 | " )\n", 260 | " inp_imgs = tf.concat([rand_imgs, old_imgs], axis=0)\n", 261 | " inp_imgs = generate_samples(\n", 262 | " self.model, inp_imgs, steps=steps, step_size=step_size, noise=noise\n", 263 | " )\n", 264 | " self.examples = tf.split(inp_imgs, BATCH_SIZE, axis=0) + self.examples\n", 265 | " self.examples = self.examples[:BUFFER_SIZE]\n", 266 | " return inp_imgs" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": null, 272 | "id": "71a2a4a1-690e-4c94-b323-86f0e5b691d5", 273 | "metadata": {}, 274 | "outputs": [], 275 | "source": [ 276 | "class EBM(models.Model):\n", 277 | " def __init__(self):\n", 278 | " super(EBM, self).__init__()\n", 279 | " self.model = model\n", 280 | " self.buffer = Buffer(self.model)\n", 281 | " self.alpha = ALPHA\n", 282 | " self.loss_metric = metrics.Mean(name=\"loss\")\n", 283 | " self.reg_loss_metric = metrics.Mean(name=\"reg\")\n", 284 | " self.cdiv_loss_metric = metrics.Mean(name=\"cdiv\")\n", 285 | " self.real_out_metric = metrics.Mean(name=\"real\")\n", 286 | " self.fake_out_metric = metrics.Mean(name=\"fake\")\n", 287 | "\n", 288 | " @property\n", 289 | " def metrics(self):\n", 290 | " return [\n", 291 | " self.loss_metric,\n", 292 | " self.reg_loss_metric,\n", 293 | " self.cdiv_loss_metric,\n", 294 | " self.real_out_metric,\n", 295 | " self.fake_out_metric,\n", 296 | " ]\n", 297 | "\n", 298 | " def train_step(self, real_imgs):\n", 299 | " real_imgs += tf.random.normal(\n", 300 | " shape=tf.shape(real_imgs), mean=0, stddev=NOISE\n", 301 | " )\n", 302 | " real_imgs = tf.clip_by_value(real_imgs, -1.0, 1.0)\n", 303 | " fake_imgs = self.buffer.sample_new_exmps(\n", 304 | " steps=STEPS, step_size=STEP_SIZE, noise=NOISE\n", 305 | " )\n", 306 | " inp_imgs = tf.concat([real_imgs, fake_imgs], axis=0)\n", 307 | " with tf.GradientTape() as training_tape:\n", 308 | " real_out, fake_out = tf.split(self.model(inp_imgs), 2, axis=0)\n", 309 | " cdiv_loss = tf.reduce_mean(fake_out, axis=0) - tf.reduce_mean(\n", 310 | " real_out, axis=0\n", 311 | " )\n", 312 | " reg_loss = self.alpha * tf.reduce_mean(\n", 313 | " real_out**2 + fake_out**2, axis=0\n", 314 | " )\n", 315 | " loss = cdiv_loss + reg_loss\n", 316 | " grads = training_tape.gradient(loss, self.model.trainable_variables)\n", 317 | " self.optimizer.apply_gradients(\n", 318 | " zip(grads, self.model.trainable_variables)\n", 319 | " )\n", 320 | " self.loss_metric.update_state(loss)\n", 321 | " self.reg_loss_metric.update_state(reg_loss)\n", 322 | " self.cdiv_loss_metric.update_state(cdiv_loss)\n", 323 | " self.real_out_metric.update_state(tf.reduce_mean(real_out, axis=0))\n", 324 | " 
self.fake_out_metric.update_state(tf.reduce_mean(fake_out, axis=0))\n", 325 | " return {m.name: m.result() for m in self.metrics}\n", 326 | "\n", 327 | " def test_step(self, real_imgs):\n", 328 | " batch_size = real_imgs.shape[0]\n", 329 | " fake_imgs = (\n", 330 | " tf.random.uniform((batch_size, IMAGE_SIZE, IMAGE_SIZE, CHANNELS))\n", 331 | " * 2\n", 332 | " - 1\n", 333 | " )\n", 334 | " inp_imgs = tf.concat([real_imgs, fake_imgs], axis=0)\n", 335 | " real_out, fake_out = tf.split(self.model(inp_imgs), 2, axis=0)\n", 336 | " cdiv = tf.reduce_mean(fake_out, axis=0) - tf.reduce_mean(\n", 337 | " real_out, axis=0\n", 338 | " )\n", 339 | " self.cdiv_loss_metric.update_state(cdiv)\n", 340 | " self.real_out_metric.update_state(tf.reduce_mean(real_out, axis=0))\n", 341 | " self.fake_out_metric.update_state(tf.reduce_mean(fake_out, axis=0))\n", 342 | " return {m.name: m.result() for m in self.metrics[2:]}" 343 | ] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "execution_count": null, 348 | "id": "6337e801-eb59-4abe-84dc-9536cf4dc257", 349 | "metadata": {}, 350 | "outputs": [], 351 | "source": [ 352 | "ebm = EBM()" 353 | ] 354 | }, 355 | { 356 | "cell_type": "markdown", 357 | "id": "35b14665-4359-447b-be58-3fd58ba69084", 358 | "metadata": {}, 359 | "source": [ 360 | "## 3. Train the EBM network " 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": null, 366 | "id": "d9ec362d-41fa-473a-ad56-ebeec6cfd3b8", 367 | "metadata": {}, 368 | "outputs": [], 369 | "source": [ 370 | "# Compile and train the model\n", 371 | "ebm.compile(\n", 372 | " optimizer=optimizers.Adam(learning_rate=LEARNING_RATE), run_eagerly=True\n", 373 | ")" 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "execution_count": null, 379 | "id": "8ceca4de-f634-40ff-beb8-09ba42fd0f75", 380 | "metadata": {}, 381 | "outputs": [], 382 | "source": [ 383 | "tensorboard_callback = callbacks.TensorBoard(log_dir=\"./logs\")\n", 384 | "\n", 385 | "\n", 386 | "class ImageGenerator(callbacks.Callback):\n", 387 | " def __init__(self, num_img):\n", 388 | " self.num_img = num_img\n", 389 | "\n", 390 | " def on_epoch_end(self, epoch, logs=None):\n", 391 | " start_imgs = (\n", 392 | " np.random.uniform(\n", 393 | " size=(self.num_img, IMAGE_SIZE, IMAGE_SIZE, CHANNELS)\n", 394 | " )\n", 395 | " * 2\n", 396 | " - 1\n", 397 | " )\n", 398 | " generated_images = generate_samples(\n", 399 | " ebm.model,\n", 400 | " start_imgs,\n", 401 | " steps=1000,\n", 402 | " step_size=STEP_SIZE,\n", 403 | " noise=NOISE,\n", 404 | " return_img_per_step=False,\n", 405 | " )\n", 406 | " generated_images = generated_images.numpy()\n", 407 | " display(\n", 408 | " generated_images,\n", 409 | " save_to=\"./output/generated_img_%03d.png\" % (epoch),\n", 410 | " )\n", 411 | "\n", 412 | " example_images = tf.concat(\n", 413 | " random.choices(ebm.buffer.examples, k=10), axis=0\n", 414 | " )\n", 415 | " example_images = example_images.numpy()\n", 416 | " display(\n", 417 | " example_images, save_to=\"./output/example_img_%03d.png\" % (epoch)\n", 418 | " )\n", 419 | "\n", 420 | "\n", 421 | "image_generator_callback = ImageGenerator(num_img=10)" 422 | ] 423 | }, 424 | { 425 | "cell_type": "code", 426 | "execution_count": null, 427 | "id": "627c1387-f29a-4cce-85a8-0903c1890e23", 428 | "metadata": {}, 429 | "outputs": [], 430 | "source": [ 431 | "class SaveModel(callbacks.Callback):\n", 432 | " def on_epoch_end(self, epoch, logs=None):\n", 433 | " model.save_weights(\"./models/model.h5\")\n", 434 | "\n", 435 | "\n", 436 | "save_model_callback = 
SaveModel()" 437 | ] 438 | }, 439 | { 440 | "cell_type": "code", 441 | "execution_count": null, 442 | "id": "bd6a5a71-eb55-4ec0-9c8c-cb11a382ff90", 443 | "metadata": { 444 | "scrolled": true, 445 | "tags": [] 446 | }, 447 | "outputs": [], 448 | "source": [ 449 | "ebm.fit(\n", 450 | " x_train,\n", 451 | " shuffle=True,\n", 452 | " epochs=60,\n", 453 | " validation_data=x_test,\n", 454 | " callbacks=[\n", 455 | " save_model_callback,\n", 456 | " tensorboard_callback,\n", 457 | " image_generator_callback,\n", 458 | " ],\n", 459 | ")" 460 | ] 461 | }, 462 | { 463 | "cell_type": "markdown", 464 | "id": "fb1f295f-ade0-4040-a6a5-a7b428b08ebc", 465 | "metadata": {}, 466 | "source": [ 467 | "## 4. Generate images " 468 | ] 469 | }, 470 | { 471 | "cell_type": "code", 472 | "execution_count": null, 473 | "id": "8db3cfe3-339e-463d-8af5-fbd403385fca", 474 | "metadata": {}, 475 | "outputs": [], 476 | "source": [ 477 | "start_imgs = (\n", 478 | " np.random.uniform(size=(10, IMAGE_SIZE, IMAGE_SIZE, CHANNELS)) * 2 - 1\n", 479 | ")" 480 | ] 481 | }, 482 | { 483 | "cell_type": "code", 484 | "execution_count": null, 485 | "id": "80087297-3f47-4e0c-ac89-8758d4386d7c", 486 | "metadata": {}, 487 | "outputs": [], 488 | "source": [ 489 | "display(start_imgs)" 490 | ] 491 | }, 492 | { 493 | "cell_type": "code", 494 | "execution_count": null, 495 | "id": "eaf4b749-5f6e-4a12-863f-b0bbcd23549c", 496 | "metadata": { 497 | "scrolled": true, 498 | "tags": [] 499 | }, 500 | "outputs": [], 501 | "source": [ 502 | "gen_img = generate_samples(\n", 503 | " ebm.model,\n", 504 | " start_imgs,\n", 505 | " steps=1000,\n", 506 | " step_size=STEP_SIZE,\n", 507 | " noise=NOISE,\n", 508 | " return_img_per_step=True,\n", 509 | ")" 510 | ] 511 | }, 512 | { 513 | "cell_type": "code", 514 | "execution_count": null, 515 | "id": "eac707f6-0597-499c-9a52-7cade6724795", 516 | "metadata": { 517 | "tags": [] 518 | }, 519 | "outputs": [], 520 | "source": [ 521 | "display(gen_img[-1].numpy())" 522 | ] 523 | }, 524 | { 525 | "cell_type": "code", 526 | "execution_count": null, 527 | "id": "8476aaa1-e0e7-44dc-a1fd-cc30344b8dcb", 528 | "metadata": {}, 529 | "outputs": [], 530 | "source": [ 531 | "imgs = []\n", 532 | "for i in [0, 1, 3, 5, 10, 30, 50, 100, 300, 999]:\n", 533 | " imgs.append(gen_img[i].numpy()[6])\n", 534 | "\n", 535 | "display(np.array(imgs))" 536 | ] 537 | }, 538 | { 539 | "cell_type": "code", 540 | "execution_count": null, 541 | "id": "3e05d5a0-e124-40d1-80ed-552260c6e350", 542 | "metadata": {}, 543 | "outputs": [], 544 | "source": [] 545 | } 546 | ], 547 | "metadata": { 548 | "kernelspec": { 549 | "display_name": "Python 3 (ipykernel)", 550 | "language": "python", 551 | "name": "python3" 552 | }, 553 | "language_info": { 554 | "codemirror_mode": { 555 | "name": "ipython", 556 | "version": 3 557 | }, 558 | "file_extension": ".py", 559 | "mimetype": "text/x-python", 560 | "name": "python", 561 | "nbconvert_exporter": "python", 562 | "pygments_lexer": "ipython3", 563 | "version": "3.8.10" 564 | } 565 | }, 566 | "nbformat": 4, 567 | "nbformat_minor": 5 568 | } 569 | -------------------------------------------------------------------------------- /notebooks/07_ebm/01_ebm/logs/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/07_ebm/01_ebm/models/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 
-------------------------------------------------------------------------------- /notebooks/07_ebm/01_ebm/output/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/08_diffusion/01_ddm/checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/08_diffusion/01_ddm/logs/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/08_diffusion/01_ddm/models/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/08_diffusion/01_ddm/output/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/09_transformer/gpt/checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/09_transformer/gpt/logs/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/09_transformer/gpt/models/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/09_transformer/gpt/output/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/11_music/01_transformer/checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/11_music/01_transformer/logs/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/11_music/01_transformer/models/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/11_music/01_transformer/output/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/11_music/01_transformer/parsed_data/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/11_music/01_transformer/transformer_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle as pkl 
3 | import music21 4 | import keras 5 | import tensorflow as tf 6 | 7 | from fractions import Fraction 8 | 9 | 10 | def parse_midi_files(file_list, parser, seq_len, parsed_data_path=None): 11 | notes_list = [] 12 | duration_list = [] 13 | notes = [] 14 | durations = [] 15 | 16 | for i, file in enumerate(file_list): 17 | print(i + 1, "Parsing %s" % file) 18 | score = parser.parse(file).chordify() 19 | 20 | notes.append("START") 21 | durations.append("0.0") 22 | 23 | for element in score.flat: 24 | note_name = None 25 | duration_name = None 26 | 27 | if isinstance(element, music21.key.Key): 28 | note_name = str(element.tonic.name) + ":" + str(element.mode) 29 | duration_name = "0.0" 30 | 31 | elif isinstance(element, music21.meter.TimeSignature): 32 | note_name = str(element.ratioString) + "TS" 33 | duration_name = "0.0" 34 | 35 | elif isinstance(element, music21.chord.Chord): 36 | note_name = element.pitches[-1].nameWithOctave 37 | duration_name = str(element.duration.quarterLength) 38 | 39 | elif isinstance(element, music21.note.Rest): 40 | note_name = str(element.name) 41 | duration_name = str(element.duration.quarterLength) 42 | 43 | elif isinstance(element, music21.note.Note): 44 | note_name = str(element.nameWithOctave) 45 | duration_name = str(element.duration.quarterLength) 46 | 47 | if note_name and duration_name: 48 | notes.append(note_name) 49 | durations.append(duration_name) 50 | print(f"{len(notes)} notes parsed") 51 | 52 | notes_list = [] 53 | duration_list = [] 54 | 55 | print(f"Building sequences of length {seq_len}") 56 | for i in range(len(notes) - seq_len): 57 | notes_list.append(" ".join(notes[i : (i + seq_len)])) 58 | duration_list.append(" ".join(durations[i : (i + seq_len)])) 59 | 60 | if parsed_data_path: 61 | with open(os.path.join(parsed_data_path, "notes"), "wb") as f: 62 | pkl.dump(notes_list, f) 63 | with open(os.path.join(parsed_data_path, "durations"), "wb") as f: 64 | pkl.dump(duration_list, f) 65 | 66 | return notes_list, duration_list 67 | 68 | 69 | def load_parsed_files(parsed_data_path): 70 | with open(os.path.join(parsed_data_path, "notes"), "rb") as f: 71 | notes = pkl.load(f) 72 | with open(os.path.join(parsed_data_path, "durations"), "rb") as f: 73 | durations = pkl.load(f) 74 | return notes, durations 75 | 76 | 77 | def get_midi_note(sample_note, sample_duration): 78 | new_note = None 79 | 80 | if "TS" in sample_note: 81 | new_note = music21.meter.TimeSignature(sample_note.split("TS")[0]) 82 | 83 | elif "major" in sample_note or "minor" in sample_note: 84 | tonic, mode = sample_note.split(":") 85 | new_note = music21.key.Key(tonic, mode) 86 | 87 | elif sample_note == "rest": 88 | new_note = music21.note.Rest() 89 | new_note.duration = music21.duration.Duration( 90 | float(Fraction(sample_duration)) 91 | ) 92 | new_note.storedInstrument = music21.instrument.Violoncello() 93 | 94 | elif "." 
in sample_note: 95 | notes_in_chord = sample_note.split(".") 96 | chord_notes = [] 97 | for current_note in notes_in_chord: 98 | n = music21.note.Note(current_note) 99 | n.duration = music21.duration.Duration( 100 | float(Fraction(sample_duration)) 101 | ) 102 | n.storedInstrument = music21.instrument.Violoncello() 103 | chord_notes.append(n) 104 | new_note = music21.chord.Chord(chord_notes) 105 | 106 | elif sample_note == "rest": 107 | new_note = music21.note.Rest() 108 | new_note.duration = music21.duration.Duration( 109 | float(Fraction(sample_duration)) 110 | ) 111 | new_note.storedInstrument = music21.instrument.Violoncello() 112 | 113 | elif sample_note != "START": 114 | new_note = music21.note.Note(sample_note) 115 | new_note.duration = music21.duration.Duration( 116 | float(Fraction(sample_duration)) 117 | ) 118 | new_note.storedInstrument = music21.instrument.Violoncello() 119 | 120 | return new_note 121 | 122 | 123 | class SinePositionEncoding(keras.layers.Layer): 124 | """Sinusoidal positional encoding layer. 125 | This layer calculates the position encoding as a mix of sine and cosine 126 | functions with geometrically increasing wavelengths. Defined and formulized 127 | in [Attention is All You Need](https://arxiv.org/abs/1706.03762). 128 | Takes as input an embedded token tensor. The input must have shape 129 | [batch_size, sequence_length, feature_size]. This layer will return a 130 | positional encoding the same size as the embedded token tensor, which 131 | can be added directly to the embedded token tensor. 132 | Args: 133 | max_wavelength: The maximum angular wavelength of the sine/cosine 134 | curves, as described in Attention is All You Need. Defaults to 135 | 10000. 136 | Examples: 137 | ```python 138 | # create a simple embedding layer with sinusoidal positional encoding 139 | seq_len = 100 140 | vocab_size = 1000 141 | embedding_dim = 32 142 | inputs = keras.Input((seq_len,), dtype=tf.float32) 143 | embedding = keras.layers.Embedding( 144 | input_dim=vocab_size, output_dim=embedding_dim 145 | )(inputs) 146 | positional_encoding = keras_nlp.layers.SinePositionEncoding()(embedding) 147 | outputs = embedding + positional_encoding 148 | ``` 149 | References: 150 | - [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762) 151 | """ 152 | 153 | def __init__( 154 | self, 155 | max_wavelength=10000, 156 | **kwargs, 157 | ): 158 | super().__init__(**kwargs) 159 | self.max_wavelength = max_wavelength 160 | 161 | def call(self, inputs): 162 | # TODO(jbischof): replace `hidden_size` with`hidden_dim` for consistency 163 | # with other layers. 
164 | input_shape = tf.shape(inputs) 165 | # length of sequence is the second last dimension of the inputs 166 | seq_length = input_shape[-2] 167 | hidden_size = input_shape[-1] 168 | position = tf.cast(tf.range(seq_length), self.compute_dtype) 169 | min_freq = tf.cast(1 / self.max_wavelength, dtype=self.compute_dtype) 170 | timescales = tf.pow( 171 | min_freq, 172 | tf.cast(2 * (tf.range(hidden_size) // 2), self.compute_dtype) 173 | / tf.cast(hidden_size, self.compute_dtype), 174 | ) 175 | angles = tf.expand_dims(position, 1) * tf.expand_dims(timescales, 0) 176 | # even indices are sine, odd are cosine 177 | cos_mask = tf.cast(tf.range(hidden_size) % 2, self.compute_dtype) 178 | sin_mask = 1 - cos_mask 179 | # embedding shape is [seq_length, hidden_size] 180 | positional_encodings = ( 181 | tf.sin(angles) * sin_mask + tf.cos(angles) * cos_mask 182 | ) 183 | 184 | return tf.broadcast_to(positional_encodings, input_shape) 185 | 186 | def get_config(self): 187 | config = super().get_config() 188 | config.update( 189 | { 190 | "max_wavelength": self.max_wavelength, 191 | } 192 | ) 193 | return config 194 | -------------------------------------------------------------------------------- /notebooks/11_music/02_musegan/checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/11_music/02_musegan/logs/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/11_music/02_musegan/models/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/11_music/02_musegan/musegan_utils.py: -------------------------------------------------------------------------------- 1 | import music21 2 | import numpy as np 3 | from matplotlib import pyplot as plt 4 | 5 | 6 | def binarise_output(output): 7 | # output is a set of scores: [batch size , steps , pitches , tracks] 8 | max_pitches = np.argmax(output, axis=3) 9 | return max_pitches 10 | 11 | 12 | def notes_to_midi(output, n_bars, n_tracks, n_steps_per_bar, filename): 13 | for score_num in range(len(output)): 14 | max_pitches = binarise_output(output) 15 | midi_note_score = max_pitches[score_num].reshape( 16 | [n_bars * n_steps_per_bar, n_tracks] 17 | ) 18 | parts = music21.stream.Score() 19 | parts.append(music21.tempo.MetronomeMark(number=66)) 20 | for i in range(n_tracks): 21 | last_x = int(midi_note_score[:, i][0]) 22 | s = music21.stream.Part() 23 | dur = 0 24 | for idx, x in enumerate(midi_note_score[:, i]): 25 | x = int(x) 26 | if (x != last_x or idx % 4 == 0) and idx > 0: 27 | n = music21.note.Note(last_x) 28 | n.duration = music21.duration.Duration(dur) 29 | s.append(n) 30 | dur = 0 31 | last_x = x 32 | dur = dur + 0.25 33 | n = music21.note.Note(last_x) 34 | n.duration = music21.duration.Duration(dur) 35 | s.append(n) 36 | parts.append(s) 37 | parts.write( 38 | "midi", fp="./output/{}_{}.midi".format(filename, score_num) 39 | ) 40 | 41 | 42 | def draw_bar(data, score_num, bar, part): 43 | plt.imshow( 44 | data[score_num, bar, :, :, part].transpose([1, 0]), 45 | origin="lower", 46 | cmap="Greys", 47 | vmin=-1, 48 | vmax=1, 49 | ) 50 | 51 | 52 | def draw_score(data, score_num): 53 | n_bars = data.shape[1] 
54 | n_tracks = data.shape[-1] 55 | 56 | fig, axes = plt.subplots( 57 | ncols=n_bars, nrows=n_tracks, figsize=(12, 8), sharey=True, sharex=True 58 | ) 59 | fig.subplots_adjust(0, 0, 0.2, 1.5, 0, 0) 60 | 61 | for bar in range(n_bars): 62 | for track in range(n_tracks): 63 | if n_bars > 1: 64 | axes[track, bar].imshow( 65 | data[score_num, bar, :, :, track].transpose([1, 0]), 66 | origin="lower", 67 | cmap="Greys", 68 | ) 69 | else: 70 | axes[track].imshow( 71 | data[score_num, bar, :, :, track].transpose([1, 0]), 72 | origin="lower", 73 | cmap="Greys", 74 | ) 75 | -------------------------------------------------------------------------------- /notebooks/11_music/02_musegan/output/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /notebooks/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | 4 | def sample_batch(dataset): 5 | batch = dataset.take(1).get_single_element() 6 | if isinstance(batch, tuple): 7 | batch = batch[0] 8 | return batch.numpy() 9 | 10 | 11 | def display( 12 | images, n=10, size=(20, 3), cmap="gray_r", as_type="float32", save_to=None 13 | ): 14 | """ 15 | Displays n random images from each one of the supplied arrays. 16 | """ 17 | if images.max() > 1.0: 18 | images = images / 255.0 19 | elif images.min() < 0.0: 20 | images = (images + 1.0) / 2.0 21 | 22 | plt.figure(figsize=size) 23 | for i in range(n): 24 | _ = plt.subplot(1, n, i + 1) 25 | plt.imshow(images[i].astype(as_type), cmap=cmap) 26 | plt.axis("off") 27 | 28 | if save_to: 29 | plt.savefig(save_to) 30 | print(f"\nSaved to {save_to}") 31 | 32 | plt.show() 33 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.6.3 2 | jupyterlab==3.5.3 3 | scipy==1.10.0 4 | scikit-learn==1.2.1 5 | scikit-image==0.19.3 6 | pandas==1.5.3 7 | music21==8.1.0 8 | black[jupyter]==23.1.0 9 | click==8.0.4 10 | flake8==5.0.4 11 | kaggle==1.5.12 12 | pydot==1.4.2 13 | ipywidgets==8.0.4 14 | tensorflow==2.10.1 15 | tensorboard==2.10.1 16 | tensorflow_probability==0.18.0 17 | flake8-nb==0.5.2 -------------------------------------------------------------------------------- /sample.env: -------------------------------------------------------------------------------- 1 | JUPYTER_PORT=8888 2 | TENSORBOARD_PORT=6006 3 | KAGGLE_USERNAME= 4 | KAGGLE_KEY= -------------------------------------------------------------------------------- /scripts/download.sh: -------------------------------------------------------------------------------- 1 | DATASET=$1 2 | 3 | if [ $DATASET = "faces" ] 4 | then 5 | source scripts/downloaders/download_kaggle_data.sh jessicali9530 celeba-dataset 6 | elif [ $DATASET = "bricks" ] 7 | then 8 | source scripts/downloaders/download_kaggle_data.sh joosthazelzet lego-brick-images 9 | elif [ $DATASET = "recipes" ] 10 | then 11 | source scripts/downloaders/download_kaggle_data.sh hugodarwood epirecipes 12 | elif [ $DATASET = "flowers" ] 13 | then 14 | source scripts/downloaders/download_kaggle_data.sh nunenuh pytorch-challange-flower-dataset 15 | elif [ $DATASET = "wines" ] 16 | then 17 | source scripts/downloaders/download_kaggle_data.sh zynicide wine-reviews 18 | elif [ $DATASET = "cellosuites" ] 19 | then 20 | source 
scripts/downloaders/download_bach_cello_data.sh 21 | elif [ $DATASET = "chorales" ] 22 | then 23 | source scripts/downloaders/download_bach_chorale_data.sh 24 | else 25 | echo "Invalid dataset name - please choose from: faces, bricks, recipes, flowers, wines, cellosuites, chorales" 26 | fi 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /scripts/downloaders/download_bach_cello_data.sh: -------------------------------------------------------------------------------- 1 | docker compose exec app bash -c " 2 | mkdir /app/data/bach-cello/ 3 | cd /app/data/bach-cello/ && 4 | echo 'Downloading...' && 5 | curl -O http://www.jsbach.net/midi/cs1-1pre.mid -s && 6 | curl -O http://www.jsbach.net/midi/cs1-2all.mid -s && 7 | curl -O http://www.jsbach.net/midi/cs1-3cou.mid -s && 8 | curl -O http://www.jsbach.net/midi/cs1-4sar.mid -s && 9 | curl -O http://www.jsbach.net/midi/cs1-5men.mid -s && 10 | curl -O http://www.jsbach.net/midi/cs1-6gig.mid -s && 11 | curl -O http://www.jsbach.net/midi/cs2-1pre.mid -s && 12 | curl -O http://www.jsbach.net/midi/cs2-2all.mid -s && 13 | curl -O http://www.jsbach.net/midi/cs2-3cou.mid -s && 14 | curl -O http://www.jsbach.net/midi/cs2-4sar.mid -s && 15 | curl -O http://www.jsbach.net/midi/cs2-5men.mid -s && 16 | curl -O http://www.jsbach.net/midi/cs2-6gig.mid -s && 17 | curl -O http://www.jsbach.net/midi/cs3-1pre.mid -s && 18 | curl -O http://www.jsbach.net/midi/cs3-2all.mid -s && 19 | curl -O http://www.jsbach.net/midi/cs3-3cou.mid -s && 20 | curl -O http://www.jsbach.net/midi/cs3-4sar.mid -s && 21 | curl -O http://www.jsbach.net/midi/cs3-5bou.mid -s && 22 | curl -O http://www.jsbach.net/midi/cs3-6gig.mid -s && 23 | curl -O http://www.jsbach.net/midi/cs4-1pre.mid -s && 24 | curl -O http://www.jsbach.net/midi/cs4-2all.mid -s && 25 | curl -O http://www.jsbach.net/midi/cs4-3cou.mid -s && 26 | curl -O http://www.jsbach.net/midi/cs4-4sar.mid -s && 27 | curl -O http://www.jsbach.net/midi/cs4-5bou.mid -s && 28 | curl -O http://www.jsbach.net/midi/cs4-6gig.mid -s && 29 | curl -O http://www.jsbach.net/midi/cs5-1pre.mid -s && 30 | curl -O http://www.jsbach.net/midi/cs5-2all.mid -s && 31 | curl -O http://www.jsbach.net/midi/cs5-3cou.mid -s && 32 | curl -O http://www.jsbach.net/midi/cs5-4sar.mid -s && 33 | curl -O http://www.jsbach.net/midi/cs5-5gav.mid -s && 34 | curl -O http://www.jsbach.net/midi/cs5-6gig.mid -s && 35 | curl -O http://www.jsbach.net/midi/cs6-1pre.mid -s && 36 | curl -O http://www.jsbach.net/midi/cs6-2all.mid -s && 37 | curl -O http://www.jsbach.net/midi/cs6-3cou.mid -s && 38 | curl -O http://www.jsbach.net/midi/cs6-4sar.mid -s && 39 | curl -O http://www.jsbach.net/midi/cs6-5gav.mid -s && 40 | curl -O http://www.jsbach.net/midi/cs6-6gig.mid -s && 41 | echo '🚀 Done!' 42 | " -------------------------------------------------------------------------------- /scripts/downloaders/download_bach_chorale_data.sh: -------------------------------------------------------------------------------- 1 | docker compose exec app bash -c " 2 | mkdir /app/data/bach-chorales/ 3 | cd /app/data/bach-chorales/ && 4 | echo 'Downloading...' && 5 | curl -LO https://github.com/czhuang/JSB-Chorales-dataset/raw/master/Jsb16thSeparated.npz -s && 6 | echo '🚀 Done!' 
7 | " 8 | -------------------------------------------------------------------------------- /scripts/downloaders/download_kaggle_data.sh: -------------------------------------------------------------------------------- 1 | USER=$1 2 | DATASET=$2 3 | 4 | docker compose exec app bash -c "cd /app/data/ && kaggle datasets download -d $USER/$DATASET && echo 'Unzipping...' && unzip -q -o /app/data/$DATASET.zip -d /app/data/$DATASET && rm /app/data/$DATASET.zip && echo '🚀 Done!'" 5 | 6 | -------------------------------------------------------------------------------- /scripts/format.sh: -------------------------------------------------------------------------------- 1 | docker compose exec app black -l 80 . 2 | docker compose exec app flake8 --max-line-length=80 --exclude=./data --ignore=W503,E203,E402 3 | docker compose exec app flake8_nb --max-line-length=80 --exclude=./data --ignore=W503,E203,E402 -------------------------------------------------------------------------------- /scripts/tensorboard.sh: -------------------------------------------------------------------------------- 1 | CHAPTER=$1 2 | EXAMPLE=$2 3 | echo "Attaching Tensorboard to ./notebooks/$CHAPTER/$EXAMPLE/logs/" 4 | docker compose exec -e CHAPTER=$CHAPTER -e EXAMPLE=$EXAMPLE app bash -c 'tensorboard --logdir "./notebooks/$CHAPTER/$EXAMPLE/logs" --host 0.0.0.0 --port $TENSORBOARD_PORT' --------------------------------------------------------------------------------