├── assets ├── ddpm.png ├── unet.png ├── mnist_4or9.png ├── mnist_images.png ├── sde_ddpmpp.png ├── sde_process.png ├── ddpm_optimize.png ├── time_condition.png ├── class_condition.png ├── classifier_model.png ├── ddpm_architecture.png ├── ddpm_beta_in_sde.png ├── latent_diffusion.png ├── refinenet_block.png ├── vae_architecture.png ├── langevin_experiments.png ├── refinenet_architecture.png ├── sde_predictor_corrector.png └── conditioned_normalization.png ├── Readme.md └── 01-vae.ipynb /assets/ddpm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsmatz/diffusion-tutorials/HEAD/assets/ddpm.png -------------------------------------------------------------------------------- /assets/unet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsmatz/diffusion-tutorials/HEAD/assets/unet.png -------------------------------------------------------------------------------- /assets/mnist_4or9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsmatz/diffusion-tutorials/HEAD/assets/mnist_4or9.png -------------------------------------------------------------------------------- /assets/mnist_images.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsmatz/diffusion-tutorials/HEAD/assets/mnist_images.png -------------------------------------------------------------------------------- /assets/sde_ddpmpp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsmatz/diffusion-tutorials/HEAD/assets/sde_ddpmpp.png -------------------------------------------------------------------------------- /assets/sde_process.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tsmatz/diffusion-tutorials/HEAD/assets/sde_process.png -------------------------------------------------------------------------------- /assets/ddpm_optimize.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsmatz/diffusion-tutorials/HEAD/assets/ddpm_optimize.png -------------------------------------------------------------------------------- /assets/time_condition.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsmatz/diffusion-tutorials/HEAD/assets/time_condition.png -------------------------------------------------------------------------------- /assets/class_condition.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsmatz/diffusion-tutorials/HEAD/assets/class_condition.png -------------------------------------------------------------------------------- /assets/classifier_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsmatz/diffusion-tutorials/HEAD/assets/classifier_model.png -------------------------------------------------------------------------------- /assets/ddpm_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsmatz/diffusion-tutorials/HEAD/assets/ddpm_architecture.png -------------------------------------------------------------------------------- /assets/ddpm_beta_in_sde.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsmatz/diffusion-tutorials/HEAD/assets/ddpm_beta_in_sde.png -------------------------------------------------------------------------------- /assets/latent_diffusion.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tsmatz/diffusion-tutorials/HEAD/assets/latent_diffusion.png -------------------------------------------------------------------------------- /assets/refinenet_block.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsmatz/diffusion-tutorials/HEAD/assets/refinenet_block.png -------------------------------------------------------------------------------- /assets/vae_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsmatz/diffusion-tutorials/HEAD/assets/vae_architecture.png -------------------------------------------------------------------------------- /assets/langevin_experiments.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsmatz/diffusion-tutorials/HEAD/assets/langevin_experiments.png -------------------------------------------------------------------------------- /assets/refinenet_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsmatz/diffusion-tutorials/HEAD/assets/refinenet_architecture.png -------------------------------------------------------------------------------- /assets/sde_predictor_corrector.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsmatz/diffusion-tutorials/HEAD/assets/sde_predictor_corrector.png -------------------------------------------------------------------------------- /assets/conditioned_normalization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsmatz/diffusion-tutorials/HEAD/assets/conditioned_normalization.png -------------------------------------------------------------------------------- /Readme.md: 
-------------------------------------------------------------------------------- 1 | # Diffusion Models Tutorial (Python) 2 | 3 | This repository shows you the implementation of representative diffusion model algorithms and their guidance techniques from scratch in Python (PyTorch), along with the theory behind the code. 4 | 5 | ## Table of Contents 6 | 7 | - [Variational Auto-Encoder (VAE)](01-vae.ipynb) 8 | - [Denoising Diffusion Probabilistic Models (DDPM)](02-ddpm.ipynb) 9 | - [Score Matching with Langevin Dynamics (SMLD)](03-smld.ipynb) 10 | - [Score-based generative modeling with Stochastic Differential Equation (SDE)](04-sde.ipynb) 11 | - [Conditional Diffusion Models](05-class-conditional.ipynb) 12 | - [Classifier Diffusion Guidance](06-classifier-guidance.ipynb) 13 | - [Classifier-free Diffusion Guidance](07-classifier-free-guidance.ipynb) 14 | 15 | In the first part of this repository (tutorial 1 - 4), I'll focus on the algorithms of diffusion models.
16 | In the latter part (tutorial 5 - 7), we proceed to steer diffusion models with external knowledge (such as a class or a generic text prompt) and learn the fundamentals of diffusion guidance. 17 | 18 | Historically, image generation (including text-to-image models) was developed and improved by adopting GAN-based methods and autoregressive methods from 2014 to 2021.
19 | In recent works, diffusion models have achieved great success, outperforming these previous approaches in the quality of image generation. (The GAN-based approach, however, still has its own advantages. See the note below.) 20 | 21 | > Note : The drawback of diffusion models (compared to GANs) is their computational cost. As you will see in the examples in this repository, diffusion models require hundreds or thousands of iterations to generate an image - i.e., image generation is slow.
22 | > There are ongoing works to improve the sampling speed of diffusion models. (See the note on DDIM in [this notebook](./02-ddpm.ipynb).) 23 | 24 | *Tsuyoshi Matsuzaki @ Microsoft* 25 | -------------------------------------------------------------------------------- /01-vae.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "70657d36-7f9d-41b7-985f-51b171952c70", 6 | "metadata": {}, 7 | "source": [ 8 | "# Background : Variational Auto-Encoder (VAE) and Evidence Lower Bound (ELBO)\n", 9 | "\n", 10 | "Before jumping into diffusion models, I'll first introduce VAE (Variational Auto-encoder) as a starting point. By studying the variational auto-encoder, you will learn the fundamentals of modeling with a forward process / reverse process.\n", 11 | "\n", 12 | "An auto-encoder is a primitive model (neural network) that copies data (e.g., images) with an encoder-decoder pattern, in which the encoder compresses data into the latent space and the decoder reconstructs it from the encoded representation.
\n", 13 | "In VAE, instead of using the variables directly, we use (and optimize) probability distributions that describe them. VAE can be regarded as an application of probabilistic graphical models and variational Bayesian inference.
\n", 14 | "In this manner, VAE approximates the true distribution and draws samples from it.\n", 15 | "\n", 16 | "> Note : Chapter 10 in \"[Pattern Recognition and Machine Learning](https://www.microsoft.com/en-us/research/uploads/prod/2006/01/Bishop-Pattern-Recognition-and-Machine-Learning-2006.pdf)\" (Christopher M. Bishop, Microsoft) is a good reference for understanding variational Bayesian inference.
\n", 17 | "> See [here](https://github.com/tsmatz/gmm/) for a Python example of variational Bayesian inference in a probabilistic graphical model.\n", 18 | "\n", 19 | "*(back to [index](https://github.com/tsmatz/diffusion-tutorials/))*" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "id": "4634b50e-3d5e-4625-8733-c9eb779a8d1d", 25 | "metadata": {}, 26 | "source": [ 27 | "## Architecture\n", 28 | "\n", 29 | "In VAE, we suppose that the distribution of images $p(X)$ (where $X=\\{\\mathbf{x}^{(i)}\\}_{i=1}^N$ is a set of $N$ observed data points) is generated from continuous latent variables with distribution $p(Z)$ (where $Z=\\{\\mathbf{z}^{(i)}\\}_{i=1}^N$ is a set of corresponding latent variables).
\n", 30 | "In this problem, however, the real distribution $p(X)$ and the hidden data points $Z=\\{\\mathbf{z}^{(i)}\\}_{i=1}^N$ are never known, and the true posterior $p(Z|X)$ cannot be computed. It's therefore intractable to solve this problem with the EM algorithm (a maximum-likelihood approach), because its E-step requires this posterior.
\n", 31 | "Therefore, in VAE, we apply variational Bayesian inference to get the optimal encoder/decoder between generated images and the latent variables, by assuming some prior distribution $p(Z)$. (See [here](https://github.com/tsmatz/gmm/) for the comparison between the EM algorithm and variational Bayesian inference in a GMM example.)\n", 32 | "\n", 33 | "For clarity, here I'll illustrate with handwritten digit images (MNIST).\n", 34 | "\n", 35 | "![MNIST images](./assets/mnist_images.png)\n", 36 | "\n", 37 | "Firstly, we suppose that the image $\\mathbf{x}$ (vector) is generated by some latent variable $\\mathbf{z}$ (vector).
\n", 38 | "For example, each digit's image $\\mathbf{x}$ (which has 28 x 28 = 784 dimensions) may have been generated from some digit number. In this case, the latent variable $\\mathbf{z}$ can be the discrete digit number 0, 1, 2, ... , 9.\n", 39 | "\n", 40 | "The real latent space for generating images, however, is mostly continuous, and we therefore model it with probability distributions.
\n", 41 | "For example, the latent value of the following image may lie between digits 4 and 9.\n", 42 | "\n", 43 | "![image either 4 or 9](./assets/mnist_4or9.png)\n", 44 | "\n", 45 | "In our example, we assume the latent prior distribution $p(\\mathbf{z}) = \\mathcal{N}(\\mathbf{0},\\mathbf{I})$. (Note that here the prior distribution $p(\\mathbf{z})$ doesn't have any learnable parameters.)\n", 46 | "\n", 47 | "> Note : In practice, the dimension of the latent space should be large enough to capture the variations in the data.\n", 48 | "\n", 49 | "We suppose that a value $\\mathbf{x}$ is generated from some conditional distribution $p_{\\theta^{\\ast}}(\\mathbf{x}|\\mathbf{z})$, where $p_{\\theta}(\\mathbf{x}|\\mathbf{z})$ is determined by a neural network (or a function) with parameter $\\theta$, and $\\theta^{\\ast}$ is the optimal parameter.
\n", 50 | "In our example, we suppose that $p_{\\theta}(\\mathbf{x}|\\mathbf{z})$ is also a Gaussian distribution $p_{\\theta}(\\mathbf{x}|\\mathbf{z}) = \\mathcal{N}(\\mu_{\\theta}^{\\verb|dec|}(\\mathbf{z}),(\\sigma^{\\verb|dec|})^2)$ where $\\sigma^{\\verb|dec|}$ is some fixed value.
\n", 51 | "For example, when the latent variable $\\mathbf{z}$ represents the digit number 7, then $\\mu_{\\theta}^{\\verb|dec|}(\\mathbf{z})$ will be the image vector that represents the digit number 7.\n", 52 | "\n", 53 | "We also consider a conditional distribution $q_{\\phi}(\\mathbf{z}|\\mathbf{x})$, which is also determined by a neural network (or a function) with parameter $\\phi$.
\n", 54 | "In our example, we also suppose that $q_{\\phi}(\\mathbf{z}|\\mathbf{x})$ is a Gaussian distribution $q_{\\phi}(\\mathbf{z}|\\mathbf{x}) = \\mathcal{N}(\\mu_{\\phi}^{\\verb|enc|}(\\mathbf{x}),\\verb|diag|((\\sigma_{\\phi}^{\\verb|enc|}(\\mathbf{x}))^2))$.
\n", 55 | "For example, when the image vector $\\mathbf{x}$ is above (the digit's image of the middle of digit 4 and 9), then $q_{\\phi}(\\mathbf{z}|\\mathbf{x})$ represents the latent distribution of the middle of digit 4 and 9.\n", 56 | "\n", 57 | "$q_{\\phi}(\\mathbf{z}|\\mathbf{x})$ is called a encoder, and $p_{\\theta}(\\mathbf{x}|\\mathbf{z})$ is called a decoder.\n", 58 | "\n", 59 | "![VAE architecture](./assets/vae_architecture.png)\n", 60 | "\n", 61 | "Our goal is to optimize both parameters $\\phi$ and $\\theta$ to maximize the log likelihood $\\log p(\\mathbf{x})$." 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "id": "06b0b4ba-565a-40a4-ba5e-6593f2f9962b", 67 | "metadata": {}, 68 | "source": [ 69 | "## Method\n", 70 | "\n", 71 | "Let's briefly follow the methods to solve optimal $\\phi$ and $\\theta$ along with the original paper [[Diederik P. Kingma and Max Welling, 2013](https://arxiv.org/pdf/1312.6114)].\n", 72 | "\n", 73 | "Firstly, the log likelihood of observation $p(\\mathbf{x})$ is decomposed as : \n", 74 | "\n", 75 | "$\\displaystyle \\log p(\\mathbf{x})$\n", 76 | "\n", 77 | "$\\displaystyle = \\int \\log p(\\mathbf{x}) q_{\\phi}(\\mathbf{z}|\\mathbf{x}) d\\mathbf{z} $       (because $\\int q_{\\phi}(\\mathbf{z}|\\mathbf{x}) d\\mathbf{z} = 1$)\n", 78 | "\n", 79 | "$\\displaystyle = \\mathbb{E}_{q_{\\phi}(\\mathbf{z}|\\mathbf{x})}[\\log p(\\mathbf{x})] $\n", 80 | "\n", 81 | "$\\displaystyle = \\mathbb{E}_{q_{\\phi}(\\mathbf{z}|\\mathbf{x})}\\left[\\log \\frac{p(\\mathbf{x}, \\mathbf{z})}{q_{\\phi}(\\mathbf{z}|\\mathbf{x})}\\right] + \\mathbb{E}_{q_{\\phi}(\\mathbf{z}|\\mathbf{x})}\\left[\\log\\frac{q_{\\phi}(\\mathbf{z}|\\mathbf{x})}{p(\\mathbf{z}|\\mathbf{x})}\\right] $\n", 82 | "\n", 83 | "$\\displaystyle = \\mathbb{E}_{q_{\\phi}(\\mathbf{z}|\\mathbf{x})}\\left[\\log \\frac{p(\\mathbf{x}, \\mathbf{z})}{q_{\\phi}(\\mathbf{z}|\\mathbf{x})}\\right] + 
D_{KL}\\left(q_{\\phi}(\\mathbf{z}|\\mathbf{x})\\|p(\\mathbf{z}|\\mathbf{x})\\right)$\n", 84 | "\n", 85 | "where $p(\\mathbf{z}|\\mathbf{x})$ is the true posterior distribution and $D_{KL}(q\\|p)$ is the KL-divergence between $q$ and $p$.\n", 86 | "\n", 87 | "Because the KL-divergence is always non-negative, $\\mathbb{E}_{q_{\\phi}(\\mathbf{z}|\\mathbf{x})}\\left[\\log \\frac{p(\\mathbf{x}, \\mathbf{z})}{q_{\\phi}(\\mathbf{z}|\\mathbf{x})}\\right]$ is a lower-bound for $\\log p(\\mathbf{x})$, and this bound becomes tight (the KL-divergence $D_{KL}\\left(q_{\\phi}(\\mathbf{z}|\\mathbf{x})\\|p(\\mathbf{z}|\\mathbf{x})\\right)$ goes to zero) when $q_{\\phi}(\\mathbf{z}|\\mathbf{x})$ goes to the true distribution.\n", 88 | "\n", 89 | "$\\displaystyle \\log p(\\mathbf{x}) \\geq \\mathbb{E}_{q_{\\phi}(\\mathbf{z}|\\mathbf{x})}\\left[\\log \\frac{p(\\mathbf{x}, \\mathbf{z})}{q_{\\phi}(\\mathbf{z}|\\mathbf{x})}\\right]$\n", 90 | "\n", 91 | "This right-hand side expectation (lower-bound) is called the Evidence Lower Bound (ELBO) (see [here](https://github.com/tsmatz/gmm/blob/master/02-gmm-variational-inference.ipynb)), and our goal is to maximize this lower-bound in order to maximize $\\log p(\\mathbf{x})$ on the observed data points $\\{\\mathbf{x}^{(i)}\\}_{i=1}^N$.\n", 92 | "\n", 93 | "Because $p(\\mathbf{x}, \\mathbf{z}) = p(\\mathbf{x}|\\mathbf{z}) p(\\mathbf{z})$, the Evidence Lower Bound (ELBO) is then decomposed as follows. 
:\n", 94 | "\n", 95 | "$\\displaystyle \\mathbb{E}_{q_{\\phi}(\\mathbf{z}|\\mathbf{x})}\\left[\\log \\frac{p(\\mathbf{x}, \\mathbf{z})}{q_{\\phi}(\\mathbf{z}|\\mathbf{x})}\\right] = \\mathbb{E}_{q_{\\phi}(\\mathbf{z}|\\mathbf{x})}[\\log p_{\\theta}(\\mathbf{x}|\\mathbf{z})] - D_{KL} \\left( q_{\\phi}(\\mathbf{z}|\\mathbf{x}) \\| p(\\mathbf{z}) \\right)$\n", 96 | "\n", 97 | "The second term (KL-divergence term) $D_{KL} \\left( q_{\\phi}(\\mathbf{z}|\\mathbf{x}) \\| p(\\mathbf{z}) \\right)$ measures how different the distribution $q_{\\phi}(\\mathbf{z}|\\mathbf{x})$ is from the prior distribution $p(\\mathbf{z})$, and it can be integrated analytically, because both $q_{\\phi}(\\mathbf{z}|\\mathbf{x})$ and $p(\\mathbf{z})$ are Gaussian distributions of known form in this example. (See the description in the implementation below.)\n", 98 | "\n", 99 | "For the first term $\\mathbb{E}_{q_{\\phi}(\\mathbf{z}|\\mathbf{x})}[\\log p_{\\theta}(\\mathbf{x}|\\mathbf{z})]$, we should maximize the likelihood $p_{\\theta}(\\mathbf{x}|\\mathbf{z})$ for the given input $\\mathbf{x}$, and we can then optimize it by minimizing the difference between the reconstructed image $\\hat{\\mathbf{x}}$ (= $\\mu_{\\theta}^{\\verb|dec|}(\\mathbf{z})$) and the input image $\\mathbf{x}$.
\n", 100 | "In practice, this expectation is estimated by sampling from $q_{\\phi}(\\mathbf{z}|\\mathbf{x})$. :\n", 101 | "\n", 102 | "$\\displaystyle \\mathbb{E}_{q_{\\phi}(\\mathbf{z}|\\mathbf{x})}[\\log p_{\\theta}(\\mathbf{x}|\\mathbf{z})] \\simeq \\frac{1}{M} \\sum_{i=1}^M \\left( \\frac{1}{L} \\sum_{l=1}^L \\left( \\log p_{\\theta}(\\mathbf{x}^{(i)} | \\mathbf{z}^{(i,l)}) \\right) \\right)$\n", 103 | "\n", 104 | "where $M$ is the minibatch size (the minibatch is $\\{\\mathbf{x}^{(i)}\\}_{i=1}^M$) and $L$ is the number of samples $\\mathbf{z}^{(i,l)} \\sim q_{\\phi}(\\mathbf{z}|\\mathbf{x}^{(i)})$ drawn for each $\\mathbf{x}^{(i)}$.\n", 105 | "\n", 106 | "But optimizing this term is problematic, because $q_{\\phi}(\\mathbf{z}|\\mathbf{x})$ depends on parameter $\\phi$ which should be optimized, and the gradient $\\nabla_{\\phi}\\mathbb{E}_{q_{\\phi}(\\mathbf{z}|\\mathbf{x})}[\\cdot]$ cannot be obtained simply by backpropagating through the sampling step (the naive Monte Carlo gradient estimator exhibits very high variance).\n", 107 | "\n", 108 | "In order to solve this problem, we now apply the **reparameterization** trick (change of variables) as follows.\n", 109 | "\n", 110 | "Let $\\mathbf{z}$ be a continuous random variable with $\\mathbf{z} \\sim q_{\\phi}(\\mathbf{z}|\\mathbf{x})$, and we assume that $\\mathbf{z}$ can be expressed as $\\mathbf{z} = g_{\\phi}(\\epsilon, \\mathbf{x})$, where $\\epsilon \\sim p(\\epsilon)$ is drawn from a distribution independent of parameter $\\phi$, and $g_{\\phi}(\\cdot)$ is some vector-valued function parameterized by $\\phi$.
\n", 111 | "Under this assumption, we also assume that the infinitesimal volume elements satisfy $q_{\\phi}(\\mathbf{z}|\\mathbf{x}) dz_0 dz_1 \\ldots dz_{n-1} = p(\\epsilon) d\\epsilon_0 d\\epsilon_1 \\ldots d\\epsilon_{n-1}$, where $\\mathbf{z} = (z_0, z_1, \\ldots, z_{n-1})$ and $\\epsilon = (\\epsilon_0, \\epsilon_1, \\ldots, \\epsilon_{n-1})$.\n", 112 | "\n", 113 | "> Note : In this case, it also satisfies $q_{\\phi}(\\mathbf{z}|\\mathbf{x}) \\cdot \\left| \\det \\left( \\frac{\\partial \\mathbf{z}}{\\partial \\mathbf{\\epsilon}} \\right) \\right| = p(\\mathbf{\\epsilon})$, where $\\frac{\\partial \\mathbf{z}}{\\partial \\mathbf{\\epsilon}}$ is the Jacobian. (This is an equivalent form of the condition above, since $d\\mathbf{z} = \\left| \\det \\left( \\frac{\\partial \\mathbf{z}}{\\partial \\mathbf{\\epsilon}} \\right) \\right| d\\mathbf{\\epsilon}$.)
\n", 114 | "> See [here](https://tutorial.math.lamar.edu/classes/calciii/changeofvariables.aspx) for the change of variables and the Jacobian.\n", 115 | "\n", 116 | "In our example, the distribution $q_{\\phi}(\\mathbf{z}|\\mathbf{x})$ is $\\mathcal{N}(\\mu_{\\phi}^{\\verb|enc|}(\\mathbf{x}),\\verb|diag|((\\sigma_{\\phi}^{\\verb|enc|}(\\mathbf{x}))^2))$ (see above), and then we can define $p(\\epsilon)$ and $g_{\\phi}(\\epsilon, \\mathbf{x})$ to satisfy above conditions as follows. :\n", 117 | "\n", 118 | "---\n", 119 | "\n", 120 | "reparameterization\n", 121 | "\n", 122 | "$\\displaystyle \\mathbf{z} = g_{\\phi}(\\epsilon, \\mathbf{x}) \\stackrel{\\mathrm{def}}{=} \\mu_{\\phi}^{\\verb|enc|}(\\mathbf{x}) + \\verb|diag|(\\sigma_{\\phi}^{\\verb|enc|}(\\mathbf{x})) \\odot \\epsilon$\n", 123 | "\n", 124 | "where $\\epsilon \\sim \\mathcal{N}(\\mathbf{0},\\mathbf{I})$, and the second term $\\verb|diag|(\\sigma_{\\phi}^{\\verb|enc|}(\\mathbf{x})) \\odot \\epsilon$ means element-wise multiplication between diagonal elements of $\\sigma_{\\phi}^{\\verb|enc|}(\\mathbf{x})$ and the vector $\\epsilon$.\n", 125 | "\n", 126 | "---\n", 127 | "\n", 128 | "In this assumption, we can construct the differentiable estimator by sampling, because the following equation holds. 
:\n", 129 | "\n", 130 | "$\\displaystyle \\mathbb{E}_{q_{\\phi}(\\mathbf{z}|\\mathbf{x})}[f(\\mathbf{z})]$\n", 131 | "\n", 132 | "$\\displaystyle =\\int q_{\\phi}(\\mathbf{z}|\\mathbf{x}) f(\\mathbf{z}) d \\mathbf{z}$\n", 133 | "\n", 134 | "$\\displaystyle =\\int p(\\epsilon) f(\\mathbf{z}) d \\epsilon$\n", 135 | "\n", 136 | "$\\displaystyle =\\int p(\\epsilon) f(g_{\\phi}(\\epsilon, \\mathbf{x})) d \\epsilon$\n", 137 | "\n", 138 | "$\\displaystyle =\\mathbb{E}_{p(\\epsilon)}[f(g_{\\phi}(\\epsilon, \\mathbf{x}))]$\n", 139 | "\n", 140 | "Now we can minimize the difference between the reconstructed image $\\hat{\\mathbf{x}}$ (i.e., $\\mu_{\\theta}^{\\verb|dec|}(\\mathbf{z})$) and the input image $\\mathbf{x}$ by a gradient method, and we can then finally maximize the above ELBO (Evidence Lower Bound) by optimizing $\\phi$ and $\\theta$." 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "id": "6f31e8e3-32e1-4dfe-b66d-d1e388de1104", 146 | "metadata": {}, 147 | "source": [ 148 | "## Implementation" 149 | ] 150 | }, 151 | { 152 | "cell_type": "markdown", 153 | "id": "1a11d841-bae2-41e6-8d06-30f9997f8234", 154 | "metadata": {}, 155 | "source": [ 156 | "### Build and train VAE" 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "id": "e3b70852-6571-4ba0-84c7-08d0ea86ca48", 162 | "metadata": {}, 163 | "source": [ 164 | "Before we start, we need to install the required packages." 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "id": "4aadf8c8-af1d-4df8-8690-219fa984b93d", 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "!pip install torch torchvision matplotlib" 175 | ] 176 | }, 177 | { 178 | "cell_type": "markdown", 179 | "id": "7b773d1e-30a2-4cf6-b844-2d417bb52d31", 180 | "metadata": {}, 181 | "source": [ 182 | "We load the handwritten digit images (MNIST) dataset and create a dataloader.
\n", 183 | "Each batch has shape ```[batch_size, 1, 28, 28]```." 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 1, 189 | "id": "66dd3529-1ddc-4580-9086-6614429284b5", 190 | "metadata": {}, 191 | "outputs": [ 192 | { 193 | "data": { 194 | "image/png": "", 195 | "text/plain": [ 196 | "
" 197 | ] 198 | }, 199 | "metadata": {}, 200 | "output_type": "display_data" 201 | }, 202 | { 203 | "data": { 204 | "image/png": "", 205 | "text/plain": [ 206 | "
" 207 | ] 208 | }, 209 | "metadata": {}, 210 | "output_type": "display_data" 211 | }, 212 | { 213 | "data": { 214 | "image/png": "", 215 | "text/plain": [ 216 | "
" 217 | ] 218 | }, 219 | "metadata": {}, 220 | "output_type": "display_data" 221 | } 222 | ], 223 | "source": [ 224 | "import torch\n", 225 | "from torchvision import datasets, transforms\n", 226 | "import matplotlib.pyplot as plt\n", 227 | "\n", 228 | "batch_size = 100\n", 229 | "\n", 230 | "dataset = datasets.MNIST(\n", 231 | " \"./data\",\n", 232 | " train=True,\n", 233 | " download=True,\n", 234 | " transform=transforms.Compose([transforms.ToTensor()]))\n", 235 | "\n", 236 | "loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)\n", 237 | "\n", 238 | "# show examples\n", 239 | "for _, (data, _) in enumerate(loader):\n", 240 | " images = data[:3]\n", 241 | " break\n", 242 | "for i in images:\n", 243 | " plt.imshow(i[0].numpy())\n", 244 | " plt.show()" 245 | ] 246 | }, 247 | { 248 | "cell_type": "markdown", 249 | "id": "62e32935-5d3e-4290-9d4b-650b0fb72f76", 250 | "metadata": {}, 251 | "source": [ 252 | "Now we build an encoder network, which outputs $\\mu_{\\phi}^{\\verb|enc|}(\\mathbf{x})$ and the diagonal elements in $\\verb|diag|((\\sigma_{\\phi}^{\\verb|enc|}(\\mathbf{x}))^2)$ in the following distribution.\n", 253 | "\n", 254 | "$q_{\\phi}(\\mathbf{z}|\\mathbf{x}) = \\mathcal{N}(\\mu_{\\phi}^{\\verb|enc|}(\\mathbf{x}),\\verb|diag|((\\sigma_{\\phi}^{\\verb|enc|}(\\mathbf{x}))^2))$" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": 2, 260 | "id": "1e02b36f-c3d7-48af-81b9-71970704f603", 261 | "metadata": {}, 262 | "outputs": [], 263 | "source": [ 264 | "import torch.nn as nn\n", 265 | "from torch.nn import functional as F\n", 266 | "\n", 267 | "latent_dim = 32\n", 268 | "\n", 269 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 270 | "\n", 271 | "class EncoderNet(nn.Module):\n", 272 | " def __init__(self, latent_dim):\n", 273 | " super().__init__()\n", 274 | " self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)\n", 275 | " self.conv2 = nn.Conv2d(32, 64, 
kernel_size=3, stride=2, padding=1)\n", 276 | " cur_dim = 64 * 7 * 7\n", 277 | " self.fc_mean = nn.Linear(cur_dim, latent_dim)\n", 278 | " self.fc_logvar = nn.Linear(cur_dim, latent_dim)\n", 279 | "\n", 280 | " def forward(self, x):\n", 281 | " x = self.conv1(x) # --> size [batch_size, 32, 14, 14]\n", 282 | " x = F.relu(x)\n", 283 | " x = self.conv2(x) # --> size [batch_size, 64, 7, 7]\n", 284 | " x = F.relu(x)\n", 285 | " x = torch.flatten(x, 1) # --> size [batch_size, 64*7*7]\n", 286 | "\n", 287 | " mean = self.fc_mean(x)\n", 288 | " logvar = self.fc_logvar(x)\n", 289 | " return mean, logvar\n", 290 | "\n", 291 | "#\n", 292 | "# Generate a model for encoder distribution\n", 293 | "#\n", 294 | "encoder_net = EncoderNet(latent_dim=latent_dim).to(device)" 295 | ] 296 | }, 297 | { 298 | "cell_type": "markdown", 299 | "id": "cb5d5de9-9258-4bc0-9f50-ade3063b2174", 300 | "metadata": {}, 301 | "source": [ 302 | "We also build a decoder network.
\n", 303 | "As we saw above, $p_{\\theta}(\\mathbf{x}|\\mathbf{z}) = \\mathcal{N}(\\mu_{\\theta}^{\\verb|dec|}(\\mathbf{z}),(\\sigma^{\\verb|dec|})^2)$ (where $\\sigma$ is some constant) and we need to get only mean value in this network." 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": 3, 309 | "id": "8ddf8e49-6eb8-406f-bb04-4245ab5b8ccc", 310 | "metadata": {}, 311 | "outputs": [], 312 | "source": [ 313 | "class DecoderNet(nn.Module):\n", 314 | " def __init__(self, latent_dim):\n", 315 | " super().__init__()\n", 316 | " self.projection = nn.Linear(latent_dim, 64 * 7 * 7)\n", 317 | " self.convtrans1 = nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2)\n", 318 | " self.convtrans2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2)\n", 319 | " self.convtrans3 = nn.ConvTranspose2d(32, 1, kernel_size=3, stride=1)\n", 320 | "\n", 321 | " def forward(self, x):\n", 322 | " x = self.projection(x) # --> size [batch_size, 64*7*7]\n", 323 | " x = torch.reshape(x, (-1, 64, 7, 7)) # --> size [batch_size, 64, 7, 7]\n", 324 | " x = self.convtrans1(x) # --> size [batch_size, 64, 15, 15]\n", 325 | " x = x[:,:,:-1,:-1] # --> size [batch_size, 64, 14, 14] (cropping)\n", 326 | " x = F.relu(x)\n", 327 | " x = self.convtrans2(x) # --> size [batch_size, 32, 29, 29]\n", 328 | " x = x[:,:,:-1,:-1] # --> size [batch_size, 32, 28, 28] (cropping)\n", 329 | " x = F.relu(x)\n", 330 | " x = self.convtrans3(x) # --> size [batch_size, 1, 30, 30]\n", 331 | " x = x[:,:,:-2,:-2] # --> size [batch_size, 1, 28, 28] (cropping)\n", 332 | " return torch.sigmoid(x)\n", 333 | "\n", 334 | "#\n", 335 | "# Generate a decoder model\n", 336 | "#\n", 337 | "decoder_net = DecoderNet(latent_dim=latent_dim).to(device)" 338 | ] 339 | }, 340 | { 341 | "cell_type": "markdown", 342 | "id": "c5983fdb-e9c4-43a5-8ccb-e5c5c1fd152b", 343 | "metadata": {}, 344 | "source": [ 345 | "We compute KL-divergence (KL-loss) : $D_{KL} \\left( q_{\\phi}(\\mathbf{z}|\\mathbf{x}) \\| p(\\mathbf{z}) 
\\right)$.\n", 346 | "\n", 347 | "This value is analytically computed as follows. (See \"Appendix B\" in [original paper](https://arxiv.org/pdf/1312.6114).)\n", 348 | "\n", 349 | "$\\displaystyle D_{KL} \\left( q_{\\phi}(\\mathbf{z}|\\mathbf{x}) \\| p(\\mathbf{z}) \\right)$\n", 350 | "\n", 351 | "$\\displaystyle = \\int q_{\\phi}(\\mathbf{z}|\\mathbf{x}) \\log q_{\\phi}(\\mathbf{z}|\\mathbf{x}) d\\mathbf{z} - \\int q_{\\phi}(\\mathbf{z}|\\mathbf{x}) \\log p(\\mathbf{z}) d\\mathbf{z} $\n", 352 | "\n", 353 | "$\\displaystyle = \\int \\mathcal{N}(\\mu(\\phi),\\verb|diag|(\\sigma(\\phi)^2)) \\log \\mathcal{N}(\\mu(\\phi),\\verb|diag|(\\sigma(\\phi)^2)) d\\mathbf{z} - \\int \\mathcal{N}(\\mu(\\phi),\\verb|diag|(\\sigma(\\phi)^2)) \\log \\mathcal{N}(\\mathbf{0},\\mathbf{I}) d\\mathbf{z} $\n", 354 | "\n", 355 | "$\\displaystyle = \\left( -\\frac{J}{2} \\log(2\\pi) - \\frac{1}{2} \\sum_{j=1}^J (1 + \\log (\\sigma(\\phi)_j^2)) \\right) - \\left( -\\frac{J}{2} \\log(2\\pi) - \\frac{1}{2} \\sum_{j=1}^J (\\mu(\\phi)_j^2 + \\sigma(\\phi)_j^2) \\right)$\n", 356 | "\n", 357 | "$\\displaystyle = -\\frac{1}{2} \\sum_{j=1}^J \\left( 1 + \\log (\\sigma(\\phi)_j^2) - \\mu(\\phi)_j^2 - \\sigma(\\phi)_j^2 \\right)$\n", 358 | "\n", 359 | "where $J$ is the dimensionality of latent space $\\mathbf{z}$, and $\\mu(\\phi)_j$, $\\sigma(\\phi)_j$ are j-th elements of each vector." 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": 4, 365 | "id": "564a2422-fc0b-4d10-b391-f23ba4ab0ee1", 366 | "metadata": {}, 367 | "outputs": [], 368 | "source": [ 369 | "def kl(mean, logvar):\n", 370 | " elements = logvar + 1.0 - mean**2 - torch.exp(logvar)\n", 371 | " return torch.sum(elements, dim=-1) * (-0.5)" 372 | ] 373 | }, 374 | { 375 | "cell_type": "markdown", 376 | "id": "1dd7e7ee-4b5e-42c9-8679-4bcb80c039c8", 377 | "metadata": {}, 378 | "source": [ 379 | "Get samples from $q_{\\phi}(\\mathbf{z}|\\mathbf{x})$, by applying reparameterization as follows. 
:\n", 380 | "\n", 381 | "$\\displaystyle \\mathbf{z} = \\mu_{\\phi} + \\sigma_{\\phi}^{\\verb|diag|} \\odot \\epsilon$\n", 382 | "\n", 383 | "where $\\epsilon \\sim \\mathcal{N}(\\mathbf{0},\\mathbf{I})$, and $\\sigma_{\\phi}^{\\verb|diag|} \\odot \\epsilon$ means element-wise multiplication between $\\sigma_{\\phi}^{\\verb|diag|}$ (diagonal elements of $\\sigma_{\\phi}$) and $\\epsilon$." 384 | ] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "execution_count": 5, 389 | "id": "735e0c98-46aa-4086-ad83-a3a2c422d23b", 390 | "metadata": {}, 391 | "outputs": [], 392 | "source": [ 393 | "# expand tensor, such as : [batch_size, latent_dim] --> [batch_size * num_sampling, latent_dim]\n", 394 | "def expand_for_sampling(x, num_sampling):\n", 395 | " latent_dim = x.shape[1]\n", 396 | "\n", 397 | " x_sampling = x.unsqueeze(dim=1)\n", 398 | " x_sampling = x_sampling.expand(-1, num_sampling, -1)\n", 399 | " return torch.reshape(x_sampling, (-1, latent_dim))\n", 400 | "\n", 401 | "def reparameter_sampling(mean, logvar, num_sampling):\n", 402 | " # expand mean : [batch_size, latent_dim] --> [batch_size * num_sampling, latent_dim]\n", 403 | " mean_sampling = torch.repeat_interleave(mean, num_sampling, dim=0)\n", 404 | "\n", 405 | " # expand logvar : [batch_size, latent_dim] --> [batch_size * num_sampling, latent_dim]\n", 406 | " logvar_sampling = torch.repeat_interleave(logvar, num_sampling, dim=0)\n", 407 | "\n", 408 | " # get epsilon\n", 409 | " epsilon_sampling = torch.randn_like(mean_sampling).to(device)\n", 410 | "\n", 411 | " return torch.mul(torch.exp(logvar_sampling*0.5), epsilon_sampling) + mean_sampling" 412 | ] 413 | }, 414 | { 415 | "cell_type": "markdown", 416 | "id": "e42212ad-9a51-43c2-9e9d-9c37c2c2926c", 417 | "metadata": {}, 418 | "source": [ 419 | "Now we train models to optimize $\\phi$ and $\\theta$.\n", 420 | "\n", 421 | "As I have mentioned [here](https://tsmatz.wordpress.com/2017/08/30/regression-in-machine-learning-math-for-beginners/), log likelihood in 
Gaussian distribution with a fixed variance equals the negative MSE (mean squared error) loss up to a positive scale factor and an additive constant. Thus, you can also use the MSE loss function instead of ```gaussian_nll_loss()```, but in that case please take care of the relative scale between the reconstruction loss and the KL loss." 422 | ] 423 | }, 424 | { 425 | "cell_type": "code", 426 | "execution_count": 6, 427 | "id": "12a5f647-e535-411e-a411-cdd737f15532", 428 | "metadata": {}, 429 | "outputs": [ 430 | { 431 | "name": "stdout", 432 | "output_type": "stream", 433 | "text": [ 434 | "epoch0 (iter599) - loss -46608.4448\n", 435 | "epoch1 (iter599) - loss -49586.5290\n", 436 | "epoch2 (iter599) - loss -49813.8067\n", 437 | "epoch3 (iter599) - loss -49931.7259\n", 438 | "epoch4 (iter599) - loss -50013.8970\n", 439 | "epoch5 (iter599) - loss -50073.4233\n", 440 | "epoch6 (iter599) - loss -50121.1562\n", 441 | "epoch7 (iter599) - loss -50160.6847\n", 442 | "epoch8 (iter599) - loss -50191.2306\n", 443 | "epoch9 (iter599) - loss -50219.7375\n", 444 | "epoch10 (iter599) - loss -50242.0084\n", 445 | "epoch11 (iter599) - loss -50263.6263\n", 446 | "epoch12 (iter599) - loss -50283.5339\n", 447 | "epoch13 (iter599) - loss -50297.9388\n", 448 | "epoch14 (iter599) - loss -50314.2273\n", 449 | "epoch15 (iter599) - loss -50326.3326\n", 450 | "epoch16 (iter599) - loss -50337.4770\n", 451 | "epoch17 (iter599) - loss -50349.7116\n", 452 | "epoch18 (iter599) - loss -50360.9656\n", 453 | "epoch19 (iter599) - loss -50370.1686\n", 454 | "Done\n" 455 | ] 456 | } 457 | ], 458 | "source": [ 459 | "num_sampling = 5 # the number of sampling\n", 460 | "num_epochs = 20\n", 461 | "\n", 462 | "opt = torch.optim.AdamW(list(encoder_net.parameters()) + list(decoder_net.parameters()), lr=0.001)\n", 463 | "\n", 464 | "loss_records = []\n", 465 | "for epoch_idx in range(num_epochs):\n", 466 | " for batch_idx, (data, _) in enumerate(loader):\n", 467 | " x = data.to(device)\n", 468 | "\n", 469 | " opt.zero_grad()\n", 470 | "\n", 471 | " # get mean and logvar\n", 472 | " z_mean, z_logvar = 
encoder_net(x)\n", 473 | " \n", 474 | " # get KL-div (KL loss)\n", 475 | " kl_loss_batch = kl(z_mean, z_logvar) # shape: [batch_size,]\n", 476 | " kl_loss = torch.sum(kl_loss_batch)\n", 477 | " \n", 478 | " # sampling by reparameterization\n", 479 | " z_samples = reparameter_sampling(z_mean, z_logvar, num_sampling)\n", 480 | " \n", 481 | " # reconstruct x by decoder\n", 482 | " decoded_x_samples = decoder_net(z_samples)\n", 483 | " decoded_x_samples = torch.reshape(decoded_x_samples, (-1, 28*28))\n", 484 | " \n", 485 | " # get Gaussian negative log-likelihood (reconstruction) loss between the input x and the reconstructed x (it can be negative, since log(var) < 0 when var < 1)\n", 486 | " x_samples = torch.repeat_interleave(x, num_sampling, dim=0) # expand to shape: [batch_size * num_sampling, 1, 28, 28]\n", 487 | " x_samples = torch.reshape(x_samples, (-1, 28*28))\n", 488 | " pvar = torch.ones_like(x_samples).to(device) * 0.25 # sharp fit (p(x)_sigma=0.5)\n", 489 | " # pvar = torch.ones_like(x_samples).to(device) * 1.00 # loose fit (p(x)_sigma=1.0)\n", 490 | " reconstruction_loss = F.gaussian_nll_loss(\n", 491 | " decoded_x_samples,\n", 492 | " x_samples,\n", 493 | " var=pvar,\n", 494 | " reduction=\"sum\"\n", 495 | " )\n", 496 | " reconstruction_loss /= num_sampling\n", 497 | " ########## for debug\n", 498 | " # if batch_idx == 0:\n", 499 | " # print(reconstruction_loss)\n", 500 | " # print(kl_loss)\n", 501 | " \n", 502 | " # optimize parameters\n", 503 | " total_loss = reconstruction_loss + kl_loss\n", 504 | " total_loss.backward()\n", 505 | " opt.step()\n", 506 | " \n", 507 | " # log\n", 508 | " loss_records.append(total_loss.item())\n", 509 | " print(\"epoch{} (iter{}) - loss {:5.4f}\".format(epoch_idx, batch_idx, total_loss), end=\"\\r\")\n", 510 | "\n", 511 | " epoch_loss_all = loss_records[-(batch_idx+1):]\n", 512 | " average_loss = sum(epoch_loss_all)/len(epoch_loss_all)\n", 513 | " print(\"epoch{} (iter{}) - loss {:5.4f}\".format(epoch_idx, batch_idx, average_loss))\n", 514 | "\n", 515 | "print(\"Done\")" 516 | ] 517 | }, 518 | { 519 | "cell_type": 
"markdown", 520 | "id": "0acf2cff-1268-4c44-a941-67a0164cdbae", 521 | "metadata": {}, 522 | "source": [ 523 | "After the training is completed, plot the loss recorded at each iteration to see the training progress." 524 | ] 525 | }, 526 | { 527 | "cell_type": "code", 528 | "execution_count": 7, 529 | "id": "126114de-f98b-497b-b60d-e7549be86f7d", 530 | "metadata": {}, 531 | "outputs": [ 532 | { 533 | "data": { 534 | "text/plain": [ 535 | "[]" 536 | ] 537 | }, 538 | "execution_count": 7, 539 | "metadata": {}, 540 | "output_type": "execute_result" 541 | }, 542 | { 543 | "data": { 544 | "image/png": "", 545 | "text/plain": [ 546 | "
" 547 | ] 548 | }, 549 | "metadata": {}, 550 | "output_type": "display_data" 551 | } 552 | ], 553 | "source": [ 554 | "plt.plot(loss_records)" 555 | ] 556 | }, 557 | { 558 | "cell_type": "markdown", 559 | "id": "c2cf9816-fb75-4217-8023-f9eb55221e9e", 560 | "metadata": {}, 561 | "source": [ 562 | "### Generate images" 563 | ] 564 | }, 565 | { 566 | "cell_type": "markdown", 567 | "id": "e3f16541-5b4f-42ab-9086-7b07f48e2c29", 568 | "metadata": {}, 569 | "source": [ 570 | "Using the trained model, let's generate images from random latent vectors. (Here I generate 5 latent vectors.)" 571 | ] 572 | }, 573 | { 574 | "cell_type": "code", 575 | "execution_count": 8, 576 | "id": "d14719d0-507c-414c-b629-ab071bb18502", 577 | "metadata": {}, 578 | "outputs": [ 579 | { 580 | "data": { 581 | "image/png": "", 582 | "text/plain": [ 583 | "
" 584 | ] 585 | }, 586 | "metadata": {}, 587 | "output_type": "display_data" 588 | }, 589 | { 590 | "data": { 591 | "image/png": "", 592 | "text/plain": [ 593 | "
" 594 | ] 595 | }, 596 | "metadata": {}, 597 | "output_type": "display_data" 598 | }, 599 | { 600 | "data": { 601 | "image/png": "", 602 | "text/plain": [ 603 | "
" 604 | ] 605 | }, 606 | "metadata": {}, 607 | "output_type": "display_data" 608 | }, 609 | { 610 | "data": { 611 | "image/png": "", 612 | "text/plain": [ 613 | "
" 614 | ] 615 | }, 616 | "metadata": {}, 617 | "output_type": "display_data" 618 | }, 619 | { 620 | "data": { 621 | "image/png": "", 622 | "text/plain": [ 623 | "
" 624 | ] 625 | }, 626 | "metadata": {}, 627 | "output_type": "display_data" 628 | } 629 | ], 630 | "source": [ 631 | "import matplotlib.pyplot as plt\n", 632 | "\n", 633 | "latent_samples = torch.randn(5, latent_dim).to(device)\n", 634 | "with torch.no_grad():\n", 635 | " imgs = decoder_net(latent_samples)\n", 636 | "for i in imgs:\n", 637 | " plt.imshow(i[0].cpu().numpy())\n", 638 | " plt.show()" 639 | ] 640 | }, 641 | { 642 | "cell_type": "markdown", 643 | "id": "addfccfb-ca3b-4ebe-b79b-9cd70ea5f707", 644 | "metadata": {}, 645 | "source": [ 646 | "You can also apply a variational auto-encoder (VAE) to a variety of tasks, such as denoising and anomaly detection.\n", 647 | "\n", 648 | "Here I simply recover a clean image from a noisy one by encoding and then decoding it.\n", 649 | "\n", 650 | "> Note : To build a real denoising model, use a denoising dataset (i.e., pairs of noisy images and clean images) and train the VAE architecture on it. (Here I simply reuse the above model to recover a clean image.)" 651 | ] 652 | }, 653 | { 654 | "cell_type": "code", 655 | "execution_count": 10, 656 | "id": "d722ac15-1084-44df-b324-ea70a0b02e71", 657 | "metadata": {}, 658 | "outputs": [ 659 | { 660 | "data": { 661 | "image/png": "", 662 | "text/plain": [ 663 | "
" 664 | ] 665 | }, 666 | "metadata": {}, 667 | "output_type": "display_data" 668 | }, 669 | { 670 | "data": { 671 | "image/png": "", 672 | "text/plain": [ 673 | "
" 674 | ] 675 | }, 676 | "metadata": {}, 677 | "output_type": "display_data" 678 | } 679 | ], 680 | "source": [ 681 | "# make noisy image\n", 682 | "noise_factor = 0.2\n", 683 | "for _, (data, _) in enumerate(loader):\n", 684 | " data = data.to(device)\n", 685 | " org_image = data[0]\n", 686 | " noisy_image = org_image + noise_factor * torch.randn_like(org_image)\n", 687 | " noisy_image = torch.clamp(noisy_image, min=0.0, max=1.0)\n", 688 | " break\n", 689 | "\n", 690 | "# show noisy image\n", 691 | "plt.imshow(noisy_image[0].cpu().numpy())\n", 692 | "plt.show()\n", 693 | "\n", 694 | "# recover through encoder/decoder (only the mean mu is decoded; the encoder's second output is the log-variance, not sigma, and is unused here)\n", 695 | "with torch.no_grad():\n", 696 | " mu, sigma = encoder_net(noisy_image.unsqueeze(dim=0))\n", 697 | " decoded_image = decoder_net(mu)\n", 698 | "\n", 699 | "# show the recovered image\n", 700 | "plt.imshow(decoded_image[0][0].cpu().numpy())\n", 701 | "plt.show()" 702 | ] 703 | }, 704 | { 705 | "cell_type": "code", 706 | "execution_count": null, 707 | "id": "c876ba48-e1ac-4580-9a40-b3c56d18b1a7", 708 | "metadata": {}, 709 | "outputs": [], 710 | "source": [] 711 | } 712 | ], 713 | "metadata": { 714 | "kernelspec": { 715 | "display_name": "Python 3 (ipykernel)", 716 | "language": "python", 717 | "name": "python3" 718 | }, 719 | "language_info": { 720 | "codemirror_mode": { 721 | "name": "ipython", 722 | "version": 3 723 | }, 724 | "file_extension": ".py", 725 | "mimetype": "text/x-python", 726 | "name": "python", 727 | "nbconvert_exporter": "python", 728 | "pygments_lexer": "ipython3", 729 | "version": "3.12.3" 730 | } 731 | }, 732 | "nbformat": 4, 733 | "nbformat_minor": 5 734 | } 735 | --------------------------------------------------------------------------------