├── JAX_FLAX
│   ├── requirements.txt
│   ├── chapter_03_vae
│   │   └── utils.py
│   └── chapter_05_autoregressive
│       └── 01_lstm.ipynb
├── PyTorch
│   ├── requirements.txt
│   ├── chapter_03_vae
│   │   └── utils.py
│   └── chapter_05_autoregressive
│       ├── 01_lstm.ipynb
│       └── 02_pixelcnn.ipynb
├── LICENSE
├── .gitignore
└── README.md
/JAX_FLAX/requirements.txt: -------------------------------------------------------------------------------- 1 | -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 2 | jax[cuda11_pip]==0.4.14 3 | flax==0.6.10 4 | clu==0.0.9 5 | matplotlib==3.7.1 6 | tensorflow==2.13.* 7 | tensorflow-datasets==4.9.2 8 | tensorflow-probability==0.21.0 9 | kaggle==1.5.13 10 | pandas==2.0.2 11 | scikit-learn==1.3.0 12 | jupyterlab 13 | jupyterlab-night 14 | ipywidgets 15 | -------------------------------------------------------------------------------- /PyTorch/requirements.txt: -------------------------------------------------------------------------------- 1 | --find-links https://download.pytorch.org/whl/cu118 2 | torch==2.0.1 3 | torchvision==0.15.2 4 | torchaudio==2.0.2 5 | matplotlib==3.7.1 6 | tqdm==4.65.0 7 | torchsummary==1.5.1 8 | torchinfo==1.8.0 9 | kaggle==1.5.13 10 | pandas==2.0.2 11 | scipy==1.10.1 12 | torcheval==0.0.6 13 | tensorboard==2.13.0 14 | torchtext==0.15.2 15 | scikit-learn==1.3.0 16 | jupyterlab 17 | ipywidgets 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Tingsong (Terrence) Ou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | .DS_Store 163 | data/ 164 | basics/ 165 | models/ 166 | -------------------------------------------------------------------------------- /JAX_FLAX/chapter_03_vae/utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ############################################################# 3 | IMPORTANT: 4 | 5 | This Python file is adapted from David Foster's code provided 6 | in his book Generative Deep Learning 2nd Edition: 7 | https://github.com/davidADSP/Generative_Deep_Learning_2nd_Edition/blob/main/notebooks/03_vae/03_vae_faces/vae_utils.py 8 | The original code is under the Apache-2.0 license; this file adapts 9 | the original implementation to the JAX context. 10 | 11 | ############################################################# 12 | ''' 13 | 14 | import numpy as np 15 | import matplotlib.pyplot as plt 16 | 17 | 18 | # Get the vector towards the given feature 19 | def get_vector_from_label(dataset, embedding_dim, label, encode_fn, state, rng): 20 | # Initialize parameters 21 | current_sum_POS = np.zeros(shape=embedding_dim, dtype=np.float32) 22 | current_n_POS = 0 23 | current_mean_POS = np.zeros(shape=embedding_dim, dtype=np.float32) 24 | 25 | current_sum_NEG = np.zeros(shape=embedding_dim, dtype=np.float32) 26 | current_n_NEG = 0 27 | current_mean_NEG = np.zeros(shape=embedding_dim, dtype=np.float32) 28 | 29 | current_vector = np.zeros(shape=embedding_dim, dtype=np.float32) 30 | current_dist = 0 31 | 32 | print('label: ' + label) 33 | print('images | POS move | NEG move | distance | 𝛥 distance:') 34 | 35 | total_POS_samples = 5000 36 | curr_iter = 0 37 | while current_n_POS < total_POS_samples: 38 | # Sampling new POS and NEG images 39 | batch = next(iter(dataset)) 40 | imgs = batch[0].numpy() 41 | attributes = batch[1].numpy() 42 | z = encode_fn(state, imgs, rng) 43 | z_POS = z[attributes==1] 44 | z_NEG = z[attributes==-1] 45 | 46 | # Update the mean vectors for both POS and NEG samples 47 | if len(z_POS) > 0: 48 | current_sum_POS = current_sum_POS + np.sum(z_POS, axis=0) 49 | current_n_POS += len(z_POS) 50 | new_mean_POS = current_sum_POS / current_n_POS 51 | movement_POS = np.linalg.norm(new_mean_POS - current_mean_POS) 52 | 53 | if len(z_NEG) > 0: 54 | current_sum_NEG = current_sum_NEG + np.sum(z_NEG, axis=0) 55 | current_n_NEG += len(z_NEG) 56 | new_mean_NEG = current_sum_NEG / current_n_NEG 57 | movement_NEG = np.linalg.norm(new_mean_NEG - current_mean_NEG) 58 | 59 | # Update the feature vector 60 | current_vector = new_mean_POS - new_mean_NEG 61 | new_dist = np.linalg.norm(current_vector) 62 | dist_change = new_dist - current_dist 63 | 64 | # Print the vector-finding process 65 | placeholder = '| ' 66 | if curr_iter % 5 == 0: 67 | print(f'{current_n_POS:6d}', placeholder, 68 | f'{movement_POS:6.3f}', placeholder, 69 | f'{movement_NEG:6.3f}', placeholder, 70 | f'{new_dist:6.3f}', placeholder, 71 | f'{dist_change:6.3f}') 72 | 73 | current_mean_POS = np.copy(new_mean_POS) 74 | current_mean_NEG = np.copy(new_mean_NEG) 75 | current_dist = np.copy(new_dist) 76 | 77 | # When the changing distance is very small, terminate the while loop 78 | stop_thresh = 8e-2 79 | if np.sum([movement_POS, movement_NEG]) < stop_thresh: 80 | current_vector = current_vector / current_dist 81 | print('Found the ' + label + ' vector') 82 | break 83 | 84 | curr_iter += 1 85 | return current_vector 86 | 87 | 88 | # Plot the feature transitions 89 | def add_vector_to_images(dataset, feature_vec, encode_fn, decode_fn, state, rng): 90 | n_plots = 5
91 | factors = np.arange(-4, 5) 92 | batch = next(iter(dataset)) 93 | imgs = batch[0].numpy() 94 | 95 | # Get image embeddings 96 | z = encode_fn(state, imgs, rng) 97 | 98 | fig = plt.figure(figsize=(18, 10)) 99 | counter = 1 100 | 101 | for i in range(n_plots): 102 | img = imgs[i] 103 | ax = fig.add_subplot(n_plots, len(factors) + 1, counter) 104 | ax.axis('off') 105 | ax.imshow(img) 106 | counter += 1 107 | # Add transition images 108 | for factor in factors: 109 | new_z_sample = z[i] + feature_vec * factor 110 | generated_img = decode_fn(state, np.expand_dims(new_z_sample, axis=0))[0] 111 | ax = fig.add_subplot(n_plots, len(factors) + 1, counter) 112 | ax.axis('off') 113 | ax.imshow(generated_img) 114 | counter += 1 115 | plt.show() 116 | 117 | 118 | # Morph between two faces 119 | def morph_faces(dataset, encode_fn, decode_fn, state, rng): 120 | factors = np.arange(0.0, 1.0, 0.1) 121 | 122 | sample_faces = next(iter(dataset))[0][:2] # sample two faces 123 | imgs = sample_faces.numpy() 124 | z = encode_fn(state, imgs, rng) 125 | fig = plt.figure(figsize=(18, 8)) 126 | counter = 1 127 | 128 | face_a = imgs[0] 129 | face_b = imgs[1] 130 | 131 | # show original face 132 | ax = fig.add_subplot(1, len(factors) + 2, counter) 133 | ax.axis('off') 134 | ax.imshow(face_a) 135 | 136 | counter += 1 137 | 138 | # plot transitions 139 | for factor in factors: 140 | factored_z = z[0] * (1 - factor) + z[1] * factor 141 | generated_img = decode_fn(state, np.expand_dims(factored_z, axis=0))[0] 142 | ax = fig.add_subplot(1, len(factors) + 2, counter) 143 | ax.axis('off') 144 | ax.imshow(generated_img) 145 | counter += 1 146 | 147 | # show target face 148 | ax = fig.add_subplot(1, len(factors) + 2, counter) 149 | ax.axis('off') 150 | ax.imshow(face_b) 151 | 152 | plt.show() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # **Generative Deep Learning 2nd Edition in JAX and PyTorch 2.0** 2 | 3 | This repository includes the PyTorch 2.0 and JAX implementations of the examples in **Generative Deep Learning 2nd Edition** by *David Foster*. You can find the physical/Kindle book on [Amazon](https://www.amazon.com/Generative-Deep-Learning-Teaching-Machines/dp/1098134184/ref=sr_1_1?keywords=generative+deep+learning%2C+2nd+edition&qid=1684898042&sprefix=generative%2Caps%2C96&sr=8-1&ufe=app_do%3Aamzn1.fos.006c50ae-5d4c-4777-9bc0-4513d670b6bc) or read it online through the [O'Reilly library](https://learning.oreilly.com/home/) (paid subscription needed). 4 | 5 | ### **Motivation of this project** 6 | I started my deep learning journey with the help of the first edition of this book. The author introduces the topics of generative deep learning in a clear and concise way that helped me quickly grasp the key points of each type of algorithm without being overwhelmed by heavy mathematics. The code in the book, written in TensorFlow and Keras, helped me quickly put the theory into practice.
7 | We now have other popular deep learning frameworks, like PyTorch and JAX, used by various ML communities. Therefore, I want to translate the TensorFlow and Keras code provided in the book to PyTorch and JAX to help more people study this valuable book more easily. 8 | 9 | ### **File structure** 10 | The files are organized by framework: 11 | ```bash 12 | ├── JAX_FLAX 13 | │ ├── chapter_**_** 14 | │ │ ├── **.ipynb 15 | │ ├── requirements.txt 16 | ├── PyTorch 17 | │ ├── chapter_**_** 18 | │ │ ├── **.ipynb 19 | │ ├── requirements.txt 20 | ├── .gitignore 21 | ``` 22 | 23 | ## **Environment setup** 24 | I recommend using separate environments for PyTorch and JAX to avoid potential conflicts between CUDA versions or other related packages. I use [`miniconda`](https://docs.conda.io/en/latest/miniconda.html) to manage packages for both environments.
25 | 26 | Configure `PyTorch` environment: 27 | ```bash 28 | cd PyTorch 29 | conda create -n GDL_PyTorch python==3.9 30 | conda activate GDL_PyTorch 31 | pip install -r requirements.txt 32 | ``` 33 | NOTE: If you're using PyTorch on WSL, please add `export LD_LIBRARY_PATH=/usr/lib/wsl/lib:$LD_LIBRARY_PATH` to `~/.bashrc` to avoid kernel restarting problems.
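A quick optional sanity check after installation (my own suggestion, not part of the book or this repo's notebooks) to confirm the CUDA build is visible:
```bash
# Run inside the GDL_PyTorch environment; expect the version string and `True` on a GPU machine
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```
Once the JAX environment below is set up, the analogous probe is `python -c "import jax; print(jax.devices())"`.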
34 |
35 | Configure `JAX` environment: 36 | ```bash 37 | cd JAX_FLAX 38 | conda create -n GDL_JAX python==3.9 39 | conda activate GDL_JAX 40 | pip install -r requirements.txt 41 | ``` 42 | 43 | `.ipynb` is the Jupyter notebook file extension; I use [Jupyter Lab](https://jupyter.org/install) to run the notebooks in this repository. 44 | 45 | ## **Model list** 46 | ### *Chapter 2 Deep Learning* 47 | - MLP ([PyTorch](https://github.com/terrence-ou/Generative-Deep-Learning-2nd-Edition-PyTorch-JAX/blob/main/PyTorch/chapter_02_deeplearning/01_MLP.ipynb) | [JAX](https://github.com/terrence-ou/Generative-Deep-Learning-2nd-Edition-PyTorch-JAX/blob/main/JAX_FLAX/chapter_02_deeplearning/01_MLP.ipynb)) 48 | - CNN ([PyTorch](https://github.com/terrence-ou/Generative-Deep-Learning-2nd-Edition-PyTorch-JAX/blob/main/PyTorch/chapter_02_deeplearning/02_CNN.ipynb) | [JAX](https://github.com/terrence-ou/Generative-Deep-Learning-2nd-Edition-PyTorch-JAX/blob/main/JAX_FLAX/chapter_02_deeplearning/02_CNN.ipynb)) 49 | ### *Chapter 3 Variational AutoEncoder* 50 | - AutoEncoder ([PyTorch](https://github.com/terrence-ou/Generative-Deep-Learning-2nd-Edition-PyTorch-JAX/blob/main/PyTorch/chapter_03_vae/01_autoencoder.ipynb) | [JAX](https://github.com/terrence-ou/Generative-Deep-Learning-2nd-Edition-PyTorch-JAX/blob/main/JAX_FLAX/chapter_03_vae/01_autoencoder.ipynb)) 51 | - VAE (FashionMNIST) ([PyTorch](https://github.com/terrence-ou/Generative-Deep-Learning-2nd-Edition-PyTorch-JAX/blob/main/PyTorch/chapter_03_vae/02_vae_fashion.ipynb) | [JAX](https://github.com/terrence-ou/Generative-Deep-Learning-2nd-Edition-PyTorch-JAX/blob/main/JAX_FLAX/chapter_03_vae/02_vae_fashion.ipynb)) 52 | - VAE (CelebA Dataset) ([PyTorch](https://github.com/terrence-ou/Generative-Deep-Learning-2nd-Edition-PyTorch-JAX/blob/main/PyTorch/chapter_03_vae/03_vae_face.ipynb) | [JAX](https://github.com/terrence-ou/Generative-Deep-Learning-2nd-Edition-PyTorch-JAX/blob/main/JAX_FLAX/chapter_03_vae/03_vae_faces.ipynb)) 53 | ### *Chapter 4 Generative Adversarial Networks* 54 | - DCGAN ([PyTorch](https://github.com/terrence-ou/Generative-Deep-Learning-2nd-Edition-PyTorch-JAX/blob/main/PyTorch/chapter_04_gan/01_dcgan.ipynb) | [JAX](https://github.com/terrence-ou/Generative-Deep-Learning-2nd-Edition-PyTorch-JAX/blob/main/JAX_FLAX/chapter_04_gan/01_dcgan.ipynb)) 55 | - WGAN-GP ([PyTorch](https://github.com/terrence-ou/Generative-Deep-Learning-2nd-Edition-PyTorch-JAX/blob/main/PyTorch/chapter_04_gan/02_wgan_gp.ipynb) | [JAX](https://github.com/terrence-ou/Generative-Deep-Learning-2nd-Edition-PyTorch-JAX/blob/main/JAX_FLAX/chapter_04_gan/02_wgan_gp.ipynb)) 56 | - Conditional GAN ([PyTorch](https://github.com/terrence-ou/Generative-Deep-Learning-2nd-Edition-PyTorch-JAX/blob/main/PyTorch/chapter_04_gan/03_cgan.ipynb) | [JAX](https://github.com/terrence-ou/Generative-Deep-Learning-2nd-Edition-PyTorch-JAX/blob/main/JAX_FLAX/chapter_04_gan/03_cgan.ipynb)) 57 | ### *Chapter 5 Autoregressive Models* 58 | - LSTM ([PyTorch](https://github.com/terrence-ou/Generative-Deep-Learning-2nd-Edition-PyTorch-JAX/blob/main/PyTorch/chapter_05_autoregressive/01_lstm.ipynb) | [JAX](https://github.com/terrence-ou/Generative-Deep-Learning-2nd-Edition-PyTorch-JAX/blob/main/JAX_FLAX/chapter_05_autoregressive/01_lstm.ipynb)) 59 | - PixelCNN (FashionMNIST) ([PyTorch](https://github.com/terrence-ou/Generative-Deep-Learning-2nd-Edition-PyTorch-JAX/blob/main/PyTorch/chapter_05_autoregressive/02_pixelcnn.ipynb) | 
[JAX](https://github.com/terrence-ou/Generative-Deep-Learning-2nd-Edition-PyTorch-JAX/blob/main/JAX_FLAX/chapter_05_autoregressive/02_pixelcnn.ipynb)) 60 | ### *Chapter 6 Normalizing Flow Models* 61 | - RealNVP ([PyTorch](https://github.com/terrence-ou/Generative-Deep-Learning-2nd-Edition-PyTorch-JAX/blob/main/PyTorch/chapter_06_normalizing_flow/01_realnvp.ipynb) | [JAX](https://github.com/terrence-ou/Generative-Deep-Learning-2nd-Edition-PyTorch-JAX/blob/main/JAX_FLAX/chapter_06_normalizing_flow/01_realnvp.ipynb)) 62 | ### *Chapter 7 Energy-Based Models* 63 | - EBM (MNIST) ([PyTorch](https://github.com/terrence-ou/Generative-Deep-Learning-2nd-Edition-PyTorch-JAX/blob/main/PyTorch/chapter_07_energy_based_model/01_ebm.ipynb) | [JAX](https://github.com/terrence-ou/Generative-Deep-Learning-2nd-Edition-PyTorch-JAX/blob/main/JAX_FLAX/chapter_07_energy_based_model/01_ebm.ipynb)) -------------------------------------------------------------------------------- /PyTorch/chapter_03_vae/utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ############################################################# 3 | IMPORTANT: 4 | 5 | This Python file is adapted from David Foster's code provided 6 | in his book Generative Deep Learning 2nd Edition: 7 | https://github.com/davidADSP/Generative_Deep_Learning_2nd_Edition/blob/main/notebooks/03_vae/03_vae_faces/vae_utils.py 8 | The original code is under the Apache-2.0 license; this file adapts 9 | the original implementation to the PyTorch 2.0 context. 10 | 11 | ############################################################# 12 | ''' 13 | 14 | import numpy as np 15 | import matplotlib.pyplot as plt 16 | import torch 17 | 18 | 19 | # Get the vector towards the given feature 20 | def get_vector_from_label(dataloader, embedding_dim, label, model): 21 | model.eval() 22 | device = next(model.parameters()).device 23 | # Initialize parameters 24 | current_sum_POS = np.zeros(shape=embedding_dim, dtype=np.float32) 25 | current_n_POS = 0 26 | current_mean_POS = np.zeros(shape=embedding_dim, dtype=np.float32) 27 | 28 | current_sum_NEG = np.zeros(shape=embedding_dim, dtype=np.float32) 29 | current_n_NEG = 0 30 | current_mean_NEG = np.zeros(shape=embedding_dim, dtype=np.float32) 31 | 32 | current_vector = np.zeros(shape=embedding_dim, dtype=np.float32) 33 | current_dist = 0 34 | 35 | print('label: ' + label) 36 | print('images | POS move | NEG move | distance | 𝛥 distance:') 37 | 38 | total_POS_samples = 5000 39 | curr_iter = 0 40 | while current_n_POS < total_POS_samples: 41 | # Sampling new POS and NEG images 42 | imgs, labels = next(iter(dataloader)) 43 | imgs = imgs.to(device) 44 | with torch.no_grad(): 45 | mean, logvar = model.encoder(imgs) 46 | z = model.reparameterize(mean, logvar) 47 | z_POS = z[labels==1].detach().cpu().numpy() 48 | z_NEG = z[labels==-1].detach().cpu().numpy() 49 | 50 | # Update the mean vectors for both POS and NEG samples 51 | if len(z_POS) > 0: 52 | current_sum_POS = current_sum_POS + np.sum(z_POS, axis=0) 53 | current_n_POS += len(z_POS) 54 | new_mean_POS = current_sum_POS / current_n_POS 55 | movement_POS = np.linalg.norm(new_mean_POS - current_mean_POS) 56 | 57 | if len(z_NEG) > 0: 58 | current_sum_NEG = current_sum_NEG + np.sum(z_NEG, axis=0) 59 | current_n_NEG += len(z_NEG) 60 | new_mean_NEG = current_sum_NEG / current_n_NEG 61 | movement_NEG = np.linalg.norm(new_mean_NEG - current_mean_NEG) 62 | 63 | # Update the feature vector 64 | current_vector = new_mean_POS - new_mean_NEG 65 | new_dist = 
np.linalg.norm(current_vector) 66 | dist_change = new_dist - current_dist 67 | 68 | # Print the vector-finding process 69 | placeholder = '| ' 70 | if curr_iter % 5 == 0: 71 | print(f'{current_n_POS:6d}', placeholder, 72 | f'{movement_POS:6.3f}', placeholder, 73 | f'{movement_NEG:6.3f}', placeholder, 74 | f'{new_dist:6.3f}', placeholder, 75 | f'{dist_change:6.3f}') 76 | 77 | current_mean_POS = np.copy(new_mean_POS) 78 | current_mean_NEG = np.copy(new_mean_NEG) 79 | current_dist = np.copy(new_dist) 80 | 81 | # When the changing distance is very small, terminate the while loop 82 | stop_thresh = 8e-2 83 | if np.sum([movement_POS, movement_NEG]) < stop_thresh: 84 | current_vector = current_vector / current_dist 85 | print('Found the ' + label + ' vector') 86 | break 87 | 88 | curr_iter += 1 89 | return current_vector 90 | 91 | 92 | # Plot the feature transitions 93 | def add_vector_to_images(dataloader, feature_vec, model): 94 | 95 | model.eval() 96 | 97 | n_plots = 5 98 | factors = np.arange(-4, 5) 99 | 100 | device = next(model.parameters()).device 101 | imgs, labels = next(iter(dataloader)) 102 | imgs = imgs.to(device) 103 | 104 | # Get image embeddings 105 | with torch.no_grad(): 106 | mean, logvar = model.encoder(imgs) 107 | z = model.reparameterize(mean, logvar) 108 | 109 | fig = plt.figure(figsize=(18, 10)) 110 | counter = 1 111 | 112 | imgs = imgs.detach().cpu().permute(0, 2, 3, 1) 113 | 114 | for i in range(n_plots): 115 | img = imgs[i] 116 | ax = fig.add_subplot(n_plots, len(factors) + 1, counter) 117 | ax.axis('off') 118 | ax.imshow(img) 119 | counter += 1 120 | # Add transition images 121 | for factor in factors: 122 | new_z_sample = z[i] + torch.tensor(feature_vec * factor, dtype=torch.float32, device=device) 123 | generated_img = model.decoder(new_z_sample.unsqueeze(0))[0] 124 | generated_img = generated_img.detach().cpu().permute(1, 2, 0) 125 | ax = fig.add_subplot(n_plots, len(factors) + 1, counter) 126 | ax.axis('off') 127 | ax.imshow(generated_img) 128 | counter += 1 129 | plt.show() 130 | 131 | 132 | # Morph between two faces 133 | def morph_faces(dataloader, model): 134 | 135 | device = next(model.parameters()).device 136 | factors = np.arange(0.0, 1.0, 0.1) 137 | 138 | sample_faces = next(iter(dataloader))[0][:2] # sample two faces 139 | # imgs = sample_faces.numpy() 140 | # z = encode_fn(state, imgs, rng) 141 | 142 | sample_faces = sample_faces.to(device) 143 | with torch.no_grad(): 144 | mean, logvar = model.encoder(sample_faces) 145 | z = model.reparameterize(mean, logvar) 146 | 147 | fig = plt.figure(figsize=(18, 8)) 148 | counter = 1 149 | 150 | face_a = sample_faces[0].detach().cpu().permute(1, 2, 0) 151 | face_b = sample_faces[1].detach().cpu().permute(1, 2, 0) 152 | 153 | # show original face 154 | ax = fig.add_subplot(1, len(factors) + 2, counter) 155 | ax.axis('off') 156 | ax.imshow(face_a) 157 | 158 | counter += 1 159 | 160 | # plot transitions 161 | for factor in factors: 162 | factored_z = (z[0] * (1 - factor) + z[1] * factor).to(device) 163 | generated_img = model.decoder(factored_z.unsqueeze(0))[0] 164 | ax = fig.add_subplot(1, len(factors) + 2, counter) 165 | ax.axis('off') 166 | ax.imshow(generated_img.detach().cpu().permute(1, 2, 0)) 167 | counter += 1 168 | 169 | # show target face 170 | ax = fig.add_subplot(1, len(factors) + 2, counter) 171 | ax.axis('off') 172 | ax.imshow(face_b) 173 | 174 | plt.show() -------------------------------------------------------------------------------- /JAX_FLAX/chapter_05_autoregressive/01_lstm.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b4e1ee6e-4367-4ef1-b15c-9ad225053267", 6 | "metadata": {}, 7 | "source": [ 8 | "# LSTM on Recipe Data" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "4a8e0b4f-1228-4eda-8842-9bf19873b034", 14 | "metadata": {}, 15 | "source": [ 16 | "**The notebook has been adapted from the notebook provided in David Foster's Generative Deep Learning, 2nd Edition.**\n", 17 | "\n", 18 | "- Book: [Amazon](https://www.amazon.com/Generative-Deep-Learning-Teaching-Machines/dp/1098134184/ref=sr_1_1?keywords=generative+deep+learning%2C+2nd+edition&qid=1684708209&sprefix=generative+de%2Caps%2C93&sr=8-1)\n", 19 | "- Original notebook (tensorflow and keras): [Github](https://github.com/davidADSP/Generative_Deep_Learning_2nd_Edition/blob/main/notebooks/05_autoregressive/01_lstm/lstm.ipynb)\n", 20 | "- Dataset: [Kaggle](https://www.kaggle.com/datasets/hugodarwood/epirecipes)" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 1, 26 | "id": "2d625ea0-486e-4619-ad9d-0ebbd7c24d04", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import re\n", 31 | "import string\n", 32 | "import json\n", 33 | "from collections import defaultdict\n", 34 | "import time\n", 35 | "\n", 36 | "import numpy as np\n", 37 | "import jax\n", 38 | "import jax.numpy as jnp\n", 39 | "from tensorflow.data import Dataset\n", 40 | "from tensorflow.keras.layers import TextVectorization\n", 41 | "from tensorflow.keras import utils\n", 42 | "\n", 43 | "from flax import struct\n", 44 | "from flax.training import train_state\n", 45 | "import flax.linen as nn\n", 46 | "import optax\n", 47 | "\n", 48 | "from clu import metrics" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "id": "ef76dd9c-8a94-4b81-907d-b0a64ee0c905", 54 | "metadata": {}, 55 | "source": [ 56 | "## 0. Train parameters" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 2, 62 | "id": "de4684fb-c1d3-41b8-b7da-4ea8da5dfa47", 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "DATA_DIR = '../../data/epirecipes/full_format_recipes.json'\n", 67 | "\n", 68 | "EMBEDDING_DIM = 100\n", 69 | "HIDDEN_DIM = 128\n", 70 | "NUM_LSTM_LAYERS = 1\n", 71 | "VALIDATION_SPLIT = 0.2\n", 72 | "BATCH_SIZE = 32\n", 73 | "EPOCHS = 30\n", 74 | "VOCAB_SIZE = 8200\n", 75 | "LR = 1e-3\n", 76 | "\n", 77 | "MAX_PAD_LEN = 200\n", 78 | "MAX_VAL_TOKENS = 100 # Max number of tokens when generating texts" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "id": "c7785935-9cdb-4b7f-87d4-c62df452fa08", 84 | "metadata": {}, 85 | "source": [ 86 | "## 1. 
Load dataset" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 3, 92 | "id": "0c07bf4e-9e45-4811-a65c-b0e5378aeeed", 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "def pad_punctuation(sentence):\n", 97 | " sentence = re.sub(f'([{string.punctuation}])', r' \\1 ', sentence)\n", 98 | " sentence = re.sub(' +', ' ', sentence)\n", 99 | " return sentence" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 4, 105 | "id": "4ad38cac-1ae6-42f7-95ef-3a9bff45e293", 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "# load dataset\n", 110 | "with open(DATA_DIR, 'r+') as f:\n", 111 | " recipe_data = json.load(f)" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 5, 117 | "id": "a6f19bd7-9d03-4f72-8191-7d082c22663d", 118 | "metadata": {}, 119 | "outputs": [ 120 | { 121 | "name": "stdout", 122 | "output_type": "stream", 123 | "text": [ 124 | "Total recipe loaded: 20098\n" 125 | ] 126 | } 127 | ], 128 | "source": [ 129 | "# preprocess dataset\n", 130 | "filtered_data = [\n", 131 | " 'Recipe for ' + x['title'] + ' | ' + ' '.join(x['directions'])\n", 132 | " for x in recipe_data\n", 133 | " if 'title' in x and x['title']\n", 134 | " and 'directions' in x and x['directions']\n", 135 | "]\n", 136 | "\n", 137 | "text_ds = [pad_punctuation(sentence) for sentence in filtered_data]\n", 138 | "print(f'Total recipe loaded: {len(text_ds)}')" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 6, 144 | "id": "a0bc4013-f1cb-46ea-81d1-abc8017365dd", 145 | "metadata": {}, 146 | "outputs": [ 147 | { 148 | "name": "stdout", 149 | "output_type": "stream", 150 | "text": [ 151 | "Sample data:\n", 152 | "Recipe for Cucumber , Radish , and Red Onion Salad | Peel , halve , and seed cucumber . Diagonally cut cucumber into thin slices and cut radishes into julienne strips . In a bowl toss together all ingredients and season with salt and pepper . \n" 153 | ] 154 | } 155 | ], 156 | "source": [ 157 | "print('Sample data:')\n", 158 | "sample_data = np.random.choice(text_ds)\n", 159 | "print(sample_data)" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "id": "a2809b15-a4bc-488c-a8cd-ae70f5fd1dae", 165 | "metadata": {}, 166 | "source": [ 167 | "## 2. 
Build vocabularies" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 7, 173 | "id": "ef8cf15a-6a3c-4b16-a7f5-06426b294859", 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "# convert texts list to tf dataset\n", 178 | "text_ds_tf = Dataset.from_tensor_slices(text_ds)\n", 179 | "\n", 180 | "vectorize_layer = TextVectorization(\n", 181 | " standardize='lower',\n", 182 | " max_tokens=VOCAB_SIZE,\n", 183 | " output_mode='int',\n", 184 | " output_sequence_length=MAX_PAD_LEN+1\n", 185 | ")" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 8, 191 | "id": "b1ccea64-db03-4e3b-b30c-6c2e4c88f92c", 192 | "metadata": {}, 193 | "outputs": [ 194 | { 195 | "name": "stdout", 196 | "output_type": "stream", 197 | "text": [ 198 | "0: \n", 199 | "1: [UNK]\n", 200 | "2: .\n", 201 | "3: ,\n", 202 | "4: and\n", 203 | "5: to\n", 204 | "6: in\n", 205 | "7: the\n", 206 | "8: with\n", 207 | "9: a\n" 208 | ] 209 | } 210 | ], 211 | "source": [ 212 | "vectorize_layer.adapt(text_ds_tf)\n", 213 | "vocab = vectorize_layer.get_vocabulary()\n", 214 | "index_to_word = {index : word for index, word in enumerate(vocab)}\n", 215 | "word_to_index = {word : index for index, word in enumerate(vocab)}\n", 216 | "\n", 217 | "# First 10 items in the vocabulary\n", 218 | "for i, word in enumerate(vocab[:10]):\n", 219 | " print(f'{i}: {word}')" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": 9, 225 | "id": "3c385672-4086-4ddb-8113-e9efd56c2c6e", 226 | "metadata": {}, 227 | "outputs": [ 228 | { 229 | "name": "stdout", 230 | "output_type": "stream", 231 | "text": [ 232 | "Source text:\n", 233 | "Recipe for Cucumber , Radish , and Red Onion Salad | Peel , halve , and seed cucumber . Diagonally cut cucumber into thin slices and cut radishes into julienne strips . In a bowl toss together all ingredients and season with salt and pepper . \n", 234 | "\n", 235 | "\n", 236 | "Mapped sample:\n", 237 | "[ 26 16 569 3 1362 3 4 282 115 189 27 175 3 538\n", 238 | " 3 4 805 569 2 932 74 569 25 355 160 4 74 941\n", 239 | " 25 1710 393 2 6 9 21 117 110 122 131 4 63 8\n", 240 | " 24 4 33 2 0 0 0 0 0 0 0 0 0 0\n", 241 | " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", 242 | " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", 243 | " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", 244 | " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", 245 | " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", 246 | " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", 247 | " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", 248 | " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", 249 | " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", 250 | " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", 251 | " 0 0 0 0 0]\n" 252 | ] 253 | } 254 | ], 255 | "source": [ 256 | "sample_data_tokenized = vectorize_layer(sample_data)\n", 257 | "print('Source text:')\n", 258 | "print(sample_data)\n", 259 | "print('\\n')\n", 260 | "print('Mapped sample:')\n", 261 | "print(sample_data_tokenized.numpy())" 262 | ] 263 | }, 264 | { 265 | "cell_type": "markdown", 266 | "id": "8d4c4727-7bd6-49d3-a4b5-1ea3e37fd4bb", 267 | "metadata": {}, 268 | "source": [ 269 | "## 3. 
Create train/validation datasets" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": 10, 275 | "id": "a717b2b3-68ce-4511-8c7a-6db64813f4ab", 276 | "metadata": {}, 277 | "outputs": [], 278 | "source": [ 279 | "# map a single text slice into source and targets\n", 280 | "def map_src_tgt(text):\n", 281 | " tokenized_sentence = vectorize_layer(text)\n", 282 | " src = tokenized_sentence[:-1]\n", 283 | " tgt = tokenized_sentence[1:]\n", 284 | " return src, tgt\n", 285 | " \n", 286 | "# create datasets\n", 287 | "def get_datasets(input_ds):\n", 288 | " train_size = int(len(input_ds) * (1 - VALIDATION_SPLIT))\n", 289 | " train_ds = input_ds.take(train_size) # take train dataset\n", 290 | " valid_ds = input_ds.skip(train_size) # take validation dataset\n", 291 | " print(f'train size: {train_ds.cardinality()}, valid size: {valid_ds.cardinality()}')\n", 292 | "\n", 293 | " train_ds = train_ds.map(map_src_tgt)\n", 294 | " valid_ds = valid_ds.map(map_src_tgt)\n", 295 | " \n", 296 | " train_ds = train_ds.batch(BATCH_SIZE).shuffle(1024).prefetch(1)\n", 297 | " valid_ds = valid_ds.batch(BATCH_SIZE).prefetch(1)\n", 298 | "\n", 299 | " print(f'train batch: {train_ds.cardinality()}, valid batch: {valid_ds.cardinality()}')\n", 300 | " return train_ds, valid_ds" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": 11, 306 | "id": "c14bc2df-97b5-404c-b600-e6481f16a187", 307 | "metadata": {}, 308 | "outputs": [ 309 | { 310 | "name": "stdout", 311 | "output_type": "stream", 312 | "text": [ 313 | "train size: 16078, valid size: 4020\n", 314 | "train batch: 503, valid batch: 126\n" 315 | ] 316 | } 317 | ], 318 | "source": [ 319 | "train_ds, valid_ds = get_datasets(text_ds_tf)" 320 | ] 321 | }, 322 | { 323 | "cell_type": "markdown", 324 | "id": "c93c4189-e96b-44d0-977e-c1eff2c61607", 325 | "metadata": {}, 326 | "source": [ 327 | "## 4. 
Build LSTM model" 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": 12, 333 | "id": "51f616bd-e50c-4c44-ab5b-fcfcc686bdfc", 334 | "metadata": {}, 335 | "outputs": [ 336 | { 337 | "name": "stdout", 338 | "output_type": "stream", 339 | "text": [ 340 | "\n", 341 | "\u001b[3m LSTM_model Summary \u001b[0m\n", 342 | "┏━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", 343 | "┃\u001b[1m \u001b[0m\u001b[1mpath \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mmodule \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1minputs \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1moutputs \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mparams \u001b[0m\u001b[1m \u001b[0m┃\n", 344 | "┡━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", 345 | "│ │ LSTM_model │ \u001b[2mint32\u001b[0m[32,200] │ \u001b[2mfloat32\u001b[0m[32,200,8200] │ │\n", 346 | "├───────────────┼────────────┼─────────────────────┼──────────────────────┼──────────────────────────────┤\n", 347 | "│ embed │ Embed │ \u001b[2mint32\u001b[0m[32,200] │ \u001b[2mfloat32\u001b[0m[32,200,128] │ embedding: \u001b[2mfloat32\u001b[0m[8200,128] │\n", 348 | "│ │ │ │ │ │\n", 349 | "│ │ │ │ │ \u001b[1m1,049,600 \u001b[0m\u001b[1;2m(4.2 MB)\u001b[0m │\n", 350 | "├───────────────┼────────────┼─────────────────────┼──────────────────────┼──────────────────────────────┤\n", 351 | "│ lstm_layers_0 │ RNN │ \u001b[2mfloat32\u001b[0m[32,200,128] │ \u001b[2mfloat32\u001b[0m[32,200,128] │ \u001b[1m131,584 \u001b[0m\u001b[1;2m(526.3 KB)\u001b[0m │\n", 352 | "├───────────────┼────────────┼─────────────────────┼──────────────────────┼──────────────────────────────┤\n", 353 | "│ dense │ Dense │ \u001b[2mfloat32\u001b[0m[32,200,128] │ \u001b[2mfloat32\u001b[0m[32,200,8200] │ bias: \u001b[2mfloat32\u001b[0m[8200] │\n", 354 | "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[128,8200] │\n", 355 | "│ │ │ │ │ │\n", 356 | "│ │ │ │ │ \u001b[1m1,057,800 \u001b[0m\u001b[1;2m(4.2 MB)\u001b[0m │\n", 357 | "├───────────────┼────────────┼─────────────────────┼──────────────────────┼──────────────────────────────┤\n", 358 | "│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m Total\u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m2,238,984 \u001b[0m\u001b[1;2m(9.0 MB)\u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\n", 359 | "└───────────────┴────────────┴─────────────────────┴──────────────────────┴──────────────────────────────┘\n", 360 | "\u001b[1m \u001b[0m\n", 361 | "\u001b[1m Total Parameters: 2,238,984 \u001b[0m\u001b[1;2m(9.0 MB)\u001b[0m\u001b[1m \u001b[0m\n", 362 | "\n", 363 | "\n" 364 | ] 365 | } 366 | ], 367 | "source": [ 368 | "class LSTM_model(nn.Module):\n", 369 | "\n", 370 | " num_lstm_layers: int\n", 371 | " \n", 372 | " def setup(self):\n", 373 | " self.embed = nn.Embed(num_embeddings=VOCAB_SIZE, features=HIDDEN_DIM)\n", 374 | " \n", 375 | " self.lstm_layers = [nn.RNN(nn.OptimizedLSTMCell(), HIDDEN_DIM) \n", 376 | " for _ in range(self.num_lstm_layers)]\n", 377 | " \n", 378 | " self.dense = nn.Dense(features=VOCAB_SIZE)\n", 379 | "\n", 380 | " def __call__(self, x):\n", 381 | " # Embedding\n", 382 | " x = self.embed(x)\n", 383 | " for lstm in self.lstm_layers:\n", 384 | " x = lstm(x)\n", 385 | " \n", 386 | " # Dense layer\n", 387 | " x = 
self.dense(x)\n", 388 | " return x\n", 389 | "\n", 390 | "lstm_model = LSTM_model(NUM_LSTM_LAYERS)\n", 391 | "rng = jax.random.PRNGKey(0)\n", 392 | "\n", 393 | "print(lstm_model.tabulate(rng, jnp.ones((BATCH_SIZE, MAX_PAD_LEN), dtype=int), depth=1))" 394 | ] 395 | }, 396 | { 397 | "cell_type": "markdown", 398 | "id": "c655caae-b24f-4b7d-aaf2-0f23e42fb42c", 399 | "metadata": {}, 400 | "source": [ 401 | "## 5. Create `TrainState`" 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "execution_count": 13, 407 | "id": "ad1c0d0f-8e64-4757-9c40-9e00a000f028", 408 | "metadata": {}, 409 | "outputs": [], 410 | "source": [ 411 | "@struct.dataclass\n", 412 | "class Metrics(metrics.Collection):\n", 413 | " loss: metrics.Average.from_output('loss')\n", 414 | "\n", 415 | "class TrainState(train_state.TrainState):\n", 416 | " metrics: Metrics\n", 417 | "\n", 418 | "# train state for lstm model\n", 419 | "def create_train_state(model, param_key, learning_rate):\n", 420 | " # initialize model\n", 421 | " params = model.init(param_key, jnp.ones((BATCH_SIZE, MAX_PAD_LEN), dtype=int))['params']\n", 422 | " # initialize optimizer\n", 423 | " tx = optax.adam(learning_rate=learning_rate)\n", 424 | " return TrainState.create(\n", 425 | " apply_fn=model.apply,\n", 426 | " params=params,\n", 427 | " tx=tx,\n", 428 | " metrics=Metrics.empty())" 429 | ] 430 | }, 431 | { 432 | "cell_type": "markdown", 433 | "id": "cb99e7e0-091a-43f8-88da-1c688debed12", 434 | "metadata": {}, 435 | "source": [ 436 | "## 6. Train step functions" 437 | ] 438 | }, 439 | { 440 | "cell_type": "code", 441 | "execution_count": 14, 442 | "id": "1249a94d-8cde-4c3b-8117-5c4d2dee48e3", 443 | "metadata": {}, 444 | "outputs": [], 445 | "source": [ 446 | "# train step\n", 447 | "@jax.jit\n", 448 | "def train_step(state, batch):\n", 449 | " def loss_fn(params):\n", 450 | " preds = state.apply_fn({'params': params}, batch[0])\n", 451 | " loss = optax.softmax_cross_entropy_with_integer_labels(preds, batch[1]).mean()\n", 452 | " return loss\n", 453 | "\n", 454 | " # compute loss and apply gradients\n", 455 | " grad_fn = jax.value_and_grad(loss_fn)\n", 456 | " loss, grads = grad_fn(state.params)\n", 457 | " state = state.apply_gradients(grads=grads)\n", 458 | "\n", 459 | " # Update metrics\n", 460 | " metric_updates = state.metrics.single_from_model_output(loss=loss)\n", 461 | " metrics = state.metrics.merge(metric_updates)\n", 462 | " state = state.replace(metrics=metrics)\n", 463 | " return state \n", 464 | "\n", 465 | "\n", 466 | "# evaluation\n", 467 | "@jax.jit\n", 468 | "def validation(state, batch):\n", 469 | " preds = state.apply_fn({'params': state.params}, batch[0])\n", 470 | " loss = optax.softmax_cross_entropy_with_integer_labels(preds, batch[1]).mean()\n", 471 | "\n", 472 | " # Update metrics\n", 473 | " metric_updates = state.metrics.single_from_model_output(loss=loss)\n", 474 | " metrics = state.metrics.merge(metric_updates)\n", 475 | " state = state.replace(metrics=metrics)\n", 476 | " return state" 477 | ] 478 | }, 479 | { 480 | "cell_type": "code", 481 | "execution_count": 15, 482 | "id": "72be04d9-b2b7-4e96-a67e-e278976f93ff", 483 | "metadata": {}, 484 | "outputs": [], 485 | "source": [ 486 | "# get next-word probability distribution\n", 487 | "@jax.jit\n", 488 | "def get_probs(state, input_tokens):\n", 489 | " return state.apply_fn({'params': state.params}, input_tokens)[0][-1]\n", 490 | "\n", 491 | "\n", 492 | "# Text generator\n", 493 | "class TextGenerator():\n", 494 | " def __init__(self, index_to_word):\n", 495 | " 
self.index_to_word = index_to_word\n", 496 | " self.word_to_index = {word : index for index, word in index_to_word.items()}\n", 497 | "\n", 498 | " # scaling the model's output probability with temperature\n", 499 | " def sample_from(self, probs, temperature):\n", 500 | " probs = probs ** (1 / temperature)\n", 501 | " probs = probs / np.sum(probs)\n", 502 | " return np.random.choice(VOCAB_SIZE, p=probs), probs\n", 503 | " \n", 504 | " # generate text\n", 505 | " def generate(self, state, start_prompt, max_tokens, temperature, output_info=False):\n", 506 | " \n", 507 | " start_tokens = [self.word_to_index[word] for word in start_prompt.split()]\n", 508 | " sample_token = None\n", 509 | " info = []\n", 510 | "\n", 511 | " while len(start_tokens) < max_tokens and sample_token != 0:\n", 512 | " input_tokens = np.array(start_tokens).reshape(1, -1)\n", 513 | " probs = get_probs(state, input_tokens)\n", 514 | " probs = nn.log_softmax(probs)\n", 515 | " sample_token, probs = self.sample_from(np.exp(probs), temperature)\n", 516 | " start_tokens.append(sample_token)\n", 517 | " if output_info:\n", 518 | " info.append({'tokens': np.copy(start_tokens), 'word_probs': probs})\n", 519 | " \n", 520 | " output_text = [self.index_to_word[token] for token in start_tokens if token != 0]\n", 521 | " print(' '.join(output_text))\n", 522 | "\n", 523 | " return info" 524 | ] 525 | }, 526 | { 527 | "cell_type": "markdown", 528 | "id": "3ea137ad-33ca-4525-bbd5-ed3097b49214", 529 | "metadata": {}, 530 | "source": [ 531 | "## 7. Training" 532 | ] 533 | }, 534 | { 535 | "cell_type": "code", 536 | "execution_count": 16, 537 | "id": "0ba8cf1f-9330-4a29-80b2-0c494b3327d2", 538 | "metadata": {}, 539 | "outputs": [], 540 | "source": [ 541 | "lstm_model = LSTM_model(NUM_LSTM_LAYERS)\n", 542 | "state = create_train_state(lstm_model, jax.random.PRNGKey(0), learning_rate=LR)\n", 543 | "text_generator = TextGenerator(index_to_word)" 544 | ] 545 | }, 546 | { 547 | "cell_type": "code", 548 | "execution_count": 17, 549 | "id": "99ebe744-3b35-4473-9e8a-35d81238c7e9", 550 | "metadata": {}, 551 | "outputs": [ 552 | { 553 | "name": "stdout", 554 | "output_type": "stream", 555 | "text": [ 556 | "Epoch: 1\tepoch time 0.21 min\n", 557 | "\ttrain loss: 3.9904, valid loss: 2.9524\n", 558 | "Epoch: 2\tepoch time 0.18 min\n", 559 | "\ttrain loss: 2.6645, valid loss: 2.4425\n", 560 | "Epoch: 3\tepoch time 0.18 min\n", 561 | "\ttrain loss: 2.3290, valid loss: 2.2334\n", 562 | "Epoch: 4\tepoch time 0.18 min\n", 563 | "\ttrain loss: 2.1584, valid loss: 2.1128\n", 564 | "Epoch: 5\tepoch time 0.18 min\n", 565 | "\ttrain loss: 2.0486, valid loss: 2.0321\n", 566 | "Epoch: 6\tepoch time 0.18 min\n", 567 | "\ttrain loss: 1.9692, valid loss: 1.9724\n", 568 | "Epoch: 7\tepoch time 0.18 min\n", 569 | "\ttrain loss: 1.9072, valid loss: 1.9247\n", 570 | "Epoch: 8\tepoch time 0.18 min\n", 571 | "\ttrain loss: 1.8567, valid loss: 1.8866\n", 572 | "Epoch: 9\tepoch time 0.18 min\n", 573 | "\ttrain loss: 1.8137, valid loss: 1.8557\n", 574 | "Epoch: 10\tepoch time 0.18 min\n", 575 | "\ttrain loss: 1.7761, valid loss: 1.8270\n", 576 | "\n", 577 | "Generated text:\n", 578 | "recipe for miso - water sauce | 1 . scrape them into an spice grinder . combine peaches , chili , and lemon zest in 2 inches . line with a bowl with nonstick oil . cook cumin and garlic in oil in same skillet over moderate heat , stirring , stirring , until wilted , about 5 minutes . remove from heat . add butter and sauté until all sauce is just wilted , about 1 minute . 
add chili powder , the reserved reserved 2 tablespoons butter , shallot , red pepper flakes , and ground black pepper\n", 579 | "\n", 580 | "\n", 581 | "Epoch: 11\tepoch time 0.18 min\n", 582 | "\ttrain loss: 1.7424, valid loss: 1.8037\n", 583 | "Epoch: 12\tepoch time 0.18 min\n", 584 | "\ttrain loss: 1.7133, valid loss: 1.7849\n", 585 | "Epoch: 13\tepoch time 0.18 min\n", 586 | "\ttrain loss: 1.6881, valid loss: 1.7693\n", 587 | "Epoch: 14\tepoch time 0.18 min\n", 588 | "\ttrain loss: 1.6653, valid loss: 1.7563\n", 589 | "Epoch: 15\tepoch time 0.18 min\n", 590 | "\ttrain loss: 1.6450, valid loss: 1.7436\n", 591 | "Epoch: 16\tepoch time 0.18 min\n", 592 | "\ttrain loss: 1.6265, valid loss: 1.7327\n", 593 | "Epoch: 17\tepoch time 0.18 min\n", 594 | "\ttrain loss: 1.6096, valid loss: 1.7246\n", 595 | "Epoch: 18\tepoch time 0.18 min\n", 596 | "\ttrain loss: 1.5942, valid loss: 1.7166\n", 597 | "Epoch: 19\tepoch time 0.18 min\n", 598 | "\ttrain loss: 1.5798, valid loss: 1.7087\n", 599 | "Epoch: 20\tepoch time 0.18 min\n", 600 | "\ttrain loss: 1.5667, valid loss: 1.7038\n", 601 | "\n", 602 | "Generated text:\n", 603 | "recipe for moroccan lamb | cook creole seasoning in a hot - salted water to a boil over high heat until evenly brown and almost and vegetables are boiling , about 1 minute . season with salt and pepper . transfer to a bowl . thinly slice avocado and zest , grapefruit juice , and lime juice and puree into a food processor . using a vegetable peeler , peel potatoes with a sharp paring knife , add orange juice and basil to skillet until smooth . in another bowl whisk together mayonnaise , salt and parsley until mixture just\n", 604 | "\n", 605 | "\n", 606 | "Epoch: 21\tepoch time 0.18 min\n", 607 | "\ttrain loss: 1.5543, valid loss: 1.6994\n", 608 | "Epoch: 22\tepoch time 0.18 min\n", 609 | "\ttrain loss: 1.5422, valid loss: 1.6951\n", 610 | "Epoch: 23\tepoch time 0.18 min\n", 611 | "\ttrain loss: 1.5328, valid loss: 1.6912\n", 612 | "Epoch: 24\tepoch time 0.18 min\n", 613 | "\ttrain loss: 1.5210, valid loss: 1.6868\n", 614 | "Epoch: 25\tepoch time 0.18 min\n", 615 | "\ttrain loss: 1.5114, valid loss: 1.6840\n", 616 | "Epoch: 26\tepoch time 0.18 min\n", 617 | "\ttrain loss: 1.5021, valid loss: 1.6821\n", 618 | "Epoch: 27\tepoch time 0.18 min\n", 619 | "\ttrain loss: 1.4933, valid loss: 1.6805\n", 620 | "Epoch: 28\tepoch time 0.18 min\n", 621 | "\ttrain loss: 1.4854, valid loss: 1.6770\n", 622 | "Epoch: 29\tepoch time 0.18 min\n", 623 | "\ttrain loss: 1.4772, valid loss: 1.6756\n", 624 | "Epoch: 30\tepoch time 0.18 min\n", 625 | "\ttrain loss: 1.4697, valid loss: 1.6750\n", 626 | "\n", 627 | "Generated text:\n", 628 | "recipe for minted ditalini sauce | preheat oven to 425°f . gently mix 1 / 4 cup fresh oyster concentrate , remaining 3 / 4 cup butter , eggs , and shallot in a small bowl . toss well - seasoned flour in a 13 - by 9 - inch baking pan . preheat oven to 475° . remove the rolls from grill and reserve . cut into 1 / 16 - inch pieces and in a small bowl toss ingredients with remaining teaspoon salt and shake , reserving . 
add to blue tomato sauce and stir until emulsified .\n", 629 | "\n", 630 | "\n" 631 | ] 632 | } 633 | ], 634 | "source": [ 635 | "loss_hist = defaultdict(list)\n", 636 | "\n", 637 | "for i in range(EPOCHS):\n", 638 | " prev_time = time.time()\n", 639 | " \n", 640 | " #training\n", 641 | " for batch in train_ds.as_numpy_iterator():\n", 642 | " state = train_step(state, batch)\n", 643 | "\n", 644 | " train_loss = state.metrics.compute()['loss']\n", 645 | " state = state.replace(metrics=state.metrics.empty())\n", 646 | "\n", 647 | " #validation\n", 648 | " test_state = state\n", 649 | " for batch in valid_ds.as_numpy_iterator():\n", 650 | " test_state = validation(test_state, batch)\n", 651 | "\n", 652 | " valid_loss = test_state.metrics.compute()['loss']\n", 653 | " \n", 654 | " loss_hist['train_loss'].append(train_loss)\n", 655 | " loss_hist['valid_loss'].append(valid_loss)\n", 656 | "\n", 657 | " curr_time = time.time()\n", 658 | " print(f'Epoch: {i+1}\\tepoch time {(curr_time - prev_time) / 60:.2f} min')\n", 659 | " print(f'\\ttrain loss: {train_loss:.4f}, valid loss: {valid_loss:.4f}')\n", 660 | " \n", 661 | " if (i + 1) % 10 == 0:\n", 662 | " # generate text\n", 663 | " print('\\nGenerated text:')\n", 664 | " info = text_generator.generate(state, 'recipe for', MAX_VAL_TOKENS, 1.0)\n", 665 | " print('\\n')" 666 | ] 667 | }, 668 | { 669 | "cell_type": "markdown", 670 | "id": "d64e8861-4971-423a-81ab-5f38512465ee", 671 | "metadata": {}, 672 | "source": [ 673 | "## 8. Generate texts" 674 | ] 675 | }, 676 | { 677 | "cell_type": "code", 678 | "execution_count": 18, 679 | "id": "06a2d4c7-c0d1-4972-9260-ac68b79e9c3d", 680 | "metadata": {}, 681 | "outputs": [], 682 | "source": [ 683 | "# print prompt and top k candidate words probability\n", 684 | "def print_probs(info, index_to_word, top_k=5):\n", 685 | " assert len(info) > 0, 'Please make `output_info=True`'\n", 686 | " for i in range(len(info)):\n", 687 | " start_tokens, word_probs = info[i].values()\n", 688 | " start_prompts = [index_to_word[token] for token in start_tokens if token != 0]\n", 689 | " start_prompts = ' '.join(start_prompts)\n", 690 | " print(f'\\nPrompt: {start_prompts}')\n", 691 | " # word_probs\n", 692 | " probs_sorted = np.argsort(word_probs)[::-1][:top_k]\n", 693 | " for idx in probs_sorted:\n", 694 | " print(f'{index_to_word[idx]}\\t{word_probs[idx] * 100:.2f}%')" 695 | ] 696 | }, 697 | { 698 | "cell_type": "code", 699 | "execution_count": 19, 700 | "id": "8b73be6f-b081-45c1-b998-73d9ceeec322", 701 | "metadata": {}, 702 | "outputs": [ 703 | { 704 | "name": "stdout", 705 | "output_type": "stream", 706 | "text": [ 707 | "recipe for roast turkey with au\n", 708 | "\n", 709 | "Prompt: recipe for roast turkey\n", 710 | "chicken\t22.24%\n", 711 | "pork\t16.31%\n", 712 | "turkey\t14.29%\n", 713 | "beef\t9.31%\n", 714 | "lamb\t4.78%\n", 715 | "\n", 716 | "Prompt: recipe for roast turkey with\n", 717 | "with\t91.31%\n", 718 | "breast\t1.43%\n", 719 | "and\t0.74%\n", 720 | "|\t0.58%\n", 721 | "legs\t0.31%\n", 722 | "\n", 723 | "Prompt: recipe for roast turkey with au\n", 724 | "lemon\t2.73%\n", 725 | "port\t2.62%\n", 726 | "red\t1.99%\n", 727 | "roasted\t1.95%\n", 728 | "salt\t1.81%\n" 729 | ] 730 | } 731 | ], 732 | "source": [ 733 | "# Candidate words probability with temperature = 1.0\n", 734 | "info = text_generator.generate(state, \n", 735 | " 'recipe for roast', \n", 736 | " max_tokens=6, \n", 737 | " temperature=1.0, \n", 738 | " output_info=True)\n", 739 | "\n", 740 | "print_probs(info, index_to_word, 5)" 741 | ] 742 | }, 
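{ "cell_type": "markdown", "id": "temperature-scaling-check", "metadata": {}, "source": [ "A quick sanity check of the temperature scaling in `sample_from` (approximate, since only the top-5 probabilities are printed): each probability is rescaled as $\\tilde{p}_i = p_i^{1/T} / \\sum_j p_j^{1/T}$, so with $T=0.2$ (i.e. $1/T=5$) the next-word probability of `chicken` after `recipe for roast` becomes $0.2224^5 / (0.2224^5 + 0.1631^5 + 0.1429^5 + 0.0931^5 + 0.0478^5 + \\cdots) \\approx 0.75$, matching the $74.93\\%$ reported in the next cell." ] }, 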
743 | { 744 | "cell_type": "code", 745 | "execution_count": 20, 746 | "id": "2975f0c4-3c0e-4081-a365-83d9e83ec503", 747 | "metadata": {}, 748 | "outputs": [ 749 | { 750 | "name": "stdout", 751 | "output_type": "stream", 752 | "text": [ 753 | "recipe for roast chicken with roasted\n", 754 | "\n", 755 | "Prompt: recipe for roast chicken\n", 756 | "chicken\t74.93%\n", 757 | "pork\t15.88%\n", 758 | "turkey\t8.19%\n", 759 | "beef\t0.96%\n", 760 | "lamb\t0.03%\n", 761 | "\n", 762 | "Prompt: recipe for roast chicken with\n", 763 | "with\t100.00%\n", 764 | "|\t0.00%\n", 765 | "and\t0.00%\n", 766 | "in\t0.00%\n", 767 | "breasts\t0.00%\n", 768 | "\n", 769 | "Prompt: recipe for roast chicken with roasted\n", 770 | "lemon\t40.40%\n", 771 | "roasted\t31.21%\n", 772 | "red\t5.84%\n", 773 | "mustard\t4.11%\n", 774 | "white\t3.64%\n" 775 | ] 776 | } 777 | ], 778 | "source": [ 779 | "# Candidate words probability distribution with temperature = 0.2\n", 780 | "info = text_generator.generate(state, \n", 781 | " 'recipe for roast', \n", 782 | " max_tokens=6, \n", 783 | " temperature=0.2, \n", 784 | " output_info=True)\n", 785 | "\n", 786 | "print_probs(info, index_to_word, 5)" 787 | ] 788 | }, 789 | { 790 | "cell_type": "code", 791 | "execution_count": 21, 792 | "id": "94def9a4-fc37-4154-877b-157f93039d73", 793 | "metadata": {}, 794 | "outputs": [ 795 | { 796 | "name": "stdout", 797 | "output_type": "stream", 798 | "text": [ 799 | "recipe for roast pork chops with pimiento cheese and basil | cook guajillo and onion in large pot of boiling salted water 30 seconds . drain potatoes in well . transfer carrots to large baking sheet . lift off skin and discard . add sliced prunes to outside of 9 - inch - diameter glass dish with pan juices . stir in butter and chile paste . increase heat by 1 cup , thyme , onion , cardamom , and cloves , necks and peppercorns to hot water . cover , and cook until potatoes are tender , about 8\n" 800 | ] 801 | } 802 | ], 803 | "source": [ 804 | "# generate text with temperature = 1.0\n", 805 | "info = text_generator.generate(state, \n", 806 | " 'recipe for roast', \n", 807 | " max_tokens=100, \n", 808 | " temperature=1.0, \n", 809 | " output_info=True)" 810 | ] 811 | }, 812 | { 813 | "cell_type": "code", 814 | "execution_count": 22, 815 | "id": "d4b2f17c-37e3-4990-a3b7-fa48cbe25f82", 816 | "metadata": {}, 817 | "outputs": [ 818 | { 819 | "name": "stdout", 820 | "output_type": "stream", 821 | "text": [ 822 | "recipe for roast chicken with ginger - ginger butter | preheat oven to 350°f . butter and flour a 9 - inch - diameter glass pie dish . combine first 6 ingredients in small bowl . add 1 / 2 cup oil ; rub with remaining 1 / 4 cup oil . place chicken on baking sheet . roast until tender , about 20 minutes . transfer to large rimmed baking sheet . add remaining 1 / 4 cup oil ; rub in pan . 
roast until thermometer inserted into thickest part of thigh registers 165°f , about 35 minutes\n" 823 | ] 824 | } 825 | ], 826 | "source": [ 827 | "# generate text with temperature = 0.2\n", 828 | "info = text_generator.generate(state, \n", 829 | " 'recipe for roast', \n", 830 | " max_tokens=100, \n", 831 | " temperature=0.2, \n", 832 | " output_info=True)" 833 | ] 834 | } 835 | ], 836 | "metadata": { 837 | "kernelspec": { 838 | "display_name": "Python 3", 839 | "language": "python", 840 | "name": "python3" 841 | }, 842 | "language_info": { 843 | "codemirror_mode": { 844 | "name": "ipython", 845 | "version": 3 846 | }, 847 | "file_extension": ".py", 848 | "mimetype": "text/x-python", 849 | "name": "python", 850 | "nbconvert_exporter": "python", 851 | "pygments_lexer": "ipython3", 852 | "version": "3.9.0" 853 | } 854 | }, 855 | "nbformat": 4, 856 | "nbformat_minor": 5 857 | } 858 | -------------------------------------------------------------------------------- /PyTorch/chapter_05_autoregressive/01_lstm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "c12c6fab-b046-4739-884f-0a63b9298246", 6 | "metadata": {}, 7 | "source": [ 8 | "# LSTM on Recipe Data" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "fce0db59-ce7c-432c-afda-2bd1f5a3f666", 14 | "metadata": {}, 15 | "source": [ 16 | "**The notebook has been adapted from the notebook provided in David Foster's Generative Deep Learning, 2nd Edition.**\n", 17 | "\n", 18 | "- Book: [Amazon](https://www.amazon.com/Generative-Deep-Learning-Teaching-Machines/dp/1098134184/ref=sr_1_1?keywords=generative+deep+learning%2C+2nd+edition&qid=1684708209&sprefix=generative+de%2Caps%2C93&sr=8-1)\n", 19 | "- Original notebook (tensorflow and keras): [Github](https://github.com/davidADSP/Generative_Deep_Learning_2nd_Edition/blob/main/notebooks/05_autoregressive/01_lstm/lstm.ipynb)\n", 20 | "- Dataset: [Kaggle](https://www.kaggle.com/datasets/hugodarwood/epirecipes)" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 1, 26 | "id": "18a75e59-1fd7-4879-8ddc-2534fbdcfc94", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import numpy as np\n", 31 | "import json\n", 32 | "import re\n", 33 | "import string\n", 34 | "import time\n", 35 | "\n", 36 | "import torch\n", 37 | "from torch import nn\n", 38 | "from torch.nn.functional import pad\n", 39 | "from torch.utils.data import Dataset, DataLoader, random_split\n", 40 | "\n", 41 | "from torchtext.vocab import build_vocab_from_iterator\n", 42 | "from torchtext.data.utils import get_tokenizer\n", 43 | "\n", 44 | "import torchinfo" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "id": "fedd03ec-0126-4d07-893e-dc454c6d16e2", 50 | "metadata": {}, 51 | "source": [ 52 | "## 0. 
Train parameters" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 2, 58 | "id": "5d4de63f-867b-4c6b-b0ea-335a385517a6", 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "DATA_DIR = '../../data/epirecipes/full_format_recipes.json'\n", 63 | "\n", 64 | "EMBEDDING_DIM = 100\n", 65 | "HIDDEN_DIM = 128\n", 66 | "VALIDATION_SPLIT = 0.2\n", 67 | "SEED = 1024\n", 68 | "BATCH_SIZE = 32\n", 69 | "EPOCHS = 30\n", 70 | "\n", 71 | "MAX_PAD_LEN = 200\n", 72 | "MAX_VAL_TOKENS = 100 # Max number of tokens when generating texts\n", 73 | "\n", 74 | "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "id": "c48ceccc-efd6-4b85-805a-d3bea8b72120", 80 | "metadata": {}, 81 | "source": [ 82 | "## 1. Load dataset" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 3, 88 | "id": "673da423-0772-4e13-9496-a22e5c5916f9", 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "def pad_punctuation(sentence):\n", 93 | " sentence = re.sub(f'([{string.punctuation}])', r' \\1 ', sentence)\n", 94 | " sentence = re.sub(' +', ' ', sentence)\n", 95 | " return sentence" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 4, 101 | "id": "d47ce264-2543-43d3-8eef-7f24d3ecf019", 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "# Load dataset\n", 106 | "with open(DATA_DIR, 'r+') as f:\n", 107 | " recipe_data = json.load(f)" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 5, 113 | "id": "b33cbc22-91c1-403b-a602-531cf47455b4", 114 | "metadata": {}, 115 | "outputs": [ 116 | { 117 | "name": "stdout", 118 | "output_type": "stream", 119 | "text": [ 120 | "Total recipe loaded: 20098\n" 121 | ] 122 | } 123 | ], 124 | "source": [ 125 | "# preprocess dataset\n", 126 | "filtered_data = [\n", 127 | " 'Recipe for ' + x['title'] + ' | ' + ' '.join(x['directions'])\n", 128 | " for x in recipe_data\n", 129 | " if 'title' in x and x['title']\n", 130 | " and 'directions' in x and x['directions']\n", 131 | "]\n", 132 | "\n", 133 | "text_ds = [pad_punctuation(sentence) for sentence in filtered_data]\n", 134 | "\n", 135 | "print(f'Total recipe loaded: {len(text_ds)}')" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 6, 141 | "id": "bc7d7d85-b07c-4c64-a78e-b5ee38b16120", 142 | "metadata": {}, 143 | "outputs": [ 144 | { 145 | "name": "stdout", 146 | "output_type": "stream", 147 | "text": [ 148 | "Sample data:\n", 149 | "Recipe for Ricotta Cheesecake | Preheat oven to 350°F . Pulse flour , sugar , salt , and butter in a food processor until mixture resembles coarse meal . Add yolk , vanilla , and lemon juice and pulse just until mixture begins to form a dough . Spread dough with a small offset spatula or back of a spoon over buttered bottom of a 24 - centimeter springform pan and prick all over with a fork . Chill 30 minutes . Bake crust in a shallow baking pan ( to catch drips ) in middle of oven until golden brown , about 25 minutes , and cool on a rack . Increase temperature to 375°F . Discard liquid and cheesecloth and force drained ricotta through sieve into bowl . Beat yolks and sugar with an electric mixer until thick and pale , then beat in ricotta , flour , and zests . Beat whites with salt in another bowl until they hold soft peaks , and fold into ricotta mixture . Butter side of springform pan and pour filling over crust ( pan will be completely full ) . 
138 | {
139 | "cell_type": "code",
140 | "execution_count": 6,
141 | "id": "bc7d7d85-b07c-4c64-a78e-b5ee38b16120",
142 | "metadata": {},
143 | "outputs": [
144 | {
145 | "name": "stdout",
146 | "output_type": "stream",
147 | "text": [
148 | "Sample data:\n",
149 | "Recipe for Ricotta Cheesecake | Preheat oven to 350°F . Pulse flour , sugar , salt , and butter in a food processor until mixture resembles coarse meal . Add yolk , vanilla , and lemon juice and pulse just until mixture begins to form a dough . Spread dough with a small offset spatula or back of a spoon over buttered bottom of a 24 - centimeter springform pan and prick all over with a fork . Chill 30 minutes . Bake crust in a shallow baking pan ( to catch drips ) in middle of oven until golden brown , about 25 minutes , and cool on a rack . Increase temperature to 375°F . Discard liquid and cheesecloth and force drained ricotta through sieve into bowl . Beat yolks and sugar with an electric mixer until thick and pale , then beat in ricotta , flour , and zests . Beat whites with salt in another bowl until they hold soft peaks , and fold into ricotta mixture . Butter side of springform pan and pour filling over crust ( pan will be completely full ) . Bake in baking pan in middle of oven until cake is puffed and golden and a tester inserted 1 inch from center comes out clean , about 1 hour . Run a knife around top edge of cake to loosen and cool completely in springform pan on rack . Chill , loosely covered , at least 4 hours . Remove side of pan and transfer cake to a plate . Bring to room temperature before serving . \n"
150 | ]
151 | }
152 | ],
153 | "source": [
154 | "print('Sample data:')\n",
155 | "sample_data = np.random.choice(text_ds)\n",
156 | "print(sample_data)"
157 | ]
158 | },
159 | {
160 | "cell_type": "markdown",
161 | "id": "c80f4e5d-fd9a-4c89-ab0e-d82c20d7b690",
162 | "metadata": {},
163 | "source": [
164 | "## 2. Build vocabularies"
165 | ]
166 | },
167 | {
168 | "cell_type": "code",
169 | "execution_count": 7,
170 | "id": "667a7462-74b1-43de-854b-0616f581facb",
171 | "metadata": {},
172 | "outputs": [],
173 | "source": [
174 | "# The iterator that yields tokenized data\n",
175 | "def yield_tokens(data_iter, tokenizer):\n",
176 | " for sample in data_iter:\n",
177 | " yield tokenizer(sample)\n",
178 | "\n",
179 | "# Building vocabulary\n",
180 | "def build_vocab(dataset, tokenizer):\n",
181 | " vocab = build_vocab_from_iterator(\n",
182 | " yield_tokens(dataset, tokenizer),\n",
183 | " min_freq=2,\n",
184 | " specials=['<pad>', '<unk>']\n",
185 | " )\n",
186 | " return vocab"
187 | ]
188 | },
189 | {
190 | "cell_type": "code",
191 | "execution_count": 8,
192 | "id": "30ecca2d-4340-441d-9c82-d9dde0ad2174",
193 | "metadata": {},
194 | "outputs": [],
195 | "source": [
196 | "tokenizer = get_tokenizer('basic_english')\n",
197 | "vocab = build_vocab(text_ds, tokenizer)\n",
198 | "vocab.set_default_index(vocab['<unk>'])\n",
199 | "\n",
200 | "# Create index-to-word mapping\n",
201 | "index_to_word = {index : word for word, index in vocab.get_stoi().items()}"
202 | ]
203 | },
204 | {
205 | "cell_type": "code",
206 | "execution_count": 9,
207 | "id": "ab021f66-5ca0-41bf-9b8e-f0121aaf20ab",
208 | "metadata": {},
209 | "outputs": [
210 | {
211 | "name": "stdout",
212 | "output_type": "stream",
213 | "text": [
214 | "0: <pad>\n",
215 | "1: <unk>\n",
216 | "2: .\n",
217 | "3: ,\n",
218 | "4: and\n",
219 | "5: to\n",
220 | "6: in\n",
221 | "7: the\n",
222 | "8: with\n",
223 | "9: a\n"
224 | ]
225 | }
226 | ],
227 | "source": [
228 | "# display some token-word mappings\n",
229 | "for i in range(10):\n",
230 | " word = vocab.get_itos()[i]\n",
231 | " print(f'{i}: {word}')"
232 | ]
233 | },
234 | {
235 | "cell_type": "code",
236 | "execution_count": 10,
237 | "id": "b7a952c5-43da-4428-8a53-6f76afa69974",
238 | "metadata": {},
239 | "outputs": [
240 | {
241 | "name": "stdout",
242 | "output_type": "stream",
243 | "text": [
244 | "Source text:\n",
245 | "Recipe for Ricotta Cheesecake | Preheat oven to 350°F . Pulse flour , sugar , salt , and butter in a food processor until mixture resembles coarse meal . Add yolk , vanilla , and lemon juice and pulse just until mixture begins to form a dough . Spread dough with a small offset spatula or back of a spoon over buttered bottom of a 24 - centimeter springform pan and prick all over with a fork . Chill 30 minutes . Bake crust in a shallow baking pan ( to catch drips ) in middle of oven until golden brown , about 25 minutes , and cool on a rack . Increase temperature to 375°F . Discard liquid and cheesecloth and force drained ricotta through sieve into bowl . Beat yolks and sugar with an electric mixer until thick and pale , then beat in ricotta , flour , and zests . 
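The vocabulary machinery above is easiest to see on a toy corpus. A minimal standalone sketch using the same torchtext APIs (`build_vocab_from_iterator`, `set_default_index`); note that the exact ordering of tied frequencies can vary across torchtext versions:

```python
from torchtext.vocab import build_vocab_from_iterator

toy_corpus = [['add', 'salt', 'and', 'pepper'], ['add', 'oil', 'and', 'salt']]
toy_vocab = build_vocab_from_iterator(iter(toy_corpus), min_freq=2,
                                      specials=['<pad>', '<unk>'])
toy_vocab.set_default_index(toy_vocab['<unk>'])

print(toy_vocab.get_itos())           # ['<pad>', '<unk>', 'add', 'and', 'salt']
print(toy_vocab(['add', 'saffron']))  # [2, 1] -- unseen 'saffron' maps to <unk>
```

Tokens below `min_freq` ('oil', 'pepper') are dropped, and any out-of-vocabulary word falls back to the default `<unk>` index.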
Beat whites with salt in another bowl until they hold soft peaks , and fold into ricotta mixture . Butter side of springform pan and pour filling over crust ( pan will be completely full ) . Bake in baking pan in middle of oven until cake is puffed and golden and a tester inserted 1 inch from center comes out clean , about 1 hour . Run a knife around top edge of cake to loosen and cool completely in springform pan on rack . Chill , loosely covered , at least 4 hours . Remove side of pan and transfer cake to a plate . Bring to room temperature before serving . \n", 246 | "\n", 247 | "\n", 248 | "Mapped sample:\n", 249 | "[25, 16, 781, 1060, 26, 85, 46, 5, 215, 2, 434, 110, 3, 55, 3, 23, 3, 4, 49, 6, 9, 289, 187, 10, 30, 730, 409, 641, 2, 18, 697, 3, 257, 3, 4, 108, 103, 4, 434, 92, 10, 30, 546, 5, 236, 9, 93, 2, 165, 93, 8, 9, 64, 1610, 381, 40, 401, 14, 9, 97, 20, 671, 168, 14, 9, 982, 13, 5195, 902, 43, 4, 1318, 121, 20, 8, 9, 334, 2, 107, 125, 12, 2, 96, 241, 6, 9, 341, 57, 43, 33, 5, 1734, 2819, 34, 6, 253, 14, 46, 10, 99, 89, 3, 19, 353, 12, 3, 4, 59, 27, 9, 118, 2, 544, 111, 5, 421, 2, 205, 142, 4, 940, 4, 1151, 1193, 781, 101, 304, 24, 21, 2, 144, 430, 4, 55, 8, 177, 329, 287, 10, 196, 4, 500, 3, 45, 144, 6, 781, 3, 110, 3, 4, 2247, 2, 144, 380, 8, 23, 6, 255, 21, 10, 314, 530, 284, 396, 3, 4, 252, 24, 781, 30, 2, 49, 95, 14, 902, 43, 4, 106, 216, 20, 241, 33, 43, 190, 61, 222, 1254, 34, 2, 96, 6, 57, 43, 6, 253, 14, 46, 10, 136, 37, 743, 4, 99, 4, 9, 637, 366, 11, 52, 50, 167, 441, 124, 367, 3, 19, 11, 171, 2, 622, 9, 265, 274, 71, 422, 14, 136, 5, 664, 4, 59, 222, 6, 902, 43, 27, 118, 2, 107, 3, 516, 120, 3, 56, 203, 31, 105, 2, 70, 95, 14, 43, 4, 39, 136, 5, 9, 218, 2, 83, 5, 139, 111, 164, 223, 2]\n" 250 | ] 251 | } 252 | ], 253 | "source": [ 254 | "# Check mappings\n", 255 | "mapped_sample = vocab(tokenizer(sample_data))\n", 256 | "print('Source text:')\n", 257 | "print(sample_data)\n", 258 | "print('\\n')\n", 259 | "print('Mapped sample:')\n", 260 | "print(mapped_sample)" 261 | ] 262 | }, 263 | { 264 | "cell_type": "markdown", 265 | "id": "cd401c53-04a6-4cee-bf39-ae446e6ac2a9", 266 | "metadata": {}, 267 | "source": [ 268 | "# 3. 
Create DataLoader" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": 11, 274 | "id": "1f3d9056-9364-4cb9-ae1c-e43e1fc84a6c", 275 | "metadata": {}, 276 | "outputs": [], 277 | "source": [ 278 | "class Collate():\n", 279 | " def __init__(self, tokenizer, vocab, max_padding, pad_idx):\n", 280 | " self.tokenizer = tokenizer\n", 281 | " self.vocab = vocab\n", 282 | "\n", 283 | " self.max_padding = max_padding\n", 284 | " self.pad_idx = pad_idx\n", 285 | "\n", 286 | " \n", 287 | " def collate_fn(self, batch):\n", 288 | " src_list = []\n", 289 | " tgt_list = []\n", 290 | "\n", 291 | " # Prepare source and target batch\n", 292 | " for sentence in batch:\n", 293 | " # convert text to vocab tensor\n", 294 | " tokens = self.tokenizer(sentence)\n", 295 | " src_mapping = torch.tensor(self.vocab(tokens[:-1]), dtype=torch.int64)\n", 296 | " tgt_mapping = torch.tensor(self.vocab(tokens[1:]), dtype=torch.int64)\n", 297 | " # pad sequence\n", 298 | " src_padded = pad(src_mapping, [0, self.max_padding - len(src_mapping)], value=self.pad_idx)\n", 299 | " tgt_padded = pad(tgt_mapping, [0, self.max_padding - len(tgt_mapping)], value=self.pad_idx)\n", 300 | " # append padded sequence to corresponding lists\n", 301 | " src_list.append(src_padded)\n", 302 | " tgt_list.append(tgt_padded)\n", 303 | "\n", 304 | " # stack batch\n", 305 | " src = torch.stack(src_list)\n", 306 | " tgt = torch.stack(tgt_list)\n", 307 | "\n", 308 | " return (src, tgt)" 309 | ] 310 | }, 311 | { 312 | "cell_type": "code", 313 | "execution_count": 12, 314 | "id": "a5aa5a77-c71f-469c-9bb6-681c0be398f7", 315 | "metadata": {}, 316 | "outputs": [ 317 | { 318 | "name": "stdout", 319 | "output_type": "stream", 320 | "text": [ 321 | "Num. training data: \t 16079\n", 322 | "Num. validation data: \t 4019\n" 323 | ] 324 | } 325 | ], 326 | "source": [ 327 | "# Split dataset into training and validation splits\n", 328 | "train_ds, valid_ds = random_split(text_ds, [1-VALIDATION_SPLIT, VALIDATION_SPLIT])\n", 329 | "print(\"Num. training data: \\t\", len(train_ds))\n", 330 | "print(\"Num. validation data: \\t\", len(valid_ds))" 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": 13, 336 | "id": "81a9fcdf-4830-4f63-a3e4-8cf9058ee28f", 337 | "metadata": {}, 338 | "outputs": [ 339 | { 340 | "name": "stdout", 341 | "output_type": "stream", 342 | "text": [ 343 | "index of token: 0\n" 344 | ] 345 | } 346 | ], 347 | "source": [ 348 | "pad_idx = vocab.get_stoi()['']\n", 349 | "print('index of token: ', pad_idx)\n", 350 | "\n", 351 | "collate = Collate(tokenizer, vocab, MAX_PAD_LEN+1, pad_idx)\n", 352 | "\n", 353 | "train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, \n", 354 | " shuffle=True, num_workers=8, pin_memory=True,\n", 355 | " collate_fn=collate.collate_fn)\n", 356 | "\n", 357 | "valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE, \n", 358 | " shuffle=False, num_workers=8, pin_memory=True,\n", 359 | " collate_fn=collate.collate_fn)" 360 | ] 361 | }, 362 | { 363 | "cell_type": "markdown", 364 | "id": "b251e235-0cdd-4c87-9c1e-75d2faa61997", 365 | "metadata": {}, 366 | "source": [ 367 | "## 4. 
Build LSTM model" 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "execution_count": 14, 373 | "id": "cb877b2b-f1df-4ac6-a66f-b0d15bd37557", 374 | "metadata": {}, 375 | "outputs": [ 376 | { 377 | "data": { 378 | "text/plain": [ 379 | "==========================================================================================\n", 380 | "Layer (type:depth-idx) Output Shape Param #\n", 381 | "==========================================================================================\n", 382 | "LSTM_Net [32, 201, 8628] --\n", 383 | "├─Embedding: 1-1 [32, 201, 100] 862,800\n", 384 | "├─LSTM: 1-2 [32, 201, 128] 249,856\n", 385 | "├─Linear: 1-3 [32, 201, 8628] 1,113,012\n", 386 | "==========================================================================================\n", 387 | "Total params: 2,225,668\n", 388 | "Trainable params: 2,225,668\n", 389 | "Non-trainable params: 0\n", 390 | "Total mult-adds (G): 1.67\n", 391 | "==========================================================================================\n", 392 | "Input size (MB): 0.05\n", 393 | "Forward/backward pass size (MB): 455.69\n", 394 | "Params size (MB): 8.90\n", 395 | "Estimated Total Size (MB): 464.65\n", 396 | "==========================================================================================" 397 | ] 398 | }, 399 | "execution_count": 14, 400 | "metadata": {}, 401 | "output_type": "execute_result" 402 | } 403 | ], 404 | "source": [ 405 | "class LSTM_Net(nn.Module):\n", 406 | "\n", 407 | " def __init__(self, vocab_size):\n", 408 | " super().__init__()\n", 409 | " self.embedding = nn.Embedding(num_embeddings=vocab_size,\n", 410 | " embedding_dim=EMBEDDING_DIM,\n", 411 | " padding_idx=pad_idx)\n", 412 | " \n", 413 | " self.lstm = nn.LSTM(input_size=EMBEDDING_DIM,\n", 414 | " hidden_size=HIDDEN_DIM,\n", 415 | " num_layers=2,\n", 416 | " batch_first=True)\n", 417 | " \n", 418 | " self.output = nn.Linear(in_features=HIDDEN_DIM,\n", 419 | " out_features=vocab_size)\n", 420 | " \n", 421 | " def forward(self, x):\n", 422 | " x = self.embedding(x)\n", 423 | " x, hidden_state = self.lstm(x)\n", 424 | " return self.output(x)\n", 425 | "\n", 426 | "\n", 427 | "model = LSTM_Net(len(vocab))\n", 428 | "torchinfo.summary(model=model, input_size=(BATCH_SIZE, MAX_PAD_LEN+1), \n", 429 | " dtypes=[torch.int64], depth=3)" 430 | ] 431 | }, 432 | { 433 | "cell_type": "markdown", 434 | "id": "44068c25-c912-4cfd-823d-b0368ca302e9", 435 | "metadata": {}, 436 | "source": [ 437 | "## 5. 
Train step functions" 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": 15, 443 | "id": "b9b80bb9-1b77-4ec2-9113-5caccdef4561", 444 | "metadata": {}, 445 | "outputs": [], 446 | "source": [ 447 | "class TextGenerator():\n", 448 | " def __init__(self, index_to_word):\n", 449 | " self.index_to_word = index_to_word\n", 450 | "\n", 451 | " # Scaling the model's output probability with temperature\n", 452 | " def sample_from(self, probs, temperature):\n", 453 | " probs = probs ** (1 / temperature)\n", 454 | " probs = probs / np.sum(probs)\n", 455 | " return np.random.choice(len(probs), p=probs), probs\n", 456 | "\n", 457 | " # Generate text\n", 458 | " def generate(self, model, start_prompt, max_tokens, temperature, output_info=False):\n", 459 | " model.eval()\n", 460 | " \n", 461 | " start_tokens = vocab(tokenizer(start_prompt))\n", 462 | " sample_token = None\n", 463 | " info = []\n", 464 | " \n", 465 | " while len(start_tokens) < max_tokens and sample_token != 0: # also avoid padding index\n", 466 | " input_prompts = torch.tensor(start_tokens, device=DEVICE).unsqueeze(0)\n", 467 | " probs = model(input_prompts)[0][-1]\n", 468 | " probs = nn.functional.softmax(probs, dim=-1)\n", 469 | " sample_token, probs = self.sample_from(probs.detach().cpu().numpy(), temperature)\n", 470 | " \n", 471 | " start_tokens.append(sample_token)\n", 472 | " if output_info:\n", 473 | " info.append({'token': np.copy(start_tokens), 'word_probs': probs})\n", 474 | "\n", 475 | " output_text = [self.index_to_word[token] for token in start_tokens if token != 0]\n", 476 | " print(' '.join(output_text))\n", 477 | " return info" 478 | ] 479 | }, 480 | { 481 | "cell_type": "code", 482 | "execution_count": 16, 483 | "id": "94b067d3-b05c-4440-be23-8aa533c24b37", 484 | "metadata": {}, 485 | "outputs": [], 486 | "source": [ 487 | "# Training function\n", 488 | "def train_step(model, dataloader, loss_fn, optimizer):\n", 489 | " \n", 490 | " model.train()\n", 491 | " total_loss = 0\n", 492 | " \n", 493 | " for sources, targets in dataloader: \n", 494 | " optim.zero_grad()\n", 495 | " \n", 496 | " sources, targets = sources.to(DEVICE), targets.to(DEVICE)\n", 497 | " preds = model(sources)\n", 498 | " loss = loss_fn(preds.reshape(-1, preds.shape[-1]), targets.reshape(-1))\n", 499 | " loss.backward()\n", 500 | " optim.step()\n", 501 | "\n", 502 | " total_loss += loss.item()\n", 503 | "\n", 504 | " return total_loss / len(dataloader)\n", 505 | "\n", 506 | "\n", 507 | "# Evaluation function\n", 508 | "def eval(model, dataloader, loss_fn):\n", 509 | "\n", 510 | " model.eval()\n", 511 | " valid_loss = 0\n", 512 | " \n", 513 | " for sources, targets in dataloader:\n", 514 | " sources, targets = sources.to(DEVICE), targets.to(DEVICE)\n", 515 | " preds = model(sources)\n", 516 | " loss = loss_fn(preds.reshape(-1, preds.shape[-1]), targets.reshape(-1))\n", 517 | " valid_loss += loss.item()\n", 518 | "\n", 519 | " return valid_loss / len(dataloader)" 520 | ] 521 | }, 522 | { 523 | "cell_type": "markdown", 524 | "id": "9ca44dbf-52cf-4863-a696-45d7f24c37f0", 525 | "metadata": {}, 526 | "source": [ 527 | "## 6. 
Training" 528 | ] 529 | }, 530 | { 531 | "cell_type": "code", 532 | "execution_count": 17, 533 | "id": "b7a4753f-1211-400f-942d-0c28b8dab770", 534 | "metadata": {}, 535 | "outputs": [], 536 | "source": [ 537 | "model = LSTM_Net(len(vocab)).to(DEVICE)\n", 538 | "\n", 539 | "# if torch.__version__.split('.')[0] == '2':\n", 540 | "# torch.set_float32_matmul_precision('high')\n", 541 | "# model = torch.compile(model, mode=\"max-autotune\")\n", 542 | "# print('model compiled')\n", 543 | "\n", 544 | "loss_fn = nn.CrossEntropyLoss()\n", 545 | "optim = torch.optim.Adam(model.parameters())\n", 546 | "\n", 547 | "text_generator = TextGenerator(index_to_word)" 548 | ] 549 | }, 550 | { 551 | "cell_type": "code", 552 | "execution_count": 18, 553 | "id": "c4b36ca0-c2ca-4701-9741-2d37041de665", 554 | "metadata": {}, 555 | "outputs": [ 556 | { 557 | "name": "stdout", 558 | "output_type": "stream", 559 | "text": [ 560 | "Epoch: 1\tepoch time 0.11 min\n", 561 | "\ttrain loss: 4.2283, valid loss: 3.5230\n", 562 | "Epoch: 2\tepoch time 0.11 min\n", 563 | "\ttrain loss: 3.0208, valid loss: 2.7528\n", 564 | "Epoch: 3\tepoch time 0.11 min\n", 565 | "\ttrain loss: 2.5472, valid loss: 2.4550\n", 566 | "Epoch: 4\tepoch time 0.11 min\n", 567 | "\ttrain loss: 2.3142, valid loss: 2.2764\n", 568 | "Epoch: 5\tepoch time 0.11 min\n", 569 | "\ttrain loss: 2.1636, valid loss: 2.1592\n", 570 | "Epoch: 6\tepoch time 0.11 min\n", 571 | "\ttrain loss: 2.0621, valid loss: 2.0800\n", 572 | "Epoch: 7\tepoch time 0.11 min\n", 573 | "\ttrain loss: 1.9867, valid loss: 2.0197\n", 574 | "Epoch: 8\tepoch time 0.11 min\n", 575 | "\ttrain loss: 1.9275, valid loss: 1.9755\n", 576 | "Epoch: 9\tepoch time 0.11 min\n", 577 | "\ttrain loss: 1.8799, valid loss: 1.9377\n", 578 | "Epoch: 10\tepoch time 0.11 min\n", 579 | "\ttrain loss: 1.8382, valid loss: 1.9050\n", 580 | "\n", 581 | "Generated text:\n", 582 | "recipe for potato root vegetables | blend first 5 ingredients in processor . using electric mixer , beat shells in medium bowl until stiff peaks form . divide batter among parchment nonstick pans . fold 1 piece into pastry overhang until edges are hold cucumbers are golden brown bits . rub onto bottoms ( oil will be visible ) . bake meringues until lean on their sides , about 45 minutes . ( often into warm ) . drizzle 1 / 2 cup tomato cheese over tofu mixture . serve with tuna alongside . roast turkey until skin and cooked\n", 583 | "\n", 584 | "\n", 585 | "Epoch: 11\tepoch time 0.11 min\n", 586 | "\ttrain loss: 1.8028, valid loss: 1.8784\n", 587 | "Epoch: 12\tepoch time 0.11 min\n", 588 | "\ttrain loss: 1.7718, valid loss: 1.8569\n", 589 | "Epoch: 13\tepoch time 0.11 min\n", 590 | "\ttrain loss: 1.7438, valid loss: 1.8353\n", 591 | "Epoch: 14\tepoch time 0.11 min\n", 592 | "\ttrain loss: 1.7200, valid loss: 1.8229\n", 593 | "Epoch: 15\tepoch time 0.11 min\n", 594 | "\ttrain loss: 1.6980, valid loss: 1.8070\n", 595 | "Epoch: 16\tepoch time 0.11 min\n", 596 | "\ttrain loss: 1.6792, valid loss: 1.7945\n", 597 | "Epoch: 17\tepoch time 0.11 min\n", 598 | "\ttrain loss: 1.6613, valid loss: 1.7849\n", 599 | "Epoch: 18\tepoch time 0.11 min\n", 600 | "\ttrain loss: 1.6455, valid loss: 1.7723\n", 601 | "Epoch: 19\tepoch time 0.11 min\n", 602 | "\ttrain loss: 1.6299, valid loss: 1.7692\n", 603 | "Epoch: 20\tepoch time 0.11 min\n", 604 | "\ttrain loss: 1.6159, valid loss: 1.7581\n", 605 | "\n", 606 | "Generated text:\n", 607 | "recipe for tropical fruit soup | combine garlic and ginger in a bowl . 
coarsely crush both halves together with parchment paper and line with salt and wine , 1 1 / 2 teaspoons salt , and up to 1 / 2 cup . cut fat lengthwise into 1 / 8 - inch - thick slices . cut each egg crosswise into 4 1 - inch - thick slices , cut filling into 1 - inch pieces and cut slices into 1 / 4 - inch pieces . in a blender purée with remaining ingredients in a blender until smooth\n", 608 | "\n", 609 | "\n", 610 | "Epoch: 21\tepoch time 0.11 min\n", 611 | "\ttrain loss: 1.6030, valid loss: 1.7512\n", 612 | "Epoch: 22\tepoch time 0.11 min\n", 613 | "\ttrain loss: 1.5918, valid loss: 1.7453\n", 614 | "Epoch: 23\tepoch time 0.11 min\n", 615 | "\ttrain loss: 1.5798, valid loss: 1.7414\n", 616 | "Epoch: 24\tepoch time 0.11 min\n", 617 | "\ttrain loss: 1.5690, valid loss: 1.7365\n", 618 | "Epoch: 25\tepoch time 0.11 min\n", 619 | "\ttrain loss: 1.5594, valid loss: 1.7303\n", 620 | "Epoch: 26\tepoch time 0.11 min\n", 621 | "\ttrain loss: 1.5495, valid loss: 1.7276\n", 622 | "Epoch: 27\tepoch time 0.11 min\n", 623 | "\ttrain loss: 1.5407, valid loss: 1.7261\n", 624 | "Epoch: 28\tepoch time 0.11 min\n", 625 | "\ttrain loss: 1.5317, valid loss: 1.7214\n", 626 | "Epoch: 29\tepoch time 0.11 min\n", 627 | "\ttrain loss: 1.5236, valid loss: 1.7200\n", 628 | "Epoch: 30\tepoch time 0.11 min\n", 629 | "\ttrain loss: 1.5156, valid loss: 1.7185\n", 630 | "\n", 631 | "Generated text:\n", 632 | "recipe for scallops with coriander garlic sauce | season garlic with salt and pepper . heat a wok or heavy skillet over high heat . meanwhile , pat steaks dry and cut them into 1 / 4 - inch - wide pieces . wrap crushed pepper in paper or under the broiler tray over moderate heat . add the greens and transfer it to a plate . if you don ' t serve as the poaching liquid can be served immediately immediately add more toasted scallion leaves . taste for seasoning blend the soup or or serving of cold .\n", 633 | "\n", 634 | "\n" 635 | ] 636 | } 637 | ], 638 | "source": [ 639 | "loss_hist = {'train':[], 'valid':[]}\n", 640 | "\n", 641 | "for i in range(EPOCHS):\n", 642 | " prev_time = time.time()\n", 643 | " train_loss = train_step(model, train_loader, loss_fn, optim)\n", 644 | " valid_loss = eval(model, valid_loader, loss_fn)\n", 645 | "\n", 646 | " loss_hist['train'].append(train_loss)\n", 647 | " loss_hist['valid'].append(valid_loss)\n", 648 | " \n", 649 | " curr_time = time.time()\n", 650 | " print(f'Epoch: {i+1}\\tepoch time {(curr_time - prev_time) / 60:.2f} min')\n", 651 | " print(f'\\ttrain loss: {train_loss:.4f}, valid loss: {valid_loss:.4f}')\n", 652 | "\n", 653 | " if (i + 1) % 10 == 0:\n", 654 | " print('\\nGenerated text:')\n", 655 | " text_generator.generate(model, 'recipe for', MAX_VAL_TOKENS, 1.0)\n", 656 | " print('\\n')" 657 | ] 658 | }, 659 | { 660 | "cell_type": "markdown", 661 | "id": "676da91c-d20f-4d0f-a6c5-98f73e6a887a", 662 | "metadata": {}, 663 | "source": [ 664 | "## 7. 
Generate texts" 665 | ] 666 | }, 667 | { 668 | "cell_type": "code", 669 | "execution_count": 19, 670 | "id": "ab6bbcdf-40ba-439b-a3f4-adfc604a2504", 671 | "metadata": {}, 672 | "outputs": [], 673 | "source": [ 674 | "# print prompt and top k candidate words probability\n", 675 | "def print_probs(info, index_to_word, top_k=5):\n", 676 | " assert len(info) > 0, 'Please make `output_info=True`'\n", 677 | " for i in range(len(info)):\n", 678 | " start_tokens, word_probs = info[i].values()\n", 679 | " start_prompts = [index_to_word[token] for token in start_tokens if token != 0]\n", 680 | " start_prompts = ' '.join(start_prompts)\n", 681 | " print(f'\\nPrompt: {start_prompts}')\n", 682 | " # word_probs\n", 683 | " probs_sorted = np.argsort(word_probs)[::-1][:top_k]\n", 684 | " for idx in probs_sorted:\n", 685 | " print(f'{index_to_word[idx]}\\t{word_probs[idx] * 100:.2f}%')" 686 | ] 687 | }, 688 | { 689 | "cell_type": "code", 690 | "execution_count": 20, 691 | "id": "4b9b3d33-793a-4c33-b387-58e1caa356ef", 692 | "metadata": {}, 693 | "outputs": [ 694 | { 695 | "name": "stdout", 696 | "output_type": "stream", 697 | "text": [ 698 | "recipe for roast duck lo mein\n", 699 | "\n", 700 | "Prompt: recipe for roast duck\n", 701 | "chicken\t21.73%\n", 702 | "turkey\t19.65%\n", 703 | "pork\t11.71%\n", 704 | "beef\t11.62%\n", 705 | "rack\t4.90%\n", 706 | "\n", 707 | "Prompt: recipe for roast duck lo\n", 708 | "with\t46.12%\n", 709 | "breasts\t15.77%\n", 710 | "legs\t9.92%\n", 711 | "breast\t9.39%\n", 712 | "|\t2.87%\n", 713 | "\n", 714 | "Prompt: recipe for roast duck lo mein\n", 715 | "mein\t93.91%\n", 716 | "with\t1.38%\n", 717 | "|\t0.26%\n", 718 | "\t0.22%\n", 719 | "lamb\t0.11%\n" 720 | ] 721 | } 722 | ], 723 | "source": [ 724 | "# Candidate words probability with temperature = 1.0\n", 725 | "info = text_generator.generate(model, \n", 726 | " 'recipe for roast', \n", 727 | " max_tokens=6, \n", 728 | " temperature=1.0, \n", 729 | " output_info=True)\n", 730 | "\n", 731 | "print_probs(info, index_to_word, 5)" 732 | ] 733 | }, 734 | { 735 | "cell_type": "code", 736 | "execution_count": 21, 737 | "id": "36b3ecf1-b896-4c30-8293-499b10b8d8a7", 738 | "metadata": {}, 739 | "outputs": [ 740 | { 741 | "name": "stdout", 742 | "output_type": "stream", 743 | "text": [ 744 | "recipe for roast chicken with fresh\n", 745 | "\n", 746 | "Prompt: recipe for roast chicken\n", 747 | "chicken\t59.02%\n", 748 | "turkey\t35.69%\n", 749 | "pork\t2.68%\n", 750 | "beef\t2.58%\n", 751 | "rack\t0.03%\n", 752 | "\n", 753 | "Prompt: recipe for roast chicken with\n", 754 | "with\t100.00%\n", 755 | "breasts\t0.00%\n", 756 | "|\t0.00%\n", 757 | "thighs\t0.00%\n", 758 | "legs\t0.00%\n", 759 | "\n", 760 | "Prompt: recipe for roast chicken with fresh\n", 761 | "fresh\t56.15%\n", 762 | "rosemary\t19.81%\n", 763 | "lemon\t13.31%\n", 764 | "roasted\t3.59%\n", 765 | "garlic\t1.65%\n" 766 | ] 767 | } 768 | ], 769 | "source": [ 770 | "# Candidate words probability distribution with temperature = 1.0\n", 771 | "info = text_generator.generate(model, \n", 772 | " 'recipe for roast', \n", 773 | " max_tokens=6, \n", 774 | " temperature=0.2, \n", 775 | " output_info=True)\n", 776 | "\n", 777 | "print_probs(info, index_to_word, 5)" 778 | ] 779 | }, 780 | { 781 | "cell_type": "code", 782 | "execution_count": 22, 783 | "id": "86f03bbb-a19c-441a-bf7b-06d5fc077943", 784 | "metadata": {}, 785 | "outputs": [ 786 | { 787 | "name": "stdout", 788 | "output_type": "stream", 789 | "text": [ 790 | "recipe for roast chicken breasts | preheat oven to 425°f . 
halve onions lengthwise , reserving soaking seeds . toss scallions with garlic , salt , and 1 / 2 teaspoon pepper . in a small saucepan combine lemongrass , scallions , and scallions and toss with lime juice , parmesan , salt , and pepper to taste until combined . season with salt , pepper , and remaining herb mixture . garnish with lemon wedges and chives and serve immediately , passing olive oils . ( older chorizo peas can be made 2 hours ahead and refrigerated , punched\n" 791 | ] 792 | } 793 | ], 794 | "source": [ 795 | "# generate text with temperature = 1.0\n", 796 | "info = text_generator.generate(model, \n", 797 | " 'recipe for roast', \n", 798 | " max_tokens=100, \n", 799 | " temperature=1.0, \n", 800 | " output_info=True)" 801 | ] 802 | }, 803 | { 804 | "cell_type": "code", 805 | "execution_count": 23, 806 | "id": "a8334679-0dae-4132-8709-544b8c67f282", 807 | "metadata": {}, 808 | "outputs": [ 809 | { 810 | "name": "stdout", 811 | "output_type": "stream", 812 | "text": [ 813 | "recipe for roast turkey with creamy mushroom - wine glaze | preheat oven to 350°f . butter 13x9x2 - inch glass baking dish . mix first 4 ingredients in small bowl . season with salt and pepper . place 1 / 4 cup cheese in center of each . sprinkle with salt and pepper . bake until golden brown , about 15 minutes . cool slightly . cut into wedges . ( can be made 1 day ahead . cover and refrigerate . ) preheat oven to 400°f . place 1 / 4 of cheese in center of each of\n" 814 | ] 815 | } 816 | ], 817 | "source": [ 818 | "# generate text with temperature = 0.2\n", 819 | "info = text_generator.generate(model, \n", 820 | " 'recipe for roast', \n", 821 | " max_tokens=100, \n", 822 | " temperature=0.2, \n", 823 | " output_info=True)" 824 | ] 825 | } 826 | ], 827 | "metadata": { 828 | "kernelspec": { 829 | "display_name": "Python 3 (ipykernel)", 830 | "language": "python", 831 | "name": "python3" 832 | }, 833 | "language_info": { 834 | "codemirror_mode": { 835 | "name": "ipython", 836 | "version": 3 837 | }, 838 | "file_extension": ".py", 839 | "mimetype": "text/x-python", 840 | "name": "python", 841 | "nbconvert_exporter": "python", 842 | "pygments_lexer": "ipython3", 843 | "version": "3.9.17" 844 | } 845 | }, 846 | "nbformat": 4, 847 | "nbformat_minor": 5 848 | } 849 | -------------------------------------------------------------------------------- /PyTorch/chapter_05_autoregressive/02_pixelcnn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "aebf8f39-66c9-47bb-9ddc-5c2627f4a917", 6 | "metadata": {}, 7 | "source": [ 8 | "# PixelCNN for FashionMNIST" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "739b06cd-02d6-4b0f-bff2-a1b1a05ced68", 14 | "metadata": {}, 15 | "source": [ 16 | "**The notebook has been adapted from the notebook provided in David Foster's Generative Deep Learning, 2nd Edition.**\n", 17 | "\n", 18 | "- Book: [Amazon](https://www.amazon.com/Generative-Deep-Learning-Teaching-Machines/dp/1098134184/ref=sr_1_1?keywords=generative+deep+learning%2C+2nd+edition&qid=1684708209&sprefix=generative+de%2Caps%2C93&sr=8-1)\n", 19 | "- Original notebook (tensorflow and keras): [Github](https://github.com/davidADSP/Generative_Deep_Learning_2nd_Edition/blob/main/notebooks/05_autoregressive/02_pixelcnn/pixelcnn.ipynb)" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 1, 25 | "id": "393d1206-2d7f-421f-a92c-0a144102c509", 26 | "metadata": {}, 27 | "outputs": [ 28 | { 29 | 
"name": "stdout", 30 | "output_type": "stream", 31 | "text": [ 32 | "cuda\n" 33 | ] 34 | } 35 | ], 36 | "source": [ 37 | "import time\n", 38 | "import numpy as np\n", 39 | "import matplotlib.pyplot as plt\n", 40 | "\n", 41 | "import torch\n", 42 | "from torch import nn\n", 43 | "from torch.nn import functional as F\n", 44 | "from torch.utils.data import Dataset, DataLoader\n", 45 | "\n", 46 | "from torchvision import datasets\n", 47 | "from torchvision import transforms as Transforms\n", 48 | "\n", 49 | "import torchinfo\n", 50 | "\n", 51 | "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 52 | "print(DEVICE)" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "id": "37f66fa4-5289-4d7c-a5ca-3d95781ab430", 58 | "metadata": {}, 59 | "source": [ 60 | "## 0. Train Parameters" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 2, 66 | "id": "64d2939e-8ba4-46c0-8017-ca4e0a66ede3", 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "IMAGE_SIZE = 16\n", 71 | "CHANNELS = 1 # The number of image channels\n", 72 | "PIXEL_LEVELS = 8\n", 73 | "N_FILTERS = 128\n", 74 | "RESIDUAL_BLOCKS = 5\n", 75 | "BATCH_SIZE = 64\n", 76 | "EPOCHS = 100" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "id": "76640e9c-a3eb-4999-b7e5-b64cb0d8e048", 82 | "metadata": {}, 83 | "source": [ 84 | "## 1. Preparing FashionMNIST dataset" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 3, 90 | "id": "8d27d90e-5f5e-49c7-a768-7f453510359d", 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "# Turn input image into (label-pixel representatin, pixel-wise labels)\n", 95 | "def collate_fn(batch):\n", 96 | " batch = torch.stack([data[0] for data in batch])\n", 97 | " value_step = 1.0 / PIXEL_LEVELS\n", 98 | " # Convert image to integer labels with provided pixel levels\n", 99 | " labels = (batch / value_step).type(torch.long)\n", 100 | " imgs = labels.type(torch.float32) / PIXEL_LEVELS\n", 101 | " return imgs, labels\n", 102 | "\n", 103 | "def get_dataloader(train=True):\n", 104 | " transform_fn = Transforms.Compose([\n", 105 | " Transforms.ToTensor(),\n", 106 | " Transforms.Resize(IMAGE_SIZE, antialias=True), \n", 107 | " ])\n", 108 | " \n", 109 | " # Load FashionMNIST dataset\n", 110 | " fashion_ds = datasets.FashionMNIST('../../data', \n", 111 | " train=train, \n", 112 | " download=True,\n", 113 | " transform=transform_fn)\n", 114 | "\n", 115 | " # Get train dataloader\n", 116 | " dataloader = DataLoader(fashion_ds, batch_size=BATCH_SIZE, shuffle=True,\n", 117 | " num_workers=8, collate_fn=collate_fn)\n", 118 | "\n", 119 | " return dataloader" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 4, 125 | "id": "9ee97742-720d-4c6b-bf5b-0c0870553f73", 126 | "metadata": {}, 127 | "outputs": [ 128 | { 129 | "data": { 130 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAjYAAACxCAYAAADXnPd8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAAsTAAALEwEAmpwYAAAM7klEQVR4nO3dMVbcyBYG4O53Zil2YkLPJkxuljAZrIE14GhmCZDDKibECd5Lv+ideZQucKe61C1dfV+GLKlF0V2+R/r71v5wOOwAACr4z7kvAABgFIUNAFCGwgYAKENhAwCUobABAMpQ2AAAZfz23j/u9/th3wX/9u3bq58/ffo02efXr1+TbU9PTx+eOzpX+3rR+V9eXlLXkHE4HPZdB75j5Pi3Y3R3dzfZ5/Pnz5Nt7RhF4xONf3Su9jV//PgRX2yHpY9/daPH39jnVXjvZ+fwVvT/Q+8c3mvt7/0//vhjsu3PP//88Ljff/99su3vv/8eck1Zb429OzYAQBkKGwCgDIUNAFDG/r0lFXqf9T0+Pk62tZmLbMamV+/5b25uJtsyOZ+lP+du8y2Z59e7XTyOGZmxzuSgspY+/r2+fPny4T4/f/7sOu7q6mqy7f7+PnX+1tpzBmu2tvd+Nk8T5fRaUU7y1LmbJb/3v379+u7P2eMi2TxNu9/IHI6MDQBQnsIGAChDYQMAlHF0xub6+nqyLeqRshbRs9j2+W+0z9Kfc7e5p5H5ll7R+6S3t83Sxz8jysU8PDx8eFy0z/fv31Pnz5wryuK0lpwzqG7p7/1RPbSO0c4rmdxk1lLe+1E/mrX466+/uo6TsQEAylPYAABlKGwAgDIUNgBAGe8ugpkRNbTLyAZVRwbIMk0CM82jRi7eeCrteEfhuUxDrKzM3+3UYeU1asO8mQBw9lxRUPji4mKyrQ0iZwLN8D+9jfZGaufwkeFhlscdGwCgDIUNAFCGwgYAKENhAwCUcXR4OAp99a4I3Ssbes1cVxRorRBybUPe0VhkO4JmRAFrgb33Rato397edp2r97iI8PB4UefvaFv7mfEZ4i3RitwjV9IeJbrO0dfujg0AUIbCBgAoQ2EDAJRxdMbm8vJysi3KarTPj7M5nDnzOtHz6q1kQyrkhjiNNmMzMr9zDpkmnLtdfyPO6FzX19evfp57NeslaX+vaO4ZOc9n/m7R661pTowyKZn9lpi5mYM7NgBAGQobAKAMhQ0AUIbCBgAo4+jwcCRa8bsN1D0+Ps7x0u9qw2FR8HkrovDcyNW9qeuYFcaXIAr3Zr7wEIVSs4322gBtNEdW/JLCOUQh7K3ObUsID2eDziO5YwMAlKGwAQDKUNgAAGUobACAMmYJD1NH74romZWKOa8oBNx2GY5Eq5AvWfsejoK70Xu4DZxGX3iIjouCyL1djCvYanCX2CnCxO7YAABlKGwAgDIUNgBAGTI2GxVlYEausDvnquz0yazSvfbme5EoU5PRu9p2lLvJrO6daZpZYQVwc8PprGU17yh3c8y1u2MDAJShsAEAylDYAABlKGwAgDI2HR6OQmxR8K+iKDzcKxsGbPfbylifQyYY/PDwMNknatDXHvf8/HzcxQ0SvYfbkO5u198gLnNc9N6PrqHV+/lbY3j41GHhaGzbuWYrc89awsOjuWMDAJShsAEAylDYAABlKGwAgDJOFh5eQrfJ9hqikNlWVuE9x4q77WtuJcA3t0zgd7fb7a6urrrOtdTVvEcGhSOZsPvl5eVkW7SKfTvX9IaHfWb6rDF0/Z5TrJA9Shtgtro3AMC/oLABAMpQ2AAAZShsAIAyThYePkdY9SNLvKa5nDq8Hb1eG5iMQpYV9AZ3o27BmXNHgd/M60XnyjhF5+FMMHjuz2/7RYK7u7vJPr1h3qoh4Mw8c455dwtfColCuToPAwCsnMIGAChDYQMAlLGpBn1b1j7XPsffYyvvgSgrc3Fx0XVcZp9ole61izIR7fsnyrz0mjvz0ubJRjboy6wmvkaZBomszylyP+7YAABlKGwAgDIUNgBAGQobAKCMTTfog2NFTe4yQeHouCgY3O4XrbR96vDwuVb7rhQerfS7fKR37m8D19nzbGls/9/IUG7b7G9tjf7csQEAylDYAABlKGwAgDIUNgBAGScLDy/RVjrh7nbb+l3nlFkRO9NBOAoYR6tm39/fv/p5ZHA3Olfm2uEtUXD35eXl1c/ZrsuZ1dUj7ettWRv6jVYAzxx3jms4hjs2AEAZChsAoAyFDQBQxskyNm2zpSjzceomfltq5NSO/83NzWSfaPzb59XRPtEz8+h5eIXxzmRcMvssdUXupV4X69V+7nuzMu0c9pbsftVEuZhMVmZk5qX3GkY3AHTHBgAoQ2EDAJShsAEAylDYAABl7A+Hw7mvAQBgCHdsAIAyFDYAQBkKGwCgDIUNAFCGwgYAKENhAwCUobABAMpQ2AAAZShsAIAyFDYAQBkKGwCgDIUNAFCGwgYAKENhAwCU8dt7/7jf7w9zvfCnT58m2+7u7rrO9fnz58m2m5ubybanp6eu82ccDof96HPOOf7VVBj/L1++TLZdXV19eNz9/f1k28+fP4dcU9bo8e8d+2/fvk22XV9fv/r5169fk31eXl4m29r5Ippnonks2u+jc7+1LaPCe3+k6G8S/c1HWcp7P9K+96PPR+b9mpX5HP348WPY67019u7YAABlKGwAgDIUNgBAGfvD4e3HeXM+Z42e9T0+Pg47f/S8+vLyctj5W55zvxblRXpl8iJLH/92PDLZmWM8Pz9Ptj08PMz2eufIGURzSLQts0+UDciIshvR+dvcR/a4TDZk6e/9jOh3z2SOojxNZAsZmzZPs9v151bnFOVfe3M3MjYAQHkKGwCgDIUNAFDGu31s5pR5Fn6Mkd/N37Lb29tXP3///j11XG/GJsrTtNcwZ1ZkLhcXFx/uE41Zbz+a6PXa3M2pe92MFmUKej/3UQYjm98YJZoTR/b8WIpsNiqTe8r2Fqo4jq0ou7JE2R5zx+Si3LEBAMpQ2AAAZShsAIAyFDYAQBlnCw+fOpgXveacTZuqGNlob4mvdy4jg8LZ87eB4rWFh9vPbzY42nPuY46LQq/tftFxWwkPZ5shjgyCVxO9f87xf2qPU1ynOzYAQBkKGwCgDIUNAFCGwgYAKONs4WHIakOvS+88nAlAR52B1xbmPbU2FNq7Incks5J0trNrFI5sr30tQc85ROHe3qBwZqzZHndsAIAyFDYAQBkKGwCgjMVnbDLP1edeKXwrMk3dzqFdUbxd7XuNopxQb9O+uZv9LVWUebm7u5tsy8wP0bnauSY792QyHlHjveja+Uc2l7Tl/NJaRRkrq3sDAOwUNgBAIQobAKAMhQ0AUMbiw8NtYC9qpJUND2vc9L7eoPDIFbmj0Gt7/jZMzNTz8/O5L2F20ec5CuW2wcQoXBrNK5nVxCOZ8GoURN7K/NS7MnVvE7+Kqo3F6MC3OzYAQBkKGwCgDIUNAFCGwgYAKONs4eFsUK43UDdy5d+tmLvLcBsMzgZc2/DwErohvydzfdkAdO/vGoWw2/Gu2L
E487mP5pQovJj5UsJWAr8jZYOvmdA3RNyxAQDKUNgAAGUobACAMs6WsclmYGRlTmdko73M+bOvt/bcxzlEY7vVpn3tHBJlPKJtbdO+bDYkM2dFDQG3LBrb3gaJkfZcslHLMrrhoDs2AEAZChsAoAyFDQBQhsIGAChj8Q36Rh3Ha1GDuEwzuGMCv5ljKwSFo3HMBHejfXob9EXHZca2wvh/JGr0Fs0rmQZx0Wri0X7tuaKw5Fbmtuzq3r0N+rYyjrzNHRsAoAyFDQBQhsIGAChDYQMAlHG28HCvqGOnsNjH2rDw7e3tya/h4eHh3Z93u9x1Lb2Dbu/1RYHfNnB9TLi3va6lj+MobQg1mkMynW8jmRXAI1aqfq13/M39RNyxAQDKUNgAAGUobACAMhQ2AEAZqwsPb0nUqTfTiTYTQj3mGlpRoLU3GFwhPBz97m14O/odonHMHJfVHluxy3AUOH15eek6VxtMjc6TDa/2howrGhkU7v3bUps7NgBAGQobAKAMhQ0AUMbZMjbZZ6NbbsB0dXU12dabuehdJbo1Mk+TtYVsyCm074GK4xjlN6JtGe0clZ2LormtzZD0XlNVmbE9JuPEtrhjAwCUobABAMpQ2AAAZShsAIAyFtWgrzcItuUmTW2Y+C2ZoGimGd+pg8JVjApvj7yG6G/JP9qA7zGBXyHXfy8T3vYlFCLu2AAAZShsAIAyFDYAQBkKGwCgjEWFhyOZcFjVYFgUym23ZVcAb7dFx0XB3TZgOnfgtGqgNRPM7l2BfeQ1rF20SvTT09OHx11fX3e9XjT3RCt5Z65hK+b+skfV/w/Ic8cGAChDYQMAlKGwAQDKOFvGJnoO2vscessN+qLGe9kVuJdoLdf5b2V+ryjj1K7wHu0TZaqqNjr8SDSvZBrrnbo56JbnrJGiv+0WMjbR+yfz/+c5sl5t5iz6m43+PLhjAwCUobABAMpQ2AAAZShsAIAy9ofD4dzXAAAwhDs2AEAZChsAoAyFDQBQhsIGAChDYQMAlKGwAQDK+C8cSfUsEUttDAAAAABJRU5ErkJggg==", 131 | "text/plain": [ 132 | "
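`collate_fn` above discretises each pixel into one of `PIXEL_LEVELS` integer classes (the classification targets) and feeds the network the rescaled class values. A quick standalone check of that arithmetic:

```python
import torch

PIXEL_LEVELS = 8
batch = torch.tensor([[0.00, 0.12, 0.50, 0.99]])  # toy pixel intensities

labels = (batch / (1.0 / PIXEL_LEVELS)).type(torch.long)  # integer class per pixel
imgs = labels.type(torch.float32) / PIXEL_LEVELS          # rescaled network input
print(labels.tolist())  # [[0, 0, 4, 7]]
print(imgs.tolist())    # [[0.0, 0.0, 0.5, 0.875]]
```

Note that an input of exactly 1.0 would truncate to class `PIXEL_LEVELS` itself, one past the last valid label, so this scheme relies on inputs staying strictly below 1.0.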
" 133 | ] 134 | }, 135 | "metadata": { 136 | "needs_background": "light" 137 | }, 138 | "output_type": "display_data" 139 | } 140 | ], 141 | "source": [ 142 | "# Check dataset\n", 143 | "def plot_imgs(batch, num_rows=2, num_cols=6):\n", 144 | " plt.figure(figsize=(10, 3))\n", 145 | " for i in range(num_rows * num_cols):\n", 146 | " ax = plt.subplot(num_rows, num_cols, i+1)\n", 147 | " ax.imshow(batch[i], cmap='gray')\n", 148 | " ax.axis('off')\n", 149 | " plt.show()\n", 150 | "\n", 151 | "test_loader = get_dataloader()\n", 152 | "sample_data = next(iter(test_loader))\n", 153 | "\n", 154 | "plot_imgs(sample_data[0].permute(0, 2, 3, 1))" 155 | ] 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "id": "dad5f380-0074-4947-ba77-02c5a4bf6ae5", 160 | "metadata": {}, 161 | "source": [ 162 | "## 2. Build the PixelCNN\n", 163 | "\n", 164 | "This PyTorch implementation references pi-tau's GitHub repo: [Link](https://github.com/pi-tau/pixelcnn/blob/master/conv2d_mask.py)" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": 5, 170 | "id": "1200c81a-8683-4290-a31a-ac4180f842da", 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "# Building MaskedConv2D layer\n", 175 | "class MaskedConv2D(nn.Conv2d):\n", 176 | " \n", 177 | " def __init__(self, mask_type, in_channels, out_channels, kernel_size, **kwargs):\n", 178 | " kwargs['padding'] = 'same'\n", 179 | " super().__init__(in_channels, out_channels, kernel_size, **kwargs)\n", 180 | "\n", 181 | " assert mask_type in ['A', 'B'], 'Mask type should be either A or B'\n", 182 | " if isinstance(kernel_size, int):\n", 183 | " kernel_size = (kernel_size, kernel_size)\n", 184 | "\n", 185 | " # Creating masks\n", 186 | " kh, kw = kernel_size\n", 187 | " mask = torch.ones_like(self.weight)\n", 188 | " mask[:, :, (kh // 2 + 1):, :] = 0\n", 189 | " mask[:, :, (kh // 2), (kw // 2 + 1):] = 0\n", 190 | " # If mask type is A, then masking the center pixel\n", 191 | " if mask_type == 'A':\n", 192 | " mask[:, :, (kh // 2), (kw // 2)] = 0\n", 193 | "\n", 194 | " # Making the mask the non-trainable parameter of the module\n", 195 | " self.register_buffer('mask', mask)\n", 196 | "\n", 197 | " def forward(self, x):\n", 198 | " return F.conv2d(x, self.weight * self.mask, \n", 199 | " stride=self.stride, padding=self.padding)" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": 6, 205 | "id": "0e01823c-2bf5-44ee-ab1b-08abaf5192fb", 206 | "metadata": {}, 207 | "outputs": [ 208 | { 209 | "name": "stdout", 210 | "output_type": "stream", 211 | "text": [ 212 | "Type \"A\" mask of the conv layer:\n", 213 | "[[[[1. 1. 1. 1. 1.]\n", 214 | " [1. 1. 1. 1. 1.]\n", 215 | " [1. 1. 0. 0. 0.]\n", 216 | " [0. 0. 0. 0. 0.]\n", 217 | " [0. 0. 0. 0. 0.]]]]\n", 218 | "\n", 219 | "Type \"B\" mask of the conv layer:\n", 220 | "[[[[1. 1. 1. 1. 1.]\n", 221 | " [1. 1. 1. 1. 1.]\n", 222 | " [1. 1. 1. 0. 0.]\n", 223 | " [0. 0. 0. 0. 0.]\n", 224 | " [0. 0. 0. 0. 
0.]]]]\n"
225 | ]
226 | }
227 | ],
228 | "source": [
229 | "# Check the mask in the masked conv layer\n",
230 | "print(\"Type \\\"A\\\" mask of the conv layer:\")\n",
231 | "print(MaskedConv2D('A', 1, 1, 5).mask.numpy())\n",
232 | "\n",
233 | "print(\"\\nType \\\"B\\\" mask of the conv layer:\")\n",
234 | "print(MaskedConv2D('B', 1, 1, 5).mask.numpy())"
235 | ]
236 | },
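A gradient probe makes the autoregressive guarantee of these masks concrete: the output at a pixel may only depend on pixels before it in raster order. A standalone sketch (it assumes the `MaskedConv2D` class defined above is in scope):

```python
import torch

conv = MaskedConv2D('A', in_channels=1, out_channels=1, kernel_size=3)
x = torch.randn(1, 1, 5, 5, requires_grad=True)
conv(x)[0, 0, 2, 2].backward()  # gradient of the output at pixel (2, 2)

# Non-zero entries appear only above / to the left of (2, 2) in raster order;
# with a type 'A' mask the pixel never sees itself either.
print(x.grad[0, 0])
```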
237 | {
238 | "cell_type": "code",
239 | "execution_count": 7,
240 | "id": "4cb8f94b-00a2-4378-97f0-89df0c676cc7",
241 | "metadata": {},
242 | "outputs": [],
243 | "source": [
244 | "# Building the residual block\n",
245 | "class ResidualBlock(nn.Module):\n",
246 | "\n",
247 | " def __init__(self, in_channels, out_channels):\n",
248 | " super().__init__()\n",
249 | "\n",
250 | " # First regular 2D convolution\n",
251 | " self.conv_1 = nn.Sequential(\n",
252 | " nn.Conv2d(in_channels=in_channels, \n",
253 | " out_channels=out_channels // 2,\n",
254 | " kernel_size=1,\n",
255 | " stride=1),\n",
256 | " nn.ReLU())\n",
257 | "\n",
258 | " # Type 'B' masked convolution\n",
259 | " self.pixel_conv = nn.Sequential(\n",
260 | " MaskedConv2D(\n",
261 | " mask_type='B',\n",
262 | " in_channels=out_channels // 2,\n",
263 | " out_channels=out_channels // 2,\n",
264 | " kernel_size=3,\n",
265 | " padding='same'),\n",
266 | " nn.ReLU())\n",
267 | "\n",
268 | " # Second regular 2D convolution\n",
269 | " self.conv_2 = nn.Conv2d(in_channels=out_channels // 2,\n",
270 | " out_channels=out_channels,\n",
271 | " kernel_size=1)\n",
272 | "\n",
273 | " \n",
274 | " def forward(self, x):\n",
275 | " conv_x = self.conv_1(x)\n",
276 | " conv_x = self.pixel_conv(conv_x)\n",
277 | " conv_x = self.conv_2(conv_x)\n",
278 | " # residual connection\n",
279 | " return F.relu(conv_x + x)"
280 | ]
281 | },
282 | {
283 | "cell_type": "code",
284 | "execution_count": 8,
285 | "id": "c982afc6-62e3-4688-94d2-d50c833edea0",
286 | "metadata": {},
287 | "outputs": [
288 | {
289 | "data": {
290 | "text/plain": [
291 | "==========================================================================================\n",
292 | "Layer (type:depth-idx) Output Shape Param #\n",
293 | "==========================================================================================\n",
294 | "PixelCNN [1, 1, 16, 16, 8] --\n",
295 | "├─Sequential: 1-1 [1, 128, 16, 16] --\n",
296 | "│ └─MaskedConv2D: 2-1 [1, 128, 16, 16] 6,400\n",
297 | "│ └─ReLU: 2-2 [1, 128, 16, 16] --\n",
298 | "├─Sequential: 1-2 [1, 128, 16, 16] --\n",
299 | "│ └─ResidualBlock: 2-3 [1, 128, 16, 16] --\n",
300 | "│ │ └─Sequential: 3-1 [1, 64, 16, 16] 8,256\n",
301 | "│ │ └─Sequential: 3-2 [1, 64, 16, 16] 36,928\n",
302 | "│ │ └─Conv2d: 3-3 [1, 128, 16, 16] 8,320\n",
303 | "│ └─ResidualBlock: 2-4 [1, 128, 16, 16] --\n",
304 | "│ │ └─Sequential: 3-4 [1, 64, 16, 16] 8,256\n",
305 | "│ │ └─Sequential: 3-5 [1, 64, 16, 16] 36,928\n",
306 | "│ │ └─Conv2d: 3-6 [1, 128, 16, 16] 8,320\n",
307 | "│ └─ResidualBlock: 2-5 [1, 128, 16, 16] --\n",
308 | "│ │ └─Sequential: 3-7 [1, 64, 16, 16] 8,256\n",
309 | "│ │ └─Sequential: 3-8 [1, 64, 16, 16] 36,928\n",
310 | "│ │ └─Conv2d: 3-9 [1, 128, 16, 16] 8,320\n",
311 | "│ └─ResidualBlock: 2-6 [1, 128, 16, 16] --\n",
312 | "│ │ └─Sequential: 3-10 [1, 64, 16, 16] 8,256\n",
313 | "│ │ └─Sequential: 3-11 [1, 64, 16, 16] 36,928\n",
314 | "│ │ └─Conv2d: 3-12 [1, 128, 16, 16] 8,320\n",
315 | "│ └─ResidualBlock: 2-7 [1, 128, 16, 16] --\n",
316 | "│ │ └─Sequential: 3-13 [1, 64, 16, 16] 8,256\n",
317 | "│ │ └─Sequential: 3-14 [1, 64, 16, 16] 36,928\n",
318 | "│ │ └─Conv2d: 3-15 [1, 128, 16, 16] 8,320\n",
319 | "├─Sequential: 1-3 [1, 128, 16, 16] --\n",
320 | "│ └─Sequential: 2-8 [1, 128, 16, 16] --\n",
321 | "│ │ └─MaskedConv2D: 3-16 [1, 128, 16, 16] 16,512\n",
322 | "│ │ └─ReLU: 3-17 [1, 128, 16, 16] --\n",
323 | "│ └─Sequential: 2-9 [1, 128, 16, 16] --\n",
324 | "│ │ └─MaskedConv2D: 3-18 [1, 128, 16, 16] 16,512\n",
325 | "│ │ └─ReLU: 3-19 [1, 128, 16, 16] --\n",
326 | "├─Conv2d: 1-4 [1, 8, 16, 16] 1,032\n",
327 | "==========================================================================================\n",
328 | "Total params: 307,976\n",
329 | "Trainable params: 307,976\n",
330 | "Non-trainable params: 0\n",
331 | "Total mult-adds (M): 78.84\n",
332 | "==========================================================================================\n",
333 | "Input size (MB): 0.00\n",
334 | "Forward/backward pass size (MB): 3.42\n",
335 | "Params size (MB): 1.23\n",
336 | "Estimated Total Size (MB): 4.66\n",
337 | "=========================================================================================="
338 | ]
339 | },
340 | "execution_count": 8,
341 | "metadata": {},
342 | "output_type": "execute_result"
343 | }
344 | ],
345 | "source": [
346 | "class PixelCNN(nn.Module):\n",
347 | "\n",
348 | " def __init__(self, in_channels, num_filters, num_res_blocks, output_size=PIXEL_LEVELS):\n",
349 | " super().__init__()\n",
350 | " self.channels = in_channels\n",
351 | " # Masked convolution block 1\n",
352 | " self.masked_conv_1 = nn.Sequential( \n",
353 | " MaskedConv2D(\n",
354 | " mask_type='A',\n",
355 | " in_channels=in_channels,\n",
356 | " out_channels=num_filters,\n",
357 | " kernel_size=7,\n",
358 | " stride=1,\n",
359 | " padding='same'),\n",
360 | " nn.ReLU()\n",
361 | " )\n",
362 | " # residual convolution blocks\n",
363 | " self.res_convs = nn.Sequential(*[\n",
364 | " ResidualBlock(\n",
365 | " in_channels=num_filters,\n",
366 | " out_channels=num_filters)\n",
367 | " for _ in range(num_res_blocks)])\n",
368 | "\n",
369 | " # Masked convolution block 2\n",
370 | " self.masked_conv_2 = nn.Sequential(*[\n",
371 | " nn.Sequential(\n",
372 | " MaskedConv2D(\n",
373 | " mask_type='B',\n",
374 | " in_channels=num_filters,\n",
375 | " out_channels=num_filters,\n",
376 | " kernel_size=1,\n",
377 | " padding='valid'),\n",
378 | " nn.ReLU())\n",
379 | " for _ in range(2)],\n",
380 | " )\n",
381 | " # Output convolution\n",
382 | " self.output_conv = nn.Conv2d(in_channels=num_filters,\n",
383 | " out_channels=output_size,\n",
384 | " kernel_size=1,\n",
385 | " stride=1,\n",
386 | " padding='valid')\n",
387 | " # We don't need a softmax layer when using CrossEntropy Loss in PyTorch\n",
388 | "\n",
389 | " def forward(self, x):\n",
390 | " x = self.masked_conv_1(x)\n",
391 | " x = self.res_convs(x)\n",
392 | " x = self.masked_conv_2(x)\n",
393 | " x = self.output_conv(x)\n",
394 | " # Reshape so the per-pixel class logits sit in the last dimension\n",
395 | " x = x.reshape(x.shape[0], self.channels, PIXEL_LEVELS, IMAGE_SIZE, IMAGE_SIZE)\n",
396 | " x = x.permute(0, 1, 3, 4, 2) \n",
397 | " return x\n",
398 | "\n",
399 | "pixel_cnn = PixelCNN(CHANNELS, N_FILTERS, RESIDUAL_BLOCKS)\n",
400 | "torchinfo.summary(model=pixel_cnn, input_size=(1, 1, IMAGE_SIZE, IMAGE_SIZE))"
401 | ]
402 | },
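The comment in `PixelCNN` above is worth spelling out: no softmax layer is needed because `nn.CrossEntropyLoss` applies log-softmax to the raw logits internally. A standalone check of the equivalence:

```python
import torch
from torch import nn

logits = torch.randn(4, 8)          # e.g. 4 pixels, 8 intensity classes
targets = torch.randint(0, 8, (4,))

ce = nn.CrossEntropyLoss()(logits, targets)
manual = nn.NLLLoss()(torch.log_softmax(logits, dim=-1), targets)
print(torch.allclose(ce, manual))   # True
```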
Train step functions" 409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": 9, 414 | "id": "615bc215-f7fd-4112-81aa-d442388baa6c", 415 | "metadata": {}, 416 | "outputs": [], 417 | "source": [ 418 | "class ImageGenerator:\n", 419 | "\n", 420 | " def __init__(self, num_imgs):\n", 421 | " self.num_imgs = num_imgs\n", 422 | "\n", 423 | " # Sample from the model's output distribution with temperature\n", 424 | " def sample_from(self, probs, temperature):\n", 425 | " probs = probs ** (1 / temperature)\n", 426 | " probs = probs / probs.sum()\n", 427 | " return np.random.choice(len(probs), p=probs)\n", 428 | "\n", 429 | " # Generate new image pixel-by-pixel\n", 430 | " def generate(self, model, temperature):\n", 431 | " model.eval()\n", 432 | " \n", 433 | " generated_imgs = np.zeros(shape=(self.num_imgs, 1, IMAGE_SIZE, IMAGE_SIZE))\n", 434 | " batch, channels, rows, cols = generated_imgs.shape\n", 435 | "\n", 436 | " for row in range(rows):\n", 437 | " for col in range(cols):\n", 438 | " for channel in range(channels):\n", 439 | " with torch.no_grad():\n", 440 | " probs = model(torch.tensor(generated_imgs, dtype=torch.float32).cuda())[:, :, row, col]\n", 441 | " probs = nn.functional.softmax(probs, dim=-1).squeeze() # We don't have this layer in model\n", 442 | " probs = probs.detach().cpu().numpy()\n", 443 | " generated_imgs[:, channel, row, col] = [\n", 444 | " self.sample_from(x, temperature) for x in probs\n", 445 | " ]\n", 446 | " generated_imgs[:, channel, row, col] /= PIXEL_LEVELS\n", 447 | " return generated_imgs" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": 10, 453 | "id": "e7cc8d87-2773-46c8-9173-e74a0deeddcd", 454 | "metadata": {}, 455 | "outputs": [], 456 | "source": [ 457 | "# train step function\n", 458 | "def trainer(model, dataloader, loss_fn, optim):\n", 459 | " model.train()\n", 460 | " train_loss = 0\n", 461 | " \n", 462 | " for imgs, targets in dataloader:\n", 463 | " imgs, targets = imgs.to(DEVICE), targets.to(DEVICE)\n", 464 | " optim.zero_grad()\n", 465 | " logits = model(imgs)\n", 466 | " \n", 467 | " loss = loss_fn(logits.reshape(-1, PIXEL_LEVELS), targets.reshape(-1))\n", 468 | " loss.backward()\n", 469 | " optim.step()\n", 470 | " train_loss += loss.item()\n", 471 | "\n", 472 | " return train_loss / len(dataloader)\n", 473 | "\n", 474 | "# validation function\n", 475 | "def validation(model, dataloader, loss_fn):\n", 476 | " model.eval()\n", 477 | " valid_loss = 0\n", 478 | "\n", 479 | " with torch.no_grad():\n", 480 | " for imgs, targets in dataloader:\n", 481 | " imgs, targets = imgs.to(DEVICE), targets.to(DEVICE)\n", 482 | " logits = model(imgs) \n", 483 | " loss = loss_fn(logits.reshape(-1, PIXEL_LEVELS), targets.reshape(-1))\n", 484 | " valid_loss += loss.item()\n", 485 | "\n", 486 | " return valid_loss / len(dataloader)" 487 | ] 488 | }, 489 | { 490 | "cell_type": "markdown", 491 | "id": "7a103523-b6f8-4b1e-a1ba-43e65e03799a", 492 | "metadata": {}, 493 | "source": [ 494 | "## 4. 
489 | {
490 | "cell_type": "markdown",
491 | "id": "7a103523-b6f8-4b1e-a1ba-43e65e03799a",
492 | "metadata": {},
493 | "source": [
494 | "## 4. Define the model, dataloader, objective, and optimizer"
495 | ]
496 | },
497 | {
498 | "cell_type": "code",
499 | "execution_count": 11,
500 | "id": "9e468fb7-901c-46e2-84d9-d521e899f717",
501 | "metadata": {},
502 | "outputs": [],
503 | "source": [
504 | "pixel_cnn = PixelCNN(CHANNELS, N_FILTERS, RESIDUAL_BLOCKS).to(DEVICE)\n",
505 | "\n",
506 | "loss_fn = nn.CrossEntropyLoss()\n",
507 | "optim = torch.optim.Adam(pixel_cnn.parameters(), lr=5e-4)\n",
508 | "\n",
509 | "image_generator = ImageGenerator(num_imgs=12)\n",
510 | "\n",
511 | "train_loader = get_dataloader(train=True)\n",
512 | "valid_loader = get_dataloader(train=False)"
513 | ]
514 | },
515 | {
516 | "cell_type": "code",
517 | "execution_count": 12,
518 | "id": "c77c7a90-7220-4b6f-ae86-65bcc699775f",
519 | "metadata": {},
520 | "outputs": [
521 | {
522 | "name": "stdout",
523 | "output_type": "stream",
524 | "text": [
525 | "Epoch 1\tTime:0.09 min\n",
526 | "\tTrain loss: 0.7048 Valid loss: 0.5852\n",
527 | "Epoch 5\tTime:0.09 min\n",
528 | "\tTrain loss: 0.5281 Valid loss: 0.5233\n",
529 | "Epoch 10\tTime:0.09 min\n",
530 | "\tTrain loss: 0.5002 Valid loss: 0.5005\n",
531 | "Epoch 15\tTime:0.09 min\n",
532 | "\tTrain loss: 0.4841 Valid loss: 0.4870\n",
533 | "Epoch 20\tTime:0.09 min\n",
534 | "\tTrain loss: 0.4738 Valid loss: 0.4766\n"
535 | ]
536 | },
537 | {
538 | "data": {
539 | "image/png": "[base64 PNG data omitted: 12 generated FashionMNIST-style samples after epoch 20]",
540 | "text/plain": [
541 | "<Figure size 1000x300 with 12 Axes>
" 542 | ] 543 | }, 544 | "metadata": { 545 | "needs_background": "light" 546 | }, 547 | "output_type": "display_data" 548 | }, 549 | { 550 | "name": "stdout", 551 | "output_type": "stream", 552 | "text": [ 553 | "Epoch 25\tTime:0.09 min\n", 554 | "\tTrain loss: 0.4662 Valid loss: 0.4692\n", 555 | "Epoch 30\tTime:0.09 min\n", 556 | "\tTrain loss: 0.4613 Valid loss: 0.4654\n", 557 | "Epoch 35\tTime:0.09 min\n", 558 | "\tTrain loss: 0.4573 Valid loss: 0.4608\n", 559 | "Epoch 40\tTime:0.09 min\n", 560 | "\tTrain loss: 0.4544 Valid loss: 0.4577\n" 561 | ] 562 | }, 563 | { 564 | "data": { 565 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjYAAACxCAYAAADXnPd8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAAsTAAALEwEAmpwYAAAL50lEQVR4nO3dvXnjyNIGUPJ7biiSIwWx604AYhZzUxhOChsG5a/M3Qli5FC58BqfpWaJKrWa+Cmd4xEiwBYIgvUAL6u3p9NpAwBQwf/NPQAAgFEUNgBAGQobAKAMhQ0AUIbCBgAoQ2EDAJTxn0t/3G63fguedDqdtqO3af/nrW3/f/v27WzZ8Xg8W/by8nK11xz5eqP3v2M/dnNzc7bseDyu6tjPenh4uPh4tP1+f7bs+fn53fWWfOzf3d29ehz9j/f392fLfv/+/epxtB8Oh8PZssz+Gumtfe+KDQBQhsIGAChDYQMAlHExYwNcR5SViHI3//3vf7u2dXt7++72o/UyGZsom8M0ove1gig/E2U/WiMzHVH+pF02dYbks3a73avHvTml6L1o8zubTX9OaTRXbACAMhQ2AEAZChsAoAwZG5hBJgPz1rLWX3/9lXrNKFOT0eZuMrkfrmON+aY2i9HmPpZibfmZjCgH08r8321fm7dEWRwZGwCAT1DYAABlKGwAgDIUNgBAGcLDMIMoBBoFhb9///7qcRQ67g0Fs3zte7v0Bn1RA7hMU7glBHejoG0bhl3COD8i0+Qw+r/b/zMzUeaSuGIDAJShsAEAylDYAABlKGwAgDJKhIezHVt7u3a2nVefnp66trMkmY6Um836wnJrkZlFOxIde5nuxPBZ7Tkjmsk5E1aNTL1eJArDttt/fHwc9npziM7n2e+CNXHFBgAoQ2EDAJShsAEAylhdxiZqRtY2MdtscrMnR9uKsg9tNmfpGZtMk6zsven2Pnr2HnOm2VXW2u9rf8bSG7JVkZ0hvRWdCzINE5f+vh4Oh7NlmUZ7vZaQ5VvKzNS9erMyP3/+PFvWvtfRtjOZpM1mnvO3KzYAQBkKGwCgDIUNAFCGwgYAKGN14eEodNcbxOttkjaVKLC12+1ePY4CfZkQWbZRUxse/kyAsDc83L7mmgJ9b4maRUbL2mN7Cc34ljyb+NQzpE/9fkz1w4Wo+V77ues990RGNonLnB+yr7f25nW9M3C3661tP7hiAwCUobABAMpQ2AAAZShsAIAyVhceXkJ4cipR2HZUiGtkGKx3nL0h4Gt2QJ3KyBD8V9EGfKNzwdT7Ndu9fG2iz2YbKI4CxlOHTHvPIWsLw2b0/jgj2heZ/TpyZvXRXLEBAMpQ2AAAZShsAIAyZsvYZBtiZWbkzmoboGXvvc/VJC07e+rconGuZexzWXKTu/dMkSHJNNqLjGy+1yvzehVyOJG1NM9cyzinsPZZzSOu2AAAZShsAIAyFDYAQBkKGwCgjE+Hh7PBvEwAtzfkl5kVOVqWXa81VRgx0yQrEjVcGhXc7Z0tNhpDtrFfux+i/fLjx4/ucc1hTY0m26Br9LmZYxxLDWBHweB2rNmxVw0ZM97j4+PZsvZ8vdvtzp4TnU/b9aJtLzl07IoNAFCGwgYAKENhAwCUobABAMrYnk6nt/+43Z79se3+ObLTZyYo95nAYO8sv09PTxcfbzabzd9//73t2vgF0f4ndjqdFr3/2+M2OoYy682h/VxGn6PR+z+z76N9s4TOw9cUHTfH43HRx351cxz7vR4eHs6WRSHgNjwcrdf745aR3tr3rtgAAGUobACAMhQ2AEAZChsAoIyLnYej7qi9Ad9M19DekF9vKDhrTV1iWabeY7S38+zIwOxcnYbfE+0bnXrhbVHn+CgY3D5vCUHhj3DFBgAoQ2EDAJShsAEAyvhwxiZzvz1znzvKHPTey4+aVl2zKddSMwcsV3vMZI+hkTNrt5/nbB4l20wQWLYoK9M7I/fd3d2wbY3mig0AUIbCBgAoQ2EDAJShsAEAyrg4u/ft7e3ZH3ub1fWuN7LhVm/wsh1DFKZc+uzS1dn/81rTDMfVOPbntfZjP2rQlxE1+5s6PGx2bwCgPIUNAFCGwgYAKENhAwCUcTE8fM0QU7Yz8Fpm6xXgm5f9P6+1ByjXzLE/r7Uf+1EH4d1u9+56h8PhbJnwMADAYAobAKAMhQ0AUMZsGZtq3Oeel/0/r7XnDNbMsT+vNR37UTO+qNFeRpTD2e/3XdvqJWMDAJSnsAEAylDYAABlKGwAgDIuhocBANbEFRsAoAyFDQBQhsIGAChDYQMAlKGwAQDKUNgAAGUobACAMhQ2AEAZChsAoAyFDQBQhsIGAChDYQMAlKGwAQDKUNgAAGX859Ift9vtaaqBrN3pdNqO3uY19//Dw8PZsvv7+65t/f79+2zZ4+Nj17Z6rW3/VzN6/4/c99++fXv1+ObmJrXe09PTu895eXnpGtNIVY/9P//88+Ljj9jv958ayyVLPvZHib4vIks577tiAwCUobABAMpQ2AAAZWxPp7dv5y3xXt9SLek+d3Q/tM3BPD8/p7Z1d3f37nN6t5VdL2NJ+z8jygv8+++/13q5q1tKzuD79+9ny9qMze3tbWpbx+Px1eMoTxPlcDLZnJHWduxHomO/XRZ9ZqLsTCaLMzJzs5Rj/5qiHGX03RBlbHa73VXGtNnI2AAAX4DCBgAoQ2EDAJRxsY8NyxflaUb2EhiZg/kqslmAzHp8TJSfyWRq2jxNJNv/hsuy+bJM5iz6XK05q/aWNs8S9RyLcjAZUQamfb3oO+XHjx9ny6JxHQ6HV4+j92z094wrNgBAGQobAKAMhQ0AUIbCBgAoQ3h4ZTIN89Yi+l8qhJX/+eefs2W/fv169bhiwHFqUZi3N+AbBYwzgWI+rrc5ZbWmlm/pPcdnQsCRzOTH0XOy5+rMBJqjm/i5YgMAlKGwAQDKUNgAA
GUobACAMhYVHs6EjEZ21V2jTNArCoy162U7V/Z2s8yMM7K28HC2W7CuwtPIztw9altfqRvxNY/hzLYrBoV7RefXUUHhzeb8vB+dl3tDztmxf+a7wBUbAKAMhQ0AUIbCBgAoY7KMTSY/M+V2qshmYNrnRfsxus+Z2d/RvdCReZ25ZO77R8/5+fPn+MHAwkTHfm9WJlrWm+lZe54tyqC0y7Ln13a9kRnVzOttNuffD9F6ZvcGAHiDwgYAKENhAwCUobABAMq4Snj4cDicLWtDS1EoNTPDZ7Re9HqRdgy9QaqpmgSObLTX7qNs4Lf1mUZNme0vfcbv3mDifr8fOo4evQ3OvkJjtGgm70yDvpeXl2sMp5Q//vjj1ePsZ2Hk8bqEz19WJpSbDe62z5vjBxzta0bf86MDxa7YAABlKGwAgDIUNgBAGQobAKCMi+HhKJTb2/m3N3Dbvl70+qO2nTVVeDgKT/WG4NrAVm/4bKRsiKw1VZg4E17sDRNfO5Db2xV2TaaeyZvXouOpDQpn18s8L7PtCqLzW+YHMtlz+tTaMWR+FLPZCA8DAGw2G4UNAFCIwgYAKONixibKQGRyKdmGeaPW432ZXNDUjfCi93tJzfgibTZmqc3rljouavv169ewbY38rK09X9Z7Xpz6fJpp4Nqbbf0IV2wAgDIUNgBAGQobAKAMhQ0AUMaHZ/fOhFC32+3ZsjYwlG1yN0XQiHksPSjMst3c3Mw9hC9lqc0o1zKGr2qO87wrNgBAGQobAKAMhQ0AUIbCBgAo48Ph4alNNZM2wEd9pQDzNWe7Hykaw36/n3wc5EUzk3+GKzYAQBkKGwCgDIUNAFCGwgYAKGOy8HAbAl5zR+E1jx2+uijw+/Ly0rWt29vbzw5nNdrw8Jq6+S4h1Mz/izoRj+5O7IoNAFCGwgYAKENhAwCUsfgGfb0yOZje5n+aBsJ69DbRa9eLcjhfqUHfmq0pD8TnuWIDAJShsAEAylDYAABlKGwAgDImCw9nwrzXDPx+hrAwrMPxeDxb9vT09O56UQg42lart7EffFWjZ/KOuGIDAJShsAEAylDYAABlKGwAgDIW1Xk4Cun2zqQt8Au1RcHdKCicCQ9HvtLM3VCJKzYAQBkKGwCgDIUNAFDGbA36ZGCAKYxqohdlbjJN/GAud3d3rx7vdrt3n7PZnH8/j/y+vr+/P1sWNe17fn7ufg1XbACAMhQ2AEAZChsAoAyFDQBQxvZ0Os09BgCAIVyxAQDKUNgAAGUobACAMhQ2AEAZChsAoAyFDQBQxv8AYKBg8YShpgQAAAAASUVORK5CYII=", 566 | "text/plain": [ 567 | "
" 568 | ] 569 | }, 570 | "metadata": { 571 | "needs_background": "light" 572 | }, 573 | "output_type": "display_data" 574 | }, 575 | { 576 | "name": "stdout", 577 | "output_type": "stream", 578 | "text": [ 579 | "Epoch 45\tTime:0.09 min\n", 580 | "\tTrain loss: 0.4521 Valid loss: 0.4562\n", 581 | "Epoch 50\tTime:0.09 min\n", 582 | "\tTrain loss: 0.4500 Valid loss: 0.4546\n", 583 | "Epoch 55\tTime:0.09 min\n", 584 | "\tTrain loss: 0.4481 Valid loss: 0.4535\n", 585 | "Epoch 60\tTime:0.09 min\n", 586 | "\tTrain loss: 0.4469 Valid loss: 0.4533\n" 587 | ] 588 | }, 589 | { 590 | "data": { 591 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjYAAACxCAYAAADXnPd8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAAsTAAALEwEAmpwYAAAMc0lEQVR4nO3dMVobyRYGUPG+WYqdQMgCPBPiHC/BE8EWYA0QWUsYLcDhjBdAiBN7L3rBC+ZRusB1qVuqvpyTqd3dKtqi+36ln1sn2+12BQBQwX+OPQAAgKkobACAMhQ2AEAZChsAoAyFDQBQhsIGACjjt5f+8eTkZLa/Bb+4uNjZ9u7du1eP+/nz5862r1+/TjKmfWy325Opzznn9e/1+fPnnW1fvnxJHbter5+8/vPPPycZ02o1/vX//fffn7z+559/us5ze3ub2tar9/xTX/8pr33vfeXQonvi1dXVk9fROK+urob+7LcuLy93tmU+Y9E+m81mghHtZ+TPfiv6jN3d3U12/uvr651tcz6fn7v2ZmwAgDIUNgBAGQobAKCMFzM2vaLv8dptme+9V6vV6v37909e//jxI/V+0fd6I2RxRnZ+ft61T5udyZ4ryus8PDykti1dm7lZrXK5m+g4XtbmUrL3nna/fXI47bmiLEJ0H2u198MlOjs729l2enr66nFRNmeEjM2SRJ/93gxa77kOwYwNAFCGwgYAKENhAwCUsXfGJvpOLfN38VFWpne/aAxt/4dov/v7+9QYRpbJskx5XDbv0puLicb1VjI2GVEOJ+rvEe3XbnureZ1sXmBKbaamN+cT5QRHyTUwvmxGK5MnG6H303PM2AAAZShsAIAyFDYAQBkKGwCgjL3Dw5mmUqNoA3xREG/kQNRqNW9YuGJIt5rsIpXRtg8fPrx6/j/++KNjVMsX/d6397YpG/RlQ5ztPSq6Zy3pHsxxRUHzKZ95ozSQNGMDAJShsAEAylDYAABlKGwAgDJmWd070nYQjkJG2W7Erd6VR6MxjB4ebvWGibPnmvL8bTj5razkHcms5B3JrgqeOX8UMOZ5UTfzqHt5dO9p7zXZDsLtcRW6pcPczNgAAGUobACAMhQ2AEAZe2dssg152v168zRZmXEtcVXcNoMyZQYm8369+0x53BL1rqw954rcUcbmra74HWnvUdH9IsrdRPee3tW9XxtTFd+/f391n7OzswOMpLYoQxp9XqMMWGvkxpBmbACAMhQ2AEAZChsAoAyFDQBQxsEa9GVkG+3xr/V6fewh0JgygDtimHfEMc2hvR9lApW9516txlkZeUlOT093tmWCyPyaKLQuPAwAcAAKGwCgDIUNAFCGwgYAKGOo8HCv3oCxsB5zyKysHQVwsyt3zynzfjc3N/MPZEbR/SIK8/aa81xTnhue04aFs5+7UTpjm7EBAMpQ2AAAZShsAIAyFpexyTbx09iPY+nNyhw6TxPJjP3bt287+3z48GGmEe2vbSQWrch9f3+/sy2TM4juM705g1HyCXOLmupxGNlGe5lMzcifVzM2AEAZChsAoAyFDQBQhsIGACjjaOHhqDle7+q5vU2rBIw5lKWsiJ0JMN/e3u5s+/vvv6cfTIfodzqzCnG0TxuOdL8Yz9nZ2c42q3tPLxsUHqXprRkbAKAMhQ0AUIbCBgAoQ2EDAJSxd3h4ykBdJuSXle0SClPrXd17VO3PM/LYo3tI+3ufDTj2BiFH7sgKPZa2qrwZGwCgDIUNAFCGwgYAKENhAwCUsXd4OBsqygTxpuxamAnwRe8XBYyXFpximTKh3EwweeoxjBoWjoLCmY7m0e/zlH+4AKPKPmMzz8+Rn4tmbACAMhQ2AEAZChsAoIxZVveOvntrt/V+px199xd9bxiNYZSVR6HXCHmXKOdzc3Nz+IEEovtDJmPTe+4pRfenis3+otW3o1W6Mx4fH/cdzps35WdslCa4ZmwAgDIUNgBAGQobAKAMhQ0AUMYs4eHI9fX1q/v0Bo+yYcBMYHnkpkMwgrmbBM4tCktG954pg5CZgGZmn1HCmfuIAr+Xl5ev7tcbMOZfd3d3O9syz+alMWMDAJShsAEAylDYAABlKGwAgDL2Dg9HYaRecwd32w6ksDRR5+Glh3nnluk4vpQ/GljKOF+y2Wx2tkXB4NPT0yevdRneX/bzs/TPmRkbAKAMhQ0AUIbCBgAoY++MzdK/i3urzs/PU9seHh5efL3ve772fjwV5Wl6czfZlcKXnuFpm9pF96xM47t97nVtzqfiqt0sU2/2NDpulAaSZmwAgDIUNgBAGQobAKAMhQ0AUMbJdrs99hgAACZhxgYAKENhAwCUobABAMpQ2AAAZShsAIAyFDYAQBkKGwCgDIUNAFCGwgYAKENhAwCUobABAMpQ2AAAZShsAIAyFDYAQBm/vfSPJycn26ne6N27d09ev3//fmefr1+/TvV2Ke2YVqvV6ufPn13n2m63J/uOpzXl9c+4uLjY2XZ3d/fqcdE+9/f3k4wpq8L1X7Kpr/+U1/78/PzJ64eHh6lOPYTRP/uXl5dPXp+dne3sc3p6urOt3e/x8XFnn9vb251t379//8UR7mfkz351z117MzYAQBkKGwCgDIUNAFDGyXb7/Nd5vd/1RZmLq6urJ6+jPM3Hjx973i6tzdREmZIoY/Pjx49X9xn9e+6MzP9bJLoeUYZqThWu/5KNnDP4/Pnzk9fr9XqqUw9hpM9+lHlpMzZTinI3UcYmGtdURv7sVydjAwCUp7ABAMpQ2AAAZbzYxyYjymBE29ocRpTLiDIvvb1toh410fl79qlqyr4+U/5fQlbbs4b5RNmZKfM0bX4me+5Mn5xPnz71D4zhmbEBAMpQ2AAAZShsAIAyFDYAQBl7h4d7w7bZcG9v4DRqENfbNK4da9uwr4pDN9UDliu7mGXvopTR+Xu1weMoiLzZbCZ7P47LjA0AU
IbCBgAoQ2EDAJShsAEAytg7PNwrCqr2drmNROHkjEwweMpxHkt0faKfXaCYJentPJw57uHhoevcVUVB4d79egPGve8XBZPfang486xc2jPPjA0AUIbCBgAoQ2EDAJQxS8bm/v7+1X2yGZh2pfB9vutrMyTZ/EjFnEl0HaP/t/b6V7wWS5Jd4fit5gUibX7my5cvXeeJMjbr9brrXBVM2UAvm9dZ6vsdS9v0Nnru9t7To+a5vQ11p2bGBgAoQ2EDAJShsAEAylDYAABl7B0e/vjxY9dxbSh1tcqtyD13eDU6fxs67m3+N7pMc8LsyuZVV0CfUxRovL29ffW4KYPC0RjakOjj4+Nk7ze1qNFeuy0KAfc29mMaUYO+OQO+c69MPrXMMyd6dkXP2Yw2dBz9sUm7z2q1Wt3d3e1sa4+Nfpbo/O0z5Ff+cMiMDQBQhsIGAChDYQMAlKGwAQDKONrq3pFsqOjQ2lDWKN0Vp9YbzM4GyzIdqd+yT58+7WzLBHejbsRROLINIkdhyWgM/I/Vveej8/C/ontntK3VGzCO9D53r6+vu47L/LHJrzw/zNgAAGUobACAMhQ2AEAZQ2VsMk16jrG6dNVMTY/sd69WAX9Z9B3/X3/99eq2bAYmyuK07xnlcEZuvpeRab6XbcYnUzO+kbMyvaI8TW9T2N77cCbz0pvD6W3e+ivHmbEBAMpQ2AAAZShsAIAyFDYAQBlDhYejgNShV9bONJurunJ19HMdOiBWVW8I+LXzPHeuKBic2ScKDy89UDwVYeI6NpvNzrZRVvKOmty1z71sKHju52WPaOyZ58WvBKHN2AAAZShsAIAyFDYAQBkKGwCgjKHCw5EpV5zOmDrEtHRt+Cy6riMG1I6pXUV7tRonmPj/3koouO00nA0BZzsUM5body3TdXtk7X239/n2VpixAQDKUNgAAGUobACAMhQ2AEAZRwsPz905UaB1Gpn/pyjIVqHzcBs4jERBxUwoN+og3BswzoxzH23QcuTQ8ZzdgXUehmUwYwMAlKGwAQDKUNgAAGUcLGMzVeZl7uZ4vSuPvmXR/2207f7+/hDDmUym0VeUb4maf11eXj55nc2pZPaL9sk0IOvN5oycsdFU77ii35m2YWX7u/Dccb0Zt4wo40YdZmwAgDIUNgBAGQobAKAMhQ0AUMbBwsO9jd569plb1TDxCNd2dJmwYhSuzawuHG27ubn5hdHB4WRD65vN5sXXMDUzNgBAGQobAKAMhQ0AUIbCBgAo42ire19cXMx6/qoB30Nrr6POzK+LwpFt0LLtxrpaxcHL3s6qc4p+vqib7JJEHYvX6/URRrJsI3el5u0wYwMAlKGwAQDKUNgAAGUcLWMzt95VwOVFXub69GmzMpqUUdGImTDeHjM2AEAZChsAoAyFDQBQhsIGACjjZLvdHnsMAACTMGMDAJShsAEAylDYAABlKGwAgDIUNgBAGQobAKCM/wJ2MjjR7x4isgAAAABJRU5ErkJggg==", 592 | "text/plain": [ 593 | "
" 594 | ] 595 | }, 596 | "metadata": { 597 | "needs_background": "light" 598 | }, 599 | "output_type": "display_data" 600 | }, 601 | { 602 | "name": "stdout", 603 | "output_type": "stream", 604 | "text": [ 605 | "Epoch 65\tTime:0.09 min\n", 606 | "\tTrain loss: 0.4455 Valid loss: 0.4531\n", 607 | "Epoch 70\tTime:0.09 min\n", 608 | "\tTrain loss: 0.4442 Valid loss: 0.4513\n", 609 | "Epoch 75\tTime:0.09 min\n", 610 | "\tTrain loss: 0.4435 Valid loss: 0.4532\n", 611 | "Epoch 80\tTime:0.09 min\n", 612 | "\tTrain loss: 0.4424 Valid loss: 0.4527\n" 613 | ] 614 | }, 615 | { 616 | "data": { 617 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjYAAACxCAYAAADXnPd8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAAsTAAALEwEAmpwYAAAMo0lEQVR4nO3dMZbbxtIGUMx/3lKsRAq9ADudBXCWIEfyFjjaghVJSyAXoND2AhyOktFe+KL/+E2zxCk2GyRYvDcjBDabIICpA3wq3O12uwkAoIL/u/QEAABGUdgAAGUobACAMhQ2AEAZChsAoAyFDQBQxn8O/ePd3Z3/Cz5N0+Pj44vXT09Pe+tsNpu70Z87cvv/+uuvL17/9ddfo4ZehN1ut+jtfy3evn27t+zh4eHVdVar1dDtf4vbPiPa9k9PT4ve99s5v3v3LvW+9jz77du3UVMaavS5p/q+v1qtUuu1+8kxf3ddsQEAylDYAABlKGwAgDIOZmxuUXT/b7PZvHi91Hu9x2gzN9NUL3fDS+2+He3rUf6hzUhU2P+XKMrPtL9HNp9yKdE+lckoRtr3RWPbF5cl2ofb3zGy3W73lp3yd9cVGwCgDIUNAFCGwgYAKOOmMzbR/cBIxfu4Mja3J3OvO8o/tPe/23vfP3ofP5bJ00TLom2/Xq/HTexEmR4lvVmZaPtUPDdfs7bn1TTtnz+iPM1ortgAAGUobACAMhQ2AEAZChsAoIybDg9HYbSqIcgoLAytTEAzE0LmsChkWeHck2kg2Bv4jcY+RxCVWBQCz/zng3NwxQYAKENhAwCUobABAMpQ2AAAZQgPN4TR4KVsh25OM2fwdg72i9sWhYeX8vfTFRsAoAyFDQBQhsIGACjjpjI27T3BqveIe5vxzf3E73b87DzbOXgKeZ8KDeC4XVXP19dqyRlVV2wAgDIUNgBAGQobAKAMhQ0AUEbZ8HAUNMs8lThapw1dLimE2fuk5SiAOzI8HI31559/vvq+v//+e9gcTtUbVlxSE7X/tdR5nVv0u0ZP225tNpu9Zdlt2rsvLelc08o0FJym/e8w9/sYL2rGF4l+o/Z3jI6ZzPFxzPnLFRsAoAyFDQBQhsIGAChDYQMAlFE2PByFatuAUhRGikJS7fuWFOibO/Db28U4mkMbDP7ll1/21smEoXvndKwoCNcui/YFIcfLiUKI7e8RHeO9v1kUKM6MFb3v2kTfc2R34Mx2jD5PSH68bHg4Wi/73lFzmCZXbACAQhQ2AEAZChsAoIwSGZvofvWo+3qRJd3DHdm8bs4neUfLslmZc2VqWr1Zqt6MxchszrXnfDJZjaipXuZ9vdvmlHNKJjuWyQct6dyTlck2RjIN+qJl17iNlqbd10fmqXp/n2OOW1dsAIAyFDYAQBkKGwCgDIUNAFDG4sPDbUApCuFlQ31taKk3HDqyCdUtGdnsr7Ver7vGPiQKubXL5m5I1oaMe5+MfMocRo19jN5gcLbpZo9Tvnc7h2is6DsvqRlor3OERRmrd9v3Prk745hjwRUbAKAMhQ0AUIbCBgAoQ2EDAJSxqPBwFPLLdOzs7WR5jc7dhbfd/lGQd2TH4mvTG+aN3pcNlN6CzLEahRJ7g4q9T4nu/bzo+1U4P/UGhSt8d+b9jzU6DwMAN0lhAwCUobABAMpQ2AAAZVwsPHxKB+FLW3pHzCjMmwkdnzuYfEt695klhirPMaftdptaljFnoLE3LHst57pLao+ZJR4LVbXHzMhjqDeofwxXbACAMhQ2AEAZChsAoIyTMzbR/bIoT9BmaqL7pb1P741E42ee7r2Upxv/yJw5GM33jte7Dy09p1XJ6Pv3I/TmhWC07N/wc2vnpUEfAHCTFDYAQBkKGwCgDIUNAFDGwfBw79O2M0aGkzJB4VPGqtgYKhtCbtcTJn5pZDA1+8TvS1vinFiuNgQaBaej4yjTxPBWQtinBGn/V3TsPjw8dI11bsc0CXTFBgAoQ2EDAJShsAEAylDYAABlHAwPnzvMOPKpnyPH8pTZf0WhY4Hi411TSH2JnXu5vGh/jQK/7f5jfzos2oajOuJHQeH1ep2b2IxGP/HbFRsAoAyFDQBQhsIGACjjYMZms9nsLZvzycUfP358dZ2R2YRrvNfbZlxGNUxcsmv7zpn96phmU5fWzvUajxvO49z7dfS35tqb9mUaE0bHYO958VqO52Pm6YoNAFCGwgYAKENhAwCUobABAMq42+12l54DAMAQrtgAAGUobACAMhQ2AEAZChsAoAyFDQBQhsIGAChDYQMAlKGwAQDKUNgAAGUobACAMhQ2AEAZChsAoAyFDQBQhsIGACjjP4f+8e7ubneuiWS9fft2b9m7d+/2lq1Wq67xo7G22+2L15vNZm+dp6enu64PPGCJ2z/yxx9/7C17fn7eW/b169cXr79//z5sDrvd7ma3/xKM3v4jt317znh4eEi97+np6eDraZqmb9++9U9sEPv+ZS153+/1888/v7rOP//8c9Y5RJ/3o23vig0AUIbCBgAoQ2EDAJRxMGNTXZSnie6jt8uWcF99Dh8+fNhb9ubNmxevf/rpp1fXmaZpur+/f3VZm7mZpmn69OnTq/OEH4kyeG1GLhId9+1Y0fni8fExPzlOEv22UV4q+i0z+8AtiLIzUXbl/fv3r47122+/DZnTNOXnleWKDQBQhsIGAChDYQMAlHFyxia67zkygxKN34run7qneliUlYlyMVF+pldmLBkbThFlLjLnkEib1YjGmfv8x7+i33a9Xu8ti7Z/lI/qcW2Zqkw/msjI/MwluGIDAJShsAEAylDYAABlKGwAgDJODg9nGyRFyzIyoa9so71e7fgjx76UkaHgSBRObhvyRXOI3jfyYZnU1hsUZvmyv220Xhsyzga8r+lcv4QHVy6FKzYAQBkKGwCgDIUNAFCGwgYAKOPo8PBqtXrxOhvoy
oSAlxrUupWne5/b8/Pz3rIoUCw8TFbmPJM9fh3ny3JK9+Bb+C2jJ3J/+fJlts+LwsqZAPOcc/p/rtgAAGUobACAMhQ2AEAZBzM2UX4mc5+zzeFEsnmaORv79ap6v3bupn2Z8aMGfbesPQazzSh799HMMb/dbrvGHi2b72u3RfS+pXwn/qXZ4ulGNe2LxonelxkrygL1zutHXLEBAMpQ2AAAZShsAIAyFDYAQBlHN+h7fHw8+HqacuHhrJFjZWTCykttJMh1i/b1zP4frROFhzP7bTTWtQdrM987CmVX/U8C1+Lh4eHSU1isTCg4Wi/7vnOL5tWGjI9p7OeKDQBQhsIGAChDYQMAlKGwAQDKODo83IqCeR8/ftxb1ttFMhP8y4Ybe0OEbaCyQqgw2+G3fQL33N2J7+/v95Z9+vRp1s88h3b/j4KRvZ2+I72h+2jfbpdVDBPP2an8GkX7z7m3UeZvxtzn4vY7L6Ubcm8IONsteFTH4qyoG3E7fhQe/vz5czieKzYAQBkKGwCgDIUNAFDG0Rmb9t5rdN915H3IkfdQe+8RZ54OzBjfv3+/9BSOks3FzHnczJ0zaOe+2Wxm/bysJeRiqjb2620WGVnC9mhzVb37zlKaBvY26LvEHHp5ujcAwKSwAQAKUdgAAGUobACAMg6Gh0c261pqMDLj2huSRaJGeJHehny9IeBs48BzyOyzUQgxs2xkAD071qhjaQmh3VNE2yH6TqMCp9dozu96ib8Fvd+n3Qd6A9SX0hvAbYPBI5vxRUaP74oNAFCGwgYAKENhAwCUobABAMo4+eneWUsIAcMxon02E3yMnh6debL8uS1xTnNog6PZIOkthYXntMT/OBLNKRMqv9Wu89kuw5knhc8dRJ4mV2wAgEIUNgBAGQobAKAMhQ0AUMbZwsNcp7YTcNRROOoWnOk8/Pz83D+xC2kDhr0Bx2sOIS7lPwLMHe5tf6OlfO9bMXL7Z463azomswHcTOg3Gwye8/NGh45dsQEAylDYAABlKGwAgDJkbG7EqKdmX+Lp25mcz7WR17g+t9LQcJri73pNGZRRMsfpJbbLly9f9pZF2ZV2WbRONNb79+9fvM5kYLLzGvXE8UNcsQEAylDYAABlKGwAgDIUNgBAGcLDN+LNmzcHX19CNIdrbNrHckSh11GN/KJxbilQfGnZJ3LPabvd7i1br9dnncOPZEK50TptUPiUsXplgsga9AEAN0lhAwCUobABAMpQ2AAAZQgP34g2lPv169e9de7v7881nWma8kHhCp2GGa+3O25vwDh6n/DwGHOHgHvHb3/fzWazt85SwsMZ2e69bVA36ij8+fPn1FinhIB7uWIDAJShsAEAylDYAABlyNjciDan8vvvv++t8+HDh71l0XoZUX6mnUOU84mWQSTKt6xWq9R6PetEPKV9jCU8OTz6LW/h941yN21+JsrFRLmbpXDFBgAoQ2EDAJShsAEAylDYAABl3O12u0vPAQBgCFdsAIAyFDYAQBkKGwCgDIUNAFCGwgYAKENhAwCU8V+ln4VseX7RWwAAAABJRU5ErkJggg==", 618 | "text/plain": [ 619 | "
" 620 | ] 621 | }, 622 | "metadata": { 623 | "needs_background": "light" 624 | }, 625 | "output_type": "display_data" 626 | }, 627 | { 628 | "name": "stdout", 629 | "output_type": "stream", 630 | "text": [ 631 | "Epoch 85\tTime:0.09 min\n", 632 | "\tTrain loss: 0.4417 Valid loss: 0.4498\n", 633 | "Epoch 90\tTime:0.09 min\n", 634 | "\tTrain loss: 0.4408 Valid loss: 0.4513\n", 635 | "Epoch 95\tTime:0.09 min\n", 636 | "\tTrain loss: 0.4402 Valid loss: 0.4489\n", 637 | "Epoch 100\tTime:0.09 min\n", 638 | "\tTrain loss: 0.4396 Valid loss: 0.4489\n" 639 | ] 640 | }, 641 | { 642 | "data": { 643 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjYAAACxCAYAAADXnPd8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAAsTAAALEwEAmpwYAAAMqElEQVR4nO3dMXbbSBIGYGjfHMVKZkI7n0nl3D6CFYln0BmsSDqCfQCl49wO7cS+CzeY93ZHzZJUajRIsPh9mWCgCYMgUQ/4WX223W4nAIAK/nPoHQAAGEVhAwCUobABAMpQ2AAAZShsAIAyFDYAQBm/PfWPZ2dni/0W/OrqamfZxcXFzrLz8/Nnx/r58+fOsvv7+51lNzc3yb17ue12ezZ6zCWPf9bvv//+4O/Pnz+ntnv37t3Osh8/fgzZp0jV49+6vr7eWRYd6/Z9i0Tv5fv377v2a/TxX/LYv379OrWsdXd3t8TuzLb2c7/9Xv/48WPXONH3fOb6EG0bXQuia0bGMZ37We3n4cOHD4uNPU3T9O3bt51l7ectWuexY++ODQBQhsIGAChDYQMAlPFkxqZXlJVpMzXROr1evXq1syzz7HXJzM0xyuQ3ouxGlJ2J8hvt+Nm8TkXRcfzjjz92lrXHP1onkskzRdmc79+/7yxr36foPFmLTFYmu137TD87dpQFOGXt93P0ff3r169h20Xa60GU8Yz05m6OXXuu936uel9vmqbp9vb2wd9v3rxJj+eODQBQhsIGAChDYQMAlDE7YxM994z6FETrjZJ9zrrZbB78HeVw2nWq+vTp086ybH4jI8pqtNmM6PXWnN/IivIzvf+vke9JprdNtE6bxVnLexQ9l4/6bbT9MLJ9bHr71mT7dJyKbK+ZVvZ7vUe0T1Hu5lQzNq2lz9/RGR53bACAMhQ2AEAZChsAoAyFDQBQxuzwcDYYtmQIK9PcaZp293XJQPPatKHQ3lBq1FQvGiszftQgLgo1Lzl55hJ6J5KMtCHs6LhmQsFZ0bFux4/et0PIBg4zzcaioPDIif9amUn/qloyFPyYdhLM6LoVLWuvEYfY92qy4f05gWV3bACAMhQ2AEAZChsAoAyFDQBQxiKdhzN6u1GOHCsTFpumGoGxkR1sW1GX4d7Xi7Y7tvBwJsw78v3Y9/FZ8lzahyiU2NuxuFdv12SWk/nBSYVrQUZvcLf3MxK9nvAwAMCksAEAClHYAABlzM7YXFxcdK3X27CvtxnfqRuVi1g6XxGNHzUFXIsoT3PsGZTnrKVBX/QMPtP8K7vdvq1hH6rqbc466rpV0dKZMxkbAIBJYQMAFKKwAQDKUNgAAGXMDg9ntaHfKIQVBYPbWVkjm80mNdYpB79GzgDN4UTvY2+DvmM/J7KzBI8aPzv2nNDjc/swcuxTlr0WZK4/p2Dk52ofzf/csQEAylDYAABlKGwAgDIUNgBAGS8OD/fO5t3qDQrPccrh4TV2w43Cq8ceaH1M5v+VCQFH62QDxaOObTSb+yHet94wbxRejGbRzmzXuw9m8j6s7A9OOE7u2AAAZShsAIAyFDYAQBkvzti0zyHfvn27s06Uw2mfaUZ5mswzzmjs3rFgCb15k5E5lZFjtXmd6+vrnXUOMeN3b+Yl29ivHX/pJmWZ8TXoo6LovM5k3B7jjg0AUIbCBgAoQ2EDAJShsAEAylhkdm/B3cPKBEdHzhJ9ytbY+HCO6Bxog8FrPk8yIcSRIeBe0X5GTfsqyDRevbi4OPg+UIc7NgBAGQobAKAMhQ0AUIbCBgAoY5HwcGYG8N6AcXa7UbOQH6Mo0NqGhdcaAK0Wxj02UVfhtZ4r2Q7CPeuM1r5mbwfhNQSfX+rm5ubB31FQeOnwcMsPXJ62hvNszj64YwMAlKGwAQDKUNgAAGUobACAMhYJD0facNjV1dXOOvf39zvL2pBXFDI7Pz+fuXe1fP78+dllhwjpfv/+fe+vyf+1IeAoKBydO8euDepGwd2lw5KZsHBmnTWEOueKugBHy0Z+r0fXFo7LS859d2wAgDIUNgBAGQobAKCMvWVsWpvNJrWsl2eqD71///7B358+fdpZZ9+5m3bW6CqihnZtdiX6v2dmZZ+zD22m5hTyNNM0TXd3d0PGzj7jH9l8r53xu3fsNYma47VN/KZpt8lqNnMT5XWi8fnHseS2ZGwAgJOksAEAylDYAABlKGwAgDIWCQ9H4bA20KWp3ulpA61rnTV6rqgRYRvcjZrjLR0eZp5DBHcrhIUzen/s0QaMp8nM3SOsMVAsPAwAnCSFDQBQhsIGAChDYQMAlLG3zsNtOEx4+LCi8GrbnXia+gOtUXg16nZ87KKgcG8HZ4Hf0zZqBvBTIihMxB0bAKAMhQ0AUIbCBgAo42Cze3NYmdmf6RPlbtiPahmUav8f6PWSz4I7NgBAGQobAKAMhQ0AUIbCBgAo42y73R56HwAAhnDHBgAoQ2EDAJShsAEAylDYAABlKGwAgDIUNgBAGQobAKAMhQ0AUIbCBgAoQ2EDAJShsAEAylDYAABlKGwAgDIUNgBAGb899Y9nZ2fbpV749evXO8s+fPjw7LJv3751j393d/fsWO06Wdvt9qxrwycsefyz2uMYHddI5n2Kxup9347t+GfP/1b2/Mwc28yxnqbd472P47/vc//Vq1c7y87Pz59d59evX8Ne7/7+vmv8tZ/7f/3114O/r6+vd9b5+++/nx0n2m4Njv3cz/j69euwsd68eTNsrMeOvTs2AEAZChsAoAyFDQBQxtl2+/jjvN5nfZn8QDZfsW9RxuDy8vLZ7db+nDsjc/yjHEg2q9GuF40V5TfaZdH7cWzHP5OnmaZ8pqxH7+ctem+PKWeQydNkffz4ce7u/E+UsWmXReus/dxv8zNt5uYxbaYmyuFksjlLO6ZzP9J7Lb69vd1ZlsmtZr/TMuvJ2AAA5SlsAIAyFDYAQBlP
9rHJ6O3HkclSHEIm99Hb62ZNevMVUb6lt7dNheO4b3NyaGv4fB2zKN/y8+fP1LaZDE801sXFxbP7sCZRfqa3/0y7XW//G+aLvjtG9qOJtN91L/n+cscGAChDYQMAlKGwAQDKUNgAAGXMDg9nG41lRMHINjB0iCZ+a2gcONeo/8PS4dXMOUCfzCSY0bGucP7/WzSxZLSsDe5GNptNalmrtyFg1FzwVAgKz5dtxpf5zI/8rh79He+ODQBQhsIGAChDYQMAlKGwAQDKmB0e7rWGzsPZUKQOuU/LBk57O0m2AXXvx0O9nZ+z1hDgP4RM198odByFh6+urrr2oQ0L94aOD6ntRjwyBBx1OhYyfpmR1+LeH4iM5o4NAFCGwgYAKENhAwCUsUjGZmSjvXa7OTmcTIMy+Y2X631munQ2pIKRTfXW0KRxaZmmetkmd+16UU4mythEs3S360XbRdY+mzfrNmeG7H+LGvGOvIZr0AcA8AiFDQBQhsIGAChDYQMAlDE7PJwN/fSGmEYGFduxoqDw5eXlsNdbM7NoH6clPw9Zaw7YRwHfTFO7KFDcBnznNMeLAsWZfWgJE7MPvT9aWAt3bACAMhQ2AEAZChsAoAyFDQBQxiLh4WhZ1LkQeFom4DsyiL+WrsJRB+He4Gy20/AomaBwtvNwu+/Z7SqKZvKOnMLs3kt/TjM/9undh32Ejt2xAQDKUNgAAGUobACAMvbWoI/1G/kcdUlr3KeXWmsOLWq+l2nWtQ+9DewyM36PzK5EjfwyuZtoncy+n4ooY3N9fd29besUsjmRTPO9aJ1oWaZ55z6a/bljAwCUobABAMpQ2AAAZShsAIAyZoeHj1mFEOrSMiEy9mdOwHtUg74oIHh7e5vah6woSNuGcqOgcLRss9k8Oc405Wb3Htnob99NAw8pE8qNQsDtdn/++Wfq9aKg8DGFh9fwI4Le2b0zIeDo/xctywSRH+OODQBQhsIGAChDYQMAlKGwAQDKOFh4eGT3wX10MoRD6A38jjz/54T4ekWdgDPdgaN1bm5uhuzTIVSc3TvbQbhd78uXLzvrRIHfaPx2vexM4YfQ2703I/vjgyWvn/vocO+ODQBQhsIGAChDYQMAlKGwAQDKOOnOwzzUGyIbOaX9Kevt5NkbvNNFev0qhIV7O/ruO/CbCR2vxchwb2asfb/eXO7YAABlKGwAgDIUNgBAGavK2Ix85i8/8LTePMecsUY9W6363i55zLKiWbrbmXcvLy/3tTswRJufWWt2hjHcsQEAylDYAABlKGwAgDIUNgBAGYuEh7MN2zic3mZ8a5g1fQ37UFV0bNvwMDxmrU3u2tnDo/1c677zcu7YAABlKGwAgDIUNgBAGQobAKCMVXUeHqk3HFtVezxO+Vjwj+xnRKdhstYQts2EgNewnyzHHRsAoAyFDQBQhsIGACijbMamV4VsjmaIx2np8yxzXhzbuQ7QcscGAChDYQMAlKGwAQDKUNgAAGWcbbfbQ+8DAMAQ7tgAAGUobACAMhQ2AEAZChsAoAyFDQBQhsIGACjjv9430wzW34bdAAAAAElFTkSuQmCC", 644 | "text/plain": [ 645 | "
" 646 | ] 647 | }, 648 | "metadata": { 649 | "needs_background": "light" 650 | }, 651 | "output_type": "display_data" 652 | } 653 | ], 654 | "source": [ 655 | "for i in range(EPOCHS):\n", 656 | " prev_time = time.time()\n", 657 | " \n", 658 | " train_loss = trainer(pixel_cnn, train_loader, loss_fn, optim)\n", 659 | " valid_loss = validation(pixel_cnn, valid_loader, loss_fn)\n", 660 | " \n", 661 | " curr_time = time.time()\n", 662 | "\n", 663 | " if i == 0 or (i + 1) % 5 == 0:\n", 664 | " print(f'Epoch {i+1:3d}\\tTime:{(curr_time - prev_time) / 60:.2f} min')\n", 665 | " print(f'\\tTrain loss: {train_loss:.4f} Valid loss: {valid_loss:.4f}')\n", 666 | " \n", 667 | " if (i + 1) % 20 == 0:\n", 668 | " generated_imgs = image_generator.generate(pixel_cnn, 1.0)\n", 669 | " generated_imgs = np.transpose(generated_imgs, (0, 2, 3, 1))\n", 670 | " plot_imgs(generated_imgs)" 671 | ] 672 | }, 673 | { 674 | "cell_type": "markdown", 675 | "id": "7e832f09-a6d6-45c5-8d99-741cf35b4c86", 676 | "metadata": {}, 677 | "source": [ 678 | "## 5. Generate Images" 679 | ] 680 | }, 681 | { 682 | "cell_type": "code", 683 | "execution_count": 13, 684 | "id": "a8227496-8fda-4050-b511-3f33eebefed0", 685 | "metadata": {}, 686 | "outputs": [ 687 | { 688 | "data": { 689 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjYAAACxCAYAAADXnPd8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAAsTAAALEwEAmpwYAAADcElEQVR4nO3bQWoDMRBFwXTw/a/cOYHNLIRFHlXbASM+WjwEnt39AQAo+L19AACAU4QNAJAhbACADGEDAGQIGwAgQ9gAABmvTx9nxn/BH9rdOf2b9n/O/ned3t/2z7n7d7n797zb3osNAJAhbACADGEDAGQIGwAgQ9gAABnCBgDIEDYAQIawAQAyhA0AkCFsAIAMYQMAZAgbACBD2AAAGcIGAMgQNgBAhrABADKEDQCQIWwAgAxhAwBkCBsAIEPYAAAZwgYAyBA2AECGsAEAMoQNAJAhbACADGEDAGQIGwAgQ9gAABnCBgDIEDYAQIawAQAyhA0AkCFsAIAMYQMAZAgbACBD2AAAGcIGAMgQNgBAhrABADKEDQCQIWwAgAxhAwBkCBsAIEPYAAAZwgYAyBA2AECGsAEAMoQNAJAhbACADGEDAGQIGwAgQ9gAABnCBgDIEDYAQIawAQAyhA0AkCFsAIAMYQMAZAgbACBD2AAAGcIGAMgQNgBAhrABADKEDQCQIWwAgAxhAwBkCBsAIEPYAAAZwgYAyBA2AECGsAEAMoQNAJAxu3v7DAAAR3ixAQAyhA0AkCFsAIAMYQMAZAgbACBD2AAAGcIGAMgQNgBAhrABADKEDQCQIWwAgAxhAwBkCBsAIEPYAAAZr08fZ2a/dZD/bnfn9G/a/zn733V6f9s/5+7f5e7f8257LzYAQIawAQAyhA0AkCFsAIAMYQMAZAgbACBD2AAAGcIGAMgQNgBAhrABADKEDQCQIWwAgAxhAwBkCBsAIEPYAAAZwgYAyBA2AECGsAEAMoQNAJAhbACADGEDAGQIGwAgQ9gAABnCBgDIEDYAQIawAQAyhA0AkCFsAIAMYQMAZAgbACBD2AAAGcIGAMgQNgBAhrABADKEDQCQIWwAgAxhAwBkCBsAIEPYAAAZwgYAyBA2AECGsAEAMoQNAJAhbACADGEDAGQIGwAgQ9gAABnCBgDIEDYAQIawAQAyhA0AkCFsAIAMYQMAZAgbACBD2AAAGcIGAMgQNgBAhrABADKEDQCQIWwAgAxhAwBkCBsAIEPYAAAZwgYAyBA2AECGsAEAMoQNAJAhbACADGEDAGQIGwAgQ9gAABnCBgDIEDYAQMbs7u0zAAAc4cUGAMgQNgBAhrABADKEDQCQIWwAgAxhAwBk/AGvbTpbpi0yWQAAAABJRU5ErkJggg==", 690 | "text/plain": [ 691 | "
" 692 | ] 693 | }, 694 | "metadata": { 695 | "needs_background": "light" 696 | }, 697 | "output_type": "display_data" 698 | } 699 | ], 700 | "source": [ 701 | "# With temperature = 0.1\n", 702 | "generated_imgs = image_generator.generate(pixel_cnn, 0.1)\n", 703 | "generated_imgs = np.transpose(generated_imgs, (0, 2, 3, 1))\n", 704 | "plot_imgs(generated_imgs)" 705 | ] 706 | }, 707 | { 708 | "cell_type": "code", 709 | "execution_count": 14, 710 | "id": "358949c7-4d08-4fac-965d-c408142cbd75", 711 | "metadata": {}, 712 | "outputs": [ 713 | { 714 | "data": { 715 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjYAAACxCAYAAADXnPd8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAAsTAAALEwEAmpwYAAANUUlEQVR4nO3dPXbcxhIG0OE7XgqVSKEXYKd0Ti7BjMgtiN4CGUlLoHIztXI7lBJpL/Oi9zM9RbLUAwwaNfdmAwFgCwAxdYCP1Wfb7XYDAFDBv5YeAADAVBQ2AEAZChsAoAyFDQBQhsIGAChDYQMAlPHTS/94dnZW+m/Bf/31171ld3d3e8t++eWXnc+fP3+O1jmbalz/MeXxv7+/3/l8cXEx1a43T09Pe8tub28n23/Gdrsd+viP4Pz8fOdze01sNpvNmzdv9pa15zc633/++eekx3/OY395eZla9unTp659Rdr7ytevX1PbZbj2lzX18R/h2Lf3isj379+PMJKXPXfsPbEBAMpQ2AAAZShsAIAyzl6aUmGEd31T+uuvv3Y+t9mZAw39nrvNU9zc3HTtJ/teNcpqzOmUcwbR+/AoQ9XmnqLtovPbrhdlbC4uLobNGbQ5mGwuJuPdu3ep9b58+bLzOcrvZDI9kVO+9kcwcsamvQ9E94VMniYS5SiPnbuRsQEAylPYAABlKGwAgDJe7GOzZlGPmlbUjyaS7GOT2tdS2veomSwFx5V5/x2t05tnGqEPxSHevn27t+zq6iq1Xo8oT9NmZ7J6++bAc6J7Q2+Wsr2nfPv2bW+dKGMT9cZa4j7jiQ0AUIbCBgAoQ2EDAJShsAEAyjip8HBvCLhdr23099x2S4lCwJkwWG94OLuvtYdVRxA1x+s9rpnrZCRtCDgKCs+pNyicFYWcp5wsk9p679+9v/PRfX8UntgAAGUobACAMhQ2AEAZChsAoIyy4eHIH3/8sfM50514s9kPBkfh4bXpDYxlA2rR/oWHXxYFg+eUPZdRR9OedUaRnZF7TlEQOQpD393dHWE09bWdnqPjv/agdnTPneqPAaJ7d3bfDw8Pk4zhR3hiAwCUobABAMpQ2AAAZfxwxqbNpYyQN4myMpn8zEhN9UaQeReanS125OZNh+idLbd9R33sPE0km3nKXBfRrL6jaGfNXmIW7TbXE83uza62YWGUjYqOY2+GKpNnmrtJ49pFuZs2y3eMrKUnNgBAGQobAKAMhQ0AUIbCBgAo48Xw8AjNobJN9DKmCgtHgen3799Psu8pROGsNswbNVSLAq29M8ZWkPm/T9mYLnP8T7nJ4QiN9TKicbah01OeyTsK/GaWZUPfvQHf6Ly150l4+Mctcc/yxAYAKENhAwCUobABAMpQ2AAAZRwcHp4y3Btpg7q9XYY3m83m8+fPhw9opdpgahR67Q0KR12GK4Rco/9D24U36sobHce2I+cph7KPrQ2Fzh0A7d1/FCgeWTTeKIDbfo9k/59tmDoKGEeB4kwIOxrD2o7/MUX3q5Hv8Z7YAABlKGwAgDIUNgBAGT88u3erd3bv3mzOIbOJz50HGllmNuls077WyO9al5A5HtmMjWM7vSUa/a2tsVuUN2mzMpmGdpEoA9ObeemdJT0aw9rOUSvKOt7c3Ox8nvJ+MnLuxhMbAKAMhQ0AUIbCBgAoQ2EDAJRxcHi41yEh4Dn3VVUb6opCXm0Tuc0mFx6OQmvsao935riyKxP6Xcus2dkA7VKi8WXCw3MGs6Nw75SB39HPyWui+/ec+x75vu+JDQBQhsIGAChDYQMAlKGwAQDKWCw8zLKi8GrUeThjlG6Ta+KYvWxNoc1T0YZro7BtNNt2xqjne8Tg+Y+Y8z4TBYpHua95YgMAlKGwAQDKUNgAAGUobACAMoSHT1Q2PJzpWDxyB0pgfGsP6a5FdK/OdCzu3W4pntgAAGUobACAMhQ2AEAZMjYnKpuLOT8/79qOl0V5pvZYbzbTzgI+SvOsOU2Z1ehtGheNoZ31eu6Zqjk9c96bo3vHyHlLT2wAgDIUNgBAGQobAKAMhQ0AUIbwMP8VBb9OIXB6DG0wOBsKjgLFregcRdu1y0Y+t23Y9hBtCDgK944wu3TvzNiw2Ww2Dw8PXdtF94q2+d4ooeAsT2wAgDIUNgBAGQobAKAMhQ0AUIbw8ImaMjg6cgh1FG0YL3vMeo9tNlA8qig8PFUIOFpnyrByxOzVrEkmLJyd3XuJ7wdPbACAMhQ2AEAZChsAoAwZGziCERpcte+6o8zNKeSlMvmduT0+Ph7158HURrinPccTGwCgDIUNAFCGwgYAKENhAwCUcbbdbpceAwDAJDyxAQDKUNgAAGUobACAMhQ2AEAZChsAoAyFDQBQhsIGAChDYQMAlKGwAQDKUNgAAGUobACAMhQ2AEAZChsAoAyFDQBQxk8v/ePZ2dn2WAPZbDabt2/f7i27urp6dbvHx8e9ZV+/fp1kTFnb7fZs6n0e+/gfW3S+37179+p2nz592ltW4fjf3NzsLbu9vX11u2idp6enScaUNfXxH+Ha//3333c+f/jwIbXdx48fX13W7vu5ZdfX16/uu8K1H7m4uNj5HP1+vHnzpmvfU/7OrOnaj74rM/fcL1++7C27u7vbWzbK964nNgBAGQobAKAMhQ0AUMaLGZtji971vX//vmtf0fs/xhKd78vLy71lFc7l+fn5zufoHX+bKci6v7/fWxblBdplx87hjOznn39+ddk///zTva+///67a1xt7ibK2FQQXfvRdT2VaN/RsoeHh9nGsCaZHM5IPLEBAMpQ2AAAZShsAIAyjpaxaXuWRO/sMlmK6O/ko1xGpP0b/mP/zT27ot4IkbWdp+hdfdSDo8f3799T60U/r80xRBmbTN+ctWlzKlEGJiObsendtndca9PmzTab6X4/DhFd++3vW8VcWuY+HH1fR8tGuVd7YgMAlKGwAQDKUNgAAGUobACAMmYJD0ch4G
jCw4zeMFIUKG7HEE2mGC3jeLKB4lFEQcjeRnu9P69XNM5o2ZoCk9FEknOGcqN9zxkUXjJg3F4bvdd5dA33TmY5t0zgfu16J8GMthvl+9MTGwCgDIUNAFCGwgYAKENhAwCUMUt4uDconA0o9e4rs+9Rwk+sVxSObDuYfvv27VjDeVYU2JwynLyEKcO1mdm9s8vm3O5Y2iDt3N2Cs122pxL9TrbL1v77Ef0xzpzfsUvxxAYAKENhAwCUobABAMo4OGMT5Wl6G/5Mqfe9IS/LzqQ+5/FvZ2lfg953821DsLmbm43aKO0QmaxM5OPHjzufP3z48Oo6z8nMMJ7J3SyZsTl2vqT9ednMTWacUaO93377LTewE9B+r2vQBwCwEIUNAFCGwgYAKENhAwCUcXB4OBsS7W3al9lP7wzg0din3P/aRP/3q6urV9eJjk8U8G2PdxREPpVjHcnMFH5IgPPYDc+WEIVyM037epvjRfuOZhhvQ8a9+8qGleewdLD8kGu/DQvf3t4eOpxVyn4Pt/fhkYPCEU9sAIAyFDYAQBkKGwCgDIUNAFDGweHhKEAUdSlsQ6jZDratQ8Kl7VinCjRX0Z6jyN3d3d6yKFgWnadMII1da59NeE5TzuQdyXQLvr6+3lsWBYPbfUUB45Gs5bqLAvHRLN1tWPgUgvRzGzlQ7IkNAFCGwgYAKENhAwCUcXDGJpJp2NabsTlElA/hf6JsVGYW9uxM7VNlmkbK5mSblrXv9EfNMIw6rsiUGZvMvqLmeJk8zXPLMvvqWWcKmezKCNdKNIb7+/u9ZTI1P6a9V2ezraM0uPXEBgAoQ2EDAJShsAEAylDYAABlzBIeHlVvIOpUZELAhzRlGin0C/8vE8qNmvH1NtpbcpbujCiUu/Ts3lmCws/L3oPb78bou2HkBree2AAAZShsAIAyFDYAQBkKGwCgjJMKD7fBqarh4UyoKzP79pQ/r6poJuE1iwKibZB0beHM3m69mYBvtO9oWW+X5Ha7Y3UeHkHmOhuh+/HI5uz2nu04vwRPbACAMhQ2AEAZChsAoAyFDQBQxkmFhyuKwmFXV1ep9Vp3d3ddY5iyo/Ao094fIgo9Pj09vbrdlEHITPDy4uJisp+3hGz33kxwN1ons102zNuulw0T94aODxVdP21Q/tjXa3Y7geLnzX1/jb4Llrh/e2IDAJShsAEAylDYAABlyNgUdHl52bVdNEt3dubujMfHx659r60BYNS0r83YRJmbTDbgkOZ47f7XnrGJRLmbdgbuKLcyd+O7dlzZTE+7XdUGfdG139ugby2zkB9Db1Pa9p6b3S7TyO8YmRtPbACAMhQ2AEAZChsAoAyFDQBQxkmFh0eejXRUbYgsCvxGx3XKgFiFBn2ZWcDnnjW73X80plMNXkbB3evr68n234Z+szOAjzS797Fndc/MKt/7u3Yq2nvz3N+Bo3zHemIDAJShsAEAylDYAABlKGwAgDJOKjx8yjKdJKMOv+2yqLNkNCv42gK/vaLw4v39fWq9pUXdj7MBzTXpnVl7BCN1Gm6vl5ubm4VG8j+3t7d7y6Lr+lS19+tMZ+Apf15E52EAgB+gsAEAylDYAABlyNisXOad5tyi2cTnbto3srVkUjTj+7H1lsi7jJSxWVo0kzcv683P9M7uPeUYDuGJDQBQhsIGAChDYQMAlKGwAQDKONtut0uPAQBgEp7YAABlKGwAgDIUNgBAGQobAKAMhQ0AUIbCBgAo498Ift7xN1wLOQAAAABJRU5ErkJggg==", 716 | "text/plain": [ 717 | "
" 718 | ] 719 | }, 720 | "metadata": { 721 | "needs_background": "light" 722 | }, 723 | "output_type": "display_data" 724 | } 725 | ], 726 | "source": [ 727 | "# With temperature = 1.0\n", 728 | "generated_imgs = image_generator.generate(pixel_cnn, 1.0)\n", 729 | "generated_imgs = np.transpose(generated_imgs, (0, 2, 3, 1))\n", 730 | "plot_imgs(generated_imgs)" 731 | ] 732 | } 733 | ], 734 | "metadata": { 735 | "kernelspec": { 736 | "display_name": "Python 3", 737 | "language": "python", 738 | "name": "python3" 739 | }, 740 | "language_info": { 741 | "codemirror_mode": { 742 | "name": "ipython", 743 | "version": 3 744 | }, 745 | "file_extension": ".py", 746 | "mimetype": "text/x-python", 747 | "name": "python", 748 | "nbconvert_exporter": "python", 749 | "pygments_lexer": "ipython3", 750 | "version": "3.9.0" 751 | } 752 | }, 753 | "nbformat": 4, 754 | "nbformat_minor": 5 755 | } 756 | --------------------------------------------------------------------------------