├── .dockerignore
├── .gitignore
├── LICENCE
├── README.md
├── data
│ └── .gitignore
├── docker-compose.gpu.yml
├── docker-compose.yml
├── docker
│ ├── Dockerfile
│ └── Dockerfile.gpu
├── docs
│ ├── docker.md
│ ├── googlecloud.md
│ ├── history.csv
│ ├── model_sizes.png
│ └── timeline.png
├── img
│ ├── book_cover.png
│ └── structure.png
├── notebooks
├── 02_deeplearning
│ ├── 01_mlp
│ │ ├── checkpoint
│ │ │ └── .gitignore
│ │ ├── logs
│ │ │ └── .gitignore
│ │ ├── mlp.ipynb
│ │ ├── models
│ │ │ └── .gitignore
│ │ └── output
│ │ │ └── .gitignore
│ └── 02_cnn
│ │ ├── checkpoint
│ │ └── .gitignore
│ │ ├── cnn.ipynb
│ │ ├── convolutions.ipynb
│ │ ├── logs
│ │ └── .gitignore
│ │ ├── models
│ │ └── .gitignore
│ │ └── output
│ │ └── .gitignore
├── 03_vae
│ ├── 01_autoencoder
│ │ ├── autoencoder.ipynb
│ │ ├── checkpoint
│ │ │ └── .gitignore
│ │ ├── logs
│ │ │ └── .gitignore
│ │ ├── models
│ │ │ └── .gitignore
│ │ └── output
│ │ │ └── .gitignore
│ ├── 02_vae_fashion
│ │ ├── checkpoint
│ │ │ └── .gitignore
│ │ ├── logs
│ │ │ └── .gitignore
│ │ ├── models
│ │ │ └── .gitignore
│ │ ├── output
│ │ │ └── .gitignore
│ │ └── vae_fashion.ipynb
│ └── 03_vae_faces
│ │ ├── checkpoint
│ │ └── .gitignore
│ │ ├── logs
│ │ └── .gitignore
│ │ ├── models
│ │ └── .gitignore
│ │ ├── output
│ │ └── .gitignore
│ │ ├── vae_faces.ipynb
│ │ └── vae_utils.py
├── 04_gan
│ ├── 01_dcgan
│ │ ├── checkpoint
│ │ │ └── .gitignore
│ │ ├── dcgan.ipynb
│ │ ├── logs
│ │ │ └── .gitignore
│ │ ├── models
│ │ │ └── .gitignore
│ │ └── output
│ │ │ └── .gitignore
│ ├── 02_wgan_gp
│ │ ├── checkpoint
│ │ │ └── .gitignore
│ │ ├── logs
│ │ │ └── .gitignore
│ │ ├── models
│ │ │ └── .gitignore
│ │ ├── output
│ │ │ └── .gitignore
│ │ └── wgan_gp.ipynb
│ └── 03_cgan
│ │ ├── cgan.ipynb
│ │ ├── checkpoint
│ │ └── .gitignore
│ │ ├── logs
│ │ └── .gitignore
│ │ ├── models
│ │ └── .gitignore
│ │ └── output
│ │ └── .gitignore
├── 05_autoregressive
│ ├── 01_lstm
│ │ ├── checkpoint
│ │ │ └── .gitignore
│ │ ├── logs
│ │ │ └── .gitignore
│ │ ├── lstm.ipynb
│ │ ├── models
│ │ │ └── .gitignore
│ │ └── output
│ │ │ └── .gitignore
│ ├── 02_pixelcnn
│ │ ├── checkpoint
│ │ │ └── .gitignore
│ │ ├── logs
│ │ │ └── .gitignore
│ │ ├── models
│ │ │ └── .gitignore
│ │ ├── output
│ │ │ └── .gitignore
│ │ └── pixelcnn.ipynb
│ └── 03_pixelcnn_md
│ │ ├── checkpoint
│ │ └── .gitignore
│ │ ├── logs
│ │ └── .gitignore
│ │ ├── models
│ │ └── .gitignore
│ │ ├── output
│ │ └── .gitignore
│ │ └── pixelcnn_md.ipynb
├── 06_normflow
│ └── 01_realnvp
│ │ ├── checkpoint
│ │ └── .gitignore
│ │ ├── logs
│ │ └── .gitignore
│ │ ├── models
│ │ └── .gitignore
│ │ ├── output
│ │ └── .gitignore
│ │ └── realnvp.ipynb
├── 07_ebm
│ └── 01_ebm
│ │ ├── checkpoint
│ │ └── .gitignore
│ │ ├── ebm.ipynb
│ │ ├── logs
│ │ └── .gitignore
│ │ ├── models
│ │ └── .gitignore
│ │ └── output
│ │ └── .gitignore
├── 08_diffusion
│ └── 01_ddm
│ │ ├── checkpoint
│ │ └── .gitignore
│ │ ├── ddm.ipynb
│ │ ├── logs
│ │ └── .gitignore
│ │ ├── models
│ │ └── .gitignore
│ │ ├── output
│ │ └── .gitignore
│ │ └── sinusoidal_embedding.ipynb
├── 09_transformer
│ └── gpt
│ │ ├── checkpoint
│ │ └── .gitignore
│ │ ├── gpt.ipynb
│ │ ├── logs
│ │ └── .gitignore
│ │ ├── models
│ │ └── .gitignore
│ │ └── output
│ │ └── .gitignore
├── 11_music
│ ├── 01_transformer
│ │ ├── checkpoint
│ │ │ └── .gitignore
│ │ ├── logs
│ │ │ └── .gitignore
│ │ ├── models
│ │ │ └── .gitignore
│ │ ├── output
│ │ │ └── .gitignore
│ │ ├── parsed_data
│ │ │ └── .gitignore
│ │ ├── transformer.ipynb
│ │ └── transformer_utils.py
│ └── 02_musegan
│ │ ├── checkpoint
│ │ └── .gitignore
│ │ ├── logs
│ │ └── .gitignore
│ │ ├── models
│ │ └── .gitignore
│ │ ├── musegan.ipynb
│ │ ├── musegan_utils.py
│ │ └── output
│ │ └── .gitignore
└── utils.py
├── requirements.txt
├── sample.env
└── scripts
  ├── download.sh
  ├── downloaders
  │ ├── download_bach_cello_data.sh
  │ ├── download_bach_chorale_data.sh
  │ └── download_kaggle_data.sh
  ├── format.sh
  └── tensorboard.sh
/.dockerignore:
--------------------------------------------------------------------------------
1 | data/
2 | .git/
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | docs/history.twbx
2 |
3 | .DS_Store
4 |
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 |
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 |
67 | # Scrapy stuff:
68 | .scrapy
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | target/
75 |
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 |
79 | # pyenv
80 | .python-version
81 |
82 | # celery beat schedule file
83 | celerybeat-schedule
84 |
85 | # SageMath parsed files
86 | *.sage.py
87 |
88 | # Environments
89 | .env
90 | .venv
91 | env/
92 | venv/
93 | ENV/
94 | env.bak/
95 | venv.bak/
96 |
97 | # Spyder project settings
98 | .spyderproject
99 | .spyproject
100 |
101 | # Rope project settings
102 | .ropeproject
103 |
104 | # mkdocs documentation
105 | /site
106 |
107 | # mypy
108 | .mypy_cache/
109 |
110 |
--------------------------------------------------------------------------------
/LICENCE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🦜 Generative Deep Learning - 2nd Edition Codebase
2 |
3 | The official code repository for the second edition of the O'Reilly book *Generative Deep Learning: Teaching Machines to Paint, Write, Compose and Play*.
4 |
5 | [O'Reilly link](https://www.oreilly.com/library/view/generative-deep-learning/9781098134174/)
6 |
7 | [Amazon US link](https://www.amazon.com/Generative-Deep-Learning-Teaching-Machines/dp/1098134184/)
8 |
9 |
10 |
11 | ## 📖 Book Chapters
12 |
13 | Below is an outline of the book chapters:
14 |
15 | *Part I: Introduction to Generative Deep Learning*
16 |
17 | 1. Generative Modeling
18 | 2. Deep Learning
19 |
20 | *Part II: Methods*
21 |
22 | 3. Variational Autoencoders
23 | 4. Generative Adversarial Networks
24 | 5. Autoregressive Models
25 | 6. Normalizing Flows
26 | 7. Energy-Based Models
27 | 8. Diffusion Models
28 |
29 | *Part III: Applications*
30 |
31 | 9. Transformers
32 | 10. Advanced GANs
33 | 11. Music Generation
34 | 12. World Models
35 | 13. Multimodal Models
36 | 14. Conclusion
37 |
38 | ## 🌟 Star History
39 |
40 |
41 |
42 | ## 🚀 Getting Started
43 |
44 | ### Kaggle API
45 |
46 | To download some of the datasets for the book, you will need a Kaggle account and an API token. To use the Kaggle API:
47 |
48 | 1. Sign up for a [Kaggle account](https://www.kaggle.com).
49 | 2. Go to the 'Account' tab of your user profile
50 | 3. Select 'Create API Token'. This will trigger the download of `kaggle.json`, a file containing your API credentials.
51 |
52 | ### The .env file
53 |
54 | Create a file called `.env` in the root directory, containing the following values (replacing the Kaggle username and API key with the values from the JSON):
55 |
56 | ```
57 | JUPYTER_PORT=8888
58 | TENSORBOARD_PORT=6006
59 | KAGGLE_USERNAME=
60 | KAGGLE_KEY=
61 | ```
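
Alternatively, the repository ships a `sample.env` file (see the file tree above) that you may be able to copy as a starting point - a hedged sketch, assuming it contains the same keys:

```
cp sample.env .env   # then open .env and fill in your Kaggle credentials
```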
62 |
63 | ### Get set up with Docker
64 |
65 | This codebase is designed to be run with [Docker](https://docs.docker.com/get-docker/).
66 |
67 | If you've never used Docker before, don't worry! I have included a guide to Docker in the [Docker README](./docs/docker.md) file in this repository. This includes a full run-through of why Docker is awesome and a brief guide to the `Dockerfile` and `docker-compose.yml` for this project.
68 |
69 | ### Building the Docker image
70 |
71 | If you do not have a GPU, run the following command:
72 |
73 | ```
74 | docker compose build
75 | ```
76 |
77 | If you do have a GPU that you wish to use, run the following command:
78 |
79 | ```
80 | docker compose -f docker-compose.gpu.yml build
81 | ```
82 |
83 | ### Running the container
84 |
85 | If you do not have a GPU, run the following command:
86 |
87 | ```
88 | docker compose up
89 | ```
90 |
91 | If you do have a GPU that you wish to use, run the following command:
92 |
93 | ```
94 | docker compose -f docker-compose.gpu.yml up
95 | ```
96 |
97 | Jupyter will be available in your local browser, on the port specified in your env file - for example
98 |
99 | ```
100 | http://localhost:8888
101 | ```
102 |
103 | The notebooks that accompany the book are available in the `/notebooks` folder, organized by chapter and example.
104 |
105 | ## 🏞️ Downloading data
106 |
107 | The codebase comes with an in-built data downloader helper script.
108 |
109 | Run the data downloader as follows (from outside the container), choosing one of the named datasets below:
110 |
111 | ```
112 | bash scripts/download.sh [faces, bricks, recipes, flowers, wines, cellosuites, chorales]
113 | ```
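
For example, to download the `faces` dataset you would run:

```
bash scripts/download.sh faces
```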
114 |
115 | ## 📈 Tensorboard
116 |
117 | Tensorboard is really useful for monitoring experiments and seeing how your model training is progressing.
118 |
119 | To launch Tensorboard, run the following script (from outside the container), passing the required chapter and example:
120 | 
121 | ```
122 | bash scripts/tensorboard.sh <CHAPTER> <EXAMPLE>
123 | ```
124 | * `<CHAPTER>` - the required chapter (e.g. `03_vae`)
125 | * `<EXAMPLE>` - the required example (e.g. `02_vae_fashion`)
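
For example, to monitor the `02_vae_fashion` example from the `03_vae` chapter (assuming the arguments are passed positionally as above):

```
bash scripts/tensorboard.sh 03_vae 02_vae_fashion
```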
126 |
127 | Tensorboard will be available in your local browser on the port specified in your `.env` file - for example:
128 | ```
129 | http://localhost:6006
130 | ```
131 |
132 | ## ☁️ Using a cloud virtual machine
133 |
134 | To set up a virtual machine with GPU in Google Cloud Platform, follow the instructions in the [Google Cloud README](./docs/googlecloud.md) file in this repository.
135 |
136 | ## 📦 Other resources
137 |
138 | Some of the examples in this book are adapted from the excellent open source implementations that are available through the [Keras website](https://keras.io/examples/generative/). I highly recommend you check out this resource as new models and examples are constantly being added.
139 |
140 |
141 |
--------------------------------------------------------------------------------
/data/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/docker-compose.gpu.yml:
--------------------------------------------------------------------------------
1 | version: '3'
2 |
3 | services:
4 | app:
5 | build:
6 | context: .
7 | dockerfile: ./docker/Dockerfile.gpu
8 | tty: true
9 | volumes:
10 | - ./data/:/app/data
11 | - ./notebooks/:/app/notebooks
12 | - ./scripts/:/app/scripts
13 | deploy:
14 | resources:
15 | reservations:
16 | devices:
17 | - driver: nvidia
18 | count: 1
19 | capabilities: [gpu]
20 | ports:
21 | - "$JUPYTER_PORT:$JUPYTER_PORT"
22 | - "$TENSORBOARD_PORT:$TENSORBOARD_PORT"
23 | env_file:
24 | - ./.env
25 | entrypoint: jupyter lab --ip 0.0.0.0 --port=$JUPYTER_PORT --no-browser --allow-root
26 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: '3'
2 |
3 | services:
4 | app:
5 | build:
6 | context: .
7 | dockerfile: ./docker/Dockerfile
8 | tty: true
9 | volumes:
10 | - ./data/:/app/data
11 | - ./notebooks/:/app/notebooks
12 | - ./scripts/:/app/scripts
13 | ports:
14 | - "$JUPYTER_PORT:$JUPYTER_PORT"
15 | - "$TENSORBOARD_PORT:$TENSORBOARD_PORT"
16 | env_file:
17 | - ./.env
18 | entrypoint: jupyter lab --ip 0.0.0.0 --port=$JUPYTER_PORT --no-browser --allow-root
19 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM ubuntu:20.04
2 |
3 | ENV DEBIAN_FRONTEND noninteractive
4 |
5 | RUN apt-get update
6 | RUN apt-get install -y unzip graphviz curl musescore3 python3-pip
7 |
8 | RUN pip install --upgrade pip
9 |
10 | WORKDIR /app
11 |
12 | COPY ./requirements.txt /app
13 | RUN pip install -r /app/requirements.txt
14 |
15 | # Hack to get around tensorflow-io issue - https://github.com/tensorflow/io/issues/1755
16 | RUN pip install tensorflow-io
17 | RUN pip uninstall -y tensorflow-io
18 |
19 | COPY /notebooks/. /app/notebooks
20 | COPY /scripts/. /app/scripts
21 |
22 | ENV PYTHONPATH="${PYTHONPATH}:/app"
--------------------------------------------------------------------------------
/docker/Dockerfile.gpu:
--------------------------------------------------------------------------------
1 | FROM tensorflow/tensorflow:2.10.1-gpu
2 |
3 | RUN rm /etc/apt/sources.list.d/cuda.list
4 |
5 | RUN apt-get update
6 | RUN apt-get install -y unzip graphviz curl musescore3
7 |
8 | RUN pip install --upgrade pip
9 |
10 | WORKDIR /app
11 |
12 | COPY ./requirements.txt /app
13 | RUN pip install -r /app/requirements.txt
14 |
15 | # Hack to get around tensorflow-io issue - https://github.com/tensorflow/io/issues/1755
16 | RUN pip install tensorflow-io
17 | RUN pip uninstall -y tensorflow-io
18 |
19 | COPY /notebooks/. /app/notebooks
20 | COPY /scripts/. /app/scripts
21 |
22 | ENV PYTHONPATH="${PYTHONPATH}:/app"
--------------------------------------------------------------------------------
/docs/docker.md:
--------------------------------------------------------------------------------
1 | # 🐳 Getting started with Docker
2 |
3 | Before we start, you will need to download Docker Desktop for free from https://www.docker.com/products/docker-desktop/
4 |
5 | In this section, we will explain the reason for using Docker and how to get started.
6 |
7 | The overall structure of the codebase and how Docker can be used to interact with it is shown below.
8 |
9 |
10 |
11 | We will run through everything in the diagram in this `README`.
12 |
13 | ## 🙋♂️ Why is Docker useful?
14 |
15 | A common problem you may have faced when using other people's code is that it simply doesn't run on your machine. This can be especially frustrating when it runs just fine on the developer's machine and they cannot help you with specific bugs that occur when deploying onto your operating system.
16 |
17 | There are two reasons why this might be the case:
18 |
19 | 1. The underlying Python version or packages that you have installed are not identical to the ones that were used during development of the codebase
20 |
21 | 2. There are fundamental differences between your operating system and the developer's operating system that mean that the code does not run in the same way.
22 |
23 | The first of these problems can be mostly solved by using *virtual environments*. The idea here is that you create an environment for each project that contains the required Python interpreter and specific required packages and is separated from all other projects on your machine. For example, you could have one project running Python 3.8 with TensorFlow 2.10 and another running Python 3.7 with TensorFlow 2.8. Both versions of Python and TensorFlow would exist on your machine, but within isolated environments so they cannot conflict with each other.
24 |
25 | Anaconda and Virtualenv are two popular ways to create virtual environments. Typically, codebases are shipped with a `requirements.txt` file that contains the specific Python package versions that you need to run the codebase.
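
As an aside, a minimal virtual-environment workflow using Python's built-in `venv` module looks something like the sketch below (illustrative only - this codebase uses Docker instead, as explained next):

```
# create and activate an isolated environment, then install the pinned packages
python3 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
```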
26 |
27 | However, solely using virtual environments does not solve the second problem - how can you ensure that your operating system is set up in the same way as the developer's? For example, the code may have been developed on a Windows machine, but you have a Mac and the errors you are seeing are specific to macOS. Or perhaps the codebase manipulates the filesystem using scripts written in `bash`, which you do not have on your machine.
28 |
29 | Docker solves this problem. Instead of specifying only the required Python packages, the developer specifies a recipe to build what is known as an *image*, which contains everything required to run the app, including an operating system and all dependencies. The recipe is called the `Dockerfile` - it is a bit like the `requirements.txt` file, but for the entire runtime environment rather than just the Python packages.
30 |
31 | Let's first take a look at the Dockerfile for the Generative Deep Learning codebase and see how it contains all the information required to build the image.
32 |
33 | ## 📝 The Dockerfile
34 |
35 | In the codebase that you pulled from GitHub, there is a file simply called 'Dockerfile' inside the `docker` folder. This is the recipe that Docker will use to build the image. We'll walk through line by line, explaining what each step does.
36 |
37 | ```
38 | FROM ubuntu:20.04 #<1>
39 | ENV DEBIAN_FRONTEND noninteractive
40 |
41 | RUN apt-get update #<2>
42 | RUN apt-get install -y unzip graphviz curl musescore3 python3-pip
43 |
44 | RUN pip install --upgrade pip #<3>
45 |
46 | WORKDIR /app #<4>
47 |
48 | COPY ./requirements.txt /app #<5>
49 | RUN pip install -r /app/requirements.txt
50 |
51 | # Hack to get around tensorflow-io issue - https://github.com/tensorflow/io/issues/1755
52 | RUN pip install tensorflow-io
53 | RUN pip uninstall -y tensorflow-io
54 |
55 | COPY /notebooks/. /app/notebooks #<6>
56 | COPY /scripts/. /app/scripts
57 |
58 | ENV PYTHONPATH="${PYTHONPATH}:/app" #<7>
59 | ```
60 |
61 | 1. The first line defines the base image. Our base image is an Ubuntu 20.04 (Linux) operating system. This is pulled from DockerHub - the online store of publicly available images (`https://hub.docker.com/_/ubuntu`).
62 | 2. Update `apt-get`, the Linux package manager, and install the relevant packages
63 | 3. Upgrade `pip` the Python package manager
64 | 4. Change the working directory to `/app`.
65 | 5. Copy the `requirements.txt` file into the image and use `pip` to install all relevant Python packages
66 | 6. Copy relevant folders into the image
67 | 7. Update the `PYTHONPATH` so that we can import functions that we write from our `/app` directory
68 |
69 | You can see how the Dockerfile can be thought of as a recipe for building a particular run-time environment. The magic of Docker is that you do not need to worry about installing a resource intensive virtual machine on your computer - Docker is lightweight and allows you to build an environment template purely using code.
70 |
71 | A running version of an image is called a *container*. You can think of the image as a cookie cutter that can be used to create a particular cookie (the container). There is one other file that we need to look at before we finally get to build our image and run the container - the docker-compose.yaml file.
72 |
73 | ## 🎼 The docker-compose.yaml file
74 |
75 | Docker Compose is an extension to Docker that allows you to define how you would like your containers to run, through a simple YAML file, called 'docker-compose.yaml'.
76 |
77 | For example, you can specify which ports and folders should be mapped between your machine and the container. Folder mapping allows the container to treat folders on your machine as if they were folders inside the container. Therefore, any changes you make to mapped files and folders on your machine will be immediately reflected inside the container. Port mapping will forward any traffic on a local port through to the container. For example, we could map port 8888 on your local machine to port 8888 in the container, so that if you visit `localhost:8888` in your browser, you will see Jupyter, which is running on port 8888 inside the container. The ports do not have to be the same - for example you could map port 8000 to port 8888 in the container if you wanted.
78 |
79 | The alternative to using Docker Compose is to specify all of these parameters on the command line when using the `docker run` command. However, this is cumbersome and it is much easier to just use a Docker Compose YAML file. Docker Compose also gives you the ability to specify multiple services that should be run at the same time (for example, an application container and a database), and how they should talk to each other. However, for our purposes, we only need a single service - the container that runs Jupyter, so that we can interact with the generative deep learning codebase.
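
For illustration, a roughly equivalent `docker run` invocation for this project would look something like the sketch below (the image tag `gdl` is a hypothetical name) - all of this wiring is exactly what the Compose file saves you from typing:

```
# manual equivalent of `docker compose up` (illustrative only)
docker run -it --rm \
  -p 8888:8888 -p 6006:6006 \
  -v "$(pwd)/data:/app/data" \
  -v "$(pwd)/notebooks:/app/notebooks" \
  -v "$(pwd)/scripts:/app/scripts" \
  --env-file .env \
  gdl \
  jupyter lab --ip 0.0.0.0 --port=8888 --no-browser --allow-root
```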
80 |
81 | Let's now take a look at the Docker Compose YAML file.
82 |
83 | ```
84 | version: '3' #<1>
85 | services: #<2>
86 | app: #<3>
87 | build: #<4>
88 | context: .
89 | dockerfile: ./docker/Dockerfile
90 | tty: true #<5>
91 | volumes: #<6>
92 | - ./data/:/app/data
93 | - ./notebooks/:/app/notebooks
94 | - ./scripts/:/app/scripts
95 | ports: #<7>
96 | - "$JUPYTER_PORT:$JUPYTER_PORT"
97 | - "$TENSORBOARD_PORT:$TENSORBOARD_PORT"
98 | env_file: #<8>
99 | - ./.env
100 | entrypoint: jupyter lab --ip 0.0.0.0 --port=$JUPYTER_PORT --no-browser --allow-root #<9>
101 | ```
102 |
103 | 1. This specifies the version of the Compose file format to use (here, version 3)
104 | 2. Here, we specify the services we wish to launch
105 | 3. We only have one service, which we call `app`
106 | 4. Here, we tell Docker how to build the image: the build context is the root folder (the same directory as the docker-compose.yaml file) and the Dockerfile lives inside the `docker` folder
107 | 5. This allows us to open up an interactive command line inside the container, if we wish
108 | 6. Here, we map folders on our local machine (e.g. ./data) to folders inside the container (e.g. /app/data).
109 | 7. Here, we specify the port mappings - the dollar sign means that it will use the ports as specified in the `.env` file (e.g. `JUPYTER_PORT=8888`)
110 | 8. The location of the `.env` file on your local machine.
111 | 9. The command that should be run when the container runs - here, we run JupyterLab.
112 |
113 | ## 🧱 Building the image and running the container
114 |
115 | We're now at a point where we have everything we need to build our image and run the container. Building the image is simply a case of running the command shown below in your terminal, from the root folder.
116 |
117 | ```
118 | docker compose build
119 | ```
120 |
121 | You should see Docker start to run through the steps in the Dockerfile. You can use the command `docker images` to see a list of the images that have been built on your machine.
122 |
123 | To run the container based on the image we have just created, we use the command shown below:
124 |
125 | ```
126 | docker compose up
127 | ```
128 |
129 | You should see that Docker launches the Jupyter notebook server within the container and provides you with a URL to the running server.
130 |
131 | Because we have mapped port 8888 in the container to port 8888 on your machine, you can simply navigate to the address starting `http://127.0.0.1:8888/lab?token=` in a web browser and you should see the running Jupyter server. The folders that we mapped across in the `docker-compose.yaml` file should be visible on the left-hand side.
132 |
133 | Congratulations! You now have a functioning Docker container that you can use to start working through the Generative Deep Learning codebase! To stop running the Jupyter server, you use `Ctrl-C` and to bring down the running container, you use the command `docker compose down`. Because the volumes are mapped, you won't lose any of your work that you save whilst working in the Jupyter notebooks, even if you bring the container down.
134 |
135 | ## ⚡️ Using a GPU
136 |
137 | The default `Dockerfile` and `docker-compose.yaml` file assume that you do not want to use a local GPU to train your models. If you do have a GPU that you wish to use (for example, you are using a cloud VM), I have provided two extra files, `Dockerfile.gpu` and `docker-compose.gpu.yml`, that can be used in place of the default files.
138 |
139 | For example, to build an image that includes support for GPU, use the command shown below:
140 |
141 | ```
142 | docker compose -f docker-compose.gpu.yml build
143 | ```
144 |
145 | To run this image, use the following command:
146 |
147 | ```
148 | docker compose -f docker-compose.gpu.yml up
149 | ```
150 |
--------------------------------------------------------------------------------
/docs/googlecloud.md:
--------------------------------------------------------------------------------
1 | # ⚡️ Setting up a VM with GPU in Google Cloud Platform
2 |
3 | ## ☁️ Google Cloud Platform
4 |
5 | Google Cloud Platform is a public cloud vendor that offers a wide variety of computing services, running on Google infrastructure. It is extremely useful for spinning up resources on demand that would otherwise require a large upfront investment in hardware and set-up time.
6 |
7 | There are many cloud vendors - the largest being Amazon Web Services (AWS), Microsoft Azure and Google Cloud Platform (GCP). Through these cloud platforms, you can spin up servers, databases and other services through the online user interface and easily turn them off or tear them down when you are finished using them. This makes cloud computing an excellent choice for anyone who wants to get started with state-of-the-art machine learning, because you can get access to powerful resources at the click of a button, without having to invest significant amounts of money in your own hardware.
8 |
9 | In this book, we will demonstrate how to set up a virtual server with a GPU on Google Cloud Platform. You can use this environment to run the codebase that accompanies this book on accelerated GPU hardware, for faster training.
10 |
11 | ## 🚀 Getting started
12 |
13 | ### Get a Google Cloud Platform Account
14 |
15 | Firstly, you'll need a Google Cloud Console account - visit https://cloud.google.com/cloud-console/ to get set up.
16 |
17 | If you've never used GCP, you get access to $300 free credit over 90 days (correct at time of writing), which is more than enough to run all of the training examples in this book.
18 |
19 | Once you're logged in, the first thing you need to do is create a project. You can call this whatever you like, for example `generative-deep-learning`. Then you have to set up a billing account for the project - if you are using the free trial credits, you will not be automatically billed after the credits are used up.
20 |
21 | You can now go ahead and set up a compute engine with an attached GPU, within this project. You can see which project you are currently in next to the Google Cloud Platform logo in the top left corner of the screen.
22 |
23 | ### Set up a compute engine
24 |
25 | Firstly, search for 'Compute Engine' in the search bar and click 'Create Instance' (or navigate straight to https://console.cloud.google.com/compute/instancesAdd).
26 |
27 | You will see a screen where you can select the configuration of your virtual machine (VM). Below, we run through a description of each field and a recommended value to choose.
28 |
29 | | Option | Description | Example Value
30 | | --- | --- | ---
31 | | Name | What would you like to call the compute instance? | gdl-compute
32 | | Region | Where should the compute instance be physically located? | us-central1-c (Iowa)
33 | | Machine family | What type of machine would you like? | GPU
34 | | GPU type | What GPU would you like? | NVIDIA Tesla T4
35 | | Machine type | What specification of machine would you like? | n1-standard-4 (4vCPU, 15 GB memory)
36 | | Boot disk > Operating System| What boot disk OS would you like? | Deep Learning on Linux
37 |
38 | All other options can be left as the default.
39 |
40 | The NVIDIA Tesla T4 is a sufficiently powerful GPU for running the examples in this book. Note that other, more powerful GPUs are available, but these are more expensive per hour. Choosing the 'Deep Learning on Linux' boot disk means that the NVIDIA CUDA stack will be installed automatically at no extra cost.
41 |
42 | You may wish to select a different region for your machine - price does vary between regions and some regions are not compatible with GPU (see https://cloud.google.com/compute/docs/gpus/gpu-regions-zones for the full list of GPU regions). In the `us-central1-c` region, the whole virtual machine should be priced at around $0.55 per hour (without any sustained use discount applied).
43 |
44 | Click 'Create' to build your virtual machine - this may take a few minutes. You'll be able to see when the machine is ready to access, as there will be a green tick next to the machine name.
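
If you prefer the command line to the console, an equivalent machine can also be created with the `gcloud` CLI - a hedged sketch, since the exact image family behind 'Deep Learning on Linux' changes over time:

```
# illustrative only - check the current Deep Learning VM image families before running
gcloud compute instances create gdl-compute \
  --zone=us-central1-c \
  --machine-type=n1-standard-4 \
  --accelerator=type=nvidia-tesla-t4,count=1 \
  --image-family=common-cu113 \
  --image-project=deeplearning-platform-release \
  --maintenance-policy=TERMINATE
```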
45 |
46 | ### Accessing the VM
47 |
48 | The easiest way to access the VM is by using the Google Cloud CLI and Visual Studio Code.
49 |
50 | Firstly, you'll need to install the Google Cloud CLI onto your machine. You can do this by following the instructions at the following link: https://cloud.google.com/sdk/docs/install
51 |
52 | Then, to set up the configuration for the SSH connection to your virtual machine, first run the command below in your local terminal:
53 |
54 | ```
55 | gcloud compute config-ssh
56 | ```
57 |
58 | Then run the command shown below to connect to the virtual machine, replacing `<INSTANCE_NAME>` with the name of your virtual machine (e.g. `gdl-compute`).
59 |
60 | ```
61 | gcloud compute ssh <INSTANCE_NAME>
62 | ```
63 |
64 | The first time you connect, it will ask if you want to install NVIDIA drivers - select 'yes'.
65 |
66 | To start coding on the VM, I recommend installing VSCode from https://code.visualstudio.com/download, which is a modern Integrated Development Environment (IDE) through which you can easily access your virtual machine and manipulate the file system.
67 |
68 | Make sure you have the Remote SSH extension installed (see https://code.visualstudio.com/blogs/2019/10/03/remote-ssh-tips-and-tricks).
69 |
70 | In the bottom left-hand corner of VSCode, you'll see a green box with two arrows - click this and then select 'Connect Current Window to Host'. Then select your virtual machine. VSCode will then connect via SSH. Congratulations - you are now connected to the virtual machine!
71 |
72 | ### Cloning the codebase
73 |
74 | You can then clone the codebase onto the VM:
75 |
76 | ```
77 | git clone https://github.com/davidADSP/Generative_Deep_Learning_2nd_Edition.git
78 | ```
79 |
80 | Then in VSCode, if you click 'Open Folder', you can select the folder that you have just cloned, to see the files in the repository.
81 |
82 | Lastly, you'll need to make sure that Docker Compose is installed on the virtual machine, using the commands shown below:
83 |
84 | ```
85 | sudo apt-get update
86 | sudo apt-get install docker-compose-plugin
87 | ```
88 |
89 | Now you can build the image and run the container, using the GPU-specific commands discussed in the main `README` of this codebase.
90 |
91 | Congratulations - you're now running the Generative Deep Learning codebase on a cloud virtual machine, with attached GPU!
92 |
93 | If you navigate to `http://127.0.0.1:8888/` in your browser, you should see Jupyter Lab, as VS Code automatically forwards the Jupyter Lab port in the container to the same port on your machine. If not, check that the ports are mapped under the 'Ports' tab in VSCode.
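
If you are not using VS Code's automatic forwarding, you can instead forward the ports manually when you open the SSH connection - for example (the port numbers assume the defaults from the `.env` file):

```
gcloud compute ssh gdl-compute -- -L 8888:localhost:8888 -L 6006:localhost:6006
```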
94 |
95 | ## 🚦 Starting and stopping the virtual machine
96 |
97 | When you are not using the virtual machine, it's important to turn it off so that you don't incur unnecessary charges.
98 |
99 | You can do this by selecting 'Stop' in the dropdown menu next to your virtual machine, in GCP. A stopped virtual machine will not be accessible via SSH, but the storage will be persisted, so you will not lose any of your files or progress. The storage will still incur a small charge.
100 |
101 | To turn the virtual machine back on, select 'Start / Resume' from the dropdown menu next to your virtual machine, in GCP. This will take a few minutes. When the green tick appears next to the machine description, you'll be able to access your machine in VS Code as before.
102 |
103 | ### Changing IP address
104 |
105 | When you stop and restart your virtual machine, the external IP address of the server may change. You can see the current external IP address of the virtual machine in the console. If this happens, you will need to update the external IP address of the aliased SSH connection, which you can do within VS Code.
106 |
107 | Firstly, click on the green SSH connection arrows in the bottom left-hand corner of VS Code and then 'Open SSH Configuration File...'. Select your SSH configuration file and you should see a block of text similar to the example below. This is the text added by the Google Cloud CLI to give you an easy way to connect to your virtual machine.
108 |
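The block added by `gcloud compute config-ssh` looks roughly like the sketch below (illustrative only - the host alias, IP address and other fields on your machine will differ):

```
# illustrative example - your entry will contain different values
Host gdl-compute.us-central1-c.my-project
    HostName 34.123.45.67
    IdentityFile ~/.ssh/google_compute_engine
```
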
109 | To update the external IP address, you just need to change the `HostName` value and save the file.
110 |
111 | ### Destroying the VM
112 |
113 | To completely destroy the machine, select 'Delete' from the dropdown menu next to your virtual machine in GCP. This will completely remove the virtual machine and all attached storage, and there will be no further incurred charges.
114 |
115 | In this section, we have seen how to build a virtual machine with an attached GPU in Google Cloud Platform, from scratch. This should allow you to build sophisticated generative deep learning models, trained on powerful hardware that you can scale to your requirements.
--------------------------------------------------------------------------------
/docs/history.csv:
--------------------------------------------------------------------------------
1 | Name,Paper Date,paper,hide,type,parameters
2 | LSTM,01/11/1997,http://www.bioinf.jku.at/publications/older/2604.pdf,1,Autoregressive / Transformer,
3 | VAE,20/12/2013,https://arxiv.org/abs/1312.6114,0,Variational Autoencoder,
4 | Encoder Decoder,03/06/2014,https://arxiv.org/abs/1406.1078,1,Autoregressive / Transformer,
5 | GAN,10/06/2014,https://arxiv.org/abs/1406.2661,0,Generative Adversarial Network,
6 | Attention,01/09/2014,https://arxiv.org/abs/1409.0473,1,General,
7 | GRU,03/09/2014,https://arxiv.org/abs/1409.1259,0,Autoregressive / Transformer,
8 | CGAN,06/11/2014,https://arxiv.org/abs/1411.1784,0,Generative Adversarial Network,
9 | Diffusion Process,12/03/2015,https://arxiv.org/abs/1503.03585,1,Energy-Based / Diffusion Models,
10 | UNet,18/05/2015,https://arxiv.org/abs/1505.04597,0,General,
11 | Neural Style,26/08/2015,https://arxiv.org/abs/1508.06576,1,General,
12 | DCGAN,19/11/2015,https://arxiv.org/abs/1511.06434,0,Generative Adversarial Network,
13 | ResNet,10/12/2015,https://arxiv.org/abs/1512.03385,0,General,
14 | VAE-GAN,31/12/2015,https://arxiv.org/abs/1512.09300,0,Variational Autoencoder,
15 | Self Attention,25/01/2016,https://arxiv.org/abs/1601.06733,1,Autoregressive / Transformer,
16 | PixelRNN,25/01/2016,https://arxiv.org/abs/1601.06759,0,Autoregressive / Transformer,
17 | RealNVP,27/05/2016,https://arxiv.org/abs/1605.08803v3,0,Normalizing Flow,
18 | PixelCNN,16/06/2016,https://arxiv.org/abs/1606.05328,0,Autoregressive / Transformer,
19 | pix2pix,21/11/2016,https://arxiv.org/abs/1611.07004,0,Generative Adversarial Network,
20 | Stack GAN,10/12/2016,https://arxiv.org/abs/1612.03242,1,Generative Adversarial Network,
21 | PixelCNN++,19/01/2017,https://arxiv.org/abs/1701.05517,0,Autoregressive / Transformer,
22 | WGAN,26/01/2017,https://arxiv.org/abs/1701.07875,0,Generative Adversarial Network,
23 | CycleGAN,30/03/2017,https://arxiv.org/abs/1703.10593,0,Generative Adversarial Network,
24 | WGAN GP,31/03/2017,https://arxiv.org/abs/1704.00028,1,Generative Adversarial Network,
25 | Transformers,12/06/2017,https://arxiv.org/abs/1706.03762,0,Autoregressive / Transformer,
26 | MuseGAN,19/09/2017,https://arxiv.org/abs/1709.06298,0,Generative Adversarial Network,
27 | ProGAN,27/10/2017,https://arxiv.org/abs/1710.10196,0,Generative Adversarial Network,
28 | VQ-VAE,02/11/2017,https://arxiv.org/abs/1711.00937v2,0,Variational Autoencoder,
29 | World Models,27/03/2018,https://arxiv.org/abs/1803.10122,0,Variational Autoencoder,
30 | SAGAN,21/05/2018,https://arxiv.org/abs/1805.08318v2,0,Generative Adversarial Network,
31 | GPT,11/06/2018,https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf,0,Autoregressive / Transformer,0.117
32 | GLOW,09/07/2018,https://arxiv.org/abs/1807.03039,0,Normalizing Flow,
33 | Universal Transformer,10/07/2018,https://arxiv.org/abs/1807.03819,1,Autoregressive / Transformer,
34 | BigGAN,28/09/2018,https://arxiv.org/abs/1809.11096,0,Generative Adversarial Network,
35 | FFJORD,02/10/2018,https://arxiv.org/abs/1810.01367,0,Normalizing Flow,
36 | BERT,11/10/2018,https://arxiv.org/abs/1810.04805,0,Autoregressive / Transformer,0.345
37 | StyleGAN,12/12/2018,https://arxiv.org/abs/1812.04948,0,Generative Adversarial Network,
38 | Music Transformer,12/12/2018,https://arxiv.org/abs/1809.04281,0,Autoregressive / Transformer,
39 | GPT-2,14/02/2019,https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf,0,Autoregressive / Transformer,1.5
40 | MuseNet,25/04/2019,https://openai.com/blog/musenet/,0,Autoregressive / Transformer,
41 | VQ-VAE-2,02/06/2019,https://arxiv.org/abs/1906.00446v1,0,Variational Autoencoder,
42 | NCSN,12/07/2019,https://arxiv.org/abs/1907.05600,0,Energy-Based / Diffusion Models,
43 | T5,23/10/2019,https://arxiv.org/abs/1910.10683,0,Autoregressive / Transformer,11
44 | StyleGAN2,03/12/2019,https://arxiv.org/abs/1912.04958,0,Generative Adversarial Network,
45 | NeRF,19/03/2020,https://arxiv.org/abs/2003.08934,0,General,
46 | GPT-3,28/05/2020,https://arxiv.org/abs/2005.14165,0,Autoregressive / Transformer,175
47 | DDPM,19/06/2020,https://arxiv.org/abs/2006.11239,0,Energy-Based / Diffusion Models,
48 | DDIM,06/10/2020,https://arxiv.org/abs/2010.02502,0,Energy-Based / Diffusion Models,
49 | Vision Transformer,22/10/2020,https://arxiv.org/abs/2010.11929,0,Autoregressive / Transformer,
50 | VQ-GAN,17/12/2020,https://arxiv.org/abs/2012.09841,0,Generative Adversarial Network,
51 | DALL.E,24/02/2021,https://arxiv.org/abs/2102.12092,0,Multimodal Models,12
52 | CLIP,26/02/2021,https://arxiv.org/abs/2103.00020,0,Multimodal Models,
53 | GPT-Neo,21/03/2021,https://github.com/EleutherAI/gpt-neo,0,Autoregressive / Transformer,2.7
54 | GPT-J,10/06/2021,https://github.com/kingoflolz/mesh-transformer-jax,0,Autoregressive / Transformer,6
55 | StyleGAN3,23/06/2021,https://arxiv.org/abs/2106.12423,0,Generative Adversarial Network,
56 | Codex,07/07/2021,https://arxiv.org/abs/2107.03374,0,Autoregressive / Transformer,
57 | ViT VQ-GAN,09/10/2021,https://arxiv.org/abs/2110.04627,0,Generative Adversarial Network,
58 | Megatron-Turing NLG,11/10/2021,https://developer.nvidia.com/blog/using-deepspeed-and-megatron-to-train-megatron-turing-nlg-530b-the-worlds-largest-and-most-powerful-generative-language-model/,0,Autoregressive / Transformer,530
59 | Gopher,08/12/2021,https://arxiv.org/abs/2112.11446,0,Autoregressive / Transformer,280
60 | GLIDE,20/12/2021,https://arxiv.org/abs/2112.10741,0,Multimodal Models,5
61 | Latent Diffusion,20/12/2021,https://arxiv.org/abs/2112.10752,0,Energy-Based / Diffusion Models,
62 | LaMDA,20/01/2022,https://arxiv.org/abs/2201.08239,0,Autoregressive / Transformer,137
63 | StyleGAN-XL,01/02/2022,https://arxiv.org/abs/2202.00273v2,0,Generative Adversarial Network,0
64 | GPT-NeoX,02/02/2022,https://github.com/EleutherAI/gpt-neox,0,Autoregressive / Transformer,20
65 | Chinchilla,29/03/2022,https://arxiv.org/abs/2203.15556v1,0,Autoregressive / Transformer,70
66 | PaLM,05/04/2022,https://arxiv.org/abs/2204.02311,0,Autoregressive / Transformer,540
67 | DALL.E 2,13/04/2022,https://arxiv.org/abs/2204.06125,0,Multimodal Models,3.5
68 | Flamingo,29/04/2022,https://arxiv.org/abs/2204.14198,0,Multimodal Models,80
69 | OPT,02/05/2022,https://arxiv.org/abs/2205.01068,0,Autoregressive / Transformer,175
70 | Imagen,23/05/2022,https://arxiv.org/abs/2205.11487,0,Multimodal Models,4.6
71 | Parti,22/06/2022,https://arxiv.org/abs/2206.10789,0,Multimodal Models,20
72 | BLOOM,16/07/2022,https://arxiv.org/abs/2211.05100,0,Autoregressive / Transformer,176
73 | Stable Diffusion,22/08/2022,https://stability.ai/blog/stable-diffusion-public-release,0,Multimodal Models,0.89
74 | ChatGPT,30/11/2022,https://chat.openai.com/,0,Autoregressive / Transformer,
75 | MUSE,02/01/2023,https://arxiv.org/abs/2301.00704,0,Multimodal Models,3
76 | MusicLM,26/01/2023,https://arxiv.org/abs/2301.11325,0,Multimodal Models,
77 | Dreamix,02/02/2023,https://arxiv.org/pdf/2302.01329.pdf,0,Multimodal Models,
78 | Toolformer,09/02/2023,https://arxiv.org/pdf/2302.04761.pdf,0,Autoregressive / Transformer,
79 | ControlNet,10/02/2023,https://arxiv.org/abs/2302.05543,0,Multimodal Models,
80 | LLaMA,24/02/2023,https://arxiv.org/abs/2302.13971,0,Autoregressive / Transformer,65
81 | PaLM-E,06/03/2023,https://arxiv.org/abs/2303.03378,0,Multimodal Models,562
82 | Visual ChatGPT,08/03/2023,https://arxiv.org/abs/2303.04671,0,Multimodal Models,
83 | Alpaca,13/03/2023,https://github.com/tatsu-lab/stanford_alpaca,1,Autoregressive / Transformer,
84 | GPT-4,16/03/2023,https://cdn.openai.com/papers/gpt-4.pdf,0,Autoregressive / Transformer,1000
85 | Luminous,14/04/2022,https://www.aleph-alpha.com/luminous,0,Autoregressive / Transformer,
86 | Flan-T5,20/10/2022,https://arxiv.org/abs/2210.11416,0,Autoregressive / Transformer,11
87 | Falcon,17/03/2023,https://falconllm.tii.ae/,0,Autoregressive / Transformer,40
88 | PaLM 2,10/05/2023,https://ai.google/discover/palm2/,0,Autoregressive / Transformer,340
89 | PanGu-Σ,20/03/2023,https://arxiv.org/abs/2303.10845,1,Autoregressive / Transformer,1085
90 | GPT-3.5,30/11/2022,,0,Autoregressive / Transformer,175
91 | Llama 2,17/07/2023,,0,Autoregressive / Transformer,70
--------------------------------------------------------------------------------
/docs/model_sizes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidADSP/Generative_Deep_Learning_2nd_Edition/9b1048dbe0d698b486ed16d529cf16fcd3aea29d/docs/model_sizes.png
--------------------------------------------------------------------------------
/docs/timeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidADSP/Generative_Deep_Learning_2nd_Edition/9b1048dbe0d698b486ed16d529cf16fcd3aea29d/docs/timeline.png
--------------------------------------------------------------------------------
/img/book_cover.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidADSP/Generative_Deep_Learning_2nd_Edition/9b1048dbe0d698b486ed16d529cf16fcd3aea29d/img/book_cover.png
--------------------------------------------------------------------------------
/img/structure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidADSP/Generative_Deep_Learning_2nd_Edition/9b1048dbe0d698b486ed16d529cf16fcd3aea29d/img/structure.png
--------------------------------------------------------------------------------
/notebooks/02_deeplearning/01_mlp/checkpoint/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/02_deeplearning/01_mlp/logs/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/02_deeplearning/01_mlp/mlp.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# 👀 Multilayer perceptron (MLP)"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "In this notebook, we'll walk through the steps required to train your own multilayer perceptron on the CIFAR dataset"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "import numpy as np\n",
24 | "import matplotlib.pyplot as plt\n",
25 | "\n",
26 | "from tensorflow.keras import layers, models, optimizers, utils, datasets\n",
27 | "from notebooks.utils import display"
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "metadata": {
33 | "tags": []
34 | },
35 | "source": [
36 | "## 0. Parameters "
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": null,
42 | "metadata": {},
43 | "outputs": [],
44 | "source": [
45 | "NUM_CLASSES = 10"
46 | ]
47 | },
48 | {
49 | "cell_type": "markdown",
50 | "metadata": {},
51 | "source": [
52 | "## 1. Prepare the Data "
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": null,
58 | "metadata": {},
59 | "outputs": [],
60 | "source": [
61 | "(x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": null,
67 | "metadata": {},
68 | "outputs": [],
69 | "source": [
70 | "x_train = x_train.astype(\"float32\") / 255.0\n",
71 | "x_test = x_test.astype(\"float32\") / 255.0\n",
72 | "\n",
73 | "y_train = utils.to_categorical(y_train, NUM_CLASSES)\n",
74 | "y_test = utils.to_categorical(y_test, NUM_CLASSES)"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "display(x_train[:10])\n",
84 | "print(y_train[:10])"
85 | ]
86 | },
87 | {
88 | "cell_type": "markdown",
89 | "metadata": {},
90 | "source": [
91 | "## 2. Build the model "
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": null,
97 | "metadata": {},
98 | "outputs": [],
99 | "source": [
100 | "input_layer = layers.Input((32, 32, 3))\n",
101 | "\n",
102 | "x = layers.Flatten()(input_layer)\n",
103 | "x = layers.Dense(200, activation=\"relu\")(x)\n",
104 | "x = layers.Dense(150, activation=\"relu\")(x)\n",
105 | "\n",
106 | "output_layer = layers.Dense(NUM_CLASSES, activation=\"softmax\")(x)\n",
107 | "\n",
108 | "model = models.Model(input_layer, output_layer)\n",
109 | "\n",
110 | "model.summary()"
111 | ]
112 | },
113 | {
114 | "cell_type": "markdown",
115 | "metadata": {
116 | "tags": []
117 | },
118 | "source": [
119 | "## 3. Train the model "
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "execution_count": null,
125 | "metadata": {},
126 | "outputs": [],
127 | "source": [
128 | "opt = optimizers.Adam(learning_rate=0.0005)\n",
129 | "model.compile(\n",
130 | " loss=\"categorical_crossentropy\", optimizer=opt, metrics=[\"accuracy\"]\n",
131 | ")"
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": null,
137 | "metadata": {},
138 | "outputs": [],
139 | "source": [
140 | "model.fit(x_train, y_train, batch_size=32, epochs=10, shuffle=True)"
141 | ]
142 | },
143 | {
144 | "cell_type": "markdown",
145 | "metadata": {
146 | "tags": []
147 | },
148 | "source": [
149 | "## 4. Evaluation "
150 | ]
151 | },
152 | {
153 | "cell_type": "code",
154 | "execution_count": null,
155 | "metadata": {},
156 | "outputs": [],
157 | "source": [
158 | "model.evaluate(x_test, y_test)"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "execution_count": null,
164 | "metadata": {},
165 | "outputs": [],
166 | "source": [
167 | "CLASSES = np.array(\n",
168 | " [\n",
169 | " \"airplane\",\n",
170 | " \"automobile\",\n",
171 | " \"bird\",\n",
172 | " \"cat\",\n",
173 | " \"deer\",\n",
174 | " \"dog\",\n",
175 | " \"frog\",\n",
176 | " \"horse\",\n",
177 | " \"ship\",\n",
178 | " \"truck\",\n",
179 | " ]\n",
180 | ")\n",
181 | "\n",
182 | "preds = model.predict(x_test)\n",
183 | "preds_single = CLASSES[np.argmax(preds, axis=-1)]\n",
184 | "actual_single = CLASSES[np.argmax(y_test, axis=-1)]"
185 | ]
186 | },
187 | {
188 | "cell_type": "code",
189 | "execution_count": null,
190 | "metadata": {
191 | "scrolled": true
192 | },
193 | "outputs": [],
194 | "source": [
195 | "n_to_show = 10\n",
196 | "indices = np.random.choice(range(len(x_test)), n_to_show)\n",
197 | "\n",
198 | "fig = plt.figure(figsize=(15, 3))\n",
199 | "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n",
200 | "\n",
201 | "for i, idx in enumerate(indices):\n",
202 | " img = x_test[idx]\n",
203 | " ax = fig.add_subplot(1, n_to_show, i + 1)\n",
204 | " ax.axis(\"off\")\n",
205 | " ax.text(\n",
206 | " 0.5,\n",
207 | " -0.35,\n",
208 | " \"pred = \" + str(preds_single[idx]),\n",
209 | " fontsize=10,\n",
210 | " ha=\"center\",\n",
211 | " transform=ax.transAxes,\n",
212 | " )\n",
213 | " ax.text(\n",
214 | " 0.5,\n",
215 | " -0.7,\n",
216 | " \"act = \" + str(actual_single[idx]),\n",
217 | " fontsize=10,\n",
218 | " ha=\"center\",\n",
219 | " transform=ax.transAxes,\n",
220 | " )\n",
221 | " ax.imshow(img)"
222 | ]
223 | }
224 | ],
225 | "metadata": {
226 | "kernelspec": {
227 | "display_name": "Python 3 (ipykernel)",
228 | "language": "python",
229 | "name": "python3"
230 | },
231 | "language_info": {
232 | "codemirror_mode": {
233 | "name": "ipython",
234 | "version": 3
235 | },
236 | "file_extension": ".py",
237 | "mimetype": "text/x-python",
238 | "name": "python",
239 | "nbconvert_exporter": "python",
240 | "pygments_lexer": "ipython3",
241 | "version": "3.8.10"
242 | },
243 | "vscode": {
244 | "interpreter": {
245 | "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
246 | }
247 | }
248 | },
249 | "nbformat": 4,
250 | "nbformat_minor": 4
251 | }
252 |
--------------------------------------------------------------------------------
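Note: the evaluation cells above compare `preds_single` with `actual_single` image by image. A minimal follow-up sketch (not part of the notebook, assuming those arrays and `CLASSES` are already defined as in the cells above) that turns the same comparison into overall and per-class accuracy figures:

import numpy as np

# Overall accuracy: fraction of test images whose predicted class matches the true class
overall_acc = np.mean(preds_single == actual_single)
print(f"overall accuracy: {overall_acc:.3f}")

# Per-class accuracy: restrict the comparison to images whose true label is each class
for cls in CLASSES:
    mask = actual_single == cls
    print(f"{cls:>10s}: {np.mean(preds_single[mask] == cls):.3f}")
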
/notebooks/02_deeplearning/01_mlp/models/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/02_deeplearning/01_mlp/output/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/02_deeplearning/02_cnn/checkpoint/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/02_deeplearning/02_cnn/cnn.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# 🏞 Convolutional Neural Network"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 |     "In this notebook, we'll walk through the steps required to train your own convolutional neural network (CNN) on the CIFAR-10 dataset."
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "import numpy as np\n",
24 | "\n",
25 | "from tensorflow.keras import layers, models, optimizers, utils, datasets\n",
26 | "from notebooks.utils import display"
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {
32 | "tags": []
33 | },
34 | "source": [
35 | "## 0. Parameters "
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": null,
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "NUM_CLASSES = 10"
45 | ]
46 | },
47 | {
48 | "cell_type": "markdown",
49 | "metadata": {},
50 | "source": [
51 | "## 1. Prepare the Data "
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": null,
57 | "metadata": {},
58 | "outputs": [],
59 | "source": [
60 | "(x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": null,
66 | "metadata": {},
67 | "outputs": [],
68 | "source": [
69 | "x_train = x_train.astype(\"float32\") / 255.0\n",
70 | "x_test = x_test.astype(\"float32\") / 255.0\n",
71 | "\n",
72 | "y_train = utils.to_categorical(y_train, NUM_CLASSES)\n",
73 | "y_test = utils.to_categorical(y_test, NUM_CLASSES)"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": null,
79 | "metadata": {},
80 | "outputs": [],
81 | "source": [
82 | "display(x_train[:10])\n",
83 | "print(y_train[:10])"
84 | ]
85 | },
86 | {
87 | "cell_type": "markdown",
88 | "metadata": {},
89 | "source": [
90 | "## 2. Build the model "
91 | ]
92 | },
93 | {
94 | "cell_type": "code",
95 | "execution_count": null,
96 | "metadata": {},
97 | "outputs": [],
98 | "source": [
99 | "input_layer = layers.Input((32, 32, 3))\n",
100 | "\n",
101 | "x = layers.Conv2D(filters=32, kernel_size=3, strides=1, padding=\"same\")(\n",
102 | " input_layer\n",
103 | ")\n",
104 | "x = layers.BatchNormalization()(x)\n",
105 | "x = layers.LeakyReLU()(x)\n",
106 | "\n",
107 | "x = layers.Conv2D(filters=32, kernel_size=3, strides=2, padding=\"same\")(x)\n",
108 | "x = layers.BatchNormalization()(x)\n",
109 | "x = layers.LeakyReLU()(x)\n",
110 | "\n",
111 | "x = layers.Conv2D(filters=64, kernel_size=3, strides=1, padding=\"same\")(x)\n",
112 | "x = layers.BatchNormalization()(x)\n",
113 | "x = layers.LeakyReLU()(x)\n",
114 | "\n",
115 | "x = layers.Conv2D(filters=64, kernel_size=3, strides=2, padding=\"same\")(x)\n",
116 | "x = layers.BatchNormalization()(x)\n",
117 | "x = layers.LeakyReLU()(x)\n",
118 | "\n",
119 | "x = layers.Flatten()(x)\n",
120 | "\n",
121 | "x = layers.Dense(128)(x)\n",
122 | "x = layers.BatchNormalization()(x)\n",
123 | "x = layers.LeakyReLU()(x)\n",
124 | "x = layers.Dropout(rate=0.5)(x)\n",
125 | "\n",
126 | "x = layers.Dense(NUM_CLASSES)(x)\n",
127 | "output_layer = layers.Activation(\"softmax\")(x)\n",
128 | "\n",
129 | "model = models.Model(input_layer, output_layer)\n",
130 | "\n",
131 | "model.summary()"
132 | ]
133 | },
134 | {
135 | "cell_type": "markdown",
136 | "metadata": {
137 | "tags": []
138 | },
139 | "source": [
140 | "## 3. Train the model "
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": null,
146 | "metadata": {},
147 | "outputs": [],
148 | "source": [
149 | "opt = optimizers.Adam(learning_rate=0.0005)\n",
150 | "model.compile(\n",
151 | " loss=\"categorical_crossentropy\", optimizer=opt, metrics=[\"accuracy\"]\n",
152 | ")"
153 | ]
154 | },
155 | {
156 | "cell_type": "code",
157 | "execution_count": null,
158 | "metadata": {
159 | "tags": []
160 | },
161 | "outputs": [],
162 | "source": [
163 | "model.fit(\n",
164 | " x_train,\n",
165 | " y_train,\n",
166 | " batch_size=32,\n",
167 | " epochs=10,\n",
168 | " shuffle=True,\n",
169 | " validation_data=(x_test, y_test),\n",
170 | ")"
171 | ]
172 | },
173 | {
174 | "cell_type": "markdown",
175 | "metadata": {
176 | "tags": []
177 | },
178 | "source": [
179 | "## 4. Evaluation "
180 | ]
181 | },
182 | {
183 | "cell_type": "code",
184 | "execution_count": null,
185 | "metadata": {},
186 | "outputs": [],
187 | "source": [
188 |     "model.evaluate(x_test, y_test, batch_size=1000)  # evaluate on the held-out test set"
189 | ]
190 | },
191 | {
192 | "cell_type": "code",
193 | "execution_count": null,
194 | "metadata": {},
195 | "outputs": [],
196 | "source": [
197 | "CLASSES = np.array(\n",
198 | " [\n",
199 | " \"airplane\",\n",
200 | " \"automobile\",\n",
201 | " \"bird\",\n",
202 | " \"cat\",\n",
203 | " \"deer\",\n",
204 | " \"dog\",\n",
205 | " \"frog\",\n",
206 | " \"horse\",\n",
207 | " \"ship\",\n",
208 | " \"truck\",\n",
209 | " ]\n",
210 | ")\n",
211 | "\n",
212 | "preds = model.predict(x_test)\n",
213 | "preds_single = CLASSES[np.argmax(preds, axis=-1)]\n",
214 | "actual_single = CLASSES[np.argmax(y_test, axis=-1)]"
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": null,
220 | "metadata": {},
221 | "outputs": [],
222 | "source": [
223 | "import matplotlib.pyplot as plt\n",
224 | "\n",
225 | "n_to_show = 10\n",
226 | "indices = np.random.choice(range(len(x_test)), n_to_show)\n",
227 | "\n",
228 | "fig = plt.figure(figsize=(15, 3))\n",
229 | "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n",
230 | "\n",
231 | "for i, idx in enumerate(indices):\n",
232 | " img = x_test[idx]\n",
233 | " ax = fig.add_subplot(1, n_to_show, i + 1)\n",
234 | " ax.axis(\"off\")\n",
235 | " ax.text(\n",
236 | " 0.5,\n",
237 | " -0.35,\n",
238 | " \"pred = \" + str(preds_single[idx]),\n",
239 | " fontsize=10,\n",
240 | " ha=\"center\",\n",
241 | " transform=ax.transAxes,\n",
242 | " )\n",
243 | " ax.text(\n",
244 | " 0.5,\n",
245 | " -0.7,\n",
246 | " \"act = \" + str(actual_single[idx]),\n",
247 | " fontsize=10,\n",
248 | " ha=\"center\",\n",
249 | " transform=ax.transAxes,\n",
250 | " )\n",
251 | " ax.imshow(img)"
252 | ]
253 | }
254 | ],
255 | "metadata": {
256 | "kernelspec": {
257 | "display_name": "Python 3 (ipykernel)",
258 | "language": "python",
259 | "name": "python3"
260 | },
261 | "language_info": {
262 | "codemirror_mode": {
263 | "name": "ipython",
264 | "version": 3
265 | },
266 | "file_extension": ".py",
267 | "mimetype": "text/x-python",
268 | "name": "python",
269 | "nbconvert_exporter": "python",
270 | "pygments_lexer": "ipython3",
271 | "version": "3.8.10"
272 | },
273 | "vscode": {
274 | "interpreter": {
275 | "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
276 | }
277 | }
278 | },
279 | "nbformat": 4,
280 | "nbformat_minor": 4
281 | }
282 |
--------------------------------------------------------------------------------
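Note: the CNN above is trained for a fixed 10 epochs with the test set passed as validation data. A minimal sketch (not part of the notebook, assuming the `model`, `x_train`, `y_train`, `x_test`, `y_test` objects defined above) of how a Keras EarlyStopping callback could stop training when the validation loss stops improving and keep the best weights:

from tensorflow.keras import callbacks

early_stopping = callbacks.EarlyStopping(
    monitor="val_loss",          # watch the validation loss reported each epoch
    patience=3,                  # stop after 3 epochs without improvement
    restore_best_weights=True,   # roll back to the weights from the best epoch
)

model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=50,
    shuffle=True,
    validation_data=(x_test, y_test),
    callbacks=[early_stopping],
)
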
/notebooks/02_deeplearning/02_cnn/convolutions.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# 🎆 Convolutions"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 |     "In this notebook, we'll walk through how convolutional filters can pick out different aspects of an image."
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "%matplotlib inline\n",
24 | "import matplotlib.pyplot as plt\n",
25 | "import numpy as np\n",
26 | "from skimage import data\n",
27 | "from skimage.color import rgb2gray\n",
28 | "from skimage.transform import resize"
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "metadata": {},
34 | "source": [
35 | "## 0. Original Input Image "
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": null,
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "im = rgb2gray(data.coffee())\n",
45 | "im = resize(im, (64, 64))\n",
46 | "print(im.shape)\n",
47 | "\n",
48 | "plt.axis(\"off\")\n",
49 | "plt.imshow(im, cmap=\"gray\")"
50 | ]
51 | },
52 | {
53 | "cell_type": "markdown",
54 | "metadata": {},
55 | "source": [
56 | "## Horizontal Edge Filter "
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "filter1 = np.array([[1, 1, 1], [0, 0, 0], [-1, -1, -1]])\n",
66 | "\n",
67 | "new_image = np.zeros(im.shape)\n",
68 | "\n",
69 | "im_pad = np.pad(im, 1, \"constant\")\n",
70 | "\n",
71 | "for i in range(im.shape[0]):\n",
72 | " for j in range(im.shape[1]):\n",
73 | " try:\n",
74 | " new_image[i, j] = (\n",
75 | " im_pad[i - 1, j - 1] * filter1[0, 0]\n",
76 | " + im_pad[i - 1, j] * filter1[0, 1]\n",
77 | " + im_pad[i - 1, j + 1] * filter1[0, 2]\n",
78 | " + im_pad[i, j - 1] * filter1[1, 0]\n",
79 | " + im_pad[i, j] * filter1[1, 1]\n",
80 | " + im_pad[i, j + 1] * filter1[1, 2]\n",
81 | " + im_pad[i + 1, j - 1] * filter1[2, 0]\n",
82 | " + im_pad[i + 1, j] * filter1[2, 1]\n",
83 | " + im_pad[i + 1, j + 1] * filter1[2, 2]\n",
84 | " )\n",
85 |     "        except IndexError:\n",
86 | " pass\n",
87 | "\n",
88 | "plt.axis(\"off\")\n",
89 | "plt.imshow(new_image, cmap=\"Greys\")"
90 | ]
91 | },
92 | {
93 | "cell_type": "markdown",
94 | "metadata": {},
95 | "source": [
96 | "## Vertical Edge Filter "
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": null,
102 | "metadata": {},
103 | "outputs": [],
104 | "source": [
105 | "filter2 = np.array([[-1, 0, 1], [-1, 0, 1], [-1, 0, 1]])\n",
106 | "\n",
107 | "new_image = np.zeros(im.shape)\n",
108 | "\n",
109 | "im_pad = np.pad(im, 1, \"constant\")\n",
110 | "\n",
111 | "for i in range(im.shape[0]):\n",
112 | " for j in range(im.shape[1]):\n",
113 | " try:\n",
114 | " new_image[i, j] = (\n",
115 | " im_pad[i - 1, j - 1] * filter2[0, 0]\n",
116 | " + im_pad[i - 1, j] * filter2[0, 1]\n",
117 | " + im_pad[i - 1, j + 1] * filter2[0, 2]\n",
118 | " + im_pad[i, j - 1] * filter2[1, 0]\n",
119 | " + im_pad[i, j] * filter2[1, 1]\n",
120 | " + im_pad[i, j + 1] * filter2[1, 2]\n",
121 | " + im_pad[i + 1, j - 1] * filter2[2, 0]\n",
122 | " + im_pad[i + 1, j] * filter2[2, 1]\n",
123 | " + im_pad[i + 1, j + 1] * filter2[2, 2]\n",
124 | " )\n",
125 |     "        except IndexError:\n",
126 | " pass\n",
127 | "\n",
128 | "plt.axis(\"off\")\n",
129 | "plt.imshow(new_image, cmap=\"Greys\")"
130 | ]
131 | },
132 | {
133 | "cell_type": "markdown",
134 | "metadata": {},
135 | "source": [
136 | "## Horizontal Edge Filter with Stride 2 "
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": null,
142 | "metadata": {},
143 | "outputs": [],
144 | "source": [
145 | "filter1 = np.array([[1, 1, 1], [0, 0, 0], [-1, -1, -1]])\n",
146 | "\n",
147 | "stride = 2\n",
148 | "\n",
149 | "new_image = np.zeros((int(im.shape[0] / stride), int(im.shape[1] / stride)))\n",
150 | "\n",
151 | "im_pad = np.pad(im, 1, \"constant\")\n",
152 | "\n",
153 | "for i in range(0, im.shape[0], stride):\n",
154 | " for j in range(0, im.shape[1], stride):\n",
155 | " try:\n",
156 | " new_image[int(i / stride), int(j / stride)] = (\n",
157 | " im_pad[i - 1, j - 1] * filter1[0, 0]\n",
158 | " + im_pad[i - 1, j] * filter1[0, 1]\n",
159 | " + im_pad[i - 1, j + 1] * filter1[0, 2]\n",
160 | " + im_pad[i, j - 1] * filter1[1, 0]\n",
161 | " + im_pad[i, j] * filter1[1, 1]\n",
162 | " + im_pad[i, j + 1] * filter1[1, 2]\n",
163 | " + im_pad[i + 1, j - 1] * filter1[2, 0]\n",
164 | " + im_pad[i + 1, j] * filter1[2, 1]\n",
165 | " + im_pad[i + 1, j + 1] * filter1[2, 2]\n",
166 | " )\n",
167 |     "        except IndexError:\n",
168 | " pass\n",
169 | "\n",
170 | "plt.axis(\"off\")\n",
171 | "plt.imshow(new_image, cmap=\"Greys\")"
172 | ]
173 | },
174 | {
175 | "cell_type": "markdown",
176 | "metadata": {},
177 | "source": [
178 | "## Vertical Edge Filter with Stride 2 "
179 | ]
180 | },
181 | {
182 | "cell_type": "code",
183 | "execution_count": null,
184 | "metadata": {},
185 | "outputs": [],
186 | "source": [
187 | "filter2 = np.array([[-1, 0, 1], [-1, 0, 1], [-1, 0, 1]])\n",
188 | "\n",
189 | "stride = 2\n",
190 | "\n",
191 | "new_image = np.zeros((int(im.shape[0] / stride), int(im.shape[1] / stride)))\n",
192 | "\n",
193 | "im_pad = np.pad(im, 1, \"constant\")\n",
194 | "\n",
195 | "for i in range(0, im.shape[0], stride):\n",
196 | " for j in range(0, im.shape[1], stride):\n",
197 | " try:\n",
198 | " new_image[int(i / stride), int(j / stride)] = (\n",
199 | " im_pad[i - 1, j - 1] * filter2[0, 0]\n",
200 | " + im_pad[i - 1, j] * filter2[0, 1]\n",
201 | " + im_pad[i - 1, j + 1] * filter2[0, 2]\n",
202 | " + im_pad[i, j - 1] * filter2[1, 0]\n",
203 | " + im_pad[i, j] * filter2[1, 1]\n",
204 | " + im_pad[i, j + 1] * filter2[1, 2]\n",
205 | " + im_pad[i + 1, j - 1] * filter2[2, 0]\n",
206 | " + im_pad[i + 1, j] * filter2[2, 1]\n",
207 | " + im_pad[i + 1, j + 1] * filter2[2, 2]\n",
208 | " )\n",
209 |     "        except IndexError:\n",
210 | " pass\n",
211 | "\n",
212 | "plt.axis(\"off\")\n",
213 | "plt.imshow(new_image, cmap=\"Greys\")"
214 | ]
215 | }
216 | ],
217 | "metadata": {
218 | "kernelspec": {
219 | "display_name": "Python 3 (ipykernel)",
220 | "language": "python",
221 | "name": "python3"
222 | },
223 | "language_info": {
224 | "codemirror_mode": {
225 | "name": "ipython",
226 | "version": 3
227 | },
228 | "file_extension": ".py",
229 | "mimetype": "text/x-python",
230 | "name": "python",
231 | "nbconvert_exporter": "python",
232 | "pygments_lexer": "ipython3",
233 | "version": "3.8.10"
234 | },
235 | "vscode": {
236 | "interpreter": {
237 | "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
238 | }
239 | }
240 | },
241 | "nbformat": 4,
242 | "nbformat_minor": 4
243 | }
244 |
--------------------------------------------------------------------------------
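Note: the loops above apply each 3x3 filter one output pixel at a time, which is explicit but slow. A minimal sketch (not part of the notebook, assuming the `im` and `filter1` arrays defined above) of the same operation done with SciPy; `correlate2d` rather than `convolve2d` is used because the loops apply the filter without flipping it:

from scipy.signal import correlate2d
import matplotlib.pyplot as plt

# Stride-1 equivalent of the loop: zero-padded cross-correlation of the image with the filter
edges = correlate2d(im, filter1, mode="same", boundary="fill", fillvalue=0)

# Stride-2 equivalent: compute the full map, then keep every second row and column
edges_strided = edges[::2, ::2]

plt.axis("off")
plt.imshow(edges, cmap="Greys")
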
/notebooks/02_deeplearning/02_cnn/logs/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/02_deeplearning/02_cnn/models/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/02_deeplearning/02_cnn/output/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/03_vae/01_autoencoder/autoencoder.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "b076bd1a-b236-4fbc-953d-8295b25122ae",
6 | "metadata": {},
7 | "source": [
8 | "# 👖 Autoencoders on Fashion MNIST"
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "id": "9235cbd1-f136-411c-88d9-f69f270c0b96",
14 | "metadata": {},
15 | "source": [
16 |     "In this notebook, we'll walk through the steps required to train your own autoencoder on the Fashion-MNIST dataset."
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "id": "84acc7be-6764-4668-b2bb-178f63deeed3",
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "%load_ext autoreload\n",
27 | "%autoreload 2\n",
28 | "\n",
29 | "import numpy as np\n",
30 | "import matplotlib.pyplot as plt\n",
31 | "\n",
32 | "from tensorflow.keras import layers, models, datasets, callbacks\n",
33 | "import tensorflow.keras.backend as K\n",
34 | "\n",
35 | "from notebooks.utils import display"
36 | ]
37 | },
38 | {
39 | "cell_type": "markdown",
40 | "id": "339e6268-ebd7-4feb-86db-1fe7abccdbe5",
41 | "metadata": {},
42 | "source": [
43 | "## 0. Parameters "
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": null,
49 | "id": "1b2ee6ce-129f-4833-b0c5-fa567381c4e0",
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "IMAGE_SIZE = 32\n",
54 | "CHANNELS = 1\n",
55 | "BATCH_SIZE = 100\n",
56 | "BUFFER_SIZE = 1000\n",
57 | "VALIDATION_SPLIT = 0.2\n",
58 | "EMBEDDING_DIM = 2\n",
59 | "EPOCHS = 3"
60 | ]
61 | },
62 | {
63 | "cell_type": "markdown",
64 | "id": "b7716fac-0010-49b0-b98e-53be2259edde",
65 | "metadata": {},
66 | "source": [
67 | "## 1. Prepare the data "
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": null,
73 | "id": "9a73e5a4-1638-411c-8d3c-29f823424458",
74 | "metadata": {},
75 | "outputs": [],
76 | "source": [
77 | "# Load the data\n",
78 | "(x_train, y_train), (x_test, y_test) = datasets.fashion_mnist.load_data()"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": null,
84 | "id": "ebae2f0d-59fd-4796-841f-7213eae638de",
85 | "metadata": {},
86 | "outputs": [],
87 | "source": [
88 | "# Preprocess the data\n",
89 | "\n",
90 | "\n",
91 | "def preprocess(imgs):\n",
92 | " \"\"\"\n",
93 | " Normalize and reshape the images\n",
94 | " \"\"\"\n",
95 | " imgs = imgs.astype(\"float32\") / 255.0\n",
96 | " imgs = np.pad(imgs, ((0, 0), (2, 2), (2, 2)), constant_values=0.0)\n",
97 | " imgs = np.expand_dims(imgs, -1)\n",
98 | " return imgs\n",
99 | "\n",
100 | "\n",
101 | "x_train = preprocess(x_train)\n",
102 | "x_test = preprocess(x_test)"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": null,
108 | "id": "fa53709f-7f3f-483b-9db8-2e5f9b9942c2",
109 | "metadata": {},
110 | "outputs": [],
111 | "source": [
112 | "# Show some items of clothing from the training set\n",
113 | "display(x_train)"
114 | ]
115 | },
116 | {
117 | "cell_type": "markdown",
118 | "id": "aff50401-3abe-4c10-bba8-b35bc13ad7d5",
119 | "metadata": {
120 | "tags": []
121 | },
122 | "source": [
123 | "## 2. Build the autoencoder "
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": null,
129 | "id": "086e2584-c60d-4990-89f4-2092c44e023e",
130 | "metadata": {},
131 | "outputs": [],
132 | "source": [
133 | "# Encoder\n",
134 | "encoder_input = layers.Input(\n",
135 | " shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS), name=\"encoder_input\"\n",
136 | ")\n",
137 | "x = layers.Conv2D(32, (3, 3), strides=2, activation=\"relu\", padding=\"same\")(\n",
138 | " encoder_input\n",
139 | ")\n",
140 | "x = layers.Conv2D(64, (3, 3), strides=2, activation=\"relu\", padding=\"same\")(x)\n",
141 | "x = layers.Conv2D(128, (3, 3), strides=2, activation=\"relu\", padding=\"same\")(x)\n",
142 | "shape_before_flattening = K.int_shape(x)[1:] # the decoder will need this!\n",
143 | "\n",
144 | "x = layers.Flatten()(x)\n",
145 | "encoder_output = layers.Dense(EMBEDDING_DIM, name=\"encoder_output\")(x)\n",
146 | "\n",
147 | "encoder = models.Model(encoder_input, encoder_output)\n",
148 | "encoder.summary()"
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": null,
154 | "id": "6c409e63-1aea-42e2-8324-c3e2a12073ee",
155 | "metadata": {},
156 | "outputs": [],
157 | "source": [
158 | "# Decoder\n",
159 | "decoder_input = layers.Input(shape=(EMBEDDING_DIM,), name=\"decoder_input\")\n",
160 | "x = layers.Dense(np.prod(shape_before_flattening))(decoder_input)\n",
161 | "x = layers.Reshape(shape_before_flattening)(x)\n",
162 | "x = layers.Conv2DTranspose(\n",
163 | " 128, (3, 3), strides=2, activation=\"relu\", padding=\"same\"\n",
164 | ")(x)\n",
165 | "x = layers.Conv2DTranspose(\n",
166 | " 64, (3, 3), strides=2, activation=\"relu\", padding=\"same\"\n",
167 | ")(x)\n",
168 | "x = layers.Conv2DTranspose(\n",
169 | " 32, (3, 3), strides=2, activation=\"relu\", padding=\"same\"\n",
170 | ")(x)\n",
171 | "decoder_output = layers.Conv2D(\n",
172 | " CHANNELS,\n",
173 | " (3, 3),\n",
174 | " strides=1,\n",
175 | " activation=\"sigmoid\",\n",
176 | " padding=\"same\",\n",
177 | " name=\"decoder_output\",\n",
178 | ")(x)\n",
179 | "\n",
180 | "decoder = models.Model(decoder_input, decoder_output)\n",
181 | "decoder.summary()"
182 | ]
183 | },
184 | {
185 | "cell_type": "code",
186 | "execution_count": null,
187 | "id": "34dc7c69-26a8-4c17-aa24-792f1b0a69b4",
188 | "metadata": {},
189 | "outputs": [],
190 | "source": [
191 | "# Autoencoder\n",
192 | "autoencoder = models.Model(\n",
193 | " encoder_input, decoder(encoder_output)\n",
194 | ") # decoder(encoder_output)\n",
195 | "autoencoder.summary()"
196 | ]
197 | },
198 | {
199 | "cell_type": "markdown",
200 | "id": "35b14665-4359-447b-be58-3fd58ba69084",
201 | "metadata": {},
202 | "source": [
203 | "## 3. Train the autoencoder "
204 | ]
205 | },
206 | {
207 | "cell_type": "code",
208 | "execution_count": null,
209 | "id": "b429fdad-ea9c-45a2-a556-eb950d793824",
210 | "metadata": {},
211 | "outputs": [],
212 | "source": [
213 | "# Compile the autoencoder\n",
214 | "autoencoder.compile(optimizer=\"adam\", loss=\"binary_crossentropy\")"
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": null,
220 | "id": "c525e44b-b3bb-489c-9d35-fcfe3e714e6a",
221 | "metadata": {},
222 | "outputs": [],
223 | "source": [
224 | "# Create a model save checkpoint\n",
225 | "model_checkpoint_callback = callbacks.ModelCheckpoint(\n",
226 | " filepath=\"./checkpoint\",\n",
227 | " save_weights_only=False,\n",
228 | " save_freq=\"epoch\",\n",
229 | " monitor=\"loss\",\n",
230 | " mode=\"min\",\n",
231 | " save_best_only=True,\n",
232 | " verbose=0,\n",
233 | ")\n",
234 | "tensorboard_callback = callbacks.TensorBoard(log_dir=\"./logs\")"
235 | ]
236 | },
237 | {
238 | "cell_type": "code",
239 | "execution_count": null,
240 | "id": "b3c497b7-fa40-48df-b2bf-541239cc9400",
241 | "metadata": {
242 | "tags": []
243 | },
244 | "outputs": [],
245 | "source": [
246 | "autoencoder.fit(\n",
247 | " x_train,\n",
248 | " x_train,\n",
249 | " epochs=EPOCHS,\n",
250 | " batch_size=BATCH_SIZE,\n",
251 | " shuffle=True,\n",
252 | " validation_data=(x_test, x_test),\n",
253 | " callbacks=[model_checkpoint_callback, tensorboard_callback],\n",
254 | ")"
255 | ]
256 | },
257 | {
258 | "cell_type": "code",
259 | "execution_count": null,
260 | "id": "edb847d1-c22d-4923-ba92-0ecde0f12fca",
261 | "metadata": {},
262 | "outputs": [],
263 | "source": [
264 | "# Save the final models\n",
265 | "autoencoder.save(\"./models/autoencoder\")\n",
266 | "encoder.save(\"./models/encoder\")\n",
267 | "decoder.save(\"./models/decoder\")"
268 | ]
269 | },
270 | {
271 | "cell_type": "markdown",
272 | "id": "bc0f31bc-77e6-49e8-bb76-51bca124744c",
273 | "metadata": {
274 | "tags": []
275 | },
276 | "source": [
277 | "## 4. Reconstruct using the autoencoder "
278 | ]
279 | },
280 | {
281 | "cell_type": "code",
282 | "execution_count": null,
283 | "id": "d4d83729-71a2-4494-86a5-e17830974ef0",
284 | "metadata": {},
285 | "outputs": [],
286 | "source": [
287 | "n_to_predict = 5000\n",
288 | "example_images = x_test[:n_to_predict]\n",
289 | "example_labels = y_test[:n_to_predict]"
290 | ]
291 | },
292 | {
293 | "cell_type": "code",
294 | "execution_count": null,
295 | "id": "5c9b2a91-7cea-4595-a857-11f5ab00875e",
296 | "metadata": {},
297 | "outputs": [],
298 | "source": [
299 | "predictions = autoencoder.predict(example_images)\n",
300 | "\n",
301 | "print(\"Example real clothing items\")\n",
302 | "display(example_images)\n",
303 | "print(\"Reconstructions\")\n",
304 | "display(predictions)"
305 | ]
306 | },
307 | {
308 | "cell_type": "markdown",
309 | "id": "b77c88bb-ada4-4091-94e3-764f1385f1fc",
310 | "metadata": {},
311 | "source": [
312 | "## 5. Embed using the encoder "
313 | ]
314 | },
315 | {
316 | "cell_type": "code",
317 | "execution_count": null,
318 | "id": "5e723c1c-136b-47e5-9972-ee964712d148",
319 | "metadata": {},
320 | "outputs": [],
321 | "source": [
322 | "# Encode the example images\n",
323 | "embeddings = encoder.predict(example_images)"
324 | ]
325 | },
326 | {
327 | "cell_type": "code",
328 | "execution_count": null,
329 | "id": "2ed4e9bd-df14-4832-a765-dfaf36d49fca",
330 | "metadata": {},
331 | "outputs": [],
332 | "source": [
333 | "# Some examples of the embeddings\n",
334 | "print(embeddings[:10])"
335 | ]
336 | },
337 | {
338 | "cell_type": "code",
339 | "execution_count": null,
340 | "id": "3bb208e8-6351-49ac-a68c-679a830f13bf",
341 | "metadata": {},
342 | "outputs": [],
343 | "source": [
344 | "# Show the encoded points in 2D space\n",
345 | "figsize = 8\n",
346 | "\n",
347 | "plt.figure(figsize=(figsize, figsize))\n",
348 | "plt.scatter(embeddings[:, 0], embeddings[:, 1], c=\"black\", alpha=0.5, s=3)\n",
349 | "plt.show()"
350 | ]
351 | },
352 | {
353 | "cell_type": "code",
354 | "execution_count": null,
355 | "id": "138a34ca-67b4-42b7-a9fa-f7ffe397df49",
356 | "metadata": {},
357 | "outputs": [],
358 | "source": [
359 | "# Colour the embeddings by their label (clothing type - see table)\n",
360 | "example_labels = y_test[:n_to_predict]\n",
361 | "\n",
362 | "figsize = 8\n",
363 | "plt.figure(figsize=(figsize, figsize))\n",
364 | "plt.scatter(\n",
365 | " embeddings[:, 0],\n",
366 | " embeddings[:, 1],\n",
367 | " cmap=\"rainbow\",\n",
368 | " c=example_labels,\n",
369 | " alpha=0.8,\n",
370 | " s=3,\n",
371 | ")\n",
372 | "plt.colorbar()\n",
373 | "plt.show()"
374 | ]
375 | },
376 | {
377 | "cell_type": "markdown",
378 | "id": "e0616b71-3354-419c-8ddb-f64fc29850ca",
379 | "metadata": {},
380 | "source": [
381 | "## 6. Generate using the decoder "
382 | ]
383 | },
384 | {
385 | "cell_type": "code",
386 | "execution_count": null,
387 | "id": "2d494893-059f-42e4-825e-31c06fa3cd09",
388 | "metadata": {},
389 | "outputs": [],
390 | "source": [
391 | "# Get the range of the existing embeddings\n",
392 | "mins, maxs = np.min(embeddings, axis=0), np.max(embeddings, axis=0)\n",
393 | "\n",
394 | "# Sample some points in the latent space\n",
395 | "grid_width, grid_height = (6, 3)\n",
396 | "sample = np.random.uniform(\n",
397 | " mins, maxs, size=(grid_width * grid_height, EMBEDDING_DIM)\n",
398 | ")"
399 | ]
400 | },
401 | {
402 | "cell_type": "code",
403 | "execution_count": null,
404 | "id": "ba3b1c66-c89d-436a-b009-19f1f5a785e5",
405 | "metadata": {},
406 | "outputs": [],
407 | "source": [
408 | "# Decode the sampled points\n",
409 | "reconstructions = decoder.predict(sample)"
410 | ]
411 | },
412 | {
413 | "cell_type": "code",
414 | "execution_count": null,
415 | "id": "feea9b9d-8d3e-43f5-9ead-cd9e38367c00",
416 | "metadata": {},
417 | "outputs": [],
418 | "source": [
419 | "# Draw a plot of...\n",
420 | "figsize = 8\n",
421 | "plt.figure(figsize=(figsize, figsize))\n",
422 | "\n",
423 | "# ... the original embeddings ...\n",
424 | "plt.scatter(embeddings[:, 0], embeddings[:, 1], c=\"black\", alpha=0.5, s=2)\n",
425 | "\n",
426 | "# ... and the newly generated points in the latent space\n",
427 | "plt.scatter(sample[:, 0], sample[:, 1], c=\"#00B0F0\", alpha=1, s=40)\n",
428 | "plt.show()\n",
429 | "\n",
430 | "# Add underneath a grid of the decoded images\n",
431 | "fig = plt.figure(figsize=(figsize, grid_height * 2))\n",
432 | "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n",
433 | "\n",
434 | "for i in range(grid_width * grid_height):\n",
435 | " ax = fig.add_subplot(grid_height, grid_width, i + 1)\n",
436 | " ax.axis(\"off\")\n",
437 | " ax.text(\n",
438 | " 0.5,\n",
439 | " -0.35,\n",
440 | " str(np.round(sample[i, :], 1)),\n",
441 | " fontsize=10,\n",
442 | " ha=\"center\",\n",
443 | " transform=ax.transAxes,\n",
444 | " )\n",
445 | " ax.imshow(reconstructions[i, :, :], cmap=\"Greys\")"
446 | ]
447 | },
448 | {
449 | "cell_type": "code",
450 | "execution_count": null,
451 | "id": "f64434a4-41c5-4225-ad31-9cf83f8797e1",
452 | "metadata": {},
453 | "outputs": [],
454 | "source": [
455 | "# Colour the embeddings by their label (clothing type - see table)\n",
456 | "figsize = 12\n",
457 | "grid_size = 15\n",
458 | "plt.figure(figsize=(figsize, figsize))\n",
459 | "plt.scatter(\n",
460 | " embeddings[:, 0],\n",
461 | " embeddings[:, 1],\n",
462 | " cmap=\"rainbow\",\n",
463 | " c=example_labels,\n",
464 | " alpha=0.8,\n",
465 | " s=300,\n",
466 | ")\n",
467 | "plt.colorbar()\n",
468 | "\n",
469 | "x = np.linspace(min(embeddings[:, 0]), max(embeddings[:, 0]), grid_size)\n",
470 | "y = np.linspace(max(embeddings[:, 1]), min(embeddings[:, 1]), grid_size)\n",
471 | "xv, yv = np.meshgrid(x, y)\n",
472 | "xv = xv.flatten()\n",
473 | "yv = yv.flatten()\n",
474 | "grid = np.array(list(zip(xv, yv)))\n",
475 | "\n",
476 | "reconstructions = decoder.predict(grid)\n",
477 | "# plt.scatter(grid[:, 0], grid[:, 1], c=\"black\", alpha=1, s=10)\n",
478 | "plt.show()\n",
479 | "\n",
480 | "fig = plt.figure(figsize=(figsize, figsize))\n",
481 | "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n",
482 | "for i in range(grid_size**2):\n",
483 | " ax = fig.add_subplot(grid_size, grid_size, i + 1)\n",
484 | " ax.axis(\"off\")\n",
485 | " ax.imshow(reconstructions[i, :, :], cmap=\"Greys\")"
486 | ]
487 | }
488 | ],
489 | "metadata": {
490 | "kernelspec": {
491 | "display_name": "Python 3 (ipykernel)",
492 | "language": "python",
493 | "name": "python3"
494 | },
495 | "language_info": {
496 | "codemirror_mode": {
497 | "name": "ipython",
498 | "version": 3
499 | },
500 | "file_extension": ".py",
501 | "mimetype": "text/x-python",
502 | "name": "python",
503 | "nbconvert_exporter": "python",
504 | "pygments_lexer": "ipython3",
505 | "version": "3.8.10"
506 | },
507 | "vscode": {
508 | "interpreter": {
509 | "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
510 | }
511 | }
512 | },
513 | "nbformat": 4,
514 | "nbformat_minor": 5
515 | }
516 |
--------------------------------------------------------------------------------
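Note: the notebook saves the trained networks to ./models/. A minimal sketch (not part of the notebook) of how those saved models could be reloaded in a fresh session and reused; `imgs` is an assumed preprocessed batch of shape (n, 32, 32, 1), as produced by the preprocess() function above:

from tensorflow.keras import models

# Reload the SavedModel directories written by the save cells above
autoencoder = models.load_model("./models/autoencoder")
encoder = models.load_model("./models/encoder")
decoder = models.load_model("./models/decoder")

embeddings = encoder.predict(imgs)             # (n, 2) points in the latent space
reconstructions = decoder.predict(embeddings)  # (n, 32, 32, 1) reconstructed images
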
/notebooks/03_vae/01_autoencoder/checkpoint/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/03_vae/01_autoencoder/logs/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/03_vae/01_autoencoder/models/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/03_vae/01_autoencoder/output/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/03_vae/02_vae_fashion/checkpoint/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/03_vae/02_vae_fashion/logs/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/03_vae/02_vae_fashion/models/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/03_vae/02_vae_fashion/output/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/03_vae/03_vae_faces/checkpoint/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/03_vae/03_vae_faces/logs/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/03_vae/03_vae_faces/models/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/03_vae/03_vae_faces/output/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/03_vae/03_vae_faces/vae_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
4 |
5 | def get_vector_from_label(data, vae, embedding_dim, label):
6 | current_sum_POS = np.zeros(shape=embedding_dim, dtype="float32")
7 | current_n_POS = 0
8 | current_mean_POS = np.zeros(shape=embedding_dim, dtype="float32")
9 |
10 | current_sum_NEG = np.zeros(shape=embedding_dim, dtype="float32")
11 | current_n_NEG = 0
12 | current_mean_NEG = np.zeros(shape=embedding_dim, dtype="float32")
13 |
14 | current_vector = np.zeros(shape=embedding_dim, dtype="float32")
15 | current_dist = 0
16 |
17 | print("label: " + label)
18 |     print("images : POS move : NEG move : distance : 𝛥 distance")
19 | while current_n_POS < 10000:
20 | batch = list(data.take(1).get_single_element())
21 | im = batch[0]
22 | attribute = batch[1]
23 |
24 | _, _, z = vae.encoder.predict(np.array(im), verbose=0)
25 |
26 | z_POS = z[attribute == 1]
27 | z_NEG = z[attribute == -1]
28 |
29 | if len(z_POS) > 0:
30 | current_sum_POS = current_sum_POS + np.sum(z_POS, axis=0)
31 | current_n_POS += len(z_POS)
32 | new_mean_POS = current_sum_POS / current_n_POS
33 | movement_POS = np.linalg.norm(new_mean_POS - current_mean_POS)
34 |
35 | if len(z_NEG) > 0:
36 | current_sum_NEG = current_sum_NEG + np.sum(z_NEG, axis=0)
37 | current_n_NEG += len(z_NEG)
38 | new_mean_NEG = current_sum_NEG / current_n_NEG
39 | movement_NEG = np.linalg.norm(new_mean_NEG - current_mean_NEG)
40 |
41 | current_vector = new_mean_POS - new_mean_NEG
42 | new_dist = np.linalg.norm(current_vector)
43 | dist_change = new_dist - current_dist
44 |
45 | print(
46 | str(current_n_POS)
47 | + " : "
48 | + str(np.round(movement_POS, 3))
49 | + " : "
50 | + str(np.round(movement_NEG, 3))
51 | + " : "
52 | + str(np.round(new_dist, 3))
53 | + " : "
54 | + str(np.round(dist_change, 3))
55 | )
56 |
57 | current_mean_POS = np.copy(new_mean_POS)
58 | current_mean_NEG = np.copy(new_mean_NEG)
59 | current_dist = np.copy(new_dist)
60 |
61 | if np.sum([movement_POS, movement_NEG]) < 0.08:
62 | current_vector = current_vector / current_dist
63 | print("Found the " + label + " vector")
64 | break
65 |
66 | return current_vector
67 |
68 |
69 | def add_vector_to_images(data, vae, feature_vec):
70 | n_to_show = 5
71 | factors = [-4, -3, -2, -1, 0, 1, 2, 3, 4]
72 |
73 | example_batch = list(data.take(1).get_single_element())
74 | example_images = example_batch[0]
75 |
76 | _, _, z_points = vae.encoder.predict(example_images, verbose=0)
77 |
78 | fig = plt.figure(figsize=(18, 10))
79 |
80 | counter = 1
81 |
82 | for i in range(n_to_show):
83 | img = example_images[i]
84 | sub = fig.add_subplot(n_to_show, len(factors) + 1, counter)
85 | sub.axis("off")
86 | sub.imshow(img)
87 |
88 | counter += 1
89 |
90 | for factor in factors:
91 | changed_z_point = z_points[i] + feature_vec * factor
92 | changed_image = vae.decoder.predict(
93 | np.array([changed_z_point]), verbose=0
94 | )[0]
95 |
96 | sub = fig.add_subplot(n_to_show, len(factors) + 1, counter)
97 | sub.axis("off")
98 | sub.imshow(changed_image)
99 |
100 | counter += 1
101 |
102 | plt.show()
103 |
104 |
105 | def morph_faces(data, vae):
106 | factors = np.arange(0, 1, 0.1)
107 |
108 | example_batch = list(data.take(1).get_single_element())[:2]
109 | example_images = example_batch[0]
110 | _, _, z_points = vae.encoder.predict(example_images, verbose=0)
111 |
112 | fig = plt.figure(figsize=(18, 8))
113 |
114 | counter = 1
115 |
116 | img = example_images[0]
117 | sub = fig.add_subplot(1, len(factors) + 2, counter)
118 | sub.axis("off")
119 | sub.imshow(img)
120 |
121 | counter += 1
122 |
123 | for factor in factors:
124 | changed_z_point = z_points[0] * (1 - factor) + z_points[1] * factor
125 | changed_image = vae.decoder.predict(
126 | np.array([changed_z_point]), verbose=0
127 | )[0]
128 | sub = fig.add_subplot(1, len(factors) + 2, counter)
129 | sub.axis("off")
130 | sub.imshow(changed_image)
131 |
132 | counter += 1
133 |
134 | img = example_images[1]
135 | sub = fig.add_subplot(1, len(factors) + 2, counter)
136 | sub.axis("off")
137 | sub.imshow(img)
138 |
139 | plt.show()
140 |
--------------------------------------------------------------------------------
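Note: a minimal usage sketch for these utilities (not part of the repo). It assumes a trained `vae` object exposing `.encoder` and `.decoder` as in the vae_faces notebook, a tf.data pipeline `data_with_attribute` whose batches are (images, label) pairs with the label encoded as +1/-1 for the chosen CelebA attribute, and a latent dimensionality matching the trained VAE:

from vae_utils import get_vector_from_label, add_vector_to_images, morph_faces

EMBEDDING_DIM = 200  # assumed; must match the latent dimension of the trained VAE

# Estimate a latent direction that separates images with and without the attribute,
# then add it to encoded faces at a range of strengths
attribute_vector = get_vector_from_label(data_with_attribute, vae, EMBEDDING_DIM, "Smiling")
add_vector_to_images(data_with_attribute, vae, attribute_vector)

# Linearly interpolate between the latent codes of two faces from a batch
morph_faces(data_with_attribute, vae)
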
/notebooks/04_gan/01_dcgan/checkpoint/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/04_gan/01_dcgan/dcgan.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "b076bd1a-b236-4fbc-953d-8295b25122ae",
6 | "metadata": {},
7 | "source": [
8 | "# 🧱 DCGAN - Bricks Data"
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "id": "1603ea4b-8345-4e2e-ae7c-01c9953900e8",
14 | "metadata": {},
15 | "source": [
16 |     "In this notebook, we'll walk through the steps required to train your own DCGAN on the LEGO bricks dataset."
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "id": "4e0d56cc-4773-4029-97d8-26f882ba79c9",
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "%load_ext autoreload\n",
27 | "%autoreload 2\n",
28 | "import numpy as np\n",
29 | "import matplotlib.pyplot as plt\n",
30 | "\n",
31 | "import tensorflow as tf\n",
32 | "from tensorflow.keras import (\n",
33 | " layers,\n",
34 | " models,\n",
35 | " callbacks,\n",
36 | " losses,\n",
37 | " utils,\n",
38 | " metrics,\n",
39 | " optimizers,\n",
40 | ")\n",
41 | "\n",
42 | "from notebooks.utils import display, sample_batch"
43 | ]
44 | },
45 | {
46 | "cell_type": "markdown",
47 | "id": "339e6268-ebd7-4feb-86db-1fe7abccdbe5",
48 | "metadata": {},
49 | "source": [
50 | "## 0. Parameters "
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": null,
56 | "id": "1b2ee6ce-129f-4833-b0c5-fa567381c4e0",
57 | "metadata": {},
58 | "outputs": [],
59 | "source": [
60 | "IMAGE_SIZE = 64\n",
61 | "CHANNELS = 1\n",
62 | "BATCH_SIZE = 128\n",
63 | "Z_DIM = 100\n",
64 | "EPOCHS = 300\n",
65 | "LOAD_MODEL = False\n",
66 | "ADAM_BETA_1 = 0.5\n",
67 | "ADAM_BETA_2 = 0.999\n",
68 | "LEARNING_RATE = 0.0002\n",
69 | "NOISE_PARAM = 0.1"
70 | ]
71 | },
72 | {
73 | "cell_type": "markdown",
74 | "id": "b7716fac-0010-49b0-b98e-53be2259edde",
75 | "metadata": {},
76 | "source": [
77 | "## 1. Prepare the data "
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": null,
83 | "id": "73f4c594-3f6d-4c8e-94c1-2c2ba7bce076",
84 | "metadata": {},
85 | "outputs": [],
86 | "source": [
87 | "train_data = utils.image_dataset_from_directory(\n",
88 | " \"/app/data/lego-brick-images/dataset/\",\n",
89 | " labels=None,\n",
90 | " color_mode=\"grayscale\",\n",
91 | " image_size=(IMAGE_SIZE, IMAGE_SIZE),\n",
92 | " batch_size=BATCH_SIZE,\n",
93 | " shuffle=True,\n",
94 | " seed=42,\n",
95 | " interpolation=\"bilinear\",\n",
96 | ")"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": null,
102 | "id": "a995473f-c389-4158-92d2-93a2fa937916",
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "def preprocess(img):\n",
107 | " \"\"\"\n",
108 | " Normalize and reshape the images\n",
109 | " \"\"\"\n",
110 | " img = (tf.cast(img, \"float32\") - 127.5) / 127.5\n",
111 | " return img\n",
112 | "\n",
113 | "\n",
114 | "train = train_data.map(lambda x: preprocess(x))"
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": null,
120 | "id": "80bcdbdd-fb1e-451f-b89c-03fd9b80deb5",
121 | "metadata": {},
122 | "outputs": [],
123 | "source": [
124 | "train_sample = sample_batch(train)"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": null,
130 | "id": "fa53709f-7f3f-483b-9db8-2e5f9b9942c2",
131 | "metadata": {},
132 | "outputs": [],
133 | "source": [
134 | "display(train_sample)"
135 | ]
136 | },
137 | {
138 | "cell_type": "markdown",
139 | "id": "aff50401-3abe-4c10-bba8-b35bc13ad7d5",
140 | "metadata": {
141 | "tags": []
142 | },
143 | "source": [
144 | "## 2. Build the GAN "
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": null,
150 | "id": "9230b5bf-b4a8-48d5-b73b-6899a598f296",
151 | "metadata": {},
152 | "outputs": [],
153 | "source": [
154 | "discriminator_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS))\n",
155 | "x = layers.Conv2D(64, kernel_size=4, strides=2, padding=\"same\", use_bias=False)(\n",
156 | " discriminator_input\n",
157 | ")\n",
158 | "x = layers.LeakyReLU(0.2)(x)\n",
159 | "x = layers.Dropout(0.3)(x)\n",
160 | "x = layers.Conv2D(\n",
161 | " 128, kernel_size=4, strides=2, padding=\"same\", use_bias=False\n",
162 | ")(x)\n",
163 | "x = layers.BatchNormalization(momentum=0.9)(x)\n",
164 | "x = layers.LeakyReLU(0.2)(x)\n",
165 | "x = layers.Dropout(0.3)(x)\n",
166 | "x = layers.Conv2D(\n",
167 | " 256, kernel_size=4, strides=2, padding=\"same\", use_bias=False\n",
168 | ")(x)\n",
169 | "x = layers.BatchNormalization(momentum=0.9)(x)\n",
170 | "x = layers.LeakyReLU(0.2)(x)\n",
171 | "x = layers.Dropout(0.3)(x)\n",
172 | "x = layers.Conv2D(\n",
173 | " 512, kernel_size=4, strides=2, padding=\"same\", use_bias=False\n",
174 | ")(x)\n",
175 | "x = layers.BatchNormalization(momentum=0.9)(x)\n",
176 | "x = layers.LeakyReLU(0.2)(x)\n",
177 | "x = layers.Dropout(0.3)(x)\n",
178 | "x = layers.Conv2D(\n",
179 | " 1,\n",
180 | " kernel_size=4,\n",
181 | " strides=1,\n",
182 | " padding=\"valid\",\n",
183 | " use_bias=False,\n",
184 | " activation=\"sigmoid\",\n",
185 | ")(x)\n",
186 | "discriminator_output = layers.Flatten()(x)\n",
187 | "\n",
188 | "discriminator = models.Model(discriminator_input, discriminator_output)\n",
189 | "discriminator.summary()"
190 | ]
191 | },
192 | {
193 | "cell_type": "code",
194 | "execution_count": null,
195 | "id": "b30dcc08-3869-4b67-a295-61f13d5d4e4c",
196 | "metadata": {},
197 | "outputs": [],
198 | "source": [
199 | "generator_input = layers.Input(shape=(Z_DIM,))\n",
200 | "x = layers.Reshape((1, 1, Z_DIM))(generator_input)\n",
201 | "x = layers.Conv2DTranspose(\n",
202 | " 512, kernel_size=4, strides=1, padding=\"valid\", use_bias=False\n",
203 | ")(x)\n",
204 | "x = layers.BatchNormalization(momentum=0.9)(x)\n",
205 | "x = layers.LeakyReLU(0.2)(x)\n",
206 | "x = layers.Conv2DTranspose(\n",
207 | " 256, kernel_size=4, strides=2, padding=\"same\", use_bias=False\n",
208 | ")(x)\n",
209 | "x = layers.BatchNormalization(momentum=0.9)(x)\n",
210 | "x = layers.LeakyReLU(0.2)(x)\n",
211 | "x = layers.Conv2DTranspose(\n",
212 | " 128, kernel_size=4, strides=2, padding=\"same\", use_bias=False\n",
213 | ")(x)\n",
214 | "x = layers.BatchNormalization(momentum=0.9)(x)\n",
215 | "x = layers.LeakyReLU(0.2)(x)\n",
216 | "x = layers.Conv2DTranspose(\n",
217 | " 64, kernel_size=4, strides=2, padding=\"same\", use_bias=False\n",
218 | ")(x)\n",
219 | "x = layers.BatchNormalization(momentum=0.9)(x)\n",
220 | "x = layers.LeakyReLU(0.2)(x)\n",
221 | "generator_output = layers.Conv2DTranspose(\n",
222 | " CHANNELS,\n",
223 | " kernel_size=4,\n",
224 | " strides=2,\n",
225 | " padding=\"same\",\n",
226 | " use_bias=False,\n",
227 | " activation=\"tanh\",\n",
228 | ")(x)\n",
229 | "generator = models.Model(generator_input, generator_output)\n",
230 | "generator.summary()"
231 | ]
232 | },
233 | {
234 | "cell_type": "code",
235 | "execution_count": null,
236 | "id": "ed493725-488b-4390-8c64-661f3b97a632",
237 | "metadata": {
238 | "tags": []
239 | },
240 | "outputs": [],
241 | "source": [
242 | "class DCGAN(models.Model):\n",
243 | " def __init__(self, discriminator, generator, latent_dim):\n",
244 | " super(DCGAN, self).__init__()\n",
245 | " self.discriminator = discriminator\n",
246 | " self.generator = generator\n",
247 | " self.latent_dim = latent_dim\n",
248 | "\n",
249 | " def compile(self, d_optimizer, g_optimizer):\n",
250 | " super(DCGAN, self).compile()\n",
251 | " self.loss_fn = losses.BinaryCrossentropy()\n",
252 | " self.d_optimizer = d_optimizer\n",
253 | " self.g_optimizer = g_optimizer\n",
254 | " self.d_loss_metric = metrics.Mean(name=\"d_loss\")\n",
255 | " self.d_real_acc_metric = metrics.BinaryAccuracy(name=\"d_real_acc\")\n",
256 | " self.d_fake_acc_metric = metrics.BinaryAccuracy(name=\"d_fake_acc\")\n",
257 | " self.d_acc_metric = metrics.BinaryAccuracy(name=\"d_acc\")\n",
258 | " self.g_loss_metric = metrics.Mean(name=\"g_loss\")\n",
259 | " self.g_acc_metric = metrics.BinaryAccuracy(name=\"g_acc\")\n",
260 | "\n",
261 | " @property\n",
262 | " def metrics(self):\n",
263 | " return [\n",
264 | " self.d_loss_metric,\n",
265 | " self.d_real_acc_metric,\n",
266 | " self.d_fake_acc_metric,\n",
267 | " self.d_acc_metric,\n",
268 | " self.g_loss_metric,\n",
269 | " self.g_acc_metric,\n",
270 | " ]\n",
271 | "\n",
272 | " def train_step(self, real_images):\n",
273 | " # Sample random points in the latent space\n",
274 | " batch_size = tf.shape(real_images)[0]\n",
275 | " random_latent_vectors = tf.random.normal(\n",
276 | " shape=(batch_size, self.latent_dim)\n",
277 | " )\n",
278 | "\n",
279 | " # Train the discriminator on fake images\n",
280 | " with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:\n",
281 | " generated_images = self.generator(\n",
282 | " random_latent_vectors, training=True\n",
283 | " )\n",
284 | " real_predictions = self.discriminator(real_images, training=True)\n",
285 | " fake_predictions = self.discriminator(\n",
286 | " generated_images, training=True\n",
287 | " )\n",
288 | "\n",
289 | " real_labels = tf.ones_like(real_predictions)\n",
290 | " real_noisy_labels = real_labels + NOISE_PARAM * tf.random.uniform(\n",
291 | " tf.shape(real_predictions)\n",
292 | " )\n",
293 | " fake_labels = tf.zeros_like(fake_predictions)\n",
294 | " fake_noisy_labels = fake_labels - NOISE_PARAM * tf.random.uniform(\n",
295 | " tf.shape(fake_predictions)\n",
296 | " )\n",
297 | "\n",
298 | " d_real_loss = self.loss_fn(real_noisy_labels, real_predictions)\n",
299 | " d_fake_loss = self.loss_fn(fake_noisy_labels, fake_predictions)\n",
300 | " d_loss = (d_real_loss + d_fake_loss) / 2.0\n",
301 | "\n",
302 | " g_loss = self.loss_fn(real_labels, fake_predictions)\n",
303 | "\n",
304 | " gradients_of_discriminator = disc_tape.gradient(\n",
305 | " d_loss, self.discriminator.trainable_variables\n",
306 | " )\n",
307 | " gradients_of_generator = gen_tape.gradient(\n",
308 | " g_loss, self.generator.trainable_variables\n",
309 | " )\n",
310 | "\n",
311 | " self.d_optimizer.apply_gradients(\n",
312 |     "            zip(gradients_of_discriminator, self.discriminator.trainable_variables)\n",
313 | " )\n",
314 | " self.g_optimizer.apply_gradients(\n",
315 |     "            zip(gradients_of_generator, self.generator.trainable_variables)\n",
316 | " )\n",
317 | "\n",
318 | " # Update metrics\n",
319 | " self.d_loss_metric.update_state(d_loss)\n",
320 | " self.d_real_acc_metric.update_state(real_labels, real_predictions)\n",
321 | " self.d_fake_acc_metric.update_state(fake_labels, fake_predictions)\n",
322 | " self.d_acc_metric.update_state(\n",
323 | " [real_labels, fake_labels], [real_predictions, fake_predictions]\n",
324 | " )\n",
325 | " self.g_loss_metric.update_state(g_loss)\n",
326 | " self.g_acc_metric.update_state(real_labels, fake_predictions)\n",
327 | "\n",
328 | " return {m.name: m.result() for m in self.metrics}"
329 | ]
330 | },
331 | {
332 | "cell_type": "code",
333 | "execution_count": null,
334 | "id": "e898dd8e-f562-4517-8351-fc2f8b617a24",
335 | "metadata": {},
336 | "outputs": [],
337 | "source": [
338 | "# Create a DCGAN\n",
339 | "dcgan = DCGAN(\n",
340 | " discriminator=discriminator, generator=generator, latent_dim=Z_DIM\n",
341 | ")"
342 | ]
343 | },
344 | {
345 | "cell_type": "code",
346 | "execution_count": null,
347 | "id": "800a3c6e-fb11-4792-b6bc-9a43a7c977ad",
348 | "metadata": {},
349 | "outputs": [],
350 | "source": [
351 | "if LOAD_MODEL:\n",
352 | " dcgan.load_weights(\"./checkpoint/checkpoint.ckpt\")"
353 | ]
354 | },
355 | {
356 | "cell_type": "markdown",
357 | "id": "35b14665-4359-447b-be58-3fd58ba69084",
358 | "metadata": {},
359 | "source": [
360 | "## 3. Train the GAN "
361 | ]
362 | },
363 | {
364 | "cell_type": "code",
365 | "execution_count": null,
366 | "id": "245e6374-5f5b-4efa-be0a-07b182f82d2d",
367 | "metadata": {},
368 | "outputs": [],
369 | "source": [
370 | "dcgan.compile(\n",
371 | " d_optimizer=optimizers.Adam(\n",
372 | " learning_rate=LEARNING_RATE, beta_1=ADAM_BETA_1, beta_2=ADAM_BETA_2\n",
373 | " ),\n",
374 | " g_optimizer=optimizers.Adam(\n",
375 | " learning_rate=LEARNING_RATE, beta_1=ADAM_BETA_1, beta_2=ADAM_BETA_2\n",
376 | " ),\n",
377 | ")"
378 | ]
379 | },
380 | {
381 | "cell_type": "code",
382 | "execution_count": null,
383 | "id": "349865fe-ffbe-450e-97be-043ae1740e78",
384 | "metadata": {},
385 | "outputs": [],
386 | "source": [
387 | "# Create a model save checkpoint\n",
388 | "model_checkpoint_callback = callbacks.ModelCheckpoint(\n",
389 | " filepath=\"./checkpoint/checkpoint.ckpt\",\n",
390 | " save_weights_only=True,\n",
391 | " save_freq=\"epoch\",\n",
392 | " verbose=0,\n",
393 | ")\n",
394 | "\n",
395 | "tensorboard_callback = callbacks.TensorBoard(log_dir=\"./logs\")\n",
396 | "\n",
397 | "\n",
398 | "class ImageGenerator(callbacks.Callback):\n",
399 | " def __init__(self, num_img, latent_dim):\n",
400 | " self.num_img = num_img\n",
401 | " self.latent_dim = latent_dim\n",
402 | "\n",
403 | " def on_epoch_end(self, epoch, logs=None):\n",
404 | " random_latent_vectors = tf.random.normal(\n",
405 | " shape=(self.num_img, self.latent_dim)\n",
406 | " )\n",
407 | " generated_images = self.model.generator(random_latent_vectors)\n",
408 | " generated_images = generated_images * 127.5 + 127.5\n",
409 | " generated_images = generated_images.numpy()\n",
410 | " display(\n",
411 | " generated_images,\n",
412 | " save_to=\"./output/generated_img_%03d.png\" % (epoch),\n",
413 | " )"
414 | ]
415 | },
416 | {
417 | "cell_type": "code",
418 | "execution_count": null,
419 | "id": "a8913a77-f472-4008-9039-dba00e6db980",
420 | "metadata": {
421 | "tags": []
422 | },
423 | "outputs": [],
424 | "source": [
425 | "dcgan.fit(\n",
426 | " train,\n",
427 | " epochs=EPOCHS,\n",
428 | " callbacks=[\n",
429 | " model_checkpoint_callback,\n",
430 | " tensorboard_callback,\n",
431 | " ImageGenerator(num_img=10, latent_dim=Z_DIM),\n",
432 | " ],\n",
433 | ")"
434 | ]
435 | },
436 | {
437 | "cell_type": "code",
438 | "execution_count": null,
439 | "id": "369bde44-2e39-4bc6-8549-a3a27ecce55c",
440 | "metadata": {},
441 | "outputs": [],
442 | "source": [
443 | "# Save the final models\n",
444 | "generator.save(\"./models/generator\")\n",
445 | "discriminator.save(\"./models/discriminator\")"
446 | ]
447 | },
448 | {
449 | "cell_type": "markdown",
450 | "id": "26999087-0e85-4ddf-ba5f-13036466fce7",
451 | "metadata": {},
452 | "source": [
453 | "## 3. Generate new images "
454 | ]
455 | },
456 | {
457 | "cell_type": "code",
458 | "execution_count": null,
459 | "id": "48e90117-2e0e-4f4b-9138-b25dce9870f6",
460 | "metadata": {},
461 | "outputs": [],
462 | "source": [
463 | "# Sample some points in the latent space, from the standard normal distribution\n",
464 | "grid_width, grid_height = (10, 3)\n",
465 | "z_sample = np.random.normal(size=(grid_width * grid_height, Z_DIM))"
466 | ]
467 | },
468 | {
469 | "cell_type": "code",
470 | "execution_count": null,
471 | "id": "2e185509-3861-425c-882d-4fe16d82d355",
472 | "metadata": {},
473 | "outputs": [],
474 | "source": [
475 | "# Decode the sampled points\n",
476 | "reconstructions = generator.predict(z_sample)"
477 | ]
478 | },
479 | {
480 | "cell_type": "code",
481 | "execution_count": null,
482 | "id": "5e5e43c0-ef06-4d32-acf6-09f00cf2fa9c",
483 | "metadata": {},
484 | "outputs": [],
485 | "source": [
486 | "# Draw a plot of decoded images\n",
487 | "fig = plt.figure(figsize=(18, 5))\n",
488 | "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n",
489 | "\n",
490 | "# Output the grid of faces\n",
491 | "for i in range(grid_width * grid_height):\n",
492 | " ax = fig.add_subplot(grid_height, grid_width, i + 1)\n",
493 | " ax.axis(\"off\")\n",
494 | " ax.imshow(reconstructions[i, :, :], cmap=\"Greys\")"
495 | ]
496 | },
497 | {
498 | "cell_type": "code",
499 | "execution_count": null,
500 | "id": "59bfd4e4-7fdc-488a-86df-2c131c904803",
501 | "metadata": {},
502 | "outputs": [],
503 | "source": [
504 | "def compare_images(img1, img2):\n",
505 | " return np.mean(np.abs(img1 - img2))"
506 | ]
507 | },
508 | {
509 | "cell_type": "code",
510 | "execution_count": null,
511 | "id": "b568995a-d4ad-478c-98b2-d9a1cdb9e841",
512 | "metadata": {},
513 | "outputs": [],
514 | "source": [
515 | "all_data = []\n",
516 | "for i in train.as_numpy_iterator():\n",
517 | " all_data.extend(i)\n",
518 | "all_data = np.array(all_data)"
519 | ]
520 | },
521 | {
522 | "cell_type": "code",
523 | "execution_count": null,
524 | "id": "7c4b5bb1-3581-49b3-81ce-920400d6f3f7",
525 | "metadata": {},
526 | "outputs": [],
527 | "source": [
528 | "r, c = 3, 5\n",
529 | "fig, axs = plt.subplots(r, c, figsize=(10, 6))\n",
530 | "fig.suptitle(\"Generated images\", fontsize=20)\n",
531 | "\n",
532 | "noise = np.random.normal(size=(r * c, Z_DIM))\n",
533 | "gen_imgs = generator.predict(noise)\n",
534 | "\n",
535 | "cnt = 0\n",
536 | "for i in range(r):\n",
537 | " for j in range(c):\n",
538 | " axs[i, j].imshow(gen_imgs[cnt], cmap=\"gray_r\")\n",
539 | " axs[i, j].axis(\"off\")\n",
540 | " cnt += 1\n",
541 | "\n",
542 | "plt.show()"
543 | ]
544 | },
545 | {
546 | "cell_type": "code",
547 | "execution_count": null,
548 | "id": "51923e98-bf0e-4de4-948a-05147c486b72",
549 | "metadata": {},
550 | "outputs": [],
551 | "source": [
552 | "fig, axs = plt.subplots(r, c, figsize=(10, 6))\n",
553 | "fig.suptitle(\"Closest images in the training set\", fontsize=20)\n",
554 | "\n",
555 | "cnt = 0\n",
556 | "for i in range(r):\n",
557 | " for j in range(c):\n",
558 | " c_diff = 99999\n",
559 | " c_img = None\n",
560 | " for k_idx, k in enumerate(all_data):\n",
561 | " diff = compare_images(gen_imgs[cnt], k)\n",
562 | " if diff < c_diff:\n",
563 | " c_img = np.copy(k)\n",
564 | " c_diff = diff\n",
565 | " axs[i, j].imshow(c_img, cmap=\"gray_r\")\n",
566 | " axs[i, j].axis(\"off\")\n",
567 | " cnt += 1\n",
568 | "\n",
569 | "plt.show()"
570 | ]
571 | }
572 | ],
573 | "metadata": {
574 | "kernelspec": {
575 | "display_name": "Python 3 (ipykernel)",
576 | "language": "python",
577 | "name": "python3"
578 | },
579 | "language_info": {
580 | "codemirror_mode": {
581 | "name": "ipython",
582 | "version": 3
583 | },
584 | "file_extension": ".py",
585 | "mimetype": "text/x-python",
586 | "name": "python",
587 | "nbconvert_exporter": "python",
588 | "pygments_lexer": "ipython3",
589 | "version": "3.8.10"
590 | }
591 | },
592 | "nbformat": 4,
593 | "nbformat_minor": 5
594 | }
595 |
--------------------------------------------------------------------------------
/notebooks/04_gan/01_dcgan/logs/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/04_gan/01_dcgan/models/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/04_gan/01_dcgan/output/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/04_gan/02_wgan_gp/checkpoint/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/04_gan/02_wgan_gp/logs/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/04_gan/02_wgan_gp/models/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/04_gan/02_wgan_gp/output/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/04_gan/02_wgan_gp/wgan_gp.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "b076bd1a-b236-4fbc-953d-8295b25122ae",
6 | "metadata": {},
7 | "source": [
8 | "# 🤪 WGAN - CelebA Faces"
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "id": "9235cbd1-f136-411c-88d9-f69f270c0b96",
14 | "metadata": {},
15 | "source": [
16 | "In this notebook, we'll walk through the steps required to train your own Wasserstein GAN on the CelebA faces dataset"
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "id": "098731a6-4193-4fda-9018-b14114c54250",
22 | "metadata": {},
23 | "source": [
24 | "The code has been adapted from the excellent [WGAN-GP tutorial](https://keras.io/examples/generative/wgan_gp/) created by Aakash Kumar Nain, available on the Keras website."
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": null,
30 | "id": "84acc7be-6764-4668-b2bb-178f63deeed3",
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "%load_ext autoreload\n",
35 | "%autoreload 2\n",
36 | "import numpy as np\n",
37 | "\n",
38 | "import tensorflow as tf\n",
39 | "from tensorflow.keras import (\n",
40 | " layers,\n",
41 | " models,\n",
42 | " callbacks,\n",
43 | " utils,\n",
44 | " metrics,\n",
45 | " optimizers,\n",
46 | ")\n",
47 | "\n",
48 | "from notebooks.utils import display, sample_batch"
49 | ]
50 | },
51 | {
52 | "cell_type": "markdown",
53 | "id": "339e6268-ebd7-4feb-86db-1fe7abccdbe5",
54 | "metadata": {},
55 | "source": [
56 | "## 0. Parameters "
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "id": "1b2ee6ce-129f-4833-b0c5-fa567381c4e0",
63 | "metadata": {},
64 | "outputs": [],
65 | "source": [
66 | "IMAGE_SIZE = 64\n",
67 | "CHANNELS = 3\n",
68 | "BATCH_SIZE = 512\n",
69 | "NUM_FEATURES = 64\n",
70 | "Z_DIM = 128\n",
71 | "LEARNING_RATE = 0.0002\n",
72 | "ADAM_BETA_1 = 0.5\n",
73 | "ADAM_BETA_2 = 0.999\n",
74 | "EPOCHS = 200\n",
75 | "CRITIC_STEPS = 3\n",
76 | "GP_WEIGHT = 10.0\n",
77 | "LOAD_MODEL = False\n",
78 | "ADAM_BETA_1 = 0.5\n",
79 | "ADAM_BETA_2 = 0.9"
80 | ]
81 | },
82 | {
83 | "cell_type": "markdown",
84 | "id": "b7716fac-0010-49b0-b98e-53be2259edde",
85 | "metadata": {},
86 | "source": [
87 | "## 1. Prepare the data "
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": null,
93 | "id": "9a73e5a4-1638-411c-8d3c-29f823424458",
94 | "metadata": {},
95 | "outputs": [],
96 | "source": [
97 | "# Load the data\n",
98 | "train_data = utils.image_dataset_from_directory(\n",
99 | " \"/app/data/celeba-dataset/img_align_celeba/img_align_celeba\",\n",
100 | " labels=None,\n",
101 | " color_mode=\"rgb\",\n",
102 | " image_size=(IMAGE_SIZE, IMAGE_SIZE),\n",
103 | " batch_size=BATCH_SIZE,\n",
104 | " shuffle=True,\n",
105 | " seed=42,\n",
106 | " interpolation=\"bilinear\",\n",
107 | ")"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": null,
113 | "id": "ebae2f0d-59fd-4796-841f-7213eae638de",
114 | "metadata": {},
115 | "outputs": [],
116 | "source": [
117 | "# Preprocess the data\n",
118 | "def preprocess(img):\n",
119 | " \"\"\"\n",
120 | " Normalize and reshape the images\n",
121 | " \"\"\"\n",
122 | " img = (tf.cast(img, \"float32\") - 127.5) / 127.5\n",
123 | " return img\n",
124 | "\n",
125 | "\n",
126 | "train = train_data.map(lambda x: preprocess(x))"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": null,
132 | "id": "fa53709f-7f3f-483b-9db8-2e5f9b9942c2",
133 | "metadata": {},
134 | "outputs": [],
135 | "source": [
136 | "# Show some faces from the training set\n",
137 | "train_sample = sample_batch(train)"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": null,
143 | "id": "b86c15ef-82b2-4a75-99f7-2d8810440403",
144 | "metadata": {},
145 | "outputs": [],
146 | "source": [
147 | "display(train_sample, cmap=None)"
148 | ]
149 | },
150 | {
151 | "cell_type": "markdown",
152 | "id": "aff50401-3abe-4c10-bba8-b35bc13ad7d5",
153 | "metadata": {
154 | "tags": []
155 | },
156 | "source": [
157 | "## 2. Build the WGAN-GP "
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "execution_count": null,
163 | "id": "371eb69d-e534-4666-a412-b5b6fe24689a",
164 | "metadata": {},
165 | "outputs": [],
166 | "source": [
167 | "critic_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS))\n",
168 | "x = layers.Conv2D(64, kernel_size=4, strides=2, padding=\"same\")(critic_input)\n",
169 | "x = layers.LeakyReLU(0.2)(x)\n",
170 | "x = layers.Conv2D(128, kernel_size=4, strides=2, padding=\"same\")(x)\n",
171 | "x = layers.LeakyReLU()(x)\n",
172 | "x = layers.Dropout(0.3)(x)\n",
173 | "x = layers.Conv2D(256, kernel_size=4, strides=2, padding=\"same\")(x)\n",
174 | "x = layers.LeakyReLU(0.2)(x)\n",
175 | "x = layers.Dropout(0.3)(x)\n",
176 | "x = layers.Conv2D(512, kernel_size=4, strides=2, padding=\"same\")(x)\n",
177 | "x = layers.LeakyReLU(0.2)(x)\n",
178 | "x = layers.Dropout(0.3)(x)\n",
179 | "x = layers.Conv2D(1, kernel_size=4, strides=1, padding=\"valid\")(x)\n",
180 | "critic_output = layers.Flatten()(x)\n",
181 | "\n",
182 | "critic = models.Model(critic_input, critic_output)\n",
183 | "critic.summary()"
184 | ]
185 | },
186 | {
187 | "cell_type": "code",
188 | "execution_count": null,
189 | "id": "086e2584-c60d-4990-89f4-2092c44e023e",
190 | "metadata": {},
191 | "outputs": [],
192 | "source": [
193 | "generator_input = layers.Input(shape=(Z_DIM,))\n",
194 | "x = layers.Reshape((1, 1, Z_DIM))(generator_input)\n",
195 | "x = layers.Conv2DTranspose(\n",
196 | " 512, kernel_size=4, strides=1, padding=\"valid\", use_bias=False\n",
197 | ")(x)\n",
198 | "x = layers.BatchNormalization(momentum=0.9)(x)\n",
199 | "x = layers.LeakyReLU(0.2)(x)\n",
200 | "x = layers.Conv2DTranspose(\n",
201 | " 256, kernel_size=4, strides=2, padding=\"same\", use_bias=False\n",
202 | ")(x)\n",
203 | "x = layers.BatchNormalization(momentum=0.9)(x)\n",
204 | "x = layers.LeakyReLU(0.2)(x)\n",
205 | "x = layers.Conv2DTranspose(\n",
206 | " 128, kernel_size=4, strides=2, padding=\"same\", use_bias=False\n",
207 | ")(x)\n",
208 | "x = layers.BatchNormalization(momentum=0.9)(x)\n",
209 | "x = layers.LeakyReLU(0.2)(x)\n",
210 | "x = layers.Conv2DTranspose(\n",
211 | " 64, kernel_size=4, strides=2, padding=\"same\", use_bias=False\n",
212 | ")(x)\n",
213 | "x = layers.BatchNormalization(momentum=0.9)(x)\n",
214 | "x = layers.LeakyReLU(0.2)(x)\n",
215 | "generator_output = layers.Conv2DTranspose(\n",
216 | " CHANNELS, kernel_size=4, strides=2, padding=\"same\", activation=\"tanh\"\n",
217 | ")(x)\n",
218 | "generator = models.Model(generator_input, generator_output)\n",
219 | "generator.summary()"
220 | ]
221 | },
222 | {
223 | "cell_type": "code",
224 | "execution_count": null,
225 | "id": "88010f20-fb61-498c-b2b2-dac96f6c03b3",
226 | "metadata": {},
227 | "outputs": [],
228 | "source": [
229 | "class WGANGP(models.Model):\n",
230 | " def __init__(self, critic, generator, latent_dim, critic_steps, gp_weight):\n",
231 | " super(WGANGP, self).__init__()\n",
232 | " self.critic = critic\n",
233 | " self.generator = generator\n",
234 | " self.latent_dim = latent_dim\n",
235 | " self.critic_steps = critic_steps\n",
236 | " self.gp_weight = gp_weight\n",
237 | "\n",
238 | " def compile(self, c_optimizer, g_optimizer):\n",
239 | " super(WGANGP, self).compile()\n",
240 | " self.c_optimizer = c_optimizer\n",
241 | " self.g_optimizer = g_optimizer\n",
242 | " self.c_wass_loss_metric = metrics.Mean(name=\"c_wass_loss\")\n",
243 | " self.c_gp_metric = metrics.Mean(name=\"c_gp\")\n",
244 | " self.c_loss_metric = metrics.Mean(name=\"c_loss\")\n",
245 | " self.g_loss_metric = metrics.Mean(name=\"g_loss\")\n",
246 | "\n",
247 | " @property\n",
248 | " def metrics(self):\n",
249 | " return [\n",
250 | " self.c_loss_metric,\n",
251 | " self.c_wass_loss_metric,\n",
252 | " self.c_gp_metric,\n",
253 | " self.g_loss_metric,\n",
254 | " ]\n",
255 | "\n",
256 | " def gradient_penalty(self, batch_size, real_images, fake_images):\n",
257 | " alpha = tf.random.normal([batch_size, 1, 1, 1], 0.0, 1.0)\n",
258 | " diff = fake_images - real_images\n",
259 | " interpolated = real_images + alpha * diff\n",
260 | "\n",
261 | " with tf.GradientTape() as gp_tape:\n",
262 | " gp_tape.watch(interpolated)\n",
263 | " pred = self.critic(interpolated, training=True)\n",
264 | "\n",
265 | " grads = gp_tape.gradient(pred, [interpolated])[0]\n",
266 | " norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))\n",
267 | " gp = tf.reduce_mean((norm - 1.0) ** 2)\n",
268 | " return gp\n",
269 | "\n",
270 | " def train_step(self, real_images):\n",
271 | " batch_size = tf.shape(real_images)[0]\n",
272 | "\n",
273 | " for i in range(self.critic_steps):\n",
274 | " random_latent_vectors = tf.random.normal(\n",
275 | " shape=(batch_size, self.latent_dim)\n",
276 | " )\n",
277 | "\n",
278 | " with tf.GradientTape() as tape:\n",
279 | " fake_images = self.generator(\n",
280 | " random_latent_vectors, training=True\n",
281 | " )\n",
282 | " fake_predictions = self.critic(fake_images, training=True)\n",
283 | " real_predictions = self.critic(real_images, training=True)\n",
284 | "\n",
285 | " c_wass_loss = tf.reduce_mean(fake_predictions) - tf.reduce_mean(\n",
286 | " real_predictions\n",
287 | " )\n",
288 | " c_gp = self.gradient_penalty(\n",
289 | " batch_size, real_images, fake_images\n",
290 | " )\n",
291 | " c_loss = c_wass_loss + c_gp * self.gp_weight\n",
292 | "\n",
293 | " c_gradient = tape.gradient(c_loss, self.critic.trainable_variables)\n",
294 | " self.c_optimizer.apply_gradients(\n",
295 | " zip(c_gradient, self.critic.trainable_variables)\n",
296 | " )\n",
297 | "\n",
298 | " random_latent_vectors = tf.random.normal(\n",
299 | " shape=(batch_size, self.latent_dim)\n",
300 | " )\n",
301 | " with tf.GradientTape() as tape:\n",
302 | " fake_images = self.generator(random_latent_vectors, training=True)\n",
303 | " fake_predictions = self.critic(fake_images, training=True)\n",
304 | " g_loss = -tf.reduce_mean(fake_predictions)\n",
305 | "\n",
306 | " gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)\n",
307 | " self.g_optimizer.apply_gradients(\n",
308 | " zip(gen_gradient, self.generator.trainable_variables)\n",
309 | " )\n",
310 | "\n",
311 | " self.c_loss_metric.update_state(c_loss)\n",
312 | " self.c_wass_loss_metric.update_state(c_wass_loss)\n",
313 | " self.c_gp_metric.update_state(c_gp)\n",
314 | " self.g_loss_metric.update_state(g_loss)\n",
315 | "\n",
316 | " return {m.name: m.result() for m in self.metrics}"
317 | ]
318 | },
319 | {
320 | "cell_type": "code",
321 | "execution_count": null,
322 | "id": "edf2f892-9209-42ee-b251-1e7604df5335",
323 | "metadata": {},
324 | "outputs": [],
325 | "source": [
326 | "# Create a GAN\n",
327 | "wgangp = WGANGP(\n",
328 | " critic=critic,\n",
329 | " generator=generator,\n",
330 | " latent_dim=Z_DIM,\n",
331 | " critic_steps=CRITIC_STEPS,\n",
332 | " gp_weight=GP_WEIGHT,\n",
333 | ")"
334 | ]
335 | },
336 | {
337 | "cell_type": "code",
338 | "execution_count": null,
339 | "id": "b2f48907-fa82-41b5-8caa-813b2f232c79",
340 | "metadata": {
341 | "tags": []
342 | },
343 | "outputs": [],
344 | "source": [
345 | "if LOAD_MODEL:\n",
346 | " wgangp.load_weights(\"./checkpoint/checkpoint.ckpt\")"
347 | ]
348 | },
349 | {
350 | "cell_type": "markdown",
351 | "id": "35b14665-4359-447b-be58-3fd58ba69084",
352 | "metadata": {},
353 | "source": [
354 | "## 3. Train the GAN "
355 | ]
356 | },
357 | {
358 | "cell_type": "code",
359 | "execution_count": null,
360 | "id": "b429fdad-ea9c-45a2-a556-eb950d793824",
361 | "metadata": {},
362 | "outputs": [],
363 | "source": [
364 | "# Compile the GAN\n",
365 | "wgangp.compile(\n",
366 | " c_optimizer=optimizers.Adam(\n",
367 | " learning_rate=LEARNING_RATE, beta_1=ADAM_BETA_1, beta_2=ADAM_BETA_2\n",
368 | " ),\n",
369 | " g_optimizer=optimizers.Adam(\n",
370 | " learning_rate=LEARNING_RATE, beta_1=ADAM_BETA_1, beta_2=ADAM_BETA_2\n",
371 | " ),\n",
372 | ")"
373 | ]
374 | },
375 | {
376 | "cell_type": "code",
377 | "execution_count": null,
378 | "id": "c525e44b-b3bb-489c-9d35-fcfe3e714e6a",
379 | "metadata": {},
380 | "outputs": [],
381 | "source": [
382 | "# Create a model save checkpoint\n",
383 | "model_checkpoint_callback = callbacks.ModelCheckpoint(\n",
384 | " filepath=\"./checkpoint/checkpoint.ckpt\",\n",
385 | " save_weights_only=True,\n",
386 | " save_freq=\"epoch\",\n",
387 | " verbose=0,\n",
388 | ")\n",
389 | "\n",
390 | "tensorboard_callback = callbacks.TensorBoard(log_dir=\"./logs\")\n",
391 | "\n",
392 | "\n",
393 | "class ImageGenerator(callbacks.Callback):\n",
394 | " def __init__(self, num_img, latent_dim):\n",
395 | " self.num_img = num_img\n",
396 | " self.latent_dim = latent_dim\n",
397 | "\n",
398 | " def on_epoch_end(self, epoch, logs=None):\n",
399 | " random_latent_vectors = tf.random.normal(\n",
400 | " shape=(self.num_img, self.latent_dim)\n",
401 | " )\n",
402 | " generated_images = self.model.generator(random_latent_vectors)\n",
403 | " generated_images = generated_images * 127.5 + 127.5\n",
404 | " generated_images = generated_images.numpy()\n",
405 | " display(\n",
406 | " generated_images,\n",
407 | " save_to=\"./output/generated_img_%03d.png\" % (epoch),\n",
408 | " cmap=None,\n",
409 | " )"
410 | ]
411 | },
412 | {
413 | "cell_type": "code",
414 | "execution_count": null,
415 | "id": "b3c497b7-fa40-48df-b2bf-541239cc9400",
416 | "metadata": {
417 | "tags": []
418 | },
419 | "outputs": [],
420 | "source": [
421 | "wgangp.fit(\n",
422 | " train,\n",
423 | " epochs=EPOCHS,\n",
424 | " steps_per_epoch=2,\n",
425 | " callbacks=[\n",
426 | " model_checkpoint_callback,\n",
427 | " tensorboard_callback,\n",
428 | " ImageGenerator(num_img=10, latent_dim=Z_DIM),\n",
429 | " ],\n",
430 | ")"
431 | ]
432 | },
433 | {
434 | "cell_type": "code",
435 | "execution_count": null,
436 | "id": "028138af-d3a5-4134-b980-d3a8a703e70f",
437 | "metadata": {},
438 | "outputs": [],
439 | "source": [
440 | "# Save the final models\n",
441 | "generator.save(\"./models/generator\")\n",
442 | "critic.save(\"./models/critic\")"
443 | ]
444 | },
445 | {
446 | "cell_type": "markdown",
447 | "id": "0765b66b-d12c-42c4-90fa-2ff851a9b3f5",
448 | "metadata": {},
449 | "source": [
450 | "## Generate images"
451 | ]
452 | },
453 | {
454 | "cell_type": "code",
455 | "execution_count": null,
456 | "id": "86576e84-afc4-443a-b68d-9a5ee13ce730",
457 | "metadata": {},
458 | "outputs": [],
459 | "source": [
460 | "z_sample = np.random.normal(size=(10, Z_DIM))\n",
461 | "imgs = wgangp.generator.predict(z_sample)\n",
462 | "display(imgs, cmap=None)"
463 | ]
464 | }
465 | ],
466 | "metadata": {
467 | "kernelspec": {
468 | "display_name": "Python 3 (ipykernel)",
469 | "language": "python",
470 | "name": "python3"
471 | },
472 | "language_info": {
473 | "codemirror_mode": {
474 | "name": "ipython",
475 | "version": 3
476 | },
477 | "file_extension": ".py",
478 | "mimetype": "text/x-python",
479 | "name": "python",
480 | "nbconvert_exporter": "python",
481 | "pygments_lexer": "ipython3",
482 | "version": "3.8.10"
483 | },
484 | "vscode": {
485 | "interpreter": {
486 | "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
487 | }
488 | }
489 | },
490 | "nbformat": 4,
491 | "nbformat_minor": 5
492 | }
493 |
--------------------------------------------------------------------------------
/notebooks/04_gan/03_cgan/checkpoint/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/04_gan/03_cgan/logs/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/04_gan/03_cgan/models/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/04_gan/03_cgan/output/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/05_autoregressive/01_lstm/checkpoint/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/05_autoregressive/01_lstm/logs/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/05_autoregressive/01_lstm/lstm.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "b076bd1a-b236-4fbc-953d-8295b25122ae",
6 | "metadata": {},
7 | "source": [
8 | "# 🥙 LSTM on Recipe Data"
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "id": "658a95da-9645-4bcf-bd9d-4b95a4b6f582",
14 | "metadata": {},
15 | "source": [
16 | "In this notebook, we'll walk through the steps required to train your own LSTM on the recipes dataset"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "id": "4e0d56cc-4773-4029-97d8-26f882ba79c9",
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "%load_ext autoreload\n",
27 | "%autoreload 2\n",
28 | "\n",
29 | "import numpy as np\n",
30 | "import json\n",
31 | "import re\n",
32 | "import string\n",
33 | "\n",
34 | "import tensorflow as tf\n",
35 | "from tensorflow.keras import layers, models, callbacks, losses"
36 | ]
37 | },
38 | {
39 | "cell_type": "markdown",
40 | "id": "339e6268-ebd7-4feb-86db-1fe7abccdbe5",
41 | "metadata": {},
42 | "source": [
43 | "## 0. Parameters "
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": null,
49 | "id": "2d8352af-343e-4c2e-8c91-95f8bac1c8a1",
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "VOCAB_SIZE = 10000\n",
54 | "MAX_LEN = 200\n",
55 | "EMBEDDING_DIM = 100\n",
56 | "N_UNITS = 128\n",
57 | "VALIDATION_SPLIT = 0.2\n",
58 | "SEED = 42\n",
59 | "LOAD_MODEL = False\n",
60 | "BATCH_SIZE = 32\n",
61 | "EPOCHS = 25"
62 | ]
63 | },
64 | {
65 | "cell_type": "markdown",
66 | "id": "b7716fac-0010-49b0-b98e-53be2259edde",
67 | "metadata": {},
68 | "source": [
69 | "## 1. Load the data "
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": null,
75 | "id": "93cf6b0f-9667-4146-8911-763a8a2925d3",
76 | "metadata": {
77 | "tags": []
78 | },
79 | "outputs": [],
80 | "source": [
81 | "# Load the full dataset\n",
82 | "with open(\"/app/data/epirecipes/full_format_recipes.json\") as json_data:\n",
83 | " recipe_data = json.load(json_data)"
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": null,
89 | "id": "23a74eca-f1b7-4a46-9a1f-b5806a4ed361",
90 | "metadata": {
91 | "tags": []
92 | },
93 | "outputs": [],
94 | "source": [
95 | "# Filter the dataset\n",
96 | "filtered_data = [\n",
97 | " \"Recipe for \" + x[\"title\"] + \" | \" + \" \".join(x[\"directions\"])\n",
98 | " for x in recipe_data\n",
99 | " if \"title\" in x\n",
100 | " and x[\"title\"] is not None\n",
101 | " and \"directions\" in x\n",
102 | " and x[\"directions\"] is not None\n",
103 | "]"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": null,
109 | "id": "389c20de-0422-4c48-a7b4-6ee12a7bf0e2",
110 | "metadata": {
111 | "tags": []
112 | },
113 | "outputs": [],
114 | "source": [
115 | "# Count the recipes\n",
116 | "n_recipes = len(filtered_data)\n",
117 | "print(f\"{n_recipes} recipes loaded\")"
118 | ]
119 | },
120 | {
121 | "cell_type": "code",
122 | "execution_count": null,
123 | "id": "1b2e3cf7-e416-460e-874a-0dd9637bca36",
124 | "metadata": {},
125 | "outputs": [],
126 | "source": [
127 | "example = filtered_data[9]\n",
128 | "print(example)"
129 | ]
130 | },
131 | {
132 | "cell_type": "markdown",
133 | "id": "3f871aaf-d873-41c7-8946-e4eef7ac17c1",
134 | "metadata": {},
135 | "source": [
136 | "## 2. Tokenise the data"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": null,
142 | "id": "5b2064fb-5dcc-4657-b470-0928d10e2ddc",
143 | "metadata": {
144 | "tags": []
145 | },
146 | "outputs": [],
147 | "source": [
148 | "# Pad the punctuation, to treat them as separate 'words'\n",
149 | "def pad_punctuation(s):\n",
150 | " s = re.sub(f\"([{string.punctuation}])\", r\" \\1 \", s)\n",
151 | " s = re.sub(\" +\", \" \", s)\n",
152 | " return s\n",
153 | "\n",
154 | "\n",
155 | "text_data = [pad_punctuation(x) for x in filtered_data]"
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "execution_count": null,
161 | "id": "b87d7c65-9a46-492a-a5c0-a043b0d252f3",
162 | "metadata": {},
163 | "outputs": [],
164 | "source": [
165 | "# Display an example of a recipe\n",
166 | "example_data = text_data[9]\n",
167 | "example_data"
168 | ]
169 | },
170 | {
171 | "cell_type": "code",
172 | "execution_count": null,
173 | "id": "9834f916-b21a-4104-acc9-f28d3bd7a8c1",
174 | "metadata": {
175 | "tags": []
176 | },
177 | "outputs": [],
178 | "source": [
179 | "# Convert to a Tensorflow Dataset\n",
180 | "text_ds = (\n",
181 | " tf.data.Dataset.from_tensor_slices(text_data)\n",
182 | " .batch(BATCH_SIZE)\n",
183 | " .shuffle(1000)\n",
184 | ")"
185 | ]
186 | },
187 | {
188 | "cell_type": "code",
189 | "execution_count": null,
190 | "id": "884c0bcb-0807-45a1-8f7e-a32f2c6fa4de",
191 | "metadata": {},
192 | "outputs": [],
193 | "source": [
194 | "# Create a vectorisation layer\n",
195 | "vectorize_layer = layers.TextVectorization(\n",
196 | " standardize=\"lower\",\n",
197 | " max_tokens=VOCAB_SIZE,\n",
198 | " output_mode=\"int\",\n",
199 | " output_sequence_length=MAX_LEN + 1,\n",
200 | ")"
201 | ]
202 | },
203 | {
204 | "cell_type": "code",
205 | "execution_count": null,
206 | "id": "4d6dd34a-d905-497b-926a-405380ebcf98",
207 | "metadata": {},
208 | "outputs": [],
209 | "source": [
210 | "# Adapt the layer to the training set\n",
211 | "vectorize_layer.adapt(text_ds)\n",
212 | "vocab = vectorize_layer.get_vocabulary()"
213 | ]
214 | },
215 | {
216 | "cell_type": "code",
217 | "execution_count": null,
218 | "id": "f6c1c7ce-3cf0-40d4-a3dc-ab7090f69f2f",
219 | "metadata": {},
220 | "outputs": [],
221 | "source": [
222 | "# Display some token:word mappings\n",
223 | "for i, word in enumerate(vocab[:10]):\n",
224 | " print(f\"{i}: {word}\")"
225 | ]
226 | },
227 | {
228 | "cell_type": "code",
229 | "execution_count": null,
230 | "id": "1cc30186-7ec6-4eb6-b29a-65df6714d321",
231 | "metadata": {},
232 | "outputs": [],
233 | "source": [
234 | "# Display the same example converted to ints\n",
235 | "example_tokenised = vectorize_layer(example_data)\n",
236 | "print(example_tokenised.numpy())"
237 | ]
238 | },
239 | {
240 | "cell_type": "markdown",
241 | "id": "8c195efb-84c6-4be0-a989-a7542188ad35",
242 | "metadata": {},
243 | "source": [
244 | "## 3. Create the Training Set"
245 | ]
246 | },
247 | {
248 | "cell_type": "code",
249 | "execution_count": null,
250 | "id": "740294a1-1a6b-4c89-92f2-036d7d1b788b",
251 | "metadata": {},
252 | "outputs": [],
253 | "source": [
254 | "# Create the training set of recipes and the same text shifted by one word\n",
255 | "def prepare_inputs(text):\n",
256 | " text = tf.expand_dims(text, -1)\n",
257 | " tokenized_sentences = vectorize_layer(text)\n",
258 | " x = tokenized_sentences[:, :-1]\n",
259 | " y = tokenized_sentences[:, 1:]\n",
260 | " return x, y\n",
261 | "\n",
262 | "\n",
263 | "train_ds = text_ds.map(prepare_inputs)"
264 | ]
265 | },
266 | {
267 | "cell_type": "markdown",
268 | "id": "aff50401-3abe-4c10-bba8-b35bc13ad7d5",
269 | "metadata": {
270 | "tags": []
271 | },
272 | "source": [
273 | "## 4. Build the LSTM "
274 | ]
275 | },
276 | {
277 | "cell_type": "code",
278 | "execution_count": null,
279 | "id": "9230b5bf-b4a8-48d5-b73b-6899a598f296",
280 | "metadata": {},
281 | "outputs": [],
282 | "source": [
283 | "inputs = layers.Input(shape=(None,), dtype=\"int32\")\n",
284 | "x = layers.Embedding(VOCAB_SIZE, EMBEDDING_DIM)(inputs)\n",
285 | "x = layers.LSTM(N_UNITS, return_sequences=True)(x)\n",
286 | "outputs = layers.Dense(VOCAB_SIZE, activation=\"softmax\")(x)\n",
287 | "lstm = models.Model(inputs, outputs)\n",
288 | "lstm.summary()"
289 | ]
290 | },
291 | {
292 | "cell_type": "code",
293 | "execution_count": null,
294 | "id": "800a3c6e-fb11-4792-b6bc-9a43a7c977ad",
295 | "metadata": {
296 | "tags": []
297 | },
298 | "outputs": [],
299 | "source": [
300 | "if LOAD_MODEL:\n",
301 | " # model.load_weights('./models/model')\n",
302 | " lstm = models.load_model(\"./models/lstm\", compile=False)"
303 | ]
304 | },
305 | {
306 | "cell_type": "markdown",
307 | "id": "35b14665-4359-447b-be58-3fd58ba69084",
308 | "metadata": {},
309 | "source": [
310 | "## 5. Train the LSTM "
311 | ]
312 | },
313 | {
314 | "cell_type": "code",
315 | "execution_count": null,
316 | "id": "ffb1bd3b-6fd9-4536-973e-6375bbcbf16d",
317 | "metadata": {},
318 | "outputs": [],
319 | "source": [
320 | "loss_fn = losses.SparseCategoricalCrossentropy()\n",
321 | "lstm.compile(\"adam\", loss_fn)"
322 | ]
323 | },
324 | {
325 | "cell_type": "code",
326 | "execution_count": null,
327 | "id": "3ddcff5f-829d-4449-99d2-9a3cb68f7d72",
328 | "metadata": {},
329 | "outputs": [],
330 | "source": [
331 | "# Create a TextGenerator checkpoint\n",
332 | "class TextGenerator(callbacks.Callback):\n",
333 | " def __init__(self, index_to_word, top_k=10):\n",
334 | " self.index_to_word = index_to_word\n",
335 | " self.word_to_index = {\n",
336 | " word: index for index, word in enumerate(index_to_word)\n",
337 | " } # <1>\n",
338 | "\n",
339 | " def sample_from(self, probs, temperature): # <2>\n",
340 | " probs = probs ** (1 / temperature)\n",
341 | " probs = probs / np.sum(probs)\n",
342 | " return np.random.choice(len(probs), p=probs), probs\n",
343 | "\n",
344 | " def generate(self, start_prompt, max_tokens, temperature):\n",
345 | " start_tokens = [\n",
346 | " self.word_to_index.get(x, 1) for x in start_prompt.split()\n",
347 | " ] # <3>\n",
348 | " sample_token = None\n",
349 | " info = []\n",
350 | " while len(start_tokens) < max_tokens and sample_token != 0: # <4>\n",
351 | " x = np.array([start_tokens])\n",
352 | " y = self.model.predict(x, verbose=0) # <5>\n",
353 | " sample_token, probs = self.sample_from(y[0][-1], temperature) # <6>\n",
354 | " info.append({\"prompt\": start_prompt, \"word_probs\": probs})\n",
355 | " start_tokens.append(sample_token) # <7>\n",
356 | " start_prompt = start_prompt + \" \" + self.index_to_word[sample_token]\n",
357 | " print(f\"\\ngenerated text:\\n{start_prompt}\\n\")\n",
358 | " return info\n",
359 | "\n",
360 | " def on_epoch_end(self, epoch, logs=None):\n",
361 | " self.generate(\"recipe for\", max_tokens=100, temperature=1.0)"
362 | ]
363 | },
364 | {
365 | "cell_type": "code",
366 | "execution_count": null,
367 | "id": "349865fe-ffbe-450e-97be-043ae1740e78",
368 | "metadata": {},
369 | "outputs": [],
370 | "source": [
371 | "# Create a model save checkpoint\n",
372 | "model_checkpoint_callback = callbacks.ModelCheckpoint(\n",
373 | " filepath=\"./checkpoint/checkpoint.ckpt\",\n",
374 | " save_weights_only=True,\n",
375 | " save_freq=\"epoch\",\n",
376 | " verbose=0,\n",
377 | ")\n",
378 | "\n",
379 | "tensorboard_callback = callbacks.TensorBoard(log_dir=\"./logs\")\n",
380 | "\n",
381 | "# Tokenize starting prompt\n",
382 | "text_generator = TextGenerator(vocab)"
383 | ]
384 | },
385 | {
386 | "cell_type": "code",
387 | "execution_count": null,
388 | "id": "461c2b3e-b5ae-4def-8bd9-e7bab8c63d8e",
389 | "metadata": {
390 | "tags": []
391 | },
392 | "outputs": [],
393 | "source": [
394 | "lstm.fit(\n",
395 | " train_ds,\n",
396 | " epochs=EPOCHS,\n",
397 | " callbacks=[model_checkpoint_callback, tensorboard_callback, text_generator],\n",
398 | ")"
399 | ]
400 | },
401 | {
402 | "cell_type": "code",
403 | "execution_count": null,
404 | "id": "369bde44-2e39-4bc6-8549-a3a27ecce55c",
405 | "metadata": {
406 | "tags": []
407 | },
408 | "outputs": [],
409 | "source": [
410 | "# Save the final model\n",
411 | "lstm.save(\"./models/lstm\")"
412 | ]
413 | },
414 | {
415 | "cell_type": "markdown",
416 | "id": "d64e02d2-84dc-40c8-8446-40c09adf1e20",
417 | "metadata": {},
418 | "source": [
419 | "## 6. Generate text using the LSTM"
420 | ]
421 | },
422 | {
423 | "cell_type": "code",
424 | "execution_count": null,
425 | "id": "4ad23adb-3ec9-4e9a-9a59-b9f9bafca649",
426 | "metadata": {},
427 | "outputs": [],
428 | "source": [
429 | "def print_probs(info, vocab, top_k=5):\n",
430 | " for i in info:\n",
431 | " print(f\"\\nPROMPT: {i['prompt']}\")\n",
432 | " word_probs = i[\"word_probs\"]\n",
433 | " p_sorted = np.sort(word_probs)[::-1][:top_k]\n",
434 | " i_sorted = np.argsort(word_probs)[::-1][:top_k]\n",
435 | " for p, i in zip(p_sorted, i_sorted):\n",
436 | " print(f\"{vocab[i]}: \\t{np.round(100*p,2)}%\")\n",
437 | " print(\"--------\\n\")"
438 | ]
439 | },
440 | {
441 | "cell_type": "code",
442 | "execution_count": null,
443 | "id": "3cf25578-d47c-4b26-8252-fcdf2316a4ac",
444 | "metadata": {},
445 | "outputs": [],
446 | "source": [
447 | "info = text_generator.generate(\n",
448 | " \"recipe for roasted vegetables | chop 1 /\", max_tokens=10, temperature=1.0\n",
449 | ")"
450 | ]
451 | },
452 | {
453 | "cell_type": "code",
454 | "execution_count": null,
455 | "id": "9df72866-b483-4489-8e26-d5e1466410fa",
456 | "metadata": {
457 | "tags": []
458 | },
459 | "outputs": [],
460 | "source": [
461 | "print_probs(info, vocab)"
462 | ]
463 | },
464 | {
465 | "cell_type": "code",
466 | "execution_count": null,
467 | "id": "562e1fe8-cbcb-438f-9637-2f2a6279c924",
468 | "metadata": {},
469 | "outputs": [],
470 | "source": [
471 | "info = text_generator.generate(\n",
472 | " \"recipe for roasted vegetables | chop 1 /\", max_tokens=10, temperature=0.2\n",
473 | ")"
474 | ]
475 | },
476 | {
477 | "cell_type": "code",
478 | "execution_count": null,
479 | "id": "56356f21-04ac-40e5-94ff-291eca6a7054",
480 | "metadata": {},
481 | "outputs": [],
482 | "source": [
483 | "print_probs(info, vocab)"
484 | ]
485 | },
486 | {
487 | "cell_type": "code",
488 | "execution_count": null,
489 | "id": "2e434497-07f3-4989-a68d-3e31cf8fa4fe",
490 | "metadata": {},
491 | "outputs": [],
492 | "source": [
493 | "info = text_generator.generate(\n",
494 | " \"recipe for chocolate ice cream |\", max_tokens=7, temperature=1.0\n",
495 | ")\n",
496 | "print_probs(info, vocab)"
497 | ]
498 | },
499 | {
500 | "cell_type": "code",
501 | "execution_count": null,
502 | "id": "011cd0e0-956c-4a63-8ec3-f7dfed31764e",
503 | "metadata": {},
504 | "outputs": [],
505 | "source": [
506 | "info = text_generator.generate(\n",
507 | " \"recipe for chocolate ice cream |\", max_tokens=7, temperature=0.2\n",
508 | ")\n",
509 | "print_probs(info, vocab)"
510 | ]
511 | }
512 | ],
513 | "metadata": {
514 | "kernelspec": {
515 | "display_name": "Python 3 (ipykernel)",
516 | "language": "python",
517 | "name": "python3"
518 | },
519 | "language_info": {
520 | "codemirror_mode": {
521 | "name": "ipython",
522 | "version": 3
523 | },
524 | "file_extension": ".py",
525 | "mimetype": "text/x-python",
526 | "name": "python",
527 | "nbconvert_exporter": "python",
528 | "pygments_lexer": "ipython3",
529 | "version": "3.8.10"
530 | }
531 | },
532 | "nbformat": 4,
533 | "nbformat_minor": 5
534 | }
535 |
--------------------------------------------------------------------------------
/notebooks/05_autoregressive/01_lstm/models/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/05_autoregressive/01_lstm/output/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/05_autoregressive/02_pixelcnn/checkpoint/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/05_autoregressive/02_pixelcnn/logs/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/05_autoregressive/02_pixelcnn/models/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/05_autoregressive/02_pixelcnn/output/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/05_autoregressive/02_pixelcnn/pixelcnn.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "206a93a2-e9b0-4ea2-a43f-696faa83ea03",
6 | "metadata": {},
7 | "source": [
8 | "# 👾 PixelCNN from scratch"
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "id": "1af9e216-7e84-4f5b-a2db-26aca3bea464",
14 | "metadata": {},
15 | "source": [
16 | "In this notebook, we'll walk through the steps required to train your own PixelCNN on the fashion MNIST dataset from scratch"
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "id": "9c1e6bbc-6f3b-48ac-a4f3-fde6f739f0ca",
22 | "metadata": {},
23 | "source": [
24 | "The code has been adapted from the excellent [PixelCNN tutorial](https://keras.io/examples/generative/pixelcnn/) created by ADMoreau, available on the Keras website."
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": null,
30 | "id": "6acebfa8-4546-41fd-adaa-2307c65b1b8e",
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "%load_ext autoreload\n",
35 | "%autoreload 2\n",
36 | "import numpy as np\n",
37 | "\n",
38 | "import tensorflow as tf\n",
39 | "from tensorflow.keras import datasets, layers, models, optimizers, callbacks\n",
40 | "\n",
41 | "from notebooks.utils import display"
42 | ]
43 | },
44 | {
45 | "cell_type": "markdown",
46 | "id": "8543166d-f4c7-43f8-a452-21ccbf2a0496",
47 | "metadata": {},
48 | "source": [
49 | "## 0. Parameters "
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": null,
55 | "id": "444d84de-2843-40d6-8e2e-93691a5393ab",
56 | "metadata": {},
57 | "outputs": [],
58 | "source": [
59 | "IMAGE_SIZE = 16\n",
60 | "PIXEL_LEVELS = 4\n",
61 | "N_FILTERS = 128\n",
62 | "RESIDUAL_BLOCKS = 5\n",
63 | "BATCH_SIZE = 128\n",
64 | "EPOCHS = 150"
65 | ]
66 | },
67 | {
68 | "cell_type": "markdown",
69 | "id": "d65dac68-d20b-4ed9-a136-eed57095ce4f",
70 | "metadata": {},
71 | "source": [
72 | "## 1. Prepare the data "
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": null,
78 | "id": "0ed0fc56-d1b0-4d42-b029-f4198f78e666",
79 | "metadata": {},
80 | "outputs": [],
81 | "source": [
82 | "# Load the data\n",
83 | "(x_train, _), (_, _) = datasets.fashion_mnist.load_data()"
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": null,
89 | "id": "b667e78c-8fa7-4e5b-a2c0-69e50166ef77",
90 | "metadata": {},
91 | "outputs": [],
92 | "source": [
93 | "# Preprocess the data\n",
94 | "def preprocess(imgs_int):\n",
95 | " imgs_int = np.expand_dims(imgs_int, -1)\n",
96 | " imgs_int = tf.image.resize(imgs_int, (IMAGE_SIZE, IMAGE_SIZE)).numpy()\n",
97 | " imgs_int = (imgs_int / (256 / PIXEL_LEVELS)).astype(int)\n",
98 | " imgs = imgs_int.astype(\"float32\")\n",
99 | " imgs = imgs / PIXEL_LEVELS\n",
100 | " return imgs, imgs_int\n",
101 | "\n",
102 | "\n",
103 | "input_data, output_data = preprocess(x_train)"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": null,
109 | "id": "e3c2b304-8385-4931-8291-9b7cc462c95e",
110 | "metadata": {},
111 | "outputs": [],
112 | "source": [
113 | "# Show some items of clothing from the training set\n",
114 | "display(input_data)"
115 | ]
116 | },
117 | {
118 | "cell_type": "markdown",
119 | "id": "5ccd5cb2-8c7b-4667-8adb-4902f3fa60cf",
120 | "metadata": {},
121 | "source": [
122 | "## 2. Build the PixelCNN"
123 | ]
124 | },
125 | {
126 | "cell_type": "code",
127 | "execution_count": null,
128 | "id": "847050a5-e4e6-4134-9bfc-c690cb8cb44d",
129 | "metadata": {},
130 | "outputs": [],
131 | "source": [
132 | "# The first layer is the PixelCNN layer. This layer simply\n",
133 | "# builds on the 2D convolutional layer, but includes masking.\n",
134 | "class MaskedConv2D(layers.Layer):\n",
135 | " def __init__(self, mask_type, **kwargs):\n",
136 | " super(MaskedConv2D, self).__init__()\n",
137 | " self.mask_type = mask_type\n",
138 | " self.conv = layers.Conv2D(**kwargs)\n",
139 | "\n",
140 | " def build(self, input_shape):\n",
141 | " # Build the conv2d layer to initialize kernel variables\n",
142 | " self.conv.build(input_shape)\n",
143 | " # Use the initialized kernel to create the mask\n",
144 | " kernel_shape = self.conv.kernel.get_shape()\n",
145 | " self.mask = np.zeros(shape=kernel_shape)\n",
146 | " self.mask[: kernel_shape[0] // 2, ...] = 1.0\n",
147 | " self.mask[kernel_shape[0] // 2, : kernel_shape[1] // 2, ...] = 1.0\n",
148 | " if self.mask_type == \"B\":\n",
149 | " self.mask[kernel_shape[0] // 2, kernel_shape[1] // 2, ...] = 1.0\n",
150 | "\n",
151 | " def call(self, inputs):\n",
152 | " self.conv.kernel.assign(self.conv.kernel * self.mask)\n",
153 | " return self.conv(inputs)\n",
154 | "\n",
155 | " def get_config(self):\n",
156 | " cfg = super().get_config()\n",
157 | " return cfg"
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "execution_count": null,
163 | "id": "a52f7795-790e-47b0-b724-80be3e3c3666",
164 | "metadata": {},
165 | "outputs": [],
166 | "source": [
167 | "class ResidualBlock(layers.Layer):\n",
168 | " def __init__(self, filters, **kwargs):\n",
169 | " super(ResidualBlock, self).__init__(**kwargs)\n",
170 | " self.conv1 = layers.Conv2D(\n",
171 | " filters=filters // 2, kernel_size=1, activation=\"relu\"\n",
172 | " )\n",
173 | " self.pixel_conv = MaskedConv2D(\n",
174 | " mask_type=\"B\",\n",
175 | " filters=filters // 2,\n",
176 | " kernel_size=3,\n",
177 | " activation=\"relu\",\n",
178 | " padding=\"same\",\n",
179 | " )\n",
180 | " self.conv2 = layers.Conv2D(\n",
181 | " filters=filters, kernel_size=1, activation=\"relu\"\n",
182 | " )\n",
183 | "\n",
184 | " def call(self, inputs):\n",
185 | " x = self.conv1(inputs)\n",
186 | " x = self.pixel_conv(x)\n",
187 | " x = self.conv2(x)\n",
188 | " return layers.add([inputs, x])\n",
189 | "\n",
190 | " def get_config(self):\n",
191 | " cfg = super().get_config()\n",
192 | " return cfg"
193 | ]
194 | },
195 | {
196 | "cell_type": "code",
197 | "execution_count": null,
198 | "id": "19b4508f-84de-42a9-a77f-950fb493db13",
199 | "metadata": {},
200 | "outputs": [],
201 | "source": [
202 | "inputs = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 1))\n",
203 | "x = MaskedConv2D(\n",
204 | " mask_type=\"A\",\n",
205 | " filters=N_FILTERS,\n",
206 | " kernel_size=7,\n",
207 | " activation=\"relu\",\n",
208 | " padding=\"same\",\n",
209 | ")(inputs)\n",
210 | "\n",
211 | "for _ in range(RESIDUAL_BLOCKS):\n",
212 | " x = ResidualBlock(filters=N_FILTERS)(x)\n",
213 | "\n",
214 | "for _ in range(2):\n",
215 | " x = MaskedConv2D(\n",
216 | " mask_type=\"B\",\n",
217 | " filters=N_FILTERS,\n",
218 | " kernel_size=1,\n",
219 | " strides=1,\n",
220 | " activation=\"relu\",\n",
221 | " padding=\"valid\",\n",
222 | " )(x)\n",
223 | "\n",
224 | "out = layers.Conv2D(\n",
225 | " filters=PIXEL_LEVELS,\n",
226 | " kernel_size=1,\n",
227 | " strides=1,\n",
228 | " activation=\"softmax\",\n",
229 | " padding=\"valid\",\n",
230 | ")(x)\n",
231 | "\n",
232 | "pixel_cnn = models.Model(inputs, out)\n",
233 | "pixel_cnn.summary()"
234 | ]
235 | },
236 | {
237 | "cell_type": "markdown",
238 | "id": "442b5ffa-67a3-4b15-a342-eb1eed5e87ac",
239 | "metadata": {},
240 | "source": [
241 | "## 3. Train the PixelCNN "
242 | ]
243 | },
244 | {
245 | "cell_type": "code",
246 | "execution_count": null,
247 | "id": "7204789a-2ad3-48bf-b7e8-00d4cab10d9c",
248 | "metadata": {},
249 | "outputs": [],
250 | "source": [
251 | "adam = optimizers.Adam(learning_rate=0.0005)\n",
252 | "pixel_cnn.compile(optimizer=adam, loss=\"sparse_categorical_crossentropy\")"
253 | ]
254 | },
255 | {
256 | "cell_type": "code",
257 | "execution_count": null,
258 | "id": "09d327fc-aff8-40e6-b390-d1bff4c06ea6",
259 | "metadata": {},
260 | "outputs": [],
261 | "source": [
262 | "tensorboard_callback = callbacks.TensorBoard(log_dir=\"./logs\")\n",
263 | "\n",
264 | "\n",
265 | "class ImageGenerator(callbacks.Callback):\n",
266 | " def __init__(self, num_img):\n",
267 | " self.num_img = num_img\n",
268 | "\n",
269 | " def sample_from(self, probs, temperature): # <2>\n",
270 | " probs = probs ** (1 / temperature)\n",
271 | " probs = probs / np.sum(probs)\n",
272 | " return np.random.choice(len(probs), p=probs)\n",
273 | "\n",
274 | " def generate(self, temperature):\n",
275 | " generated_images = np.zeros(\n",
276 | " shape=(self.num_img,) + (pixel_cnn.input_shape)[1:]\n",
277 | " )\n",
278 | " batch, rows, cols, channels = generated_images.shape\n",
279 | "\n",
280 | " for row in range(rows):\n",
281 | " for col in range(cols):\n",
282 | " for channel in range(channels):\n",
283 | " probs = self.model.predict(generated_images, verbose=0)[\n",
284 | " :, row, col, :\n",
285 | " ]\n",
286 | " generated_images[:, row, col, channel] = [\n",
287 | " self.sample_from(x, temperature) for x in probs\n",
288 | " ]\n",
289 | " generated_images[:, row, col, channel] /= PIXEL_LEVELS\n",
290 | "\n",
291 | " return generated_images\n",
292 | "\n",
293 | " def on_epoch_end(self, epoch, logs=None):\n",
294 | " generated_images = self.generate(temperature=1.0)\n",
295 | " display(\n",
296 | " generated_images,\n",
297 | " save_to=\"./output/generated_img_%03d.png\" % (epoch),\n",
298 | " )\n",
299 | "\n",
300 | "\n",
301 | "img_generator_callback = ImageGenerator(num_img=10)"
302 | ]
303 | },
304 | {
305 | "cell_type": "code",
306 | "execution_count": null,
307 | "id": "85231056-d4a4-4897-ab91-065325a18d93",
308 | "metadata": {
309 | "tags": []
310 | },
311 | "outputs": [],
312 | "source": [
313 | "pixel_cnn.fit(\n",
314 | " input_data,\n",
315 | " output_data,\n",
316 | " batch_size=BATCH_SIZE,\n",
317 | " epochs=EPOCHS,\n",
318 | " callbacks=[tensorboard_callback, img_generator_callback],\n",
319 | ")"
320 | ]
321 | },
322 | {
323 | "cell_type": "markdown",
324 | "id": "cfb4fa72-dd2d-44c1-ad18-9c965060683e",
325 | "metadata": {},
326 | "source": [
327 | "## 4. Generate images "
328 | ]
329 | },
330 | {
331 | "cell_type": "code",
332 | "execution_count": null,
333 | "id": "7bbd4643-be09-49ba-b7bc-a524a2f00806",
334 | "metadata": {},
335 | "outputs": [],
336 | "source": [
337 | "generated_images = img_generator_callback.generate(temperature=1.0)"
338 | ]
339 | },
340 | {
341 | "cell_type": "code",
342 | "execution_count": null,
343 | "id": "52cadb4b-ae2c-42a9-92ac-68e2131380ef",
344 | "metadata": {},
345 | "outputs": [],
346 | "source": [
347 | "display(generated_images)"
348 | ]
349 | }
350 | ],
351 | "metadata": {
352 | "kernelspec": {
353 | "display_name": "Python 3 (ipykernel)",
354 | "language": "python",
355 | "name": "python3"
356 | },
357 | "language_info": {
358 | "codemirror_mode": {
359 | "name": "ipython",
360 | "version": 3
361 | },
362 | "file_extension": ".py",
363 | "mimetype": "text/x-python",
364 | "name": "python",
365 | "nbconvert_exporter": "python",
366 | "pygments_lexer": "ipython3",
367 | "version": "3.8.10"
368 | }
369 | },
370 | "nbformat": 4,
371 | "nbformat_minor": 5
372 | }
373 |
--------------------------------------------------------------------------------
/notebooks/05_autoregressive/03_pixelcnn_md/checkpoint/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/05_autoregressive/03_pixelcnn_md/logs/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/05_autoregressive/03_pixelcnn_md/models/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/05_autoregressive/03_pixelcnn_md/output/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/05_autoregressive/03_pixelcnn_md/pixelcnn_md.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "b076bd1a-b236-4fbc-953d-8295b25122ae",
6 | "metadata": {},
7 | "source": [
8 | "# 👾 PixelCNN using Tensorflow distributions"
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "id": "9235cbd1-f136-411c-88d9-f69f270c0b96",
14 | "metadata": {},
15 | "source": [
16 | "In this notebook, we'll walk through the steps required to train your own PixelCNN on the fashion MNIST dataset using Tensorflow distributions"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "id": "84acc7be-6764-4668-b2bb-178f63deeed3",
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "%load_ext autoreload\n",
27 | "%autoreload 2\n",
28 | "import numpy as np\n",
29 | "\n",
30 | "import tensorflow as tf\n",
31 | "from tensorflow.keras import datasets, layers, models, optimizers, callbacks\n",
32 | "import tensorflow_probability as tfp\n",
33 | "\n",
34 | "from notebooks.utils import display"
35 | ]
36 | },
37 | {
38 | "cell_type": "markdown",
39 | "id": "339e6268-ebd7-4feb-86db-1fe7abccdbe5",
40 | "metadata": {},
41 | "source": [
42 | "## 0. Parameters "
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": null,
48 | "id": "1b2ee6ce-129f-4833-b0c5-fa567381c4e0",
49 | "metadata": {},
50 | "outputs": [],
51 | "source": [
52 | "IMAGE_SIZE = 32\n",
53 | "N_COMPONENTS = 5\n",
54 | "EPOCHS = 10\n",
55 | "BATCH_SIZE = 128"
56 | ]
57 | },
58 | {
59 | "cell_type": "markdown",
60 | "id": "b7716fac-0010-49b0-b98e-53be2259edde",
61 | "metadata": {},
62 | "source": [
63 | "## 1. Prepare the data "
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": null,
69 | "id": "9a73e5a4-1638-411c-8d3c-29f823424458",
70 | "metadata": {},
71 | "outputs": [],
72 | "source": [
73 | "# Load the data\n",
74 | "(x_train, _), (_, _) = datasets.fashion_mnist.load_data()"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "id": "ebae2f0d-59fd-4796-841f-7213eae638de",
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "# Preprocess the data\n",
85 | "\n",
86 | "\n",
87 | "def preprocess(imgs):\n",
88 | " imgs = np.expand_dims(imgs, -1)\n",
89 | " imgs = tf.image.resize(imgs, (IMAGE_SIZE, IMAGE_SIZE)).numpy()\n",
90 | " return imgs\n",
91 | "\n",
92 | "\n",
93 | "input_data = preprocess(x_train)"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": null,
99 | "id": "fa53709f-7f3f-483b-9db8-2e5f9b9942c2",
100 | "metadata": {},
101 | "outputs": [],
102 | "source": [
103 | "# Show some items of clothing from the training set\n",
104 | "display(input_data)"
105 | ]
106 | },
107 | {
108 | "cell_type": "markdown",
109 | "id": "aff50401-3abe-4c10-bba8-b35bc13ad7d5",
110 | "metadata": {
111 | "tags": []
112 | },
113 | "source": [
114 | "## 2. Build the PixelCNN "
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": null,
120 | "id": "71a2a4a1-690e-4c94-b323-86f0e5b691d5",
121 | "metadata": {},
122 | "outputs": [],
123 | "source": [
124 | "# Define a Pixel CNN network\n",
125 | "dist = tfp.distributions.PixelCNN(\n",
126 | " image_shape=(IMAGE_SIZE, IMAGE_SIZE, 1),\n",
127 | " num_resnet=1,\n",
128 | " num_hierarchies=2,\n",
129 | " num_filters=32,\n",
130 | " num_logistic_mix=N_COMPONENTS,\n",
131 | " dropout_p=0.3,\n",
132 | ")\n",
133 | "\n",
134 | "# Define the model input\n",
135 | "image_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 1))\n",
136 | "\n",
137 | "# Define the log likelihood for the loss fn\n",
138 | "log_prob = dist.log_prob(image_input)\n",
139 | "\n",
140 | "# Define the model\n",
141 | "pixelcnn = models.Model(inputs=image_input, outputs=log_prob)\n",
142 | "pixelcnn.add_loss(-tf.reduce_mean(log_prob))"
143 | ]
144 | },
145 | {
146 | "cell_type": "markdown",
147 | "id": "35b14665-4359-447b-be58-3fd58ba69084",
148 | "metadata": {},
149 | "source": [
150 | "## 3. Train the PixelCNN "
151 | ]
152 | },
153 | {
154 | "cell_type": "code",
155 | "execution_count": null,
156 | "id": "d9ec362d-41fa-473a-ad56-ebeec6cfd3b8",
157 | "metadata": {},
158 | "outputs": [],
159 | "source": [
160 | "# Compile and train the model\n",
161 | "pixelcnn.compile(\n",
162 | " optimizer=optimizers.Adam(0.001),\n",
163 | ")"
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": null,
169 | "id": "c525e44b-b3bb-489c-9d35-fcfe3e714e6a",
170 | "metadata": {},
171 | "outputs": [],
172 | "source": [
173 | "tensorboard_callback = callbacks.TensorBoard(log_dir=\"./logs\")\n",
174 | "\n",
175 | "\n",
176 | "class ImageGenerator(callbacks.Callback):\n",
177 | " def __init__(self, num_img):\n",
178 | " self.num_img = num_img\n",
179 | "\n",
180 | " def generate(self):\n",
181 | " return dist.sample(self.num_img).numpy()\n",
182 | "\n",
183 | " def on_epoch_end(self, epoch, logs=None):\n",
184 | " generated_images = self.generate()\n",
185 | " display(\n",
186 | " generated_images,\n",
187 | " n=self.num_img,\n",
188 | " save_to=\"./output/generated_img_%03d.png\" % (epoch),\n",
189 | " )\n",
190 | "\n",
191 | "\n",
192 | "img_generator_callback = ImageGenerator(num_img=2)"
193 | ]
194 | },
195 | {
196 | "cell_type": "code",
197 | "execution_count": null,
198 | "id": "bd6a5a71-eb55-4ec0-9c8c-cb11a382ff90",
199 | "metadata": {
200 | "tags": []
201 | },
202 | "outputs": [],
203 | "source": [
204 | "pixelcnn.fit(\n",
205 | " input_data,\n",
206 | " batch_size=BATCH_SIZE,\n",
207 | " epochs=EPOCHS,\n",
208 | " verbose=True,\n",
209 | " callbacks=[tensorboard_callback, img_generator_callback],\n",
210 | ")"
211 | ]
212 | },
213 | {
214 | "cell_type": "markdown",
215 | "id": "fb1f295f-ade0-4040-a6a5-a7b428b08ebc",
216 | "metadata": {},
217 | "source": [
218 | "## 4. Generate images "
219 | ]
220 | },
221 | {
222 | "cell_type": "code",
223 | "execution_count": null,
224 | "id": "8db3cfe3-339e-463d-8af5-fbd403385fca",
225 | "metadata": {},
226 | "outputs": [],
227 | "source": [
228 | "generated_images = img_generator_callback.generate()"
229 | ]
230 | },
231 | {
232 | "cell_type": "code",
233 | "execution_count": null,
234 | "id": "80087297-3f47-4e0c-ac89-8758d4386d7c",
235 | "metadata": {},
236 | "outputs": [],
237 | "source": [
238 | "display(generated_images, n=img_generator_callback.num_img)"
239 | ]
240 | }
241 | ],
242 | "metadata": {
243 | "kernelspec": {
244 | "display_name": "Python 3 (ipykernel)",
245 | "language": "python",
246 | "name": "python3"
247 | },
248 | "language_info": {
249 | "codemirror_mode": {
250 | "name": "ipython",
251 | "version": 3
252 | },
253 | "file_extension": ".py",
254 | "mimetype": "text/x-python",
255 | "name": "python",
256 | "nbconvert_exporter": "python",
257 | "pygments_lexer": "ipython3",
258 | "version": "3.8.10"
259 | }
260 | },
261 | "nbformat": 4,
262 | "nbformat_minor": 5
263 | }
264 |
--------------------------------------------------------------------------------
/notebooks/06_normflow/01_realnvp/checkpoint/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/06_normflow/01_realnvp/logs/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/06_normflow/01_realnvp/models/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/06_normflow/01_realnvp/output/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/06_normflow/01_realnvp/realnvp.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "b076bd1a-b236-4fbc-953d-8295b25122ae",
6 | "metadata": {},
7 | "source": [
8 | "# 🌀 RealNVP"
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "id": "9235cbd1-f136-411c-88d9-f69f270c0b96",
14 | "metadata": {},
15 | "source": [
16 | "In this notebook, we'll walk through the steps required to train your own RealNVP network to predict the distribution of a demo dataset"
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "id": "2db9a506-bf8f-4a40-ab27-703b0c82371b",
22 | "metadata": {},
23 | "source": [
24 | "The code has been adapted from the excellent [RealNVP tutorial](https://keras.io/examples/generative/real_nvp) created by Mandolini Giorgio Maria, Sanna Daniele and Zannini Quirini Giorgio available on the Keras website."
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": null,
30 | "id": "84acc7be-6764-4668-b2bb-178f63deeed3",
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "%load_ext autoreload\n",
35 | "%autoreload 2\n",
36 | "\n",
37 | "import numpy as np\n",
38 | "import matplotlib.pyplot as plt\n",
39 | "\n",
40 | "from sklearn import datasets\n",
41 | "\n",
42 | "import tensorflow as tf\n",
43 | "from tensorflow.keras import (\n",
44 | " layers,\n",
45 | " models,\n",
46 | " regularizers,\n",
47 | " metrics,\n",
48 | " optimizers,\n",
49 | " callbacks,\n",
50 | ")\n",
51 | "import tensorflow_probability as tfp"
52 | ]
53 | },
54 | {
55 | "cell_type": "markdown",
56 | "id": "339e6268-ebd7-4feb-86db-1fe7abccdbe5",
57 | "metadata": {},
58 | "source": [
59 | "## 0. Parameters "
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": null,
65 | "id": "1b2ee6ce-129f-4833-b0c5-fa567381c4e0",
66 | "metadata": {},
67 | "outputs": [],
68 | "source": [
69 | "COUPLING_DIM = 256\n",
70 | "COUPLING_LAYERS = 2\n",
71 | "INPUT_DIM = 2\n",
72 | "REGULARIZATION = 0.01\n",
73 | "BATCH_SIZE = 256\n",
74 | "EPOCHS = 300"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "id": "9a73e5a4-1638-411c-8d3c-29f823424458",
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "# Load the data\n",
85 | "data = datasets.make_moons(30000, noise=0.05)[0].astype(\"float32\")\n",
86 | "norm = layers.Normalization()\n",
87 | "norm.adapt(data)\n",
88 | "normalized_data = norm(data)\n",
89 | "plt.scatter(\n",
90 | " normalized_data.numpy()[:, 0], normalized_data.numpy()[:, 1], c=\"green\"\n",
91 | ")\n",
92 | "plt.show()"
93 | ]
94 | },
95 | {
96 | "cell_type": "markdown",
97 | "id": "aff50401-3abe-4c10-bba8-b35bc13ad7d5",
98 | "metadata": {
99 | "tags": []
100 | },
101 | "source": [
102 | "## 2. Build the RealNVP network "
103 | ]
104 | },
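{
"cell_type": "markdown",
"id": "realnvp-coupling-note",
"metadata": {},
"source": [
"For reference, each coupling layer defined below leaves the masked dimensions of its input unchanged and applies an invertible scale-and-shift to the remaining dimensions, using scale and shift networks $s$ and $t$ that see only the masked part:\n",
"\n",
"$$z = x \\odot e^{s} + t, \\qquad x = (z - t) \\odot e^{-s},$$\n",
"\n",
"applied to the unmasked dimensions, with log-Jacobian determinant $\\sum_j s_j$ over those dimensions. Stacking coupling layers with alternating masks gives an invertible map $f$ from data space to the latent Gaussian, and the network is trained by minimising the negative log-likelihood $-\\log p_X(x) = -\\log p_Z(f(x)) - \\log\\left|\\det \\frac{\\partial f}{\\partial x}\\right|$."
]
},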
105 | {
106 | "cell_type": "code",
107 | "execution_count": null,
108 | "id": "71a2a4a1-690e-4c94-b323-86f0e5b691d5",
109 | "metadata": {},
110 | "outputs": [],
111 | "source": [
112 | "def Coupling(input_dim, coupling_dim, reg):\n",
113 | " input_layer = layers.Input(shape=input_dim)\n",
114 | "\n",
115 | " s_layer_1 = layers.Dense(\n",
116 | " coupling_dim, activation=\"relu\", kernel_regularizer=regularizers.l2(reg)\n",
117 | " )(input_layer)\n",
118 | " s_layer_2 = layers.Dense(\n",
119 | " coupling_dim, activation=\"relu\", kernel_regularizer=regularizers.l2(reg)\n",
120 | " )(s_layer_1)\n",
121 | " s_layer_3 = layers.Dense(\n",
122 | " coupling_dim, activation=\"relu\", kernel_regularizer=regularizers.l2(reg)\n",
123 | " )(s_layer_2)\n",
124 | " s_layer_4 = layers.Dense(\n",
125 | " coupling_dim, activation=\"relu\", kernel_regularizer=regularizers.l2(reg)\n",
126 | " )(s_layer_3)\n",
127 | " s_layer_5 = layers.Dense(\n",
128 | " input_dim, activation=\"tanh\", kernel_regularizer=regularizers.l2(reg)\n",
129 | " )(s_layer_4)\n",
130 | "\n",
131 | " t_layer_1 = layers.Dense(\n",
132 | " coupling_dim, activation=\"relu\", kernel_regularizer=regularizers.l2(reg)\n",
133 | " )(input_layer)\n",
134 | " t_layer_2 = layers.Dense(\n",
135 | " coupling_dim, activation=\"relu\", kernel_regularizer=regularizers.l2(reg)\n",
136 | " )(t_layer_1)\n",
137 | " t_layer_3 = layers.Dense(\n",
138 | " coupling_dim, activation=\"relu\", kernel_regularizer=regularizers.l2(reg)\n",
139 | " )(t_layer_2)\n",
140 | " t_layer_4 = layers.Dense(\n",
141 | " coupling_dim, activation=\"relu\", kernel_regularizer=regularizers.l2(reg)\n",
142 | " )(t_layer_3)\n",
143 | " t_layer_5 = layers.Dense(\n",
144 | " input_dim, activation=\"linear\", kernel_regularizer=regularizers.l2(reg)\n",
145 | " )(t_layer_4)\n",
146 | "\n",
147 | " return models.Model(inputs=input_layer, outputs=[s_layer_5, t_layer_5])"
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "execution_count": null,
153 | "id": "0f4dcd44-4189-4f39-b262-7afedb00a5a9",
154 | "metadata": {},
155 | "outputs": [],
156 | "source": [
157 | "class RealNVP(models.Model):\n",
158 | " def __init__(\n",
159 | " self, input_dim, coupling_layers, coupling_dim, regularization\n",
160 | " ):\n",
161 | " super(RealNVP, self).__init__()\n",
162 | " self.coupling_layers = coupling_layers\n",
163 | " self.distribution = tfp.distributions.MultivariateNormalDiag(\n",
164 | " loc=[0.0, 0.0], scale_diag=[1.0, 1.0]\n",
165 | " )\n",
166 | " self.masks = np.array(\n",
167 | " [[0, 1], [1, 0]] * (coupling_layers // 2), dtype=\"float32\"\n",
168 | " )\n",
169 | " self.loss_tracker = metrics.Mean(name=\"loss\")\n",
170 | " self.layers_list = [\n",
171 | " Coupling(input_dim, coupling_dim, regularization)\n",
172 | " for i in range(coupling_layers)\n",
173 | " ]\n",
174 | "\n",
175 | " @property\n",
176 | " def metrics(self):\n",
177 | " return [self.loss_tracker]\n",
178 | "\n",
179 | " def call(self, x, training=True):\n",
180 | " log_det_inv = 0\n",
181 | " direction = 1\n",
182 | " if training:\n",
183 | " direction = -1\n",
184 | " for i in range(self.coupling_layers)[::direction]:\n",
185 | " x_masked = x * self.masks[i]\n",
186 | " reversed_mask = 1 - self.masks[i]\n",
187 | " s, t = self.layers_list[i](x_masked)\n",
188 | " s *= reversed_mask\n",
189 | " t *= reversed_mask\n",
190 | " gate = (direction - 1) / 2\n",
191 | " x = (\n",
192 | " reversed_mask\n",
193 | " * (x * tf.exp(direction * s) + direction * t * tf.exp(gate * s))\n",
194 | " + x_masked\n",
195 | " )\n",
196 | " log_det_inv += gate * tf.reduce_sum(s, axis=1)\n",
197 | " return x, log_det_inv\n",
198 | "\n",
199 | " def log_loss(self, x):\n",
200 | " y, logdet = self(x)\n",
201 | " log_likelihood = self.distribution.log_prob(y) + logdet\n",
202 | " return -tf.reduce_mean(log_likelihood)\n",
203 | "\n",
204 | " def train_step(self, data):\n",
205 | " with tf.GradientTape() as tape:\n",
206 | " loss = self.log_loss(data)\n",
207 | " g = tape.gradient(loss, self.trainable_variables)\n",
208 | " self.optimizer.apply_gradients(zip(g, self.trainable_variables))\n",
209 | " self.loss_tracker.update_state(loss)\n",
210 | " return {\"loss\": self.loss_tracker.result()}\n",
211 | "\n",
212 | " def test_step(self, data):\n",
213 | " loss = self.log_loss(data)\n",
214 | " self.loss_tracker.update_state(loss)\n",
215 | " return {\"loss\": self.loss_tracker.result()}\n",
216 | "\n",
217 | "\n",
218 | "model = RealNVP(\n",
219 | " input_dim=INPUT_DIM,\n",
220 | " coupling_layers=COUPLING_LAYERS,\n",
221 | " coupling_dim=COUPLING_DIM,\n",
222 | " regularization=REGULARIZATION,\n",
223 | ")"
224 | ]
225 | },
226 | {
227 | "cell_type": "markdown",
228 | "id": "35b14665-4359-447b-be58-3fd58ba69084",
229 | "metadata": {},
230 | "source": [
231 | "## 3. Train the RealNVP network "
232 | ]
233 | },
234 | {
235 | "cell_type": "code",
236 | "execution_count": null,
237 | "id": "d9ec362d-41fa-473a-ad56-ebeec6cfd3b8",
238 | "metadata": {},
239 | "outputs": [],
240 | "source": [
241 | "# Compile and train the model\n",
242 | "model.compile(optimizer=optimizers.Adam(learning_rate=0.0001))"
243 | ]
244 | },
245 | {
246 | "cell_type": "code",
247 | "execution_count": null,
248 | "id": "c525e44b-b3bb-489c-9d35-fcfe3e714e6a",
249 | "metadata": {},
250 | "outputs": [],
251 | "source": [
252 | "tensorboard_callback = callbacks.TensorBoard(log_dir=\"./logs\")\n",
253 | "\n",
254 | "\n",
255 | "class ImageGenerator(callbacks.Callback):\n",
256 | " def __init__(self, num_samples):\n",
257 | " self.num_samples = num_samples\n",
258 | "\n",
259 | " def generate(self):\n",
260 | " # From data to latent space.\n",
261 | " z, _ = model(normalized_data)\n",
262 | "\n",
263 | " # From latent space to data.\n",
264 | " samples = model.distribution.sample(self.num_samples)\n",
265 | " x, _ = model.predict(samples, verbose=0)\n",
266 | "\n",
267 | " return x, z, samples\n",
268 | "\n",
269 | " def display(self, x, z, samples, save_to=None):\n",
270 | " f, axes = plt.subplots(2, 2)\n",
271 | " f.set_size_inches(8, 5)\n",
272 | "\n",
273 | " axes[0, 0].scatter(\n",
274 | " normalized_data[:, 0], normalized_data[:, 1], color=\"r\", s=1\n",
275 | " )\n",
276 | " axes[0, 0].set(title=\"Data space X\", xlabel=\"x_1\", ylabel=\"x_2\")\n",
277 | " axes[0, 0].set_xlim([-2, 2])\n",
278 | " axes[0, 0].set_ylim([-2, 2])\n",
279 | " axes[0, 1].scatter(z[:, 0], z[:, 1], color=\"r\", s=1)\n",
280 | " axes[0, 1].set(title=\"f(X)\", xlabel=\"z_1\", ylabel=\"z_2\")\n",
281 | " axes[0, 1].set_xlim([-2, 2])\n",
282 | " axes[0, 1].set_ylim([-2, 2])\n",
283 | " axes[1, 0].scatter(samples[:, 0], samples[:, 1], color=\"g\", s=1)\n",
284 | " axes[1, 0].set(title=\"Latent space Z\", xlabel=\"z_1\", ylabel=\"z_2\")\n",
285 | " axes[1, 0].set_xlim([-2, 2])\n",
286 | " axes[1, 0].set_ylim([-2, 2])\n",
287 | " axes[1, 1].scatter(x[:, 0], x[:, 1], color=\"g\", s=1)\n",
288 | " axes[1, 1].set(title=\"g(Z)\", xlabel=\"x_1\", ylabel=\"x_2\")\n",
289 | " axes[1, 1].set_xlim([-2, 2])\n",
290 | " axes[1, 1].set_ylim([-2, 2])\n",
291 | "\n",
292 | " plt.subplots_adjust(wspace=0.3, hspace=0.6)\n",
293 | " if save_to:\n",
294 | " plt.savefig(save_to)\n",
295 | " print(f\"\\nSaved to {save_to}\")\n",
296 | "\n",
297 | " plt.show()\n",
298 | "\n",
299 | " def on_epoch_end(self, epoch, logs=None):\n",
300 | " if epoch % 10 == 0:\n",
301 | " x, z, samples = self.generate()\n",
302 | " self.display(\n",
303 | " x,\n",
304 | " z,\n",
305 | " samples,\n",
306 | " save_to=\"./output/generated_img_%03d.png\" % (epoch),\n",
307 | " )\n",
308 | "\n",
309 | "\n",
310 | "img_generator_callback = ImageGenerator(num_samples=3000)"
311 | ]
312 | },
313 | {
314 | "cell_type": "code",
315 | "execution_count": null,
316 | "id": "bd6a5a71-eb55-4ec0-9c8c-cb11a382ff90",
317 | "metadata": {
318 | "tags": []
319 | },
320 | "outputs": [],
321 | "source": [
322 | "model.fit(\n",
323 | " normalized_data,\n",
324 | " batch_size=BATCH_SIZE,\n",
325 | " epochs=EPOCHS,\n",
326 | " callbacks=[tensorboard_callback, img_generator_callback],\n",
327 | ")"
328 | ]
329 | },
330 | {
331 | "cell_type": "markdown",
332 | "id": "fb1f295f-ade0-4040-a6a5-a7b428b08ebc",
333 | "metadata": {},
334 | "source": [
335 | "## 4. Generate images "
336 | ]
337 | },
338 | {
339 | "cell_type": "code",
340 | "execution_count": null,
341 | "id": "8db3cfe3-339e-463d-8af5-fbd403385fca",
342 | "metadata": {},
343 | "outputs": [],
344 | "source": [
345 | "x, z, samples = img_generator_callback.generate()"
346 | ]
347 | },
348 | {
349 | "cell_type": "code",
350 | "execution_count": null,
351 | "id": "80087297-3f47-4e0c-ac89-8758d4386d7c",
352 | "metadata": {},
353 | "outputs": [],
354 | "source": [
355 | "img_generator_callback.display(x, z, samples)"
356 | ]
357 | }
358 | ],
359 | "metadata": {
360 | "kernelspec": {
361 | "display_name": "Python 3 (ipykernel)",
362 | "language": "python",
363 | "name": "python3"
364 | },
365 | "language_info": {
366 | "codemirror_mode": {
367 | "name": "ipython",
368 | "version": 3
369 | },
370 | "file_extension": ".py",
371 | "mimetype": "text/x-python",
372 | "name": "python",
373 | "nbconvert_exporter": "python",
374 | "pygments_lexer": "ipython3",
375 | "version": "3.8.10"
376 | }
377 | },
378 | "nbformat": 4,
379 | "nbformat_minor": 5
380 | }
381 |
--------------------------------------------------------------------------------
/notebooks/07_ebm/01_ebm/checkpoint/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/07_ebm/01_ebm/ebm.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "b076bd1a-b236-4fbc-953d-8295b25122ae",
6 | "metadata": {},
7 | "source": [
8 | "# ⚡️ Energy-Based Models"
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "id": "9235cbd1-f136-411c-88d9-f69f270c0b96",
14 | "metadata": {},
15 | "source": [
16 | "In this notebook, we'll walk through the steps required to train your own Energy Based Model to predict the distribution of a demo dataset"
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "id": "2531aef5-c81a-4b53-a344-4b979dd4eec5",
22 | "metadata": {},
23 | "source": [
24 | "The code is adapted from the excellent ['Deep Energy-Based Generative Models' tutorial](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial8/Deep_Energy_Models.html) created by Phillip Lippe."
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": null,
30 | "id": "84acc7be-6764-4668-b2bb-178f63deeed3",
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "%load_ext autoreload\n",
35 | "%autoreload 2\n",
36 | "\n",
37 | "import numpy as np\n",
38 | "\n",
39 | "import tensorflow as tf\n",
40 | "from tensorflow.keras import (\n",
41 | " datasets,\n",
42 | " layers,\n",
43 | " models,\n",
44 | " optimizers,\n",
45 | " activations,\n",
46 | " metrics,\n",
47 | " callbacks,\n",
48 | ")\n",
49 | "\n",
50 | "from notebooks.utils import display, sample_batch\n",
51 | "import random"
52 | ]
53 | },
54 | {
55 | "cell_type": "markdown",
56 | "id": "339e6268-ebd7-4feb-86db-1fe7abccdbe5",
57 | "metadata": {},
58 | "source": [
59 | "## 0. Parameters "
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": null,
65 | "id": "1b2ee6ce-129f-4833-b0c5-fa567381c4e0",
66 | "metadata": {},
67 | "outputs": [],
68 | "source": [
69 | "IMAGE_SIZE = 32\n",
70 | "CHANNELS = 1\n",
71 | "STEP_SIZE = 10\n",
72 | "STEPS = 60\n",
73 | "NOISE = 0.005\n",
74 | "ALPHA = 0.1\n",
75 | "GRADIENT_CLIP = 0.03\n",
76 | "BATCH_SIZE = 128\n",
77 | "BUFFER_SIZE = 8192\n",
78 | "LEARNING_RATE = 0.0001\n",
79 | "EPOCHS = 60\n",
80 | "LOAD_MODEL = False"
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": null,
86 | "id": "9a73e5a4-1638-411c-8d3c-29f823424458",
87 | "metadata": {},
88 | "outputs": [],
89 | "source": [
90 | "# Load the data\n",
91 | "(x_train, _), (x_test, _) = datasets.mnist.load_data()"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": null,
97 | "id": "20697102-8c8d-4582-88d4-f8e2af84e060",
98 | "metadata": {},
99 | "outputs": [],
100 | "source": [
101 | "# Preprocess the data\n",
102 | "\n",
103 | "\n",
104 | "def preprocess(imgs):\n",
105 | " \"\"\"\n",
106 | " Normalize and reshape the images\n",
107 | " \"\"\"\n",
108 | " imgs = (imgs.astype(\"float32\") - 127.5) / 127.5\n",
109 | " imgs = np.pad(imgs, ((0, 0), (2, 2), (2, 2)), constant_values=-1.0)\n",
110 | " imgs = np.expand_dims(imgs, -1)\n",
111 | " return imgs\n",
112 | "\n",
113 | "\n",
114 | "x_train = preprocess(x_train)\n",
115 | "x_test = preprocess(x_test)"
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": null,
121 | "id": "13668819-2e42-4661-8682-33ff2c24ae8b",
122 | "metadata": {},
123 | "outputs": [],
124 | "source": [
125 | "x_train = tf.data.Dataset.from_tensor_slices(x_train).batch(BATCH_SIZE)\n",
126 | "x_test = tf.data.Dataset.from_tensor_slices(x_test).batch(BATCH_SIZE)"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": null,
132 | "id": "a7e1a420-699e-4869-8d10-3c049dbad030",
133 | "metadata": {},
134 | "outputs": [],
135 | "source": [
136 | "# Show some items of clothing from the training set\n",
137 | "train_sample = sample_batch(x_train)\n",
138 | "display(train_sample)"
139 | ]
140 | },
141 | {
142 | "cell_type": "markdown",
143 | "id": "f53945d9-b7c5-49d0-a356-bcf1d1e1798b",
144 | "metadata": {},
145 | "source": [
146 | "## 2. Build the EBM network "
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": null,
152 | "id": "8936d951-3281-4424-9cce-59433976bf2f",
153 | "metadata": {},
154 | "outputs": [],
155 | "source": [
156 | "ebm_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS))\n",
157 | "x = layers.Conv2D(\n",
158 | " 16, kernel_size=5, strides=2, padding=\"same\", activation=activations.swish\n",
159 | ")(ebm_input)\n",
160 | "x = layers.Conv2D(\n",
161 | " 32, kernel_size=3, strides=2, padding=\"same\", activation=activations.swish\n",
162 | ")(x)\n",
163 | "x = layers.Conv2D(\n",
164 | " 64, kernel_size=3, strides=2, padding=\"same\", activation=activations.swish\n",
165 | ")(x)\n",
166 | "x = layers.Conv2D(\n",
167 | " 64, kernel_size=3, strides=2, padding=\"same\", activation=activations.swish\n",
168 | ")(x)\n",
169 | "x = layers.Flatten()(x)\n",
170 | "x = layers.Dense(64, activation=activations.swish)(x)\n",
171 | "ebm_output = layers.Dense(1)(x)\n",
172 | "model = models.Model(ebm_input, ebm_output)\n",
173 | "model.summary()"
174 | ]
175 | },
176 | {
177 | "cell_type": "code",
178 | "execution_count": null,
179 | "id": "32221908-8819-48fa-8e57-0dc5179ca2cf",
180 | "metadata": {
181 | "tags": []
182 | },
183 | "outputs": [],
184 | "source": [
185 | "if LOAD_MODEL:\n",
186 | " model.load_weights(\"./models/model.h5\")"
187 | ]
188 | },
189 | {
190 | "cell_type": "markdown",
191 | "id": "1f392424-45a9-49cc-8ea0-c1bec9064d74",
192 | "metadata": {},
193 | "source": [
194 | "## 2. Set up a Langevin sampler function "
195 | ]
196 | },
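{
"cell_type": "markdown",
"id": "langevin-sampler-note",
"metadata": {},
"source": [
"The sampler below runs (clipped) Langevin dynamics on the model's output $f_\\theta(x)$, which plays the role of a negative energy. Each of the `steps` iterations first perturbs the images with Gaussian noise of standard deviation `noise`, then takes a gradient-ascent step of size `step_size` on the score, clipping both the gradients and the pixel values to keep them in range:\n",
"\n",
"$$x \\leftarrow \\mathrm{clip}\\big(x + \\sigma \\epsilon\\big), \\qquad x \\leftarrow \\mathrm{clip}\\big(x + \\eta \\, \\nabla_x f_\\theta(x)\\big), \\qquad \\epsilon \\sim \\mathcal{N}(0, I).$$\n",
"\n",
"Over many steps the samples drift towards images to which the model assigns a high score (low energy)."
]
},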
197 | {
198 | "cell_type": "code",
199 | "execution_count": null,
200 | "id": "bf10775a-0fbf-42df-aca5-be4b256a0c2b",
201 | "metadata": {},
202 | "outputs": [],
203 | "source": [
204 | "# Function to generate samples using Langevin Dynamics\n",
205 | "def generate_samples(\n",
206 | " model, inp_imgs, steps, step_size, noise, return_img_per_step=False\n",
207 | "):\n",
208 | " imgs_per_step = []\n",
209 | " for _ in range(steps):\n",
210 | " inp_imgs += tf.random.normal(inp_imgs.shape, mean=0, stddev=noise)\n",
211 | " inp_imgs = tf.clip_by_value(inp_imgs, -1.0, 1.0)\n",
212 | " with tf.GradientTape() as tape:\n",
213 | " tape.watch(inp_imgs)\n",
214 | " out_score = model(inp_imgs)\n",
215 | " grads = tape.gradient(out_score, inp_imgs)\n",
216 | " grads = tf.clip_by_value(grads, -GRADIENT_CLIP, GRADIENT_CLIP)\n",
217 | " inp_imgs += step_size * grads\n",
218 | " inp_imgs = tf.clip_by_value(inp_imgs, -1.0, 1.0)\n",
219 | " if return_img_per_step:\n",
220 | " imgs_per_step.append(inp_imgs)\n",
221 | " if return_img_per_step:\n",
222 | " return tf.stack(imgs_per_step, axis=0)\n",
223 | " else:\n",
224 | " return inp_imgs"
225 | ]
226 | },
227 | {
228 | "cell_type": "markdown",
229 | "id": "180fb0a1-ed16-47c2-b326-ad66071cd6e2",
230 | "metadata": {},
231 | "source": [
232 | "## 3. Set up a buffer to store examples "
233 | ]
234 | },
235 | {
236 | "cell_type": "code",
237 | "execution_count": null,
238 | "id": "52615dcd-be2b-4e05-b729-0ec45ea6ef98",
239 | "metadata": {},
240 | "outputs": [],
241 | "source": [
242 | "class Buffer:\n",
243 | " def __init__(self, model):\n",
244 | " super().__init__()\n",
245 | " self.model = model\n",
246 | " self.examples = [\n",
247 | " tf.random.uniform(shape=(1, IMAGE_SIZE, IMAGE_SIZE, CHANNELS)) * 2\n",
248 | " - 1\n",
249 | " for _ in range(BATCH_SIZE)\n",
250 | " ]\n",
251 | "\n",
252 | " def sample_new_exmps(self, steps, step_size, noise):\n",
253 | " n_new = np.random.binomial(BATCH_SIZE, 0.05)\n",
254 | " rand_imgs = (\n",
255 | " tf.random.uniform((n_new, IMAGE_SIZE, IMAGE_SIZE, CHANNELS)) * 2 - 1\n",
256 | " )\n",
257 | " old_imgs = tf.concat(\n",
258 | " random.choices(self.examples, k=BATCH_SIZE - n_new), axis=0\n",
259 | " )\n",
260 | " inp_imgs = tf.concat([rand_imgs, old_imgs], axis=0)\n",
261 | " inp_imgs = generate_samples(\n",
262 | " self.model, inp_imgs, steps=steps, step_size=step_size, noise=noise\n",
263 | " )\n",
264 | " self.examples = tf.split(inp_imgs, BATCH_SIZE, axis=0) + self.examples\n",
265 | " self.examples = self.examples[:BUFFER_SIZE]\n",
266 | " return inp_imgs"
267 | ]
268 | },
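{
"cell_type": "markdown",
"id": "cdiv-loss-note",
"metadata": {},
"source": [
"The `EBM` model defined below is trained with a contrastive-divergence style loss. Writing $x^{+}$ for a batch of (lightly noised) real images and $x^{-}$ for a batch of Langevin samples drawn through the buffer, the loss encourages the network $f_\\theta$ to score real images above generated ones, with a small penalty on the magnitude of the scores:\n",
"\n",
"$$\\mathcal{L} = \\mathbb{E}\\big[f_\\theta(x^{-})\\big] - \\mathbb{E}\\big[f_\\theta(x^{+})\\big] + \\alpha \\, \\mathbb{E}\\big[f_\\theta(x^{+})^{2} + f_\\theta(x^{-})^{2}\\big],$$\n",
"\n",
"where $\\alpha$ is the `ALPHA` parameter defined at the top of the notebook."
]
},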
269 | {
270 | "cell_type": "code",
271 | "execution_count": null,
272 | "id": "71a2a4a1-690e-4c94-b323-86f0e5b691d5",
273 | "metadata": {},
274 | "outputs": [],
275 | "source": [
276 | "class EBM(models.Model):\n",
277 | " def __init__(self):\n",
278 | " super(EBM, self).__init__()\n",
279 | " self.model = model\n",
280 | " self.buffer = Buffer(self.model)\n",
281 | " self.alpha = ALPHA\n",
282 | " self.loss_metric = metrics.Mean(name=\"loss\")\n",
283 | " self.reg_loss_metric = metrics.Mean(name=\"reg\")\n",
284 | " self.cdiv_loss_metric = metrics.Mean(name=\"cdiv\")\n",
285 | " self.real_out_metric = metrics.Mean(name=\"real\")\n",
286 | " self.fake_out_metric = metrics.Mean(name=\"fake\")\n",
287 | "\n",
288 | " @property\n",
289 | " def metrics(self):\n",
290 | " return [\n",
291 | " self.loss_metric,\n",
292 | " self.reg_loss_metric,\n",
293 | " self.cdiv_loss_metric,\n",
294 | " self.real_out_metric,\n",
295 | " self.fake_out_metric,\n",
296 | " ]\n",
297 | "\n",
298 | " def train_step(self, real_imgs):\n",
299 | " real_imgs += tf.random.normal(\n",
300 | " shape=tf.shape(real_imgs), mean=0, stddev=NOISE\n",
301 | " )\n",
302 | " real_imgs = tf.clip_by_value(real_imgs, -1.0, 1.0)\n",
303 | " fake_imgs = self.buffer.sample_new_exmps(\n",
304 | " steps=STEPS, step_size=STEP_SIZE, noise=NOISE\n",
305 | " )\n",
306 | " inp_imgs = tf.concat([real_imgs, fake_imgs], axis=0)\n",
307 | " with tf.GradientTape() as training_tape:\n",
308 | " real_out, fake_out = tf.split(self.model(inp_imgs), 2, axis=0)\n",
309 | " cdiv_loss = tf.reduce_mean(fake_out, axis=0) - tf.reduce_mean(\n",
310 | " real_out, axis=0\n",
311 | " )\n",
312 | " reg_loss = self.alpha * tf.reduce_mean(\n",
313 | " real_out**2 + fake_out**2, axis=0\n",
314 | " )\n",
315 | " loss = cdiv_loss + reg_loss\n",
316 | " grads = training_tape.gradient(loss, self.model.trainable_variables)\n",
317 | " self.optimizer.apply_gradients(\n",
318 | " zip(grads, self.model.trainable_variables)\n",
319 | " )\n",
320 | " self.loss_metric.update_state(loss)\n",
321 | " self.reg_loss_metric.update_state(reg_loss)\n",
322 | " self.cdiv_loss_metric.update_state(cdiv_loss)\n",
323 | " self.real_out_metric.update_state(tf.reduce_mean(real_out, axis=0))\n",
324 | " self.fake_out_metric.update_state(tf.reduce_mean(fake_out, axis=0))\n",
325 | " return {m.name: m.result() for m in self.metrics}\n",
326 | "\n",
327 | " def test_step(self, real_imgs):\n",
328 | " batch_size = real_imgs.shape[0]\n",
329 | " fake_imgs = (\n",
330 | " tf.random.uniform((batch_size, IMAGE_SIZE, IMAGE_SIZE, CHANNELS))\n",
331 | " * 2\n",
332 | " - 1\n",
333 | " )\n",
334 | " inp_imgs = tf.concat([real_imgs, fake_imgs], axis=0)\n",
335 | " real_out, fake_out = tf.split(self.model(inp_imgs), 2, axis=0)\n",
336 | " cdiv = tf.reduce_mean(fake_out, axis=0) - tf.reduce_mean(\n",
337 | " real_out, axis=0\n",
338 | " )\n",
339 | " self.cdiv_loss_metric.update_state(cdiv)\n",
340 | " self.real_out_metric.update_state(tf.reduce_mean(real_out, axis=0))\n",
341 | " self.fake_out_metric.update_state(tf.reduce_mean(fake_out, axis=0))\n",
342 | " return {m.name: m.result() for m in self.metrics[2:]}"
343 | ]
344 | },
345 | {
346 | "cell_type": "code",
347 | "execution_count": null,
348 | "id": "6337e801-eb59-4abe-84dc-9536cf4dc257",
349 | "metadata": {},
350 | "outputs": [],
351 | "source": [
352 | "ebm = EBM()"
353 | ]
354 | },
355 | {
356 | "cell_type": "markdown",
357 | "id": "35b14665-4359-447b-be58-3fd58ba69084",
358 | "metadata": {},
359 | "source": [
360 | "## 3. Train the EBM network "
361 | ]
362 | },
363 | {
364 | "cell_type": "code",
365 | "execution_count": null,
366 | "id": "d9ec362d-41fa-473a-ad56-ebeec6cfd3b8",
367 | "metadata": {},
368 | "outputs": [],
369 | "source": [
370 | "# Compile and train the model\n",
371 | "ebm.compile(\n",
372 | " optimizer=optimizers.Adam(learning_rate=LEARNING_RATE), run_eagerly=True\n",
373 | ")"
374 | ]
375 | },
376 | {
377 | "cell_type": "code",
378 | "execution_count": null,
379 | "id": "8ceca4de-f634-40ff-beb8-09ba42fd0f75",
380 | "metadata": {},
381 | "outputs": [],
382 | "source": [
383 | "tensorboard_callback = callbacks.TensorBoard(log_dir=\"./logs\")\n",
384 | "\n",
385 | "\n",
386 | "class ImageGenerator(callbacks.Callback):\n",
387 | " def __init__(self, num_img):\n",
388 | " self.num_img = num_img\n",
389 | "\n",
390 | " def on_epoch_end(self, epoch, logs=None):\n",
391 | " start_imgs = (\n",
392 | " np.random.uniform(\n",
393 | " size=(self.num_img, IMAGE_SIZE, IMAGE_SIZE, CHANNELS)\n",
394 | " )\n",
395 | " * 2\n",
396 | " - 1\n",
397 | " )\n",
398 | " generated_images = generate_samples(\n",
399 | " ebm.model,\n",
400 | " start_imgs,\n",
401 | " steps=1000,\n",
402 | " step_size=STEP_SIZE,\n",
403 | " noise=NOISE,\n",
404 | " return_img_per_step=False,\n",
405 | " )\n",
406 | " generated_images = generated_images.numpy()\n",
407 | " display(\n",
408 | " generated_images,\n",
409 | " save_to=\"./output/generated_img_%03d.png\" % (epoch),\n",
410 | " )\n",
411 | "\n",
412 | " example_images = tf.concat(\n",
413 | " random.choices(ebm.buffer.examples, k=10), axis=0\n",
414 | " )\n",
415 | " example_images = example_images.numpy()\n",
416 | " display(\n",
417 | " example_images, save_to=\"./output/example_img_%03d.png\" % (epoch)\n",
418 | " )\n",
419 | "\n",
420 | "\n",
421 | "image_generator_callback = ImageGenerator(num_img=10)"
422 | ]
423 | },
424 | {
425 | "cell_type": "code",
426 | "execution_count": null,
427 | "id": "627c1387-f29a-4cce-85a8-0903c1890e23",
428 | "metadata": {},
429 | "outputs": [],
430 | "source": [
431 | "class SaveModel(callbacks.Callback):\n",
432 | " def on_epoch_end(self, epoch, logs=None):\n",
433 | " model.save_weights(\"./models/model.h5\")\n",
434 | "\n",
435 | "\n",
436 | "save_model_callback = SaveModel()"
437 | ]
438 | },
439 | {
440 | "cell_type": "code",
441 | "execution_count": null,
442 | "id": "bd6a5a71-eb55-4ec0-9c8c-cb11a382ff90",
443 | "metadata": {
444 | "scrolled": true,
445 | "tags": []
446 | },
447 | "outputs": [],
448 | "source": [
449 | "ebm.fit(\n",
450 | " x_train,\n",
451 | " shuffle=True,\n",
452 | " epochs=60,\n",
453 | " validation_data=x_test,\n",
454 | " callbacks=[\n",
455 | " save_model_callback,\n",
456 | " tensorboard_callback,\n",
457 | " image_generator_callback,\n",
458 | " ],\n",
459 | ")"
460 | ]
461 | },
462 | {
463 | "cell_type": "markdown",
464 | "id": "fb1f295f-ade0-4040-a6a5-a7b428b08ebc",
465 | "metadata": {},
466 | "source": [
467 | "## 4. Generate images "
468 | ]
469 | },
470 | {
471 | "cell_type": "code",
472 | "execution_count": null,
473 | "id": "8db3cfe3-339e-463d-8af5-fbd403385fca",
474 | "metadata": {},
475 | "outputs": [],
476 | "source": [
477 | "start_imgs = (\n",
478 | " np.random.uniform(size=(10, IMAGE_SIZE, IMAGE_SIZE, CHANNELS)) * 2 - 1\n",
479 | ")"
480 | ]
481 | },
482 | {
483 | "cell_type": "code",
484 | "execution_count": null,
485 | "id": "80087297-3f47-4e0c-ac89-8758d4386d7c",
486 | "metadata": {},
487 | "outputs": [],
488 | "source": [
489 | "display(start_imgs)"
490 | ]
491 | },
492 | {
493 | "cell_type": "code",
494 | "execution_count": null,
495 | "id": "eaf4b749-5f6e-4a12-863f-b0bbcd23549c",
496 | "metadata": {
497 | "scrolled": true,
498 | "tags": []
499 | },
500 | "outputs": [],
501 | "source": [
502 | "gen_img = generate_samples(\n",
503 | " ebm.model,\n",
504 | " start_imgs,\n",
505 | " steps=1000,\n",
506 | " step_size=STEP_SIZE,\n",
507 | " noise=NOISE,\n",
508 | " return_img_per_step=True,\n",
509 | ")"
510 | ]
511 | },
512 | {
513 | "cell_type": "code",
514 | "execution_count": null,
515 | "id": "eac707f6-0597-499c-9a52-7cade6724795",
516 | "metadata": {
517 | "tags": []
518 | },
519 | "outputs": [],
520 | "source": [
521 | "display(gen_img[-1].numpy())"
522 | ]
523 | },
524 | {
525 | "cell_type": "code",
526 | "execution_count": null,
527 | "id": "8476aaa1-e0e7-44dc-a1fd-cc30344b8dcb",
528 | "metadata": {},
529 | "outputs": [],
530 | "source": [
531 | "imgs = []\n",
532 | "for i in [0, 1, 3, 5, 10, 30, 50, 100, 300, 999]:\n",
533 | " imgs.append(gen_img[i].numpy()[6])\n",
534 | "\n",
535 | "display(np.array(imgs))"
536 | ]
537 | },
538 | {
539 | "cell_type": "code",
540 | "execution_count": null,
541 | "id": "3e05d5a0-e124-40d1-80ed-552260c6e350",
542 | "metadata": {},
543 | "outputs": [],
544 | "source": []
545 | }
546 | ],
547 | "metadata": {
548 | "kernelspec": {
549 | "display_name": "Python 3 (ipykernel)",
550 | "language": "python",
551 | "name": "python3"
552 | },
553 | "language_info": {
554 | "codemirror_mode": {
555 | "name": "ipython",
556 | "version": 3
557 | },
558 | "file_extension": ".py",
559 | "mimetype": "text/x-python",
560 | "name": "python",
561 | "nbconvert_exporter": "python",
562 | "pygments_lexer": "ipython3",
563 | "version": "3.8.10"
564 | }
565 | },
566 | "nbformat": 4,
567 | "nbformat_minor": 5
568 | }
569 |
--------------------------------------------------------------------------------
/notebooks/07_ebm/01_ebm/logs/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/07_ebm/01_ebm/models/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/07_ebm/01_ebm/output/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/08_diffusion/01_ddm/checkpoint/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/08_diffusion/01_ddm/logs/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/08_diffusion/01_ddm/models/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/08_diffusion/01_ddm/output/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/09_transformer/gpt/checkpoint/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/09_transformer/gpt/logs/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/09_transformer/gpt/models/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/09_transformer/gpt/output/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/11_music/01_transformer/checkpoint/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/11_music/01_transformer/logs/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/11_music/01_transformer/models/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/11_music/01_transformer/output/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/11_music/01_transformer/parsed_data/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/11_music/01_transformer/transformer_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle as pkl
3 | import music21
4 | import keras
5 | import tensorflow as tf
6 |
7 | from fractions import Fraction
8 |
9 |
10 | def parse_midi_files(file_list, parser, seq_len, parsed_data_path=None):
11 | notes_list = []
12 | duration_list = []
13 | notes = []
14 | durations = []
15 |
16 | for i, file in enumerate(file_list):
17 | print(i + 1, "Parsing %s" % file)
18 | score = parser.parse(file).chordify()
19 |
20 | notes.append("START")
21 | durations.append("0.0")
22 |
23 | for element in score.flat:
24 | note_name = None
25 | duration_name = None
26 |
27 | if isinstance(element, music21.key.Key):
28 | note_name = str(element.tonic.name) + ":" + str(element.mode)
29 | duration_name = "0.0"
30 |
31 | elif isinstance(element, music21.meter.TimeSignature):
32 | note_name = str(element.ratioString) + "TS"
33 | duration_name = "0.0"
34 |
35 | elif isinstance(element, music21.chord.Chord):
36 | note_name = element.pitches[-1].nameWithOctave
37 | duration_name = str(element.duration.quarterLength)
38 |
39 | elif isinstance(element, music21.note.Rest):
40 | note_name = str(element.name)
41 | duration_name = str(element.duration.quarterLength)
42 |
43 | elif isinstance(element, music21.note.Note):
44 | note_name = str(element.nameWithOctave)
45 | duration_name = str(element.duration.quarterLength)
46 |
47 | if note_name and duration_name:
48 | notes.append(note_name)
49 | durations.append(duration_name)
50 | print(f"{len(notes)} notes parsed")
51 |
52 | notes_list = []
53 | duration_list = []
54 |
55 | print(f"Building sequences of length {seq_len}")
56 | for i in range(len(notes) - seq_len):
57 | notes_list.append(" ".join(notes[i : (i + seq_len)]))
58 | duration_list.append(" ".join(durations[i : (i + seq_len)]))
59 |
60 | if parsed_data_path:
61 | with open(os.path.join(parsed_data_path, "notes"), "wb") as f:
62 | pkl.dump(notes_list, f)
63 | with open(os.path.join(parsed_data_path, "durations"), "wb") as f:
64 | pkl.dump(duration_list, f)
65 |
66 | return notes_list, duration_list
67 |
68 |
69 | def load_parsed_files(parsed_data_path):
70 | with open(os.path.join(parsed_data_path, "notes"), "rb") as f:
71 | notes = pkl.load(f)
72 | with open(os.path.join(parsed_data_path, "durations"), "rb") as f:
73 | durations = pkl.load(f)
74 | return notes, durations
75 |
76 |
77 | def get_midi_note(sample_note, sample_duration):
78 | new_note = None
79 |
80 | if "TS" in sample_note:
81 | new_note = music21.meter.TimeSignature(sample_note.split("TS")[0])
82 |
83 | elif "major" in sample_note or "minor" in sample_note:
84 | tonic, mode = sample_note.split(":")
85 | new_note = music21.key.Key(tonic, mode)
86 |
87 | elif sample_note == "rest":
88 | new_note = music21.note.Rest()
89 | new_note.duration = music21.duration.Duration(
90 | float(Fraction(sample_duration))
91 | )
92 | new_note.storedInstrument = music21.instrument.Violoncello()
93 |
94 | elif "." in sample_note:
95 | notes_in_chord = sample_note.split(".")
96 | chord_notes = []
97 | for current_note in notes_in_chord:
98 | n = music21.note.Note(current_note)
99 | n.duration = music21.duration.Duration(
100 | float(Fraction(sample_duration))
101 | )
102 | n.storedInstrument = music21.instrument.Violoncello()
103 | chord_notes.append(n)
104 | new_note = music21.chord.Chord(chord_notes)
105 |
113 | elif sample_note != "START":
114 | new_note = music21.note.Note(sample_note)
115 | new_note.duration = music21.duration.Duration(
116 | float(Fraction(sample_duration))
117 | )
118 | new_note.storedInstrument = music21.instrument.Violoncello()
119 |
120 | return new_note
121 |
122 |
123 | class SinePositionEncoding(keras.layers.Layer):
124 | """Sinusoidal positional encoding layer.
125 | This layer calculates the position encoding as a mix of sine and cosine
126 | functions with geometrically increasing wavelengths. Defined and formulized
127 | in [Attention is All You Need](https://arxiv.org/abs/1706.03762).
128 | Takes as input an embedded token tensor. The input must have shape
129 | [batch_size, sequence_length, feature_size]. This layer will return a
130 | positional encoding the same size as the embedded token tensor, which
131 | can be added directly to the embedded token tensor.
132 | Args:
133 | max_wavelength: The maximum angular wavelength of the sine/cosine
134 | curves, as described in Attention is All You Need. Defaults to
135 | 10000.
136 | Examples:
137 | ```python
138 | # create a simple embedding layer with sinusoidal positional encoding
139 | seq_len = 100
140 | vocab_size = 1000
141 | embedding_dim = 32
142 | inputs = keras.Input((seq_len,), dtype=tf.float32)
143 | embedding = keras.layers.Embedding(
144 | input_dim=vocab_size, output_dim=embedding_dim
145 | )(inputs)
146 | positional_encoding = keras_nlp.layers.SinePositionEncoding()(embedding)
147 | outputs = embedding + positional_encoding
148 | ```
149 | References:
150 | - [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)
151 | """
152 |
153 | def __init__(
154 | self,
155 | max_wavelength=10000,
156 | **kwargs,
157 | ):
158 | super().__init__(**kwargs)
159 | self.max_wavelength = max_wavelength
160 |
161 | def call(self, inputs):
162 | # TODO(jbischof): replace `hidden_size` with`hidden_dim` for consistency
163 | # with other layers.
164 | input_shape = tf.shape(inputs)
165 | # length of sequence is the second last dimension of the inputs
166 | seq_length = input_shape[-2]
167 | hidden_size = input_shape[-1]
168 | position = tf.cast(tf.range(seq_length), self.compute_dtype)
169 | min_freq = tf.cast(1 / self.max_wavelength, dtype=self.compute_dtype)
170 | timescales = tf.pow(
171 | min_freq,
172 | tf.cast(2 * (tf.range(hidden_size) // 2), self.compute_dtype)
173 | / tf.cast(hidden_size, self.compute_dtype),
174 | )
175 | angles = tf.expand_dims(position, 1) * tf.expand_dims(timescales, 0)
176 | # even indices are sine, odd are cosine
177 | cos_mask = tf.cast(tf.range(hidden_size) % 2, self.compute_dtype)
178 | sin_mask = 1 - cos_mask
179 | # embedding shape is [seq_length, hidden_size]
180 | positional_encodings = (
181 | tf.sin(angles) * sin_mask + tf.cos(angles) * cos_mask
182 | )
183 |
184 | return tf.broadcast_to(positional_encodings, input_shape)
185 |
186 | def get_config(self):
187 | config = super().get_config()
188 | config.update(
189 | {
190 | "max_wavelength": self.max_wavelength,
191 | }
192 | )
193 | return config
194 |
--------------------------------------------------------------------------------
/notebooks/11_music/02_musegan/checkpoint/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/11_music/02_musegan/logs/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/11_music/02_musegan/models/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/11_music/02_musegan/musegan_utils.py:
--------------------------------------------------------------------------------
1 | import music21
2 | import numpy as np
3 | from matplotlib import pyplot as plt
4 |
5 |
6 | def binarise_output(output):
7 | # output is a set of scores: [batch size , steps , pitches , tracks]
8 | max_pitches = np.argmax(output, axis=3)
9 | return max_pitches
10 |
11 |
12 | def notes_to_midi(output, n_bars, n_tracks, n_steps_per_bar, filename):
13 | for score_num in range(len(output)):
14 | max_pitches = binarise_output(output)
15 | midi_note_score = max_pitches[score_num].reshape(
16 | [n_bars * n_steps_per_bar, n_tracks]
17 | )
18 | parts = music21.stream.Score()
19 | parts.append(music21.tempo.MetronomeMark(number=66))
20 | for i in range(n_tracks):
21 | last_x = int(midi_note_score[:, i][0])
22 | s = music21.stream.Part()
23 | dur = 0
24 | for idx, x in enumerate(midi_note_score[:, i]):
25 | x = int(x)
26 | if (x != last_x or idx % 4 == 0) and idx > 0:
27 | n = music21.note.Note(last_x)
28 | n.duration = music21.duration.Duration(dur)
29 | s.append(n)
30 | dur = 0
31 | last_x = x
32 | dur = dur + 0.25
33 | n = music21.note.Note(last_x)
34 | n.duration = music21.duration.Duration(dur)
35 | s.append(n)
36 | parts.append(s)
37 | parts.write(
38 | "midi", fp="./output/{}_{}.midi".format(filename, score_num)
39 | )
40 |
41 |
42 | def draw_bar(data, score_num, bar, part):
43 | plt.imshow(
44 | data[score_num, bar, :, :, part].transpose([1, 0]),
45 | origin="lower",
46 | cmap="Greys",
47 | vmin=-1,
48 | vmax=1,
49 | )
50 |
51 |
52 | def draw_score(data, score_num):
53 | n_bars = data.shape[1]
54 | n_tracks = data.shape[-1]
55 |
56 | fig, axes = plt.subplots(
57 | ncols=n_bars, nrows=n_tracks, figsize=(12, 8), sharey=True, sharex=True
58 | )
59 | fig.subplots_adjust(0, 0, 0.2, 1.5, 0, 0)
60 |
61 | for bar in range(n_bars):
62 | for track in range(n_tracks):
63 | if n_bars > 1:
64 | axes[track, bar].imshow(
65 | data[score_num, bar, :, :, track].transpose([1, 0]),
66 | origin="lower",
67 | cmap="Greys",
68 | )
69 | else:
70 | axes[track].imshow(
71 | data[score_num, bar, :, :, track].transpose([1, 0]),
72 | origin="lower",
73 | cmap="Greys",
74 | )
75 |
--------------------------------------------------------------------------------
/notebooks/11_music/02_musegan/output/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/notebooks/utils.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 |
3 |
4 | def sample_batch(dataset):
5 | batch = dataset.take(1).get_single_element()
6 | if isinstance(batch, tuple):
7 | batch = batch[0]
8 | return batch.numpy()
9 |
10 |
11 | def display(
12 | images, n=10, size=(20, 3), cmap="gray_r", as_type="float32", save_to=None
13 | ):
14 | """
15 |     Displays the first n images from the supplied array.
16 | """
17 | if images.max() > 1.0:
18 | images = images / 255.0
19 | elif images.min() < 0.0:
20 | images = (images + 1.0) / 2.0
21 |
22 | plt.figure(figsize=size)
23 | for i in range(n):
24 | _ = plt.subplot(1, n, i + 1)
25 | plt.imshow(images[i].astype(as_type), cmap=cmap)
26 | plt.axis("off")
27 |
28 | if save_to:
29 | plt.savefig(save_to)
30 | print(f"\nSaved to {save_to}")
31 |
32 | plt.show()
33 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.6.3
2 | jupyterlab==3.5.3
3 | scipy==1.10.0
4 | scikit-learn==1.2.1
5 | scikit-image==0.19.3
6 | pandas==1.5.3
7 | music21==8.1.0
8 | black[jupyter]==23.1.0
9 | click==8.0.4
10 | flake8==5.0.4
11 | kaggle==1.5.12
12 | pydot==1.4.2
13 | ipywidgets==8.0.4
14 | tensorflow==2.10.1
15 | tensorboard==2.10.1
16 | tensorflow_probability==0.18.0
17 | flake8-nb==0.5.2
--------------------------------------------------------------------------------
/sample.env:
--------------------------------------------------------------------------------
1 | JUPYTER_PORT=8888
2 | TENSORBOARD_PORT=6006
3 | KAGGLE_USERNAME=
4 | KAGGLE_KEY=
--------------------------------------------------------------------------------
/scripts/download.sh:
--------------------------------------------------------------------------------
1 | DATASET=$1
2 |
3 | if [ $DATASET = "faces" ]
4 | then
5 | source scripts/downloaders/download_kaggle_data.sh jessicali9530 celeba-dataset
6 | elif [ $DATASET = "bricks" ]
7 | then
8 | source scripts/downloaders/download_kaggle_data.sh joosthazelzet lego-brick-images
9 | elif [ $DATASET = "recipes" ]
10 | then
11 | source scripts/downloaders/download_kaggle_data.sh hugodarwood epirecipes
12 | elif [ $DATASET = "flowers" ]
13 | then
14 | source scripts/downloaders/download_kaggle_data.sh nunenuh pytorch-challange-flower-dataset
15 | elif [ $DATASET = "wines" ]
16 | then
17 | source scripts/downloaders/download_kaggle_data.sh zynicide wine-reviews
18 | elif [ $DATASET = "cellosuites" ]
19 | then
20 | source scripts/downloaders/download_bach_cello_data.sh
21 | elif [ $DATASET = "chorales" ]
22 | then
23 | source scripts/downloaders/download_bach_chorale_data.sh
24 | else
25 | echo "Invalid dataset name - please choose from: faces, bricks, recipes, flowers, wines, cellosuites, chorales"
26 | fi
27 |
--------------------------------------------------------------------------------
/scripts/downloaders/download_bach_cello_data.sh:
--------------------------------------------------------------------------------
1 | docker compose exec app bash -c "
2 | mkdir /app/data/bach-cello/
3 | cd /app/data/bach-cello/ &&
4 | echo 'Downloading...' &&
5 | curl -O http://www.jsbach.net/midi/cs1-1pre.mid -s &&
6 | curl -O http://www.jsbach.net/midi/cs1-2all.mid -s &&
7 | curl -O http://www.jsbach.net/midi/cs1-3cou.mid -s &&
8 | curl -O http://www.jsbach.net/midi/cs1-4sar.mid -s &&
9 | curl -O http://www.jsbach.net/midi/cs1-5men.mid -s &&
10 | curl -O http://www.jsbach.net/midi/cs1-6gig.mid -s &&
11 | curl -O http://www.jsbach.net/midi/cs2-1pre.mid -s &&
12 | curl -O http://www.jsbach.net/midi/cs2-2all.mid -s &&
13 | curl -O http://www.jsbach.net/midi/cs2-3cou.mid -s &&
14 | curl -O http://www.jsbach.net/midi/cs2-4sar.mid -s &&
15 | curl -O http://www.jsbach.net/midi/cs2-5men.mid -s &&
16 | curl -O http://www.jsbach.net/midi/cs2-6gig.mid -s &&
17 | curl -O http://www.jsbach.net/midi/cs3-1pre.mid -s &&
18 | curl -O http://www.jsbach.net/midi/cs3-2all.mid -s &&
19 | curl -O http://www.jsbach.net/midi/cs3-3cou.mid -s &&
20 | curl -O http://www.jsbach.net/midi/cs3-4sar.mid -s &&
21 | curl -O http://www.jsbach.net/midi/cs3-5bou.mid -s &&
22 | curl -O http://www.jsbach.net/midi/cs3-6gig.mid -s &&
23 | curl -O http://www.jsbach.net/midi/cs4-1pre.mid -s &&
24 | curl -O http://www.jsbach.net/midi/cs4-2all.mid -s &&
25 | curl -O http://www.jsbach.net/midi/cs4-3cou.mid -s &&
26 | curl -O http://www.jsbach.net/midi/cs4-4sar.mid -s &&
27 | curl -O http://www.jsbach.net/midi/cs4-5bou.mid -s &&
28 | curl -O http://www.jsbach.net/midi/cs4-6gig.mid -s &&
29 | curl -O http://www.jsbach.net/midi/cs5-1pre.mid -s &&
30 | curl -O http://www.jsbach.net/midi/cs5-2all.mid -s &&
31 | curl -O http://www.jsbach.net/midi/cs5-3cou.mid -s &&
32 | curl -O http://www.jsbach.net/midi/cs5-4sar.mid -s &&
33 | curl -O http://www.jsbach.net/midi/cs5-5gav.mid -s &&
34 | curl -O http://www.jsbach.net/midi/cs5-6gig.mid -s &&
35 | curl -O http://www.jsbach.net/midi/cs6-1pre.mid -s &&
36 | curl -O http://www.jsbach.net/midi/cs6-2all.mid -s &&
37 | curl -O http://www.jsbach.net/midi/cs6-3cou.mid -s &&
38 | curl -O http://www.jsbach.net/midi/cs6-4sar.mid -s &&
39 | curl -O http://www.jsbach.net/midi/cs6-5gav.mid -s &&
40 | curl -O http://www.jsbach.net/midi/cs6-6gig.mid -s &&
41 | echo '🚀 Done!'
42 | "
--------------------------------------------------------------------------------
/scripts/downloaders/download_bach_chorale_data.sh:
--------------------------------------------------------------------------------
1 | docker compose exec app bash -c "
2 | mkdir /app/data/bach-chorales/
3 | cd /app/data/bach-chorales/ &&
4 | echo 'Downloading...' &&
5 | curl -LO https://github.com/czhuang/JSB-Chorales-dataset/raw/master/Jsb16thSeparated.npz -s &&
6 | echo '🚀 Done!'
7 | "
8 |
--------------------------------------------------------------------------------
/scripts/downloaders/download_kaggle_data.sh:
--------------------------------------------------------------------------------
1 | USER=$1
2 | DATASET=$2
3 |
4 | docker compose exec app bash -c "cd /app/data/ && kaggle datasets download -d $USER/$DATASET && echo 'Unzipping...' && unzip -q -o /app/data/$DATASET.zip -d /app/data/$DATASET && rm /app/data/$DATASET.zip && echo '🚀 Done!'"
5 |
6 |
--------------------------------------------------------------------------------
/scripts/format.sh:
--------------------------------------------------------------------------------
1 | docker compose exec app black -l 80 .
2 | docker compose exec app flake8 --max-line-length=80 --exclude=./data --ignore=W503,E203,E402
3 | docker compose exec app flake8_nb --max-line-length=80 --exclude=./data --ignore=W503,E203,E402
--------------------------------------------------------------------------------
/scripts/tensorboard.sh:
--------------------------------------------------------------------------------
1 | CHAPTER=$1
2 | EXAMPLE=$2
3 | echo "Attaching Tensorboard to ./notebooks/$CHAPTER/$EXAMPLE/logs/"
4 | docker compose exec -e CHAPTER=$CHAPTER -e EXAMPLE=$EXAMPLE app bash -c 'tensorboard --logdir "./notebooks/$CHAPTER/$EXAMPLE/logs" --host 0.0.0.0 --port $TENSORBOARD_PORT'
--------------------------------------------------------------------------------