├── .gitignore ├── 2d_conditional.ipynb ├── 2d_distribution.ipynb ├── README.md ├── checkpoints └── mnist_model.pt ├── ddim ├── __init__.py ├── ddim.py └── predictor.py ├── demo.ipynb ├── mnist_analysis.ipynb ├── mnist_conditional.ipynb └── mnist_train.py /.gitignore: -------------------------------------------------------------------------------- 1 | mnist_data 2 | .ipynb_checkpoints 3 | __pycache__ 4 | 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DDIM 2 | 3 | This repo contains some experiments with [Denoising Diffusion Implicit Models](https://openreview.net/forum?id=St1giarCHLP). In particular, I play with simple distributions and see how DDIM translates latents to datapoints. This is a particularly interesting class of models because the latent space is intrinsic, it's not learned, and it's really cool to investigate. 4 | -------------------------------------------------------------------------------- /checkpoints/mnist_model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unixpickle/ddim/27950a639afbbe3de8f8f72d8605f14f91a5742a/checkpoints/mnist_model.pt -------------------------------------------------------------------------------- /ddim/__init__.py: -------------------------------------------------------------------------------- 1 | from .ddim import Diffusion, create_alpha_schedule 2 | from .predictor import Predictor, CNNPredictor, BayesPredictor, train_predictor 3 | 4 | __all__ = [ 5 | "Diffusion", 6 | "create_alpha_schedule", 7 | "Predictor", 8 | "CNNPredictor", 9 | "BayesPredictor", 10 | "train_predictor", 11 | ] 12 | -------------------------------------------------------------------------------- /ddim/ddim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from 
import numpy as np
import torch

try:
    from tqdm.auto import tqdm
except ImportError:  # tqdm only drives optional progress bars.

    def tqdm(iterable, *args, **kwargs):
        """Stand-in used when tqdm is not installed: iterate without a bar."""
        return iterable


def create_alpha_schedule(num_steps=100, beta_0=0.0001, beta_T=0.02):
    """
    Create the cumulative alpha schedule for a linear beta schedule.

    The betas are spaced linearly from beta_0 to beta_T over num_steps steps.
    Entry t of the result is prod_{s <= t} (1 - beta_s). Entry 0 is 1.0 (no
    noise), so the result has num_steps + 1 entries.

    Returns:
        A float64 numpy array of shape [num_steps + 1].
    """
    betas = np.linspace(beta_0, beta_T, num_steps)
    return np.concatenate([[1.0], np.cumprod(1 - betas)]).astype(np.float64)


def _expand_to_rank(values, ndim):
    """Append trailing singleton axes to values until it has ndim dimensions."""
    values = np.asarray(values)
    extra = ndim - values.ndim
    if extra <= 0:
        return values
    return values.reshape(values.shape + (1,) * extra)


class Diffusion:
    """
    A numpy implementation of the DDPM and DDIM setup.

    `alphas` holds one cumulative alpha per timestep t = 0..num_steps, as
    produced by create_alpha_schedule(). The samplers walk t from num_steps
    down to 1. They call predictor.predict_epsilon(x_t, alphas) with numpy
    arrays, where `alphas` has one entry per batch element.
    """

    def __init__(self, alphas):
        self.alphas = np.asarray(alphas)

    @property
    def num_steps(self):
        """Number of noising steps; self.alphas has num_steps + 1 entries."""
        return len(self.alphas) - 1

    def _timesteps(self):
        """Timesteps visited by the samplers, from num_steps down to 1."""
        return range(self.num_steps, 0, -1)

    def sample_q(self, x_0, ts, epsilon=None):
        """
        Sample from q(x_t | x_0) for a batch of x_0.

        Returns sqrt(alpha_t) * x_0 + sqrt(1 - alpha_t) * epsilon. Standard
        normal epsilon is drawn when none is given.
        """
        if epsilon is None:
            epsilon = np.random.normal(size=x_0.shape)
        alphas = self.alphas_for_ts(ts, x_0.shape)
        return np.sqrt(alphas) * x_0 + np.sqrt(1 - alphas) * epsilon

    def predict_x0(self, x_t, ts, epsilon_prediction):
        """Invert sample_q() to estimate x_0 from x_t and a noise prediction."""
        alphas = self.alphas_for_ts(ts, x_t.shape)
        return (x_t - np.sqrt(1 - alphas) * epsilon_prediction) / np.sqrt(alphas)

    def ddim_previous(self, x_t, ts, epsilon_prediction):
        """
        Take a ddim sampling step given x_t, t, and epsilon prediction.

        The step estimates x_0, then re-noises it to timestep t - 1 using the
        same predicted noise (deterministic DDIM update).
        """
        x_0 = self.predict_x0(x_t, ts, epsilon_prediction)
        return self.sample_q(x_0, ts - 1, epsilon=epsilon_prediction)

    def ddim_sample(self, x_T, predictor, progress=False):
        """
        Sample x_0 from x_T using DDIM.

        Args:
            x_T: numpy batch of latents at timestep num_steps.
            predictor: object with a method predict_epsilon(x_t, alphas).
            progress: if True, wrap the timestep loop in a tqdm progress bar.
        """
        x_t = x_T
        t_iter = self._timesteps()
        if progress:
            t_iter = tqdm(t_iter)
        for t in t_iter:
            ts = np.full(x_T.shape[0], t)
            alphas = self.alphas_for_ts(ts)
            x_t = self.ddim_previous(x_t, ts, predictor.predict_epsilon(x_t, alphas))
        return x_t

    def ddim_sample_cond(self, x_T, predictor, x_cond, mask):
        """
        Like ddim_sample(), but condition on the part of x_cond which is True
        in the boolean mask.

        Before each step, the masked entries of x_t are replaced by a fresh
        sample of q(x_t | x_cond). The masked entries of the result are exactly
        x_cond.
        """
        x_t = x_T
        for t in self._timesteps():
            ts = np.full(x_T.shape[0], t)
            alphas = self.alphas_for_ts(ts)
            x_t = np.where(mask, self.sample_q(x_cond, ts), x_t)
            x_t = self.ddim_previous(x_t, ts, predictor.predict_epsilon(x_t, alphas))
        return np.where(mask, x_cond, x_t)

    def ddpm_previous(
        self, x_t, ts, epsilon_prediction, epsilon=None, cond_prediction=None
    ):
        """
        Take one ancestral DDPM step from x_t to x_{t-1}.

        The mean is (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_pred)
        / sqrt(1 - beta_t), where beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}.
        If cond_prediction (a gradient of a log-likelihood) is given, the mean
        is shifted by beta_t * cond_prediction. Noise sqrt(beta_t) * epsilon
        is added, except on the final step (t == 1), which returns the mean
        as in the DDPM sampling algorithm.
        """
        if epsilon is None:
            epsilon = np.random.normal(size=x_t.shape)
        alphas_t = self.alphas_for_ts(ts, x_t.shape)
        alphas_prev = self.alphas_for_ts(ts - 1, x_t.shape)
        step_alphas = alphas_t / alphas_prev
        betas = 1 - step_alphas
        prev_mean = (1 / np.sqrt(step_alphas)) * (
            x_t - betas / np.sqrt(1 - alphas_t) * epsilon_prediction
        )
        if cond_prediction is not None:
            prev_mean = prev_mean + betas * cond_prediction
        # No noise on the last step, so the returned x_0 is not perturbed.
        is_noisy_step = _expand_to_rank(ts, len(x_t.shape)) > 1
        noise_std = np.where(is_noisy_step, np.sqrt(betas), 0.0)
        return prev_mean + noise_std * epsilon

    def ddpm_sample(self, x_T, predictor):
        """
        Sample x_0 from x_T using DDPM.

        Usage is the same as ddim_sample().
        """
        x_t = x_T
        for t in self._timesteps():
            ts = np.full(x_T.shape[0], t)
            alphas = self.alphas_for_ts(ts)
            x_t = self.ddpm_previous(x_t, ts, predictor.predict_epsilon(x_t, alphas))
        return x_t

    def ddpm_sample_cond(self, x_T, predictor, x_cond, mask, num_subsamples=1):
        """
        Create a masked-conditional sample using DDPM.

        See ddim_sample_cond() for usage details. At every timestep,
        num_subsamples DDPM steps are taken from independently re-noised
        versions of the known region, and their results are averaged.
        """
        x_t = x_T
        for t in self._timesteps():
            ts = np.full(x_T.shape[0], t)
            alphas = self.alphas_for_ts(ts)
            samples = []
            for _ in range(num_subsamples):
                x_t = np.where(mask, self.sample_q(x_cond, ts), x_t)
                x_next = self.ddpm_previous(
                    x_t, ts, predictor.predict_epsilon(x_t, alphas)
                )
                samples.append(x_next)
            x_t = np.mean(samples, axis=0)
        return np.where(mask, x_cond, x_t)

    def ddpm_sample_cond_energy(self, x_T, predictor, cond_fn):
        """
        Create a sample using an energy function cond_fn as a conditioning
        signal, to compute p(x)*p(y|x), where cond_fn is grad_x log(p(y|x)).

        cond_fn is called as cond_fn(x_t, alphas) with numpy arrays, where
        alphas has one entry per batch element.
        """
        x_t = x_T
        for t in self._timesteps():
            ts = np.full(x_T.shape[0], t)
            alphas = self.alphas_for_ts(ts)
            x_t = self.ddpm_previous(
                x_t,
                ts,
                predictor.predict_epsilon(x_t, alphas),
                cond_prediction=cond_fn(x_t, alphas),
            )
        return x_t

    def ddpm_sample_cond_energy_inpaint(
        self, x_T, predictor, x_cond, mask, temp=1.0, eps=1e-2
    ):
        """
        Inpaint the masked region of x_cond with gradient-guided DDPM.

        The guidance is the gradient, w.r.t. x_t, of a Gaussian log-density of
        x_cond under the model's x_0 prediction, restricted to the mask and
        divided by temp. The predictor is called directly as a torch module,
        predictor(x_t, alphas). It runs on the device of its parameters, or on
        the CPU if it has none (e.g. BayesPredictor).
        """
        param = next(predictor.parameters(), None)
        device = param.device if param is not None else torch.device("cpu")
        x_cond_torch = torch.as_tensor(x_cond, device=device)
        mask_torch = torch.as_tensor(mask, device=device).float()

        def cond_fn(x_t, alphas):
            with torch.enable_grad():
                alphas_flat = torch.as_tensor(alphas, device=device).float().reshape(-1)
                # Broadcast per-example alphas against x_t's trailing axes.
                alphas_b = alphas_flat.view(-1, *([1] * (x_t.ndim - 1)))
                x_t_torch = (
                    torch.as_tensor(x_t, device=device).float().requires_grad_(True)
                )
                eps_pred = predictor(x_t_torch, alphas_flat)
                x_start = (x_t_torch - (1 - alphas_b).sqrt() * eps_pred) / alphas_b.sqrt()

                # This should be the variance of the x_start prediction,
                # but instead we use the variance of a signal noised to
                # the current timestep as a reasonable guess.
                sigmas = eps + 1 - alphas_b

                log_density = -((x_cond_torch - x_start) ** 2) / (2 * sigmas)
                loss = (log_density * mask_torch).sum()
                grad = torch.autograd.grad(loss, x_t_torch)[0]
            return grad.detach().cpu().numpy() / temp

        return self.ddpm_sample_cond_energy(x_T, predictor, cond_fn)

    def alphas_for_ts(self, ts, shape=None):
        """
        Look up cumulative alphas for timesteps ts.

        If shape is given, trailing singleton axes are appended so that the
        result has len(shape) dimensions and broadcasts against such data.
        """
        alphas = self.alphas[ts]
        if shape is None:
            return alphas
        return _expand_to_rank(alphas, len(shape))


# --------------------------------------------------------------------------------
# /ddim/predictor.py
# --------------------------------------------------------------------------------

import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam

try:
    from tqdm.auto import tqdm
except ImportError:  # tqdm only drives optional progress bars.

    def tqdm(iterable, *args, **kwargs):
        """Stand-in used when tqdm is not installed: iterate without a bar."""
        return iterable


def train_predictor(diffusion, data_batches, lr=1e-4, predictor=None):
    """
    Train an epsilon predictor with the simple DDPM MSE objective.

    For each batch, a timestep in [1, num_steps] and standard normal noise
    are drawn per example. The batch is noised with diffusion.sample_q(), and
    the predictor learns to recover the noise.

    Args:
        diffusion: a Diffusion providing the noise schedule.
        data_batches: iterable of numpy arrays of shape [batch, *data_shape].
        lr: Adam learning rate.
        predictor: optional torch module to train. By default a new MLP
            Predictor is built from the shape of the first batch.

    Returns:
        (predictor, losses), where losses has one float per batch. If
        data_batches is empty and no predictor was given, predictor is None.
    """
    optim = None
    dev = None
    losses = []
    for batch in tqdm(data_batches):
        if predictor is None:
            predictor = Predictor(batch.shape[1:])
        if optim is None:
            optim = Adam(predictor.parameters(), lr=lr)
            dev = next(predictor.parameters()).device
        ts = torch.randint(low=1, high=diffusion.num_steps + 1, size=(batch.shape[0],))
        epsilon = torch.randn(*batch.shape)
        ts_np = ts.numpy()
        noised = diffusion.sample_q(batch, ts_np, epsilon=epsilon.numpy())
        samples = torch.from_numpy(noised).float().to(dev)
        alphas = torch.from_numpy(diffusion.alphas_for_ts(ts_np)).float().to(dev)
        predictions = predictor(samples, alphas)
        loss = torch.mean((epsilon.to(dev) - predictions) ** 2)
        losses.append(loss.item())
        optim.zero_grad()
        loss.backward()
        optim.step()
    return predictor, losses


def _predict_epsilon_numpy(module, inputs_np, alphas_np):
    """
    Run module(inputs, alphas) on numpy arrays without tracking gradients.

    The inputs are moved to the device of the module's parameters, and the
    result is returned as numpy with the same dtype as inputs_np.
    """
    dev = next(module.parameters()).device
    inputs = torch.from_numpy(inputs_np).float().to(dev)
    alphas = torch.from_numpy(alphas_np).float().to(dev)
    with torch.no_grad():
        return module(inputs, alphas).detach().cpu().numpy().astype(inputs_np.dtype)


class Predictor(nn.Module):
    """
    MLP epsilon predictor for flat-ish data.

    Inputs of shape [B, *data_shape] are flattened. A sinusoidal embedding of
    the per-example alphas (shape [B]) is added before the hidden layers.
    """

    def __init__(self, data_shape, num_layers=1, channels=128):
        super().__init__()
        self.data_shape = data_shape

        self.register_buffer(
            "timestep_coeff", torch.linspace(start=0.1, end=100, steps=channels)[None]
        )
        self.timestep_phase = nn.Parameter(torch.randn(channels)[None])
        self.input_embed = nn.Linear(int(np.prod(data_shape)), channels)
        self.timestep_embed = nn.Sequential(
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )
        self.layers = nn.Sequential(
            nn.GELU(),
            *[
                nn.Sequential(nn.Linear(channels, channels), nn.GELU())
                for _ in range(num_layers)
            ],
            nn.Linear(channels, int(np.prod(data_shape))),
        )

    def forward(self, inputs, alphas):
        embed_alphas = torch.sin(
            (self.timestep_coeff * alphas.float()[:, None]) + self.timestep_phase
        )
        embed_alphas = self.timestep_embed(embed_alphas)
        embed_ins = self.input_embed(inputs.view(inputs.shape[0], -1))
        out = self.layers(embed_ins + embed_alphas)
        return out.view(inputs.shape)

    def predict_epsilon(self, inputs_np, alphas_np):
        """Numpy wrapper around forward(); see _predict_epsilon_numpy()."""
        return _predict_epsilon_numpy(self, inputs_np, alphas_np)


class CNNPredictor(nn.Module):
    """
    Convolutional epsilon predictor for [B, C, H, W] images.

    A stack of residual conv blocks, each with squeeze-excitation, receives
    a per-block projection of a sinusoidal alpha embedding.
    """

    def __init__(self, data_shape, num_res_blocks=7, channels=128):
        super().__init__()
        assert len(data_shape) == 3
        self.data_shape = data_shape

        self.register_buffer(
            "timestep_coeff",
            torch.linspace(start=0.1, end=1000, steps=channels * 4)[None],
        )
        self.timestep_phase = nn.Parameter(torch.randn(channels * 4)[None])
        self.timestep_embed = nn.Sequential(
            nn.Linear(channels * 4, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )
        self.input_embed = nn.Conv2d(data_shape[0], channels, 1)
        self.res_blocks = nn.ModuleList([])
        self.timestep_blocks = nn.ModuleList([])
        for _ in range(num_res_blocks):
            block = nn.Sequential(
                nn.GroupNorm(8, channels),
                nn.GELU(),
                nn.Conv2d(channels, channels, 3, padding=1),
                nn.GELU(),
                nn.Conv2d(channels, channels, 3, padding=1),
                SELayer(channels),
            )
            self.res_blocks.append(block)
            self.timestep_blocks.append(
                nn.Sequential(
                    nn.Linear(channels, channels),
                    nn.GELU(),
                    nn.Linear(channels, channels),
                )
            )
        self.out_layer = nn.Conv2d(channels, data_shape[0], 3, padding=1)

    def forward(self, inputs, alphas):
        assert inputs.shape[1:] == self.data_shape
        embed_alphas = torch.sin(
            (self.timestep_coeff * alphas.float()[:, None]) + self.timestep_phase
        )
        embed_alphas = self.timestep_embed(embed_alphas)
        out = self.input_embed(inputs)
        for block, ts_block in zip(self.res_blocks, self.timestep_blocks):
            # Inject the timestep embedding as a per-channel bias.
            out = out + block(out + ts_block(embed_alphas)[..., None, None])
        out = self.out_layer(out)
        return out

    def predict_epsilon(self, inputs_np, alphas_np):
        """Numpy wrapper around forward(); see _predict_epsilon_numpy()."""
        return _predict_epsilon_numpy(self, inputs_np, alphas_np)


class SELayer(nn.Module):
    """
    A squeeze-excitation layer, from:
    https://github.com/moskomule/senet.pytorch/blob/23839e07525f9f5d39982140fccc8b925fe4dee9/senet/se_module.py

    This layer provides global context to each local part of an image,
    allowing for larger receptive field without much extra compute.
    """

    def __init__(self, channel, reduction=8):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)


class BayesPredictor(nn.Module):
    """
    An epsilon predictor that uses Bayes rule to predict epsilon without any
    learnable parameters--just a bunch of data.

    The data_batch of shape [N, *data_shape] may be a numpy array or a torch
    tensor. It is treated as a uniform empirical distribution over x_0. The
    prediction is epsilon implied by the exact posterior mean E[x_0 | x_t].
    """

    def __init__(self, data_batch):
        super().__init__()
        self.data_batch = data_batch

    def forward(self, inputs, alphas):
        data = torch.as_tensor(self.data_batch, device=inputs.device).to(inputs.dtype)
        # Reshape per-example alphas [B] to [B, 1, ..., 1] to match inputs.
        extra = inputs.dim() - alphas.dim()
        if extra > 0:
            alphas = alphas.reshape(tuple(alphas.shape) + (1,) * extra)
        variances = 1 - alphas
        # diffs[b, n] = x_t[b] - sqrt(alpha_b) * data[n]
        diffs = inputs[:, None] - torch.sqrt(alphas)[:, None] * data[None]
        sq_dists = (diffs**2 / variances[:, None]).reshape(
            diffs.shape[0], diffs.shape[1], -1
        )
        # The Gaussian normaliser is identical for every datapoint, so it
        # cancels in the softmax and is omitted.
        logits = -0.5 * sq_dists.sum(dim=-1)
        probs = torch.softmax(logits, dim=1)
        probs = probs.reshape(tuple(probs.shape) + (1,) * (data.dim() - 1))
        x_0 = torch.sum(data[None] * probs, dim=1)
        return (inputs - torch.sqrt(alphas) * x_0) / torch.sqrt(1 - alphas)

    def predict_epsilon(self, inputs_np, alphas_np):
        """Numpy wrapper around forward(); returns inputs_np's dtype."""
        inputs = torch.from_numpy(inputs_np)
        alphas = torch.from_numpy(alphas_np)
        return self(inputs, alphas).detach().numpy().astype(inputs_np.dtype)
-------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from ddim import BayesPredictor, Diffusion, create_alpha_schedule, train_predictor" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import matplotlib.pyplot as plt\n", 19 | "import numpy as np" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 3, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "USE_BAYES = True\n", 29 | "DATASET = 'bimodal' # uniform, bimodal" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 4, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "def generate_data(batch_size, num_batches):\n", 39 | " if DATASET == 'uniform':\n", 40 | " return np.random.uniform(size=(num_batches, batch_size, 1))\n", 41 | " elif DATASET == 'bimodal':\n", 42 | " raw_data = 0.2 * np.random.uniform(size=(num_batches, batch_size, 1))\n", 43 | " offsets = np.random.randint(low=0, high=2, size=raw_data.shape).astype(raw_data.dtype)\n", 44 | " return raw_data - 0.1 + (offsets - 0.5) * 2\n", 45 | " else:\n", 46 | " raise ValueError(DATASET)" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 5, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "diffusion = Diffusion(create_alpha_schedule(num_steps=1000))" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 6, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "if USE_BAYES:\n", 65 | " model = BayesPredictor(generate_data(1000, 1)[0])\n", 66 | "else:\n", 67 | " data = generate_data(batch_size=1000, num_batches=1000)\n", 68 | " print('mean', np.mean(data), 'std', np.std(data))\n", 69 | " model, 
losses = train_predictor(diffusion, data, lr=2e-3)\n", 70 | " plt.plot(losses)\n", 71 | " plt.show()\n", 72 | " print('final loss', np.mean(losses[-10:]))" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 7, 78 | "metadata": {}, 79 | "outputs": [ 80 | { 81 | "data": { 82 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADaBJREFUeJzt3X+s3fVdx/HnSzqGmXNQuNZK6cpCN8I/wHJDmBgT6TBsGlodki1Gr6amLlEzNxOt7i+NicM/xJmYJQ3grslkYHVpXci2rkCIiWO7CBs/utHSjKxNf9xtMDdjmMW3f9wv5krv5Xzvveece/vp85HcnO/3e76n58339j777bfnHFJVSJLOfT+y2gNIkobDoEtSIwy6JDXCoEtSIwy6JDXCoEtSIwy6JDXCoEtSIwy6JDVi3Tif7LLLLqstW7aM8ykl6Zz3+OOPf7uqJgbtN9agb9myhZmZmXE+pSSd85K80Gc/L7lIUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiPG+k5RSTpX3XXguWU/9sO3vH2IkyzOM3RJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RG9Ap6kouT7E3y9SSHkrwryfokB5Ic7m4vGfWwkqTF9T1D/zjwuaq6GrgWOATsBg5W1VbgYLcuSVolA4Oe5C3AzwL3AFTVD6vqJWA7MN3tNg3sGNWQkqTB+pyhXwnMAn+X5Ikkdyd5E7Chqk50+5wENoxqSEnSYH2Cvg54J/CJqroe+E9ec3mlqgqohR6cZFeSmSQzs7OzK51XkrSIPkE/Bhyrqse69b3MBf5Uko0A3e3phR5cVXuqarKqJicmJoYxsyRpAQODXlUngW8leUe3aRvwLLAfmOq2TQH7RjKhJKmXdT33+z3gU0kuBI4Cv8ncHwYPJNkJvADcMZoRJUl99Ap6VT0JTC5w17bhjiNJWi7fKSpJjTDoktQIgy5JjTDoktQIgy5JjTDoktQIgy5JjTDoktQIgy5JjTDoktQIgy5JjTDoktQIgy5JjTDoktQIgy5JjTDoktQIgy5JjTDoktQIgy5JjTDoktQIgy5JjTDoktSIdX12SvJN4PvAK8CZqppMsh64H9gCfBO4o6peHM2YkqRBlnKG/nNVdV1VTXbru4GDVbUVONitS5JWyUouuWwHprvlaWDHyseRJC1X36AX8IUkjyfZ1W3bUFUnuuWTwIahTydJ6q3XNXTgZ6rqeJKfAA4k+fr8O6uqktRCD+z+ANgFsHnz5hUNK0laXK8z9Ko63t2eBj4D3ACcSrIRoLs9vchj91TVZFVNTkxMDGdqSdJZBgY9yZuSvPnVZeDngaeB/cBUt9sUsG9UQ0qSButzyWUD8Jkkr+7/D1X1uSRfAR5IshN4AbhjdGNK0srddeC51R5hpAYGvaqOAtcusP07wLZRDCVJWjrfKSpJjTDoktQIgy5JjTDoktQIgy5JjTDoktQIgy5JjTDoktQIgy5JjTDoktQIgy5Jjej7eejntJV8IM+Hb3n7ECeRpNHxDF2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGtE76EkuSPJEks926
1cmeSzJkST3J7lwdGNKkgZZyhn6h4BD89bvBO6qqquAF4GdwxxMkrQ0vYKeZBPwC8Dd3XqAm4G93S7TwI5RDChJ6qfvx+f+NfCHwJu79UuBl6rqTLd+DLh8oQcm2QXsAti8efOyB13JR+BK0vlg4Bl6kl8ETlfV48t5gqraU1WTVTU5MTGxnF9CktRDnzP0m4DbkrwXuAj4ceDjwMVJ1nVn6ZuA46MbU5I0yMAz9Kr646raVFVbgPcDD1XVrwIPA7d3u00B+0Y2pSRpoJW8Dv2PgI8kOcLcNfV7hjOSJGk5lvT/FK2qR4BHuuWjwA3DH0mStBy+U1SSGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRA4Oe5KIkX07y1STPJPnTbvuVSR5LciTJ/UkuHP24kqTF9DlDfxm4uaquBa4Dbk1yI3AncFdVXQW8COwc3ZiSpEEGBr3m/KBbfUP3VcDNwN5u+zSwYyQTSpJ66XUNPckFSZ4ETgMHgOeBl6rqTLfLMeDy0YwoSeqjV9Cr6pWqug7YBNwAXN33CZLsSjKTZGZ2dnaZY0qSBlnSq1yq6iXgYeBdwMVJ1nV3bQKOL/KYPVU1WVWTExMTKxpWkrS4Pq9ymUhycbf8o8AtwCHmwn57t9sUsG9UQ0qSBls3eBc2AtNJLmDuD4AHquqzSZ4FPp3kz4EngHtGOKckaYCBQa+qrwHXL7D9KHPX0yVJa4DvFJWkRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWrEwKAnuSLJw0meTfJMkg9129cnOZDkcHd7yejHlSQtps8Z+hngD6rqGuBG4HeSXAPsBg5W1VbgYLcuSVolA4NeVSeq6t+75e8Dh4DLge3AdLfbNLBjVENKkgZb0jX0JFuA64HHgA1VdaK76ySwYaiTSZKWpHfQk/wY8E/A71fVf8y/r6oKqEUetyvJTJKZ2dnZFQ0rSVpcr6AneQNzMf9UVf1zt/lUko3d/RuB0ws9tqr2VNVkVU1OTEwMY2ZJ0gL6vMolwD3Aoar6q3l37QemuuUpYN/wx5Mk9bWuxz43Ab8GPJXkyW7bnwAfAx5IshN4AbhjNCNKkvoYGPSq+lcgi9y9bbjjSJKWy3eKSlIjDLokNcKgS1IjDLokNcKgS1IjDLokNcKgS1IjDLokNcKgS1IjDLokNcKgS1IjDLokNcKgS1IjDLokNcKgS1IjDLokNcKgS1IjDLokNcKgS1IjDLokNcKgS1IjDLokNcKgS1IjBgY9yb1JTid5et629UkOJDnc3V4y2jElSYP0OUP/JHDra7btBg5W1VbgYLcuSVpFA4NeVY8C333N5u3AdLc8DewY8lySpCVa7jX0DVV1ols+CWxYbMcku5LMJJmZnZ1d5tNJkgZZ8T+KVlUB9Tr376mqyaqanJiYWOnTSZIWsdygn0qyEaC7PT28kSRJy7HcoO8HprrlKWDfcMaRJC1Xn5ct3gf8G/COJMeS7AQ+BtyS5DDw7m5dkrSK1g3aoao+sMhd24Y8iyRpBXynqCQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiNWFPQktyb5RpIjSXYPayhJ0tItO+hJLgD+FngPcA3wgSTXDGswSdLSrOQM/QbgSFUdraofAp8Gtg9nLEnSUq0k6JcD35q3fqzbJklaBetG/QRJdgG7utUfJPnGqJ9zAZcB317OAz8y5EHWkGUfk4Z5TM7mM
Tnbko/JEDry1j47rSTox4Er5q1v6rb9P1W1B9izgudZsSQzVTW5mjOsNR6Ts3lMzuYxOdtaPiYrueTyFWBrkiuTXAi8H9g/nLEkSUu17DP0qjqT5HeBzwMXAPdW1TNDm0yStCQruoZeVQ8CDw5pllFa1Us+a5TH5Gwek7N5TM62Zo9Jqmq1Z5AkDYFv/ZekRjQZ9CS/kuSZJP+TZNF/jT6fProgyfokB5Ic7m4vWWS/V5I82X01+Y/cg77vSd6Y5P7u/seSbBn/lOPV45j8RpLZeb83fms15hyXJPcmOZ3k6UXuT5K/6Y7X15K8c9wzLqTJoANPA78MPLrYDufhRxfsBg5W1VbgYLe+kP+qquu6r9vGN9549Py+7wRerKqrgLuAO8c75Xgt4Wfh/nm/N+4e65Dj90ng1te5/z3A1u5rF/CJMcw0UJNBr6pDVTXoDUzn20cXbAemu+VpYMcqzrKa+nzf5x+rvcC2JBnjjON2vv0sDFRVjwLffZ1dtgN/X3O+BFycZON4pltck0Hv6Xz76IINVXWiWz4JbFhkv4uSzCT5UpIWo9/n+/5/+1TVGeB7wKVjmW519P1ZeF93eWFvkisWuP98sib7MfK3/o9Kki8CP7nAXR+tqn3jnmcteL1jMn+lqirJYi9vemtVHU/yNuChJE9V1fPDnlXnnH8B7quql5P8NnN/g7l5lWfSa5yzQa+qd6/wl+j10QXnktc7JklOJdlYVSe6vxqeXuTXON7dHk3yCHA90FLQ+3zfX93nWJJ1wFuA74xnvFUx8JhU1fz//ruBvxzDXGvZmuzH+XzJ5Xz76IL9wFS3PAWc9beYJJckeWO3fBlwE/Ds2CYcjz7f9/nH6nbgoWr7DRsDj8lrrg/fBhwa43xr0X7g17tXu9wIfG/eJc3VU1XNfQG/xNw1rZeBU8Dnu+0/BTw4b7/3As8xdwb60dWee8TH5FLmXt1yGPgisL7bPgnc3S3/NPAU8NXududqzz2iY3HW9x34M+C2bvki4B+BI8CXgbet9sxr4Jj8BfBM93vjYeDq1Z55xMfjPuAE8N9dS3YCHwQ+2N0f5l4Z9Hz3szK52jNXle8UlaRWnM+XXCSpKQZdkhph0CWpEQZdkhph0CWpEQZdkhph0CWpEQZdkhrxv+nC9zQqbtHeAAAAAElFTkSuQmCC\n", 83 | "text/plain": [ 84 | "
" 85 | ] 86 | }, 87 | "metadata": { 88 | "needs_background": "light" 89 | }, 90 | "output_type": "display_data" 91 | }, 92 | { 93 | "name": "stdout", 94 | "output_type": "stream", 95 | "text": [ 96 | "mean 0.1518110806875614 std 0.98625144122289\n" 97 | ] 98 | } 99 | ], 100 | "source": [ 101 | "x_T = np.random.normal(size=(200, 1))\n", 102 | "samples = diffusion.ddpm_sample(x_T, model)\n", 103 | "plt.hist(samples.reshape(-1), 20, alpha=0.5)\n", 104 | "plt.show()\n", 105 | "\n", 106 | "print('mean', np.mean(samples), 'std', np.std(samples))" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 8, 112 | "metadata": {}, 113 | "outputs": [ 114 | { 115 | "data": { 116 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADahJREFUeJzt3X+sX/Vdx/HnSzqGmTooXGulZGWhjPAPsNwQJsZEOgybhlZFssVoNTV1iZrJTLS6vzQmDv+wzsQsaQB3TSYDO5fWhWx2BUJMHNtF2PjRjZZmZG36426DuRnDLHv7xz0sV3ov33Pv/f5oP30+kpvvOed7vvf77rm9z56efr+3qSokSee+H5n0AJKk4TDoktQIgy5JjTDoktQIgy5JjTDoktQIgy5JjTDoktQIgy5JjVgzzie77LLLauPGjeN8Skk65z3xxBPfrKqpQfuNNegbN25kdnZ2nE8pSee8JC/22c9LLpLUCIMuSY3oFfQkFyfZk+SrSQ4meVeStUn2JznU3V4y6mElSUvre4b+UeCzVXUNcB1wENgJHKiqTcCBbl2SNCEDg57krcDPAfcCVNX3q+plYAsw0+02A2wd1ZCSpMH6nKFfCcwB/5DkyST3JHkLsK6qjnf7nADWLfbgJDuSzCaZnZubG87UkqQz9An6GuCdwMeq6gbgv3nd5ZWa/2+PFv2vj6pqd1VNV9X01NTAl1FKklaoT9CPAker6vFufQ/zgT+ZZD1Ad3tqNCNKkvoYGPSqOgF8I8k7uk2bgeeAfcC2bts2YO9IJpQk9dL3naJ/AHwiyYXAEeC3mf/D4MEk24EXgTtHM6IkDceu/c9P5HnvuvXqsTxPr6BX1VPA9CJ3bR7uOJKklfKdopLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY1Y02enJF8Hvgu8Cpyuqukka4EHgI3A14E7q+ql0YwpSRpkOWfoP19V11fVdLe+EzhQVZuAA926JGlCVnPJZQsw0y3PAFtXP44kaaX6Br2Af0vyRJId3bZ1VXW8Wz4BrFvsgUl2JJlNMjs3N7fKcSVJS+l1DR342ao6luQngf1JvrrwzqqqJLXYA6tqN7AbYHp6etF9JEmr1+sMvaqOdbengE8DNwInk6wH6G5PjWpISdJg
A4Oe5C1Jfvy1ZeAXgGeAfcC2brdtwN5RDSlJGqzPJZd1wKeTvLb/P1XVZ5N8CXgwyXbgReDO0Y0pSRpkYNCr6ghw3SLbvwVsHsVQkqTl852iktQIgy5JjTDoktQIgy5JjTDoktQIgy5JjTDoktQIgy5Jjej7w7kmbtf+51f82LtuvXqIk0jS2ckzdElqhEGXpEYYdElqhEGXpEYYdElqhEGXpEYYdElqhEGXpEYYdElqhEGXpEYYdElqhEGXpEYYdElqhEGXpEYYdElqhEGXpEYYdElqRO+gJ7kgyZNJPtOtX5nk8SSHkzyQ5MLRjSlJGmQ5Z+gfBA4uWL8b2FVVVwEvAduHOZgkaXl6BT3JBuAXgXu69QC3AHu6XWaAraMYUJLUT98z9L8F/hj4Qbd+KfByVZ3u1o8Cly/2wCQ7kswmmZ2bm1vVsJKkpQ0MepJfAk5V1RMreYKq2l1V01U1PTU1tZJPIUnqYU2PfW4Gbk/yXuAi4CeAjwIXJ1nTnaVvAI6NbkxJ0iADz9Cr6k+rakNVbQTeBzxcVb8OPALc0e22Ddg7siklSQOt5nXofwJ8KMlh5q+p3zuckSRJK9HnkssPVdWjwKPd8hHgxuGPJElaCd8pKkmNMOiS1AiDLkmNMOiS1AiDLkmNMOiS1AiDLkmNMOiS1IhlvbHoXLVr//Mrfuxdt149xEkkaXQ8Q5ekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRgwMepKLknwxyZeTPJvkz7vtVyZ5PMnhJA8kuXD040qSltLnDP0V4Jaqug64HrgtyU3A3cCuqroKeAnYProxJUmDDAx6zftet/qm7qOAW4A93fYZYOtIJpQk9dLrGnqSC5I8BZwC9gMvAC9X1elul6PA5aMZUZLUR6+gV9WrVXU9sAG4Ebim7xMk2ZFkNsns3NzcCseUJA2yrFe5VNXLwCPAu4CLk6zp7toAHFviMburarqqpqemplY1rCRpaX1e5TKV5OJu+UeBW4GDzIf9jm63bcDeUQ0pSRpszeBdWA/MJLmA+T8AHqyqzyR5Dvhkkr8EngTuHeGckqQBBga9qr4C3LDI9iPMX0+XJJ0FfKeoJDXCoEtSIwy6JDXCoEtSIwy6JDXCoEtSIwy6JDXCoEtSIwy6JDXCoEtSIwy6JDXCoEtSIwy6JDXCoEtSIwy6JDXCoEtSIwy6JDXCoEtSIwy6JDXCoEtSIwy6JDXCoEtSIwy6JDXCoEtSIwy6JDXCoEtSIwYGPckVSR5J8lySZ5N8sNu+Nsn+JIe620tGP64kaSl9ztBPA39UVdcCNwG/l+RaYCdwoKo2AQe6dUnShAwMelUdr6r/7Ja/CxwELge2ADPdbjPA1lENKUkabFnX0JNsBG4AHgfWVdXx7q4TwLqhTiZJWpbeQU/yY8CngD+sqv9aeF9VFVBLPG5Hktkks3Nzc6saVpK0tF5BT/Im5mP+iar6l27zySTru/vXA6cWe2xV7a6q6aqanpqaGsbMkqRF9HmVS4B7gYNV9TcL7toHbOuWtwF7hz+eJKmvNT32uRn4DeDpJE912/4M+AjwYJLtwIvAnaMZUZLUx8CgV9W/A1ni7s3DHUeStFK+U1SSGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRA4Oe5L4kp5I8s2Db2iT7kxzqbi8Z7ZiSpEH6nKF/HLjtddt2AgeqahNwoFuXJE3QwKBX1WPAt1+3eQsw0y3PAFuHPJckaZlWeg19XVUd75ZPAOuGNI8kaYVW/Y+iVVVALXV/kh1JZpPMzs3NrfbpJElLWGnQTyZZD9Ddnlpqx6ra
XVXTVTU9NTW1wqeTJA2y0qDvA7Z1y9uAvcMZR5K0Un1etng/8B/AO5IcTbId+Ahwa5JDwLu7dUnSBK0ZtENVvX+JuzYPeRZJ0ir4TlFJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RGrCroSW5L8rUkh5PsHNZQkqTlW3HQk1wA/D3wHuBa4P1Jrh3WYJKk5VnNGfqNwOGqOlJV3wc+CWwZzliSpOVaTdAvB76xYP1ot02SNAFrRv0ESXYAO7rV7yX52qifc4HLgG+u5hN8aEiDnGVWfVwa5DFZnMflTMs+JkPoyNv67LSaoB8DrliwvqHb9v9U1W5g9yqeZ8WSzFbV9CSe+2zmcTmTx2RxHpcznc3HZDWXXL4EbEpyZZILgfcB+4YzliRpuVZ8hl5Vp5P8PvA54ALgvqp6dmiTSZKWZVXX0KvqIeChIc0yChO51HMO8LicyWOyOI/Lmc7aY5KqmvQMkqQh8K3/ktSIpoKe5NeSPJvkB0mW/Ffo8+1HFiRZm2R/kkPd7SVL7Pdqkqe6jyb/gXvQ1z7Jm5M80N3/eJKN459yvHock99KMrfg98bvTGLOcUpyX5JTSZ5Z4v4k+bvumH0lyTvHPeNimgo68AzwK8BjS+1wnv7Igp3AgaraBBzo1hfzP1V1ffdx+/jGG4+eX/vtwEtVdRWwC7h7vFOO1zK+Hx5Y8HvjnrEOORkfB257g/vfA2zqPnYAHxvDTAM1FfSqOlhVg964dD7+yIItwEy3PANsneAsk9Tna7/wWO0BNifJGGcct/Px+2GgqnoM+PYb7LIF+Mea9wXg4iTrxzPd0poKek/n448sWFdVx7vlE8C6Jfa7KMlski8kaTH6fb72P9ynqk4D3wEuHct0k9H3++FXu0sLe5Jcscj955uzsiMjf+v/sCX5PPBTi9z14araO+55zhZvdFwWrlRVJVnqpU1vq6pjSd4OPJzk6ap6Ydiz6pzzr8D9VfVKkt9l/m8wt0x4Ji3inAt6Vb17lZ+i148sONe80XFJcjLJ+qo63v218NQSn+NYd3skyaPADUBLQe/ztX9tn6NJ1gBvBb41nvEmYuAxqaqFv/57gL8ew1xnu7OyI+fjJZfz8UcW7AO2dcvbgDP+JpPkkiRv7pYvA24GnhvbhOPR52u/8FjdATxcbb9ZY+Axed214duBg2Oc72y1D/jN7tUuNwHfWXBZc3KqqpkP4JeZv5b1CnAS+Fy3/aeBhxbs917geebPPj886bnHcFwuZf7VLYeAzwNru+3TwD3d8s8ATwNf7m63T3ruER2LM772wF8At3fLFwH/DBwGvgi8fdIznwXH5K+AZ7vfG48A10x65jEck/uB48D/dk3ZDnwA+EB3f5h/ddAL3ffL9KRnrirfKSpJrTgfL7lIUpMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ14v8AMWb5rcrhmLYAAAAASUVORK5CYII=\n", 117 | "text/plain": [ 118 | "
" 119 | ] 120 | }, 121 | "metadata": { 122 | "needs_background": "light" 123 | }, 124 | "output_type": "display_data" 125 | }, 126 | { 127 | "name": "stdout", 128 | "output_type": "stream", 129 | "text": [ 130 | "mean 0.17585555054618038 std 0.9831223405397485\n" 131 | ] 132 | } 133 | ], 134 | "source": [ 135 | "x_T = np.random.normal(size=(200, 1))\n", 136 | "samples = diffusion.ddim_sample(x_T, model)\n", 137 | "plt.hist(samples.reshape(-1), 20, alpha=0.5)\n", 138 | "plt.show()\n", 139 | "\n", 140 | "print('mean', np.mean(samples), 'std', np.std(samples))" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 9, 146 | "metadata": {}, 147 | "outputs": [ 148 | { 149 | "data": { 150 | "image/png": "\n", 151 | "text/plain": [ 152 | "
" 153 | ] 154 | }, 155 | "metadata": { 156 | "needs_background": "light" 157 | }, 158 | "output_type": "display_data" 159 | } 160 | ], 161 | "source": [ 162 | "x_T = np.linspace(-3, 3, 100)\n", 163 | "samples = diffusion.ddim_sample(x_T.reshape([-1, 1]), model)\n", 164 | "plt.xlabel('latent')\n", 165 | "plt.ylabel('sample')\n", 166 | "plt.ylim(-2, 2)\n", 167 | "plt.plot(x_T, samples)\n", 168 | "plt.show()" 169 | ] 170 | } 171 | ], 172 | "metadata": { 173 | "kernelspec": { 174 | "display_name": "Python 3", 175 | "language": "python", 176 | "name": "python3" 177 | }, 178 | "language_info": { 179 | "codemirror_mode": { 180 | "name": "ipython", 181 | "version": 3 182 | }, 183 | "file_extension": ".py", 184 | "mimetype": "text/x-python", 185 | "name": "python", 186 | "nbconvert_exporter": "python", 187 | "pygments_lexer": "ipython3", 188 | "version": "3.6.3" 189 | } 190 | }, 191 | "nbformat": 4, 192 | "nbformat_minor": 2 193 | } 194 | -------------------------------------------------------------------------------- /mnist_conditional.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "metadata": { 3 | "language_info": { 4 | "codemirror_mode": { 5 | "name": "ipython", 6 | "version": 3 7 | }, 8 | "file_extension": ".py", 9 | "mimetype": "text/x-python", 10 | "name": "python", 11 | "nbconvert_exporter": "python", 12 | "pygments_lexer": "ipython3", 13 | "version": "3.8.5-final" 14 | }, 15 | "orig_nbformat": 2, 16 | "kernelspec": { 17 | "name": "python3", 18 | "display_name": "Python 3", 19 | "language": "python" 20 | } 21 | }, 22 | "nbformat": 4, 23 | "nbformat_minor": 2, 24 | "cells": [ 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "from IPython.display import display\n", 32 | "from PIL import Image\n", 33 | "import matplotlib.pyplot as plt\n", 34 | "import numpy as np\n", 35 | "import torch\n", 36 | "from tqdm.auto import tqdm\n", 37 | "\n", 38 | 
"from ddim import BayesPredictor, CNNPredictor, Diffusion, create_alpha_schedule\n", 39 | "from mnist_train import create_datasets" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 3, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "diffusion = Diffusion(\n", 49 | " create_alpha_schedule(num_steps=100, beta_0=0.001, beta_T=0.2)\n", 50 | ")" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 4, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "def load_trained_model():\n", 60 | " model = CNNPredictor((1, 28, 28))\n", 61 | " model.load_state_dict(torch.load('checkpoints/mnist_model.pt'))\n", 62 | " return model\n", 63 | "\n", 64 | "cnn_model = load_trained_model()" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 5, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "def load_test_data():\n", 74 | " _, loader = create_datasets(10000, False)\n", 75 | " batch = next(iter(loader))[0]\n", 76 | " return batch.numpy()\n", 77 | "\n", 78 | "test_data = load_test_data()" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 6, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "def show_sample(sample):\n", 88 | " image = sample*0.3081 + 0.1307\n", 89 | " image = (image * 255).clip(0, 255).astype('uint8')\n", 90 | " image = image.reshape([4, 4, 1, 28, 28]).transpose(0, 3, 1, 4, 2).reshape([28*4, 28*4, 1])\n", 91 | " image = np.concatenate([image] * 3, axis=-1)\n", 92 | " display(Image.fromarray(image))" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 11, 98 | "metadata": {}, 99 | "outputs": [ 100 | { 101 | "output_type": "display_data", 102 | "data": { 103 | "text/plain": "", 104 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAHAAAABwCAIAAABJgmMcAAAWW0lEQVR4nO2dfVATxxvHkzRBkkEGUIMRobairS+8KFUZFLCKL7Sio4JTiy+jAgNqtdXB1+oERKltba34ii9FK1QoiKJCAamKZgrKS6GCIAUUEME0gJJAJtm9/P7Yur3m9e4S2/46+fzhhNvb5577unu7t/vcLotlxYoVK1as/N+wcuXKmpoajUZDkFAoFBEREfb29sxscjgcW1tb4+fw+Xy7lxg6x8HBoaCgICMjQ2+ql5dXZGTkqFGjqLiUmJiYlpYWGxvr7Ozs4OCAj3t4ePD5fCoWTDNgwIA9e/aoVCpogIyMDEdHR7pmuVzu/v37m5qa3nrrLd1UR0dHkUg0bdq0zs5OfCFDpoKCggiCaGxsHDt2rG5qbGwsQRDLli2j4lVnZycqK7W1tQqForq6Ojc3t7u7W6lU9vb2nj59OvAlbDab+s3+BbFYbEhKDEV3yXz66acob2Njo1AoxMe5XO7GjRubmprI9qVSqVQqNWSquLiYIIj09HS9qUjQa9euUfFq/Pjx+/fv7+jouH79OrkutrS0nDlz5ujRo/jI2rVr6d4yi8ViJSUlqdVq8r1duHDh0KFDNjY2Fy5cwAeTkpLoWs7IyEB579y5g6szh8PZvXs3+XIymay8vHzKlCnjx4/Xa2fZsmUQQpOCEgTx1VdfUfRNKBSKRKK4uLjdu3e//fbbc+fORVUwODi4r68PWZPL5XPmzHFzc6Nxzxs3biSrWVFR4e/vz+PxUKqjo2NaWhpKUqvVGzZsoGGaJOiFCxfwwTVr1qCDe/bsKSwsvHXrllgsNm4H6dXZ2TlhwgS9JwwfPrylpYUgiCdPnowcOZKWk7osWLDgq6++Wrx4cUhISGlpaXV1NVVNhw0bVl5eTi4sQUFBWufY2NhcvnwZpZaVlYlEIuqeYUE7OztdXFzQwevXrxcWFt64cWPw4MGurq5U7CiVSoIgNm/ebOSc+vp6VKw8PT2pe2iScePGpaam5uXlTZ06VSuJo3v2kCFDvL29jVtUqVTFxcXo94QJEwYPHszArcGDB+MGNDw8PDg4ePHixb///ntra6vJvBMnTuRw9DivRVlZGQPHTFJTU1NRUTF58uTk5GStJNM+ZWdnV1ZWWtCb48eP49+44HR2dgIAurq6qFjg8XgbNmzgcrkAgBcvXhg5U/eGLUVeXp6jo+Pw4cNNn+rl5YUr+7Vr12xsbPSe5uvr++zZM3Sah4cHdVe4XO7169dRRolEwqCX5+XlhSpyXV2d8TMDAwNfRZVnsVj+/v4EQZCbAYSJElpUVKRSqfQmlZSUUCxQWgAAbt68iX77+voGBwczMILYtGmT8RMOHDjA2LgR+Hx+eHg4AODLL7/USjJd5Q0hEAhee+01Znlv3Lih0WjYbDaHwwkMDKSbff369ejHmjVrhpJITU09f/780KFDZ8yYkZOTk5OTo/fdwUxsbW3Pnz8fFRXV29vb0tKilcplbDcmJubNN99kllcikRQUFMyZM0ej0Wg0GrrZcau1cOHChQsXaqV++OGHzLyiyNy5cxcuXNja2nru3Llnz55ppZoQdOfOnRKJpLS01OJuJSUlzZkzh1nexMTElJSUY8eOcbl/+D927FiVSvXGG2+oVKpbt26hgyNGjEAv8ikpKQ8ePKB7FS6X6+fnx2azR4wY8ejRI2TQxcUlPj6exWLt3bsXtXgcDocgCGOGxo8f39fXR35B0r2Sg4MDfoPs6+sbN24cXXdFIhF6yzx48CDdvLq8884748ePDw0NDQkJwQfxm9KhQ4foGpw2bVpJSQl6L0D/6rJt27bw8PDU1FTT5g4dOoQFraysHDFiBDl19uzZOFWpVC5YsICKixwOx9/f/5133kF/Ojk51dfXW0pQvTAWNCIi4smTJwRBAAAIgoiLi2toaCAIorm5Gaspl8u9vLxSU1N1+/Z6EAqFSqWSrGlAQACPx+PxeAE
BAb/88gtO+uWXXyh6+cUXX0AIFQpFcXFxcXFxdXU1svBvE3Tp0qVIuIcPHx48eHD48OFCoXD69OmnTp0KCwtbtGiRSCQSCoWDBg2i583hw4fhX8nIyMBvjYiqqqo33niDosFTp05BHRobG19FQ4xAgmo0GlqClpaWEgRx/fr1mTNnzpo1Kysra9euXRbwZvLkyWfPntWVAHPv3j3y4JtJRCLR/fv3tYzQHVihBRJUrVaHhoZSz5WWlobKdVdX13fffff555/Tem0xho2NTVxcnO4Ac1dXV1FREa0BEYRIJEpJScHl3dPTEzfTrwIkaH9/P61cYrG4qakJDdbV1dVptR8WYMWKFatXr75y5QoSIicnZ+7cuYytsdlsgUAgEAioDG2YCTNBEW5ubvPmzWM+OG/FihUrVqxYsWLFYsTExBAEUVJS8k878q+GRu8aDQaPHj3azEs6Ozvv2LHjxIkTZtrRYvHixVohEadPn4YQTpo0yUzLXC531KhR8fHxbW1tmpdUVFTEx8cbibsyTXp6OnrvZJbdzs7O19f34sWLUqm0o6MjICDA1dV19EvM8ozF8vHxUSqVDQ0N+EhkZKRarQYAXLlyhbFZDoczatSohoYGQwMad+/efe+99xhaLysrYyYom80ODQ3Ny8sjj8729/dLpVKlUimXy+VyeV1d3dWrV52dnZn5lpWVBQC4dOkSPtLS0gIAaGpqio2NZWbTzs5u69atRoaHEImJiUysjxw5squri5mgLi4uMpmMIIjq6mo0d3b69OmwsLBJkyahoQculxsWFlZTU1NVVcVAU6FQWFlZCQBYvnw5PtjT0wMAmDZtGl1riEGDBj148EBLu76+vra2tra2NiQFhDA7O9vX15fJBb788ktkorCwkEF2e3t7JyenAQMGGDnn22+/JQji3LlzdI3PnDkTAPD8+XMcMhQfH69UKltbWylG9ehCDk5AUhYVFc2ePZvFYg0bNiwrKwsdd3JyYmafdeTIEWTi888/Z2jCKIsXL5bL5QCA7du3082LBL1//z7609XVFUIIADhz5gwzZ0aOHPnw4UOspkKhiIqKEggEo0ePDgkJwcUTQuju7s7sEiy5XA4hlMvl8+fPZ2jCMOvWrUMP1s8++4xBdi1B9+7da6agM2fOJBfPS5cuCQQCvcPtDx48MF7t9BMcHIzy4xlaC7JmzZqnT5/KZDKxWMxs/BEJWldXh6J6sKDvv/8+M5dEIhFZNYlEcuvWLb0t0vHjx5mMkWNzlhVUIBAkJSWpVKquri4/Pz/GdpCguFFCs4EAgGHDhjEzOGDAgLy8PL0KYsrKysLCwnDMLD1OnDgBIVSpVDgmyXyEQmFNTQ1BED09PYzbYgRZ0IEDBwIAIIS3b99m3LedMGFCa2urISlLS0vDwsIM1XTTxZXP56O27NChQ52dncxc1MLd3T0jI2Ps2LFpaWmRkZF9fX3mWGOz2eRnBQqZqqyslMvlDKx5e3tfuXLF0IxZREREenq6WQ4vWbIE/c+8ePHi8OHDzA29xMPD48mTJ0qlkkFAh15wCW1vb9+0aRMqof39/T09PRKJhFbHlsvlbtq0yUhNr6qqYhzR9Qd+fn7t7e0QQh8fH4tMAfr4+PT398tksjfffJPP5/P5fIZPopdgQTHoGVpfXx8bG0vdOJvNPnr0KFm+7u7uzMxMdPuYjIwMs16UJ06ciAzt37+fuRUdm+fPn8/Jybly5YpCoWhoaDhw4IBYLBaLxfPmzaM7IWpIULpBtlwut6amhqwd+oLGz88vLS2to6MDH9+wYYNAIKBl/E9QowkhtEh918XR0dHJyWnRokUff/zxjh07CIKge6GJEydKpVIAgEKhWLdu3ZkzZyCjVp7H48lkMqza5s2bcUlcs2aNQqHASVlZWcwjClAT/+oExQgEggMHDhAE8dNPP9HNO2HChBkzZqC4LdwPpSuojY0NlqylpQWVwVmzZhUWFpLVhBBu3LiRrod/0t3d/TcI6uv
rW1tbSxBEcXHx66+/bo4piwiqUqmOHj0qFovJkZ0Qwt9++y0mJsasgJfc3FyCICwuaHR0dGBgoK+vb3JyskKhUKvVTU1NCxYsMLOBYpEEffvtt+nmPXjwoL62/Q9SU1NXrVplpnssT0/P9vb24uJiHx8fc229ZOrUqUqlEgWyKpXKwsLCWbNmMfgOVy9Y0GPHjtHNa2Njk5SUpCtlc3NzVFTU3xA+xJCFCxei0ZCUlJQZM2ZY1nhgYGBeXh7jNyU7O7u4uDhUVDs6OuLi4saMGfNKg9qsWLFixYoVK1b+Jj7++OOuri4U0FFfX894ptMkbDY7JiZGo9E8fvzY8rH35jB9+nT0/SwZsVg8ffp0Btba2trInfAHDx6Eh4db2mXWoEGDyF8GmTnLYEm0pEQ6isVi/Cddg1qCojfx2tpac0eCSTg5OVVVVWH7dMet/yQoKEihULS2tupdGIkBWmreuHEDJ2FNyQepoCsowlIz3h4eHmVlZWhAA6k5dOhQJoaCgoJkMhkass3NzdVKtbe337Zt28WLF6kbnD59OpYMDSdr1XEsNK26HxwcfPLkyZMnT27evFkkEh07dgzd+e3bt6kbMcTQoUOlUimEEAnKXM21a9cqFAqFQrFr167vv/8eAKC1uFp0dDQAoKenh/oKPrh4GqrXWHG6mpKxtbVFMxnNzc1mjgSyWKxVq1ah/x61Wp2ens6wpn/yySfd3d29vb2LFi1isVhDhgyRSCTk2DZ0GQDAvn37KNokV3Yjp5nzMMVMnToVQlhbW2tmlKSrq2tlZSUS1PhqMcawtbUtKysDAOzcuRMfDAgIwH+6uLigx1Zubi71YRiKRY/8WGDgvLOz84IFCwoKCiCECoUChXcxpqioCKlZWVnJvIkLDQ2FEBq6Hy6Xi6Iq2tvb8TJWJqFYPBG4kNJw+iXr16/HjdLVq1fNGWPdsmUL/sY9ISGBsR1WRUUFAEB3NTGEr68vaqb27NlD0SDdJyM+n0Gtv3v3rkVaeTc3N7ysYWlpKXmOc9GiRdu3b1+yZAklQ/Pnz0eR1Hr7SePGjXv27BkSlPpkAN3HoqUEzc/PZ7bkGYfDKSgoQEPgKpXK2dnZzs5u37599+7dQx/go6Ta2lrTyxCEhYUhvXQFFQgEP/zwA0rVXV7HCHQbbvx8YNDQkwWFEFZUVJDDmini7++PsisUiujo6H379qF19nS7uqbjQw0J6uHhkZ2djSMJaK1ySf3pyWLUyjs5OeGpvdWrV1dWVpLDZVtaWmjFGfN4vNzcXJS3sbGxp6cH/Ualtb6+vrm5GR0pKSkZOHCgaYt1dXUow2ESUqlUo9EUFBS8ePEC0lw2lLpAzF6W8vPzt27dSj4yePDgDRs2PH/+HN0IranKadOm6ZbE9vb2kpKSefPm2djYfPPNN+jgN998Q8mit7d3fX09DmjBrF271sHBoaOjAwBAa6afokZYTbqVva2trbGxcdKkSVq97o8++oiBoOjrJi3Onj3r5ubG4/HwekGbNm3SWzz19CKrqqp8fHwmTZrk7++P62l6evrDhw/xOXfu3KHuIsK4Rjdu3CCfQDcQdcSIEWiZpaKiInxw5syZtIwYYdmyZQEBAR0dHZMnT2axWGVlZfn5+b29vRYwjUoos2eo3lpP7lQx68+Tn5i6qNXqpUuXUremt4TCl+/yEEKZTGZyfVUaMBCUXJexpmjITmvwidn7+4oVK9ADSi+//fYbLWvGBb1169bu3bsZOGkQ1GTRGl7VO5ysBbMXTcyRI0f0avrkyZPIyEhapmJiYvQKqlQqExISTC6/T5tly5bRLaEIQ5qaKSUmOjpaKzL23LlzY8aMoWuHzWYvX748ISHh8ePHyE5paWlCQgLzgFDjCIVCAEBJSYlQKKQb2KX1uGRcx/9TODg4PH36FADQ29vr5eX1T7vzn2DKlClSqZTiooxWrFixYsWK5UBrvZo522XlDzw8PCCEAIAhQ4b80778V0hNTVUoFLSXHLZ
iCKhv5XXzWbVqVXJyslKpVCqVycnJ+fn5sbGxUVFRFr/Qv4ulS5daXNAxY8Y0NTVpbSeGAAAcOXLEgtd6RQwYMKCiogJCuHr1aq0kE2EKjo6OFl8294cffjAUIcNms1EhXbdunQWvaG9vrxuQRB4vN8KYMWN4PF51dTU+4uHhcfjwYU9PT72zZCYEpT65RtG5devWGV/sjcPhREVF8Xi8pKSkX3/9ldmFQkJCyP8lw4YN09oL4tGjR1RCsrhc7vHjxz09PVtaWoqLi9Gi/WKxGC1P0N/fn52dTc8zNDhoqSpvZM0zXb777jsGl/D29jb0PMHDmgcOHKDSxjo5OTU3N2t9Mk6GyUizBQVduXKlkS1CIYR37tzBM7TozmkFpQYFBendewDT3Ny8bds2ik+w9evXo5CO3t5eiUQikUi0BM3IyGCywFBAQACEsLu728wPMUeNGqW1YSdGKpVmZ2eHhYXZ29v7+vqSk3bs2EHRPgplJectLS3NzMzcvXt3ZmZmZmZmSkoK9a402swMCYcX5poyZcqePXtwwOyUKVOYCMFisTQaDYTQzI59cnKyXjUlEgne8cvOzu7cuXPkVCq7/bFYrLVr1/b29uJcKpUqPDyc8f67fD4/Pz8fqfbjjz+S5zxmzJihVCoBABKJhPnXnw8fPtR6U3J1dQ0LC6M12E7eqgVCmJyc7Ofn5+fnR57axsttkTFp2d7eHoeEq9Xq8vJyWnOculRVVSE1ZTKZVsjJpUuXAABKpdKcrRFYpaWlWFBbW9uIiIju7m4AQH9/P8WVHl1dXbUE1W3o3333XfKiHojm5maTxidOnIiLJ/VHhBGQoFKpVOsz6alTp6KtHdvb2826AIQQAIBmOxISEiCE5eXlERERT58+ra+vp2IBb2WFSElJ0ZrwGj169M2bN7XUpB7gh57Ou3btssj37Fu2bElISNBdDQIVIwAA1UBGQ6SmpqJGbdWqVX19fRcvXkRypKenAwCoWNBaiBN9MmVvb8/n893d3b/++mvdsgkhPHXqFEUPf/75ZwjhBx98wPQWTZOVlYXaEgtsFDpu3Dh8k/jtQigU5uXlGdlXmwxyBXP16tX09PTbt28XFBTgnUG1UKvVum91hnBxcamrq7t7967FNoL/K5MnT0ZPFQAAeYs2hggEgoyMDHKMo6ura01NjZEoZy0MNfGGeP78OcVqxeVy0VT28uXLlUqlVqNsKebPn4869snJyUz6nnotorDmpKQk1Cip1ep9+/ZR7zp8//33FNWUyWTUdy9PTExcuXIl+r169er+/v5XMTCWmJgIAFCpVGZ+AvEXzp49i9/Anj59GhcXRyv7+vXrjYQfIVA4K62lbAiCuHz5Mv6zoKBALpdbMpKLxYqMjGxvbwcAMAg4NIabm1tSUpJUKo2Li6O0w7IO0dHRKEJCV0q5XH7kyBEGsdsSiYQsaHx8PISQ0jaG1PD29kY+AwAofkD3d+9oFR0drRunmpmZSaXLqUtwcPCpU6dCQkIqKipCQ0PPnj1ra2sbEBAgkUjMd5XL5TY0NLi6ugIATpw4YdZqYv8vODs7P378uKOjY/v27bh7b5ESyufzUX8RAEC9A/dfwN3dnTzscv/+feqfoxkBf45VXl5u8a3U/+24u7vn5eXV1NTk5ORYajkHLCj17vB/DQcHh3/aBSuvgP8B3hx7W8D2UQcAAAAASUVORK5CYII=\n" 105 | }, 106 | "metadata": {} 107 | }, 108 | { 109 | "output_type": "display_data", 110 | "data": { 111 | "text/plain": "", 112 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAHAAAABwCAIAAABJgmMcAAAML0lEQVR4nO2dfUwb5R/Aj+Z4TW1W3rrCOl8QdLLCppk0bGPENbpq2JKxmizMLDhHBJZNs6DbVFJBN+fL/hA3sZOpM5uAbODcQF4WAq6JzFIEhTGWgQLqulqY0m6X9rmrfzw/ntxa2l6Px6g/n89fcM9z3+d7H+65K71vnqMoAoFAIBD+NWzdunVwcNDj8XA8nE7nM888I5PJxMWUSCRRUVGB+0RHR0tn8ddnwYIFbW1t9fX1c7ZmZmZu3749NTVVSEoHDhw4efJkWVmZQqFYsGAB2q5Wq6Ojo4VECE5kZGRlZaXL5WL9UF9fL5fLQw1L0/TBgwdHR0fvu+8+31a5XK5UKletWmW1WtFA/kJptVqO465evfrAAw/4tpaVlXEct2XLFiFZWa1WeK4MDQ05nc6BgYHm5ubp6WmGYWZmZmpqatbMEhYWJvxgb8NgMPhTiRCYLp+XX34Z7nv16tXExES0nabpXbt2jY6O8uPbbDabzeYvVHd3N8dxdXV1c7ZCoefOnROS1dKlSw8ePHjt2rWOjg7+XBwfHz927NiRI0fQlpKSklAPmaIoqqqqyu1284+ttrb23XffjYiIqK2tRRurqqpCjVxfXw/3vXDhAprOEomkvLycP5zdbu/t7c3Kylq6dOmccbZs2cKybFChHMcdOnRIYG6JiYlKpfLVV18tLy+///77161bB6egTqe7efMmjOZwOB577LHFixeHcMy7du3i27RYLKtXrw4PD4etcrn85MmTsMntdu/cuTOE0DyhtbW1aOO2bdvgxsrKyvb29q6uLoPBEDgO9GW1WpcvXz5nh0WLFo2Pj3Mc9/PPP6ekpISUpC8bNmw4dOhQfn5+Xl5eT0/PwMCAUKdJSUm9vb38k0Wr1Xr1iYiI+OKLL2Cr2WxWKpXCM0NCrVZrcnIy3NjR0dHe3t7Z2RkfH69SqYTEYRiG47jdu3cH6HP58mV4WmVkZAjPMCjp6eknTpxoaWlZuXKlV5PEt3dCQsKyZcsCR3S5XN3d3fDn5cuXx8fHi0grPj4e3UALCgp0Ol1+fv5vv/02MTERdN8HH3xQIpkjeS/MZrOIxIIyODhosVgefvhho9Ho1RQ8p8bGxr6+PozZVFdXo5/RiWO1WgEAU1NTQiKEh4fv3LmTpmkAwB9//BGgp+8B46KlpUUuly9atCh418zMTDTZz507FxERMWc3jUZz/fp12E2tVgtPhabpjo4OuKPJZBLxKS8zMxNO5OHh4cA916xZ81dMeYqiVq9ezXEc/zYACbv77rvxjvQfJ/iUJ4QEEYoZIhQzRChmiFDMEKGYIUIxQ4RihgjFDBGKGSIUM0QoZohQAoFAIPwbKS4u5jjum2+++bsT+UcTwl3e4/F4PJ60tLR5DqlQKPbt2/fBBx/MM44X+fn5XiURNTU1LMuuWLFinpFpmk5NTa2oqJicnPTMYrFYKioqAtRdBaeuro5lWYEPJn2RSqUajeb06dM2m+3atWs5OTkqlSptlnllRlEPPfQQwzBXrlxBW7Zv3+52uwEAX375peiwEokkNTX1ypUr/iqRLl68+Pjjj4uMbjabxQkNCwvbtGlTS0sLv2bo1q1bNpuNYRiHw+FwOIaHh8+ePatQKMTldurUKQBAU1MT2jI+Pg4AGB0dLSsrExdTKpW++OKL/lQiDhw4ICZ6SkrK1NSUOKHJycl2u53juIGBgTNnzpw5c6ampkav169YseKuu+6iKIqmab1ePzg42N/fL8JpYmJiX18fAOCpp55CG2/cuAEAWLVqVajRIHFxcZcuXfJyd/PmzcnJycnJSaiCZdnGxkaNRiNmgLfffhuGaG9vF7G7TCaLjY2NjIwM0Oejjz7iOO748eOhBl+7di0A4Pfff0clQxUVFQzDTExMCKzq8YVfnABVnj9//tFHH6UoKikp6dSpU3B7bGysuPjU4cOHYYg333xTZIi
A5OfnOxwOAMDevXtD3RcK/eGHH+CvKpWKZVkAwLFjx8Qlk5KSMjIygmw6nc6ioqKYmJi0tLS8vDx0erIse++994obgnI4HCzLOhyO9evXiwzhn9LSUnhhfeONN0Ts7iX09ddfn6fQtWvX8k/PpqammJiYTz75xPcCeunSpcDTbm50Oh3cv6urS1yKAdi2bduvv/5qt9sNBoO44mAodHh4GFb1IKFPPPGEuJSUSiXfmslk6urqmvOOVF1dTdN0yAOgcHiFxsTEVFVVuVyuqamp7Oxs0XGgUHRTeuutt6DQpKQkcQEjIyNbWlrmNIgwm816vR7VzCJIbRNmyPehmCFCMUOEYoYIxQwRihkiFDNEKGaIUMwQoZghQjFDhGKGCMUMEUogEAj/bZ577rmpqSlY0HH58mXRTzqDEhYWVlxc7PF4fvrpJ/gA/J9Cbm5uZ2en53YMBkNubq6IaJOTk14PzgoKCnCnTMXFxX344YdoFNFP/PHjpRJ6NBgM6NdQA3oJZVnW5XINDQ3dc889uHKOjY3t7+/nP7YTWeSi1WqdTufExMScCyOJwMtmZ2cnakJO+RuF4CsUguuJt1qtNpvNHMchmwsXLhQTSKvV2u12+ECxubnZq1Umk+3Zs+f06dPCA+bm5iJlBoPBd44j0SHNfZ1Od/To0aNHj+7evVupVL7//vvwyL/++mvhQfyxcOFCm83GsiwUKt5mSUmJ0+l0Op2vvPLKZ599BgDwWlzt2WefBQDcuHFD+Ao+6PT0N6+R8VCd8omKijpy5AjLsmNjY3feeae4IIjCwkL453G73XV1dSJn+vPPPz89PT0zM7Nx40aKohISEkwmE7+2DQ4DANi/f7/AmPzJHqDbfC6miJUrV7IsOzQ0NM8qSZVK1dfXB4UGXi0mEFFRUWazGQDw0ksvoY05OTno1+TkZHjZam5uFl43IfDU418WRCSvUCg2bNjQ1tbGsqzT6YTlXaI5f/48tNnX1yf+Frdp0yaWZf0dD03TsKril19+QctYBUXg6QlBJ2kISc+yY8cOdFM6e/asiLUOES+88ALDMDDUa6+9JjoOZbFYAAC+q4lBNBoNvE1VVlYKDBjqlRH1FzHrL168iOUuv3jxYrSsYU9PT0xMDGrauHHj3r17n3zySUGB1q9fDyup5/yclJ6efv36dSi0sLBQYHKhXhZxCW1tbRW35JlEImlra4MFgS6XS6FQSKXS/fv3f/vttxzHeWaXUh0aGoqLi/Pal9Q2YYZ8H4oZIhQzRChmiFDMEKGYIUIxQ4RihgjFDBGKGSIUM0QoZohQzBChBALhNjiOGxkZmefTLsL/UKvV8OlmQkLC353L/wsnTpxwOp2+3/UTRMLe/uIjXBQWFhqNRoZhGIYxGo2tra1lZWVFRUXYB/pnsXnzZuxClyxZMjo66vU6MQgA4PDhwxjH+ouIjIy0WCwsyz799NNeTUHKFORyufh32vnh888/91chExYWBk/S0tJSjCPKZDLfgqSRkREh+y5ZsiQ8PHxgYABtUavV7733XkZGxpyVA0GEii448JdcaWlp4MXeJBJJUVFReHh4VVXV999/L26gvLw8/p8kKSkpPT2d3+HHH38UUpJF03R1dXVGRsb4+Hh3d/d3331HUZTBYIBvOLt161ZjY2NomRUXF2Oc8gHWPPPl008/FTHEsmXL/F1PIAzDvPPOO0LusbGxsWNjY2AWeEXiU15eHnJ+GIVu3bo1wCtCWZa9cOHC2NgY/8hDKkrVarX8CmNfxsbG9uzZI/AKtmPHDljSMTMzYzKZTCaTl9D6+noxCwzl5OSwLDs9PT2fIiGKolJTU71e2Imw2WyNjY16vV4mk2k0Gn7Tvn37BMaHpaz8fXt6ehoaGsrLyxsaGhoaGj7++GPhH6Xhy8ygOLQwV1ZWVmVlJSqYzcrKEiOCoiiPx8Oy7Dw/2BuNxjltmkwmnU4H+0il0uPHj/Nbhbztj6KokpKSmZkZtJfL5SooKBD9/t3
o6OjW1lZo7auvvuJXxT7yyCMMwwAATCaTmOWaICMjI17/KalUKr1eH1JBLCpjgxiNxuzs7Ozs7DvuuAP1Qctt8QkaWSaToZJwt9vd29u7efPmUI7Pm/7+fmjTbrd7LW3X1NQEAGAYZt26df52J7VNmCHfh2KGCMUMEYoZIhQzRChmiFDMEKGYIUIxQ4RihgjFDBGKGSIUM0QoZv4EFd651ol7tbUAAAAASUVORK5CYII=\n" 113 | }, 114 | "metadata": {} 115 | }, 116 | { 117 | "output_type": "display_data", 118 | "data": { 119 | "text/plain": "", 120 | "image/png": "\n" 121 | }, 122 | "metadata": {} 123 | } 124 | ], 125 | "source": [ 126 | "latents = np.random.normal(size=[16, 1, 28, 28])\n", 127 | "images = test_data[:16]\n", 128 | "image_mask = np.ones_like(images)\n", 129 | "image_mask[:, :, 14:] = 0\n", 130 | "show_sample(images)\n", 131 | "show_sample(images*image_mask)\n", 132 | "\n", 133 | "samples = diffusion.ddpm_sample_cond_energy_inpaint(\n", 134 | " latents, cnn_model, images, image_mask, temp=1.0, eps=1e-2\n", 135 | ")\n", 136 | "show_sample(images*image_mask + samples*(1-image_mask))" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 12, 142 | "metadata": {}, 143 | "outputs": [ 144 | { 145 | "output_type": "display_data", 146 | "data": { 147 | "text/plain": "", 148 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAHAAAABwCAIAAABJgmMcAAADGklEQVR4nO3av0tqcRzG8Y+hoAgiOChKU26CiGuri0ObS4gIjjUHKTg7+R/4B9ioi+I/kEQtURAK/kCDwlElUY7cwYNXLrebeux0nnueZ+oc9MUXjRJ5i3Acx3Ecx3FbzKLlyS6Xy+fz/XGz1WqZ2dztBT07O7u8vFxf+v3+UCi0+YBer3dyckLz60UikU6ns1gslE82m82KxaLH49n+iCY1Y7FYqVT6zFIUpdvtXl9fWyw7/LL/x+YXeiwWu7m5cbvd6zv39/eDweDx8TEcDovIZDK5uroajUY7ndKk5sXFxXg8Xr8b8/k8mUw6HI7tz0Tz91wu13A4XHGLxeLh4eH8/FzLEU1uSjQaXb9FuVxOK2d6U0Sk0+koipLP54+OjmhuuX89+v39XUTa7fZyudR0OpqrBQKBl5eXu7s7jX/gaYrVarXZbCKSSqVms1m9Xrfb7TT3X6FQSKfTq58zmczHx0e5XKa5/5bLZaVSWV82Go3JZBKJRGh+ub//U7q9vd28bDabDofD6XTue0izmxKPx19fX6PRqIgkEonpdKooyunpKc095/V6+/3+29tbNptdf8TViJrZFBEJBoOrD7erPT09BQIBmlrdWq32/PxcrVaPj48PIJrbVLf5/RVNjuM4juM4jjPk2DYd2GTb9EMmQDOEYhqkGUIx2TbpaGI0QygmSjOEYsI0QyimiMGaIRSTbZN+JkIzhGKiNEMoJkwzhGIasRlCMdk26WHCNEMoJkwzhGKK4DRDKKbqQjRDKKY6lGYIxeQ4juM4juO4bxjbpgObbJt+yARohlBMgzRDKCbbJh1NjGYIxURphlBMmGYIxRQxWDOEYrJt0s9EaIZQTJRmCMWEaYZQTCM2Qygm2yY9TJhmCMWEaYZQTBGcZgjFVF2IZgjFVIfSDKGYHMdxHMdxHPcNY9t0YJNt0w+ZAM0QimmQZgjFZNuko4nRDKGYKM0QignTDKGYIgZrhlBMtk36mQjNEIqJ0gyhmDDNEIppxGYIxWTbpIcJ0wyhmDDNEIopgtMMoZiqC9EMoZjqUJohFJPbbb8A/1X9jy48d6UAAAAASUVORK5CYII=\n" 149 | }, 150 | "metadata": {} 151 | }, 152 | { 153 | "output_type": "display_data", 154 | "data": { 155 | "text/plain": "", 156 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAHAAAABwCAIAAABJgmMcAAAB80lEQVR4nO3aMWrCcBiG8aRbliC4iJs4CuIhsnqA4A3MLOQS3sALeBddHB1U0C2jDmIIOBRSKYiKL/7zmee3Vdq34Wmn8HkeAAAAnuC/88NhGLZarX8frtfrOm++FnQ4HCZJUn7Zbrd7vd7tN+x2u263y+Zjg8Fgs9nkeV7ccT6fp9Nps9l8/hFruhlF0Ww2u7dVFMV2u03T1Pdf+Gf/4s0H61EUzefzRqNRfrJYLPb7/Wq16vf7nuedTqfJZJJl2UtPWdPN8Xh8PB7Lv8blchmNRkEQPP9MbP4Jw/BwOPzO5Xm+XC7jOH7nEWuy6Xc6nTd/PW79uH6Ab0NQMYKKEVSMoGIEFSOoGEHFCCpGUDGCihFUjKAAAABwgdsm8Sa3TY42DdwMWdmsyM2QlU1umz64aeNmyMpm1W6GrGxy2yTG+1AxgooRVIygYgQVI6gYQcUIKkZQMYKKEVSMoGIEBQAAgAvcNok3uW1ytGngZsjKZkVuhqxsctv0wU0bN0NWNqt2M2Rlk9smMd6HihFUjKBiBBUjqBhBxQgqRlAxgooRVIygYgQVIygAAABc4LZJvMltk6NNAzdDVjYrcjNkZZPbpg9u2rgZsrJZtZshK5vcNonxPlSMoGIEFSOoGEHFCCpGUDGCihFUjKBiBBUjqBhBxa4wg/Ve5ngsIQAAAABJRU5ErkJggg==\n" 157 | }, 158 | "metadata": {} 159 | }, 160 | { 161 | "output_type": "display_data", 162 | "data": { 163 | "text/plain": "", 164 | "image/png": "\n" 165 | }, 166 | "metadata": {} 167 | } 168 | ], 169 | "source": [ 170 | "latents = np.random.normal(size=[16, 1, 28, 28])\n", 171 | "images = np.tile(test_data[14:15], [16, 1, 1, 1])\n", 172 | "image_mask = np.ones_like(images)\n", 173 | "image_mask[:, :, 14:] = 0\n", 174 | "show_sample(images)\n", 175 | "show_sample(images*image_mask)\n", 176 | "\n", 177 | "samples = diffusion.ddpm_sample_cond_energy_inpaint(\n", 178 | " latents, cnn_model, images, image_mask, temp=1.0, eps=1e-2\n", 179 | ")\n", 180 | "show_sample(images*image_mask + samples*(1-image_mask))" 181 | ] 182 | } 183 | ] 184 | } -------------------------------------------------------------------------------- /mnist_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | from torchvision import datasets, transforms 6 | 7 | from ddim import Diffusion, CNNPredictor, create_alpha_schedule 8 | 9 | USE_CUDA = torch.cuda.is_available() 10 | DEVICE = torch.device("cpu" if not USE_CUDA else "cuda") 11 | SAVE_PATH = "mnist_model.pt" 12 | 13 | 
def main():
    """
    Train a CNNPredictor on MNIST with the diffusion noise-prediction loss.

    Weights are loaded from SAVE_PATH when that file exists, and written back
    to SAVE_PATH every ``--save-interval`` optimizer steps. Each step prints the
    loss on one training batch and one (no-grad) test batch. The loop has no
    stopping condition; it runs until the process is interrupted.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", default=32, type=int)
    parser.add_argument("--lr", default=1e-4, type=float)
    parser.add_argument("--save-interval", default=1000, type=int)
    args = parser.parse_args()

    train_data, test_data = create_datasets(args.batch_size, USE_CUDA)
    diffusion = Diffusion(
        create_alpha_schedule(num_steps=100, beta_0=0.001, beta_T=0.2)
    )
    model = CNNPredictor((1, 28, 28))
    if os.path.exists(SAVE_PATH):
        # Load onto the CPU first so resuming works regardless of which device
        # the checkpoint tensors were saved from; the model is moved below.
        model.load_state_dict(torch.load(SAVE_PATH, map_location="cpu"))
    model.to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr)

    loaders = zip(iterate_loader(train_data), iterate_loader(test_data))
    for step, (train_batch, test_batch) in enumerate(loaders):
        train_loss = compute_loss(diffusion, model, train_batch)
        with torch.no_grad():
            test_loss = compute_loss(diffusion, model, test_batch)
        print(
            f"step {step}: test={test_loss.item():.5f} train={train_loss.item():.5f}"
        )
        opt.zero_grad()
        train_loss.backward()
        opt.step()
        # Save after every `save_interval` completed optimizer steps.
        if (step + 1) % args.save_interval == 0:
            _save_checkpoint(model)


def _save_checkpoint(model):
    """Write the model's state dict to SAVE_PATH as CPU tensors, then restore DEVICE."""
    # Saving from the CPU keeps the checkpoint loadable on machines without CUDA.
    model.cpu()
    torch.save(model.state_dict(), SAVE_PATH)
    model.to(DEVICE)


def compute_loss(diffusion, model, batch):
    """
    Compute the epsilon-prediction MSE for one batch.

    Each example gets a timestep drawn uniformly from ``[1, diffusion.num_steps]``.
    It is noised with ``diffusion.sample_q`` using freshly sampled Gaussian noise,
    and the model is asked to predict that noise.

    Args:
        diffusion: object providing ``num_steps``, ``sample_q(x_0, ts, epsilon=...)``
            and ``alphas_for_ts(ts)`` on numpy arrays.
        model: callable ``model(samples, alphas)``. It receives the float32
            noised batch and float32 per-example alphas, both on ``batch.device``.
        batch: tensor of clean inputs, batch dimension first.

    Returns:
        A scalar tensor: mean squared error between the true and predicted noise.
    """
    # The diffusion helpers operate on numpy arrays, so the timesteps stay on
    # the CPU instead of making a round trip through batch.device.
    ts = torch.randint(
        low=1, high=diffusion.num_steps + 1, size=(batch.shape[0],)
    ).numpy()
    epsilon = torch.randn_like(batch)
    noised = diffusion.sample_q(
        batch.cpu().numpy(), ts, epsilon=epsilon.cpu().numpy()
    )
    samples = torch.from_numpy(noised).float().to(batch.device)
    alphas = torch.from_numpy(diffusion.alphas_for_ts(ts)).float().to(batch.device)
    predictions = model(samples, alphas)
    return torch.mean((epsilon - predictions) ** 2)


def iterate_loader(loader):
    """Yield input batches from ``loader`` on DEVICE forever, discarding labels."""
    while True:
        for x, _ in loader:
            yield x.to(DEVICE)


def create_datasets(batch, use_cuda):
    """
    Build shuffled MNIST train and test DataLoaders.

    Images are converted to tensors and normalized with mean 0.1307 and
    std 0.3081.

    Args:
        batch: batch size for both loaders.
        use_cuda: when true, load with one worker process and pinned memory.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    # Taken from pytorch MNIST demo.
    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )

    def make_loader(train):
        # download=True is a no-op once the files exist, so either split can be
        # built first without a "dataset not found" error.
        dataset = datasets.MNIST(
            "mnist_data", train=train, download=True, transform=transform
        )
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch, shuffle=True, **kwargs
        )

    return make_loader(True), make_loader(False)


if __name__ == "__main__":
    main()