├── .gitignore ├── 2d_conditional.ipynb ├── 2d_distribution.ipynb ├── README.md ├── checkpoints └── mnist_model.pt ├── ddim ├── __init__.py ├── ddim.py └── predictor.py ├── demo.ipynb ├── mnist_analysis.ipynb ├── mnist_conditional.ipynb └── mnist_train.py /.gitignore: -------------------------------------------------------------------------------- 1 | mnist_data 2 | .ipynb_checkpoints 3 | __pycache__ 4 | 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DDIM 2 | 3 | This repo contains some experiments with [Denoising Diffusion Implicit Models](https://openreview.net/forum?id=St1giarCHLP). In particular, I play with simple distributions and see how DDIM translates latents to datapoints. This is a particularly interesting class of models because the latent space is intrinsic, it's not learned, and it's really cool to investigate. 4 | -------------------------------------------------------------------------------- /checkpoints/mnist_model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unixpickle/ddim/27950a639afbbe3de8f8f72d8605f14f91a5742a/checkpoints/mnist_model.pt -------------------------------------------------------------------------------- /ddim/__init__.py: -------------------------------------------------------------------------------- 1 | from .ddim import Diffusion, create_alpha_schedule 2 | from .predictor import Predictor, CNNPredictor, BayesPredictor, train_predictor 3 | 4 | __all__ = [ 5 | "Diffusion", 6 | "create_alpha_schedule", 7 | "Predictor", 8 | "CNNPredictor", 9 | "BayesPredictor", 10 | "train_predictor", 11 | ] 12 | -------------------------------------------------------------------------------- /ddim/ddim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm.auto import tqdm 4 | 5 | 6 | def create_alpha_schedule(num_steps=100, beta_0=0.0001, beta_T=0.02): 7 | betas = np.linspace(beta_0, beta_T, num_steps) 8 | result = [1.0] 9 | alpha = 1.0 10 | for beta in betas: 11 | alpha *= 1 - beta 12 | result.append(alpha) 13 | return np.array(result, dtype=np.float64) 14 | 15 | 16 | class Diffusion: 17 | """ 18 | A numpy implementation of the DDPM and DDIM setup. 19 | """ 20 | 21 | def __init__(self, alphas): 22 | self.alphas = alphas 23 | 24 | @property 25 | def num_steps(self): 26 | return len(self.alphas) - 1 27 | 28 | def sample_q(self, x_0, ts, epsilon=None): 29 | """ 30 | Sample from q(x_t | x_0) for a batch of x_0. 31 | """ 32 | if epsilon is None: 33 | epsilon = np.random.normal(size=x_0.shape) 34 | alphas = self.alphas_for_ts(ts, x_0.shape) 35 | return np.sqrt(alphas) * x_0 + np.sqrt(1 - alphas) * epsilon 36 | 37 | def predict_x0(self, x_t, ts, epsilon_prediction): 38 | alphas = self.alphas_for_ts(ts, x_t.shape) 39 | return (x_t - np.sqrt(1 - alphas) * epsilon_prediction) / np.sqrt(alphas) 40 | 41 | def ddim_previous(self, x_t, ts, epsilon_prediction): 42 | """ 43 | Take a ddim sampling step given x_t, t, and epsilon prediction. 44 | """ 45 | x_0 = self.predict_x0(x_t, ts, epsilon_prediction) 46 | return self.sample_q(x_0, ts - 1, epsilon=epsilon_prediction) 47 | 48 | def ddim_sample(self, x_T, predictor, progress=False): 49 | """ 50 | Sample x_0 from x_t using DDIM, assuming a method on predictor called 51 | predict_epsilon(x_t, alphas). 52 | """ 53 | x_t = x_T 54 | t_iter = range(1, self.num_steps + 1)[::-1] 55 | if progress: 56 | t_iter = tqdm(t_iter) 57 | for t in t_iter: 58 | ts = np.array([t] * x_T.shape[0]) 59 | alphas = self.alphas_for_ts(ts) 60 | x_t = self.ddim_previous(x_t, ts, predictor.predict_epsilon(x_t, alphas)) 61 | return x_t 62 | 63 | def ddim_sample_cond(self, x_T, predictor, x_cond, mask): 64 | """ 65 | Like ddim_sample(), but condition on the part of x_cond which is True 66 | in the boolean mask. 67 | """ 68 | x_t = x_T 69 | for t in range(1, self.num_steps + 1)[::-1]: 70 | ts = np.array([t] * x_T.shape[0]) 71 | alphas = self.alphas_for_ts(ts) 72 | x_t = np.where(mask, self.sample_q(x_cond, ts), x_t) 73 | x_t = self.ddim_previous(x_t, ts, predictor.predict_epsilon(x_t, alphas)) 74 | return np.where(mask, x_cond, x_t) 75 | 76 | def ddpm_previous( 77 | self, x_t, ts, epsilon_prediction, epsilon=None, cond_prediction=None 78 | ): 79 | if epsilon is None: 80 | epsilon = np.random.normal(size=x_t.shape) 81 | alphas_t = self.alphas_for_ts(ts, x_t.shape) 82 | alphas_prev = self.alphas_for_ts(ts - 1, x_t.shape) 83 | alphas = alphas_t / alphas_prev 84 | betas = 1 - alphas 85 | prev_mean = (1 / np.sqrt(alphas)) * ( 86 | x_t - betas / np.sqrt(1 - alphas_t) * epsilon_prediction 87 | ) 88 | if cond_prediction is not None: 89 | prev_mean += betas * cond_prediction 90 | return prev_mean + np.sqrt(betas) * epsilon 91 | 92 | def ddpm_sample(self, x_T, predictor): 93 | """ 94 | Sample x_0 from x_t using DDPM. 95 | 96 | Usage is the same as ddim_sample(). 97 | """ 98 | x_t = x_T 99 | for t in range(1, self.num_steps + 1)[::-1]: 100 | ts = np.array([t] * x_T.shape[0]) 101 | alphas = self.alphas_for_ts(ts) 102 | x_t = self.ddpm_previous(x_t, ts, predictor.predict_epsilon(x_t, alphas)) 103 | return x_t 104 | 105 | def ddpm_sample_cond(self, x_T, predictor, x_cond, mask, num_subsamples=1): 106 | """ 107 | Create a masked-conditional sample using DDPM. 108 | 109 | See ddim_sample_cond() for usage details. 110 | """ 111 | x_t = x_T 112 | for t in range(1, self.num_steps + 1)[::-1]: 113 | samples = [] 114 | for _ in range(num_subsamples): 115 | ts = np.array([t] * x_T.shape[0]) 116 | x_t = np.where(mask, self.sample_q(x_cond, ts), x_t) 117 | alphas = self.alphas_for_ts(ts) 118 | x_next = self.ddpm_previous( 119 | x_t, ts, predictor.predict_epsilon(x_t, alphas) 120 | ) 121 | samples.append(x_next) 122 | x_t = np.mean(samples, axis=0) 123 | return np.where(mask, x_cond, x_t) 124 | 125 | def ddpm_sample_cond_energy(self, x_T, predictor, cond_fn): 126 | """ 127 | Create a sample using an energy function cond_fn as a conditioning 128 | signal, to compute p(x)*p(y|x), where cond_fn is grad_x log(p(y|x)). 129 | """ 130 | x_t = x_T 131 | for t in range(1, self.num_steps + 1)[::-1]: 132 | ts = np.array([t] * x_T.shape[0]) 133 | alphas = self.alphas_for_ts(ts) 134 | x_t = self.ddpm_previous( 135 | x_t, 136 | ts, 137 | predictor.predict_epsilon(x_t, alphas), 138 | cond_prediction=cond_fn(x_t, alphas), 139 | ) 140 | return x_t 141 | 142 | def ddpm_sample_cond_energy_inpaint( 143 | self, x_T, predictor, x_cond, mask, temp=1.0, eps=1e-2 144 | ): 145 | def cond_fn(x_t, alphas): 146 | while len(alphas.shape) < len(x_t.shape): 147 | alphas = alphas[..., None] 148 | with torch.enable_grad(): 149 | alphas_torch = torch.from_numpy(alphas).float() 150 | x_t_torch = torch.from_numpy(x_t).float().requires_grad_(True) 151 | eps_pred = predictor(x_t_torch, alphas_torch.view(-1)) 152 | x_start = ( 153 | x_t_torch - (1 - alphas_torch).sqrt() * eps_pred 154 | ) / alphas_torch.sqrt() 155 | 156 | # This should be the variance of the x_start prediction, 157 | # but instead we use the variance of a signal noised to 158 | # the current timestep as a reasonable guess. 159 | sigmas = eps + 1 - alphas_torch 160 | 161 | log_density = -((torch.from_numpy(x_cond) - x_start) ** 2) / ( 162 | 2 * sigmas 163 | ) 164 | loss = (log_density * torch.from_numpy(mask).float()).sum() 165 | grad = torch.autograd.grad(loss, x_t_torch)[0] 166 | return grad.detach().numpy() / temp 167 | 168 | return self.ddpm_sample_cond_energy(x_T, predictor, cond_fn) 169 | 170 | def alphas_for_ts(self, ts, shape=None): 171 | alphas = self.alphas[ts] 172 | if shape is None: 173 | return alphas 174 | while len(alphas.shape) < len(shape): 175 | alphas = alphas[..., None] 176 | return alphas 177 | -------------------------------------------------------------------------------- /ddim/predictor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.optim import Adam 5 | from tqdm.auto import tqdm 6 | 7 | 8 | def train_predictor(diffusion, data_batches, lr=1e-4): 9 | predictor = None 10 | optim = None 11 | losses = [] 12 | for batch in tqdm(data_batches): 13 | if predictor is None: 14 | predictor = Predictor(batch.shape[1:]) 15 | optim = Adam(predictor.parameters(), lr=lr) 16 | dev = next(predictor.parameters()).device 17 | ts = torch.randint( 18 | low=1, high=diffusion.num_steps + 1, size=(batch.shape[0],) 19 | ).to(dev) 20 | epsilon = torch.randn(*batch.shape).to(dev) 21 | samples = ( 22 | torch.from_numpy( 23 | diffusion.sample_q( 24 | batch, ts.cpu().numpy(), epsilon=epsilon.cpu().numpy() 25 | ) 26 | ) 27 | .float() 28 | .to(dev) 29 | ) 30 | alphas = torch.from_numpy(diffusion.alphas_for_ts(ts.cpu().numpy())).to(dev) 31 | predictions = predictor(samples, alphas.float()) 32 | loss = torch.mean((epsilon - predictions) ** 2) 33 | losses.append(loss.item()) 34 | optim.zero_grad() 35 | loss.backward() 36 | optim.step() 37 | return predictor, losses 38 | 39 | 40 | class Predictor(nn.Module): 41 | def __init__(self, data_shape, num_layers=1, channels=128): 42 | super().__init__() 43 | self.data_shape = data_shape 44 | 45 | self.register_buffer( 46 | "timestep_coeff", torch.linspace(start=0.1, end=100, steps=channels)[None] 47 | ) 48 | self.timestep_phase = nn.Parameter(torch.randn(channels)[None]) 49 | self.input_embed = nn.Linear(int(np.prod(data_shape)), channels) 50 | self.timestep_embed = nn.Sequential( 51 | nn.Linear(channels, channels), 52 | nn.GELU(), 53 | nn.Linear(channels, channels), 54 | ) 55 | self.layers = nn.Sequential( 56 | nn.GELU(), 57 | *[ 58 | nn.Sequential(nn.Linear(channels, channels), nn.GELU()) 59 | for _ in range(num_layers) 60 | ], 61 | nn.Linear(channels, int(np.prod(data_shape))), 62 | ) 63 | 64 | def forward(self, inputs, alphas): 65 | embed_alphas = torch.sin( 66 | (self.timestep_coeff * alphas.float()[:, None]) + self.timestep_phase 67 | ) 68 | embed_alphas = self.timestep_embed(embed_alphas) 69 | embed_ins = self.input_embed(inputs.view(inputs.shape[0], -1)) 70 | out = self.layers(embed_ins + embed_alphas) 71 | return out.view(inputs.shape) 72 | 73 | def predict_epsilon(self, inputs_np, alphas_np): 74 | dev = next(self.parameters()).device 75 | inputs = torch.from_numpy(inputs_np).float().to(dev) 76 | alphas = torch.from_numpy(alphas_np).float().to(dev) 77 | with torch.no_grad(): 78 | return self(inputs, alphas).detach().cpu().numpy().astype(inputs_np.dtype) 79 | 80 | 81 | class CNNPredictor(nn.Module): 82 | def __init__(self, data_shape, num_res_blocks=7, channels=128): 83 | super().__init__() 84 | assert len(data_shape) == 3 85 | self.data_shape = data_shape 86 | 87 | self.register_buffer( 88 | "timestep_coeff", 89 | torch.linspace(start=0.1, end=1000, steps=channels * 4)[None], 90 | ) 91 | self.timestep_phase = nn.Parameter(torch.randn(channels * 4)[None]) 92 | self.timestep_embed = nn.Sequential( 93 | nn.Linear(channels * 4, channels), 94 | nn.GELU(), 95 | nn.Linear(channels, channels), 96 | ) 97 | self.input_embed = nn.Conv2d(data_shape[0], channels, 1) 98 | self.res_blocks = nn.ModuleList([]) 99 | self.timestep_blocks = nn.ModuleList([]) 100 | for _ in range(num_res_blocks): 101 | block = nn.Sequential( 102 | nn.GroupNorm(8, channels), 103 | nn.GELU(), 104 | nn.Conv2d(channels, channels, 3, padding=1), 105 | nn.GELU(), 106 | nn.Conv2d(channels, channels, 3, padding=1), 107 | SELayer(channels), 108 | ) 109 | self.res_blocks.append(block) 110 | self.timestep_blocks.append( 111 | nn.Sequential( 112 | nn.Linear(channels, channels), 113 | nn.GELU(), 114 | nn.Linear(channels, channels), 115 | ) 116 | ) 117 | self.out_layer = nn.Conv2d(channels, data_shape[0], 3, padding=1) 118 | 119 | def forward(self, inputs, alphas): 120 | assert inputs.shape[1:] == self.data_shape 121 | embed_alphas = torch.sin( 122 | (self.timestep_coeff * alphas.float()[:, None]) + self.timestep_phase 123 | ) 124 | embed_alphas = self.timestep_embed(embed_alphas) 125 | out = self.input_embed(inputs) 126 | for block, ts_block in zip(self.res_blocks, self.timestep_blocks): 127 | out = out + block(out + ts_block(embed_alphas)[..., None, None]) 128 | out = self.out_layer(out) 129 | return out 130 | 131 | def predict_epsilon(self, inputs_np, alphas_np): 132 | dev = next(self.parameters()).device 133 | inputs = torch.from_numpy(inputs_np).float().to(dev) 134 | alphas = torch.from_numpy(alphas_np).float().to(dev) 135 | with torch.no_grad(): 136 | return self(inputs, alphas).detach().cpu().numpy().astype(inputs_np.dtype) 137 | 138 | 139 | class SELayer(nn.Module): 140 | """ 141 | A squeeze-excitation layer, from: 142 | https://github.com/moskomule/senet.pytorch/blob/23839e07525f9f5d39982140fccc8b925fe4dee9/senet/se_module.py 143 | 144 | This layer provides global context to each local part of an image, 145 | allowing for larger receptive field without much extra compute. 146 | """ 147 | 148 | def __init__(self, channel, reduction=8): 149 | super().__init__() 150 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 151 | self.fc = nn.Sequential( 152 | nn.Linear(channel, channel // reduction, bias=False), 153 | nn.ReLU(inplace=True), 154 | nn.Linear(channel // reduction, channel, bias=False), 155 | nn.Sigmoid(), 156 | ) 157 | 158 | def forward(self, x): 159 | b, c, _, _ = x.size() 160 | y = self.avg_pool(x).view(b, c) 161 | y = self.fc(y).view(b, c, 1, 1) 162 | return x * y.expand_as(x) 163 | 164 | 165 | class BayesPredictor(nn.Module): 166 | """ 167 | An epsilon predictor that uses Bayes rule to predict epsilon without any 168 | learnable parameters--just a bunch of data. 169 | """ 170 | 171 | def __init__(self, data_batch): 172 | super().__init__() 173 | self.data_batch = data_batch 174 | 175 | def forward(self, inputs, alphas): 176 | while len(alphas.shape) < len(inputs.shape): 177 | alphas = alphas[..., None] 178 | means = torch.sqrt(alphas)[:, None] * self.data_batch[None] 179 | variances = 1 - alphas 180 | while len(variances.shape) < len(means.shape): 181 | variances = variances[..., None] 182 | logits = -( 183 | 0.5 * torch.log(variances) 184 | + (0.5 / variances) * (inputs[:, None] - means) ** 2 185 | ) 186 | while len(logits.shape) > 2: 187 | logits = torch.sum(logits, dim=-1) 188 | logits = logits - torch.max(logits, dim=1, keepdims=True)[0] 189 | probs = torch.exp(logits) 190 | probs = probs / torch.sum(probs, dim=1, keepdims=True) 191 | while len(probs.shape) < len(self.data_batch.shape) + 1: 192 | probs = probs[..., None] 193 | x_0 = torch.sum(torch.from_numpy(self.data_batch[None]) * probs, dim=1) 194 | return (inputs - torch.sqrt(alphas) * x_0) / torch.sqrt(1 - alphas) 195 | 196 | def predict_epsilon(self, inputs_np, alphas_np): 197 | inputs = torch.from_numpy(inputs_np) 198 | alphas = torch.from_numpy(alphas_np) 199 | return self(inputs, alphas).detach().numpy().astype(inputs_np.dtype) 200 | -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from ddim import BayesPredictor, Diffusion, create_alpha_schedule, train_predictor" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import matplotlib.pyplot as plt\n", 19 | "import numpy as np" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 3, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "USE_BAYES = True\n", 29 | "DATASET = 'bimodal' # uniform, bimodal" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 4, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "def generate_data(batch_size, num_batches):\n", 39 | " if DATASET == 'uniform':\n", 40 | " return np.random.uniform(size=(num_batches, batch_size, 1))\n", 41 | " elif DATASET == 'bimodal':\n", 42 | " raw_data = 0.2 * np.random.uniform(size=(num_batches, batch_size, 1))\n", 43 | " offsets = np.random.randint(low=0, high=2, size=raw_data.shape).astype(raw_data.dtype)\n", 44 | " return raw_data - 0.1 + (offsets - 0.5) * 2\n", 45 | " else:\n", 46 | " raise ValueError(DATASET)" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 5, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "diffusion = Diffusion(create_alpha_schedule(num_steps=1000))" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 6, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "if USE_BAYES:\n", 65 | " model = BayesPredictor(generate_data(1000, 1)[0])\n", 66 | "else:\n", 67 | " data = generate_data(batch_size=1000, num_batches=1000)\n", 68 | " print('mean', np.mean(data), 'std', np.std(data))\n", 69 | " model, losses = train_predictor(diffusion, data, lr=2e-3)\n", 70 | " plt.plot(losses)\n", 71 | " plt.show()\n", 72 | " print('final loss', np.mean(losses[-10:]))" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 7, 78 | "metadata": {}, 79 | "outputs": [ 80 | { 81 | "data": { 82 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADaBJREFUeJzt3X+s3fVdx/HnSzqGmXNQuNZK6cpCN8I/wHJDmBgT6TBsGlodki1Gr6amLlEzNxOt7i+NicM/xJmYJQ3grslkYHVpXci2rkCIiWO7CBs/utHSjKxNf9xtMDdjmMW3f9wv5krv5Xzvveece/vp85HcnO/3e76n58339j777bfnHFJVSJLOfT+y2gNIkobDoEtSIwy6JDXCoEtSIwy6JDXCoEtSIwy6JDXCoEtSIwy6JDVi3Tif7LLLLqstW7aM8ykl6Zz3+OOPf7uqJgbtN9agb9myhZmZmXE+pSSd85K80Gc/L7lIUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiPG+k5RSTpX3XXguWU/9sO3vH2IkyzOM3RJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RG9Ap6kouT7E3y9SSHkrwryfokB5Ic7m4vGfWwkqTF9T1D/zjwuaq6GrgWOATsBg5W1VbgYLcuSVolA4Oe5C3AzwL3AFTVD6vqJWA7MN3tNg3sGNWQkqTB+pyhXwnMAn+X5Ikkdyd5E7Chqk50+5wENoxqSEnSYH2Cvg54J/CJqroe+E9ec3mlqgqohR6cZFeSmSQzs7OzK51XkrSIPkE/Bhyrqse69b3MBf5Uko0A3e3phR5cVXuqarKqJicmJoYxsyRpAQODXlUngW8leUe3aRvwLLAfmOq2TQH7RjKhJKmXdT33+z3gU0kuBI4Cv8ncHwYPJNkJvADcMZoRJUl99Ap6VT0JTC5w17bhjiNJWi7fKSpJjTDoktQIgy5JjTDoktQIgy5JjTDoktQIgy5JjTDoktQIgy5JjTDoktQIgy5JjTDoktQIgy5JjTDoktQIgy5JjTDoktQIgy5JjTDoktQIgy5JjTDoktQIgy5JjTDoktSIdX12SvJN4PvAK8CZqppMsh64H9gCfBO4o6peHM2YkqRBlnKG/nNVdV1VTXbru4GDVbUVONitS5JWyUouuWwHprvlaWDHyseRJC1X36AX8IUkjyfZ1W3bUFUnuuWTwIahTydJ6q3XNXTgZ6rqeJKfAA4k+fr8O6uqktRCD+z+ANgFsHnz5hUNK0laXK8z9Ko63t2eBj4D3ACcSrIRoLs9vchj91TVZFVNTkxMDGdqSdJZBgY9yZuSvPnVZeDngaeB/cBUt9sUsG9UQ0qSButzyWUD8Jkkr+7/D1X1uSRfAR5IshN4AbhjdGNK0srddeC51R5hpAYGvaqOAtcusP07wLZRDCVJWjrfKSpJjTDoktQIgy5JjTDoktQIgy5JjTDoktQIgy5JjTDoktQIgy5JjTDoktQIgy5Jjej7eejntJV8IM+Hb3n7ECeRpNHxDF2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGtE76EkuSPJEks9261cmeSzJkST3J7lwdGNKkgZZyhn6h4BD89bvBO6qqquAF4GdwxxMkrQ0vYKeZBPwC8Dd3XqAm4G93S7TwI5RDChJ6qfvx+f+NfCHwJu79UuBl6rqTLd+DLh8oQcm2QXsAti8efOyB13JR+BK0vlg4Bl6kl8ETlfV48t5gqraU1WTVTU5MTGxnF9CktRDnzP0m4DbkrwXuAj4ceDjwMVJ1nVn6ZuA46MbU5I0yMAz9Kr646raVFVbgPcDD1XVrwIPA7d3u00B+0Y2pSRpoJW8Dv2PgI8kOcLcNfV7hjOSJGk5lvT/FK2qR4BHuuWjwA3DH0mStBy+U1SSGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRA4Oe5KIkX07y1STPJPnTbvuVSR5LciTJ/UkuHP24kqTF9DlDfxm4uaquBa4Dbk1yI3AncFdVXQW8COwc3ZiSpEEGBr3m/KBbfUP3VcDNwN5u+zSwYyQTSpJ66XUNPckFSZ4ETgMHgOeBl6rqTLfLMeDy0YwoSeqjV9Cr6pWqug7YBNwAXN33CZLsSjKTZGZ2dnaZY0qSBlnSq1yq6iXgYeBdwMVJ1nV3bQKOL/KYPVU1WVWTExMTKxpWkrS4Pq9ymUhycbf8o8AtwCHmwn57t9sUsG9UQ0qSBls3eBc2AtNJLmDuD4AHquqzSZ4FPp3kz4EngHtGOKckaYCBQa+qrwHXL7D9KHPX0yVJa4DvFJWkRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWrEwKAnuSLJw0meTfJMkg9129cnOZDkcHd7yejHlSQtps8Z+hngD6rqGuBG4HeSXAPsBg5W1VbgYLcuSVolA4NeVSeq6t+75e8Dh4DLge3AdLfbNLBjVENKkgZb0jX0JFuA64HHgA1VdaK76ySwYaiTSZKWpHfQk/wY8E/A71fVf8y/r6oKqEUetyvJTJKZ2dnZFQ0rSVpcr6AneQNzMf9UVf1zt/lUko3d/RuB0ws9tqr2VNVkVU1OTEwMY2ZJ0gL6vMolwD3Aoar6q3l37QemuuUpYN/wx5Mk9bWuxz43Ab8GPJXkyW7bnwAfAx5IshN4AbhjNCNKkvoYGPSq+lcgi9y9bbjjSJKWy3eKSlIjDLokNcKgS1IjDLokNcKgS1IjDLokNcKgS1IjDLokNcKgS1IjDLokNcKgS1IjDLokNcKgS1IjDLokNcKgS1IjDLokNcKgS1IjDLokNcKgS1IjDLokNcKgS1IjDLokNcKgS1IjBgY9yb1JTid5et629UkOJDnc3V4y2jElSYP0OUP/JHDra7btBg5W1VbgYLcuSVpFA4NeVY8C333N5u3AdLc8DewY8lySpCVa7jX0DVV1ols+CWxYbMcku5LMJJmZnZ1d5tNJkgZZ8T+KVlUB9Tr376mqyaqanJiYWOnTSZIWsdygn0qyEaC7PT28kSRJy7HcoO8HprrlKWDfcMaRJC1Xn5ct3gf8G/COJMeS7AQ+BtyS5DDw7m5dkrSK1g3aoao+sMhd24Y8iyRpBXynqCQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiNWFPQktyb5RpIjSXYPayhJ0tItO+hJLgD+FngPcA3wgSTXDGswSdLSrOQM/QbgSFUdraofAp8Gtg9nLEnSUq0k6JcD35q3fqzbJklaBetG/QRJdgG7utUfJPnGqJ9zAZcB317OAz8y5EHWkGUfk4Z5TM7mMTnbko/JEDry1j47rSTox4Er5q1v6rb9P1W1B9izgudZsSQzVTW5mjOsNR6Ts3lMzuYxOdtaPiYrueTyFWBrkiuTXAi8H9g/nLEkSUu17DP0qjqT5HeBzwMXAPdW1TNDm0yStCQruoZeVQ8CDw5pllFa1Us+a5TH5Gwek7N5TM62Zo9Jqmq1Z5AkDYFv/ZekRjQZ9CS/kuSZJP+TZNF/jT6fProgyfokB5Ic7m4vWWS/V5I82X01+Y/cg77vSd6Y5P7u/seSbBn/lOPV45j8RpLZeb83fms15hyXJPcmOZ3k6UXuT5K/6Y7X15K8c9wzLqTJoANPA78MPLrYDufhRxfsBg5W1VbgYLe+kP+qquu6r9vGN9549Py+7wRerKqrgLuAO8c75Xgt4Wfh/nm/N+4e65Dj90ng1te5/z3A1u5rF/CJMcw0UJNBr6pDVTXoDUzn20cXbAemu+VpYMcqzrKa+nzf5x+rvcC2JBnjjON2vv0sDFRVjwLffZ1dtgN/X3O+BFycZON4pltck0Hv6Xz76IINVXWiWz4JbFhkv4uSzCT5UpIWo9/n+/5/+1TVGeB7wKVjmW519P1ZeF93eWFvkisWuP98sib7MfK3/o9Kki8CP7nAXR+tqn3jnmcteL1jMn+lqirJYi9vemtVHU/yNuChJE9V1fPDnlXnnH8B7quql5P8NnN/g7l5lWfSa5yzQa+qd6/wl+j10QXnktc7JklOJdlYVSe6vxqeXuTXON7dHk3yCHA90FLQ+3zfX93nWJJ1wFuA74xnvFUx8JhU1fz//ruBvxzDXGvZmuzH+XzJ5Xz76IL9wFS3PAWc9beYJJckeWO3fBlwE/Ds2CYcjz7f9/nH6nbgoWr7DRsDj8lrrg/fBhwa43xr0X7g17tXu9wIfG/eJc3VU1XNfQG/xNw1rZeBU8Dnu+0/BTw4b7/3As8xdwb60dWee8TH5FLmXt1yGPgisL7bPgnc3S3/NPAU8NXududqzz2iY3HW9x34M+C2bvki4B+BI8CXgbet9sxr4Jj8BfBM93vjYeDq1Z55xMfjPuAE8N9dS3YCHwQ+2N0f5l4Z9Hz3szK52jNXle8UlaRWnM+XXCSpKQZdkhph0CWpEQZdkhph0CWpEQZdkhph0CWpEQZdkhrxv+nC9zQqbtHeAAAAAElFTkSuQmCC\n", 83 | "text/plain": [ 84 | "
" 85 | ] 86 | }, 87 | "metadata": { 88 | "needs_background": "light" 89 | }, 90 | "output_type": "display_data" 91 | }, 92 | { 93 | "name": "stdout", 94 | "output_type": "stream", 95 | "text": [ 96 | "mean 0.1518110806875614 std 0.98625144122289\n" 97 | ] 98 | } 99 | ], 100 | "source": [ 101 | "x_T = np.random.normal(size=(200, 1))\n", 102 | "samples = diffusion.ddpm_sample(x_T, model)\n", 103 | "plt.hist(samples.reshape(-1), 20, alpha=0.5)\n", 104 | "plt.show()\n", 105 | "\n", 106 | "print('mean', np.mean(samples), 'std', np.std(samples))" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 8, 112 | "metadata": {}, 113 | "outputs": [ 114 | { 115 | "data": { 116 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADahJREFUeJzt3X+sX/Vdx/HnSzqGmTooXGulZGWhjPAPsNwQJsZEOgybhlZFssVoNTV1iZrJTLS6vzQmDv+wzsQsaQB3TSYDO5fWhWx2BUJMHNtF2PjRjZZmZG36426DuRnDLHv7xz0sV3ov33Pv/f5oP30+kpvvOed7vvf77rm9z56efr+3qSokSee+H5n0AJKk4TDoktQIgy5JjTDoktQIgy5JjTDoktQIgy5JjTDoktQIgy5JjVgzzie77LLLauPGjeN8Skk65z3xxBPfrKqpQfuNNegbN25kdnZ2nE8pSee8JC/22c9LLpLUCIMuSY3oFfQkFyfZk+SrSQ4meVeStUn2JznU3V4y6mElSUvre4b+UeCzVXUNcB1wENgJHKiqTcCBbl2SNCEDg57krcDPAfcCVNX3q+plYAsw0+02A2wd1ZCSpMH6nKFfCcwB/5DkyST3JHkLsK6qjnf7nADWLfbgJDuSzCaZnZubG87UkqQz9An6GuCdwMeq6gbgv3nd5ZWa/2+PFv2vj6pqd1VNV9X01NTAl1FKklaoT9CPAker6vFufQ/zgT+ZZD1Ad3tqNCNKkvoYGPSqOgF8I8k7uk2bgeeAfcC2bts2YO9IJpQk9dL3naJ/AHwiyYXAEeC3mf/D4MEk24EXgTtHM6IkDceu/c9P5HnvuvXqsTxPr6BX1VPA9CJ3bR7uOJKklfKdopLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY0w6JLUCIMuSY1Y02enJF8Hvgu8Cpyuqukka4EHgI3A14E7q+ql0YwpSRpkOWfoP19V11fVdLe+EzhQVZuAA926JGlCVnPJZQsw0y3PAFtXP44kaaX6Br2Af0vyRJId3bZ1VXW8Wz4BrFvsgUl2JJlNMjs3N7fKcSVJS+l1DR342ao6luQngf1JvrrwzqqqJLXYA6tqN7AbYHp6etF9JEmr1+sMvaqOdbengE8DNwInk6wH6G5PjWpISdJgA4Oe5C1Jfvy1ZeAXgGeAfcC2brdtwN5RDSlJGqzPJZd1wKeTvLb/P1XVZ5N8CXgwyXbgReDO0Y0pSRpkYNCr6ghw3SLbvwVsHsVQkqTl852iktQIgy5JjTDoktQIgy5JjTDoktQIgy5JjTDoktQIgy5Jjej7w7kmbtf+51f82LtuvXqIk0jS2ckzdElqhEGXpEYYdElqhEGXpEYYdElqhEGXpEYYdElqhEGXpEYYdElqhEGXpEYYdElqhEGXpEYYdElqhEGXpEYYdElqhEGXpEYYdElqRO+gJ7kgyZNJPtOtX5nk8SSHkzyQ5MLRjSlJGmQ5Z+gfBA4uWL8b2FVVVwEvAduHOZgkaXl6BT3JBuAXgXu69QC3AHu6XWaAraMYUJLUT98z9L8F/hj4Qbd+KfByVZ3u1o8Cly/2wCQ7kswmmZ2bm1vVsJKkpQ0MepJfAk5V1RMreYKq2l1V01U1PTU1tZJPIUnqYU2PfW4Gbk/yXuAi4CeAjwIXJ1nTnaVvAI6NbkxJ0iADz9Cr6k+rakNVbQTeBzxcVb8OPALc0e22Ddg7siklSQOt5nXofwJ8KMlh5q+p3zuckSRJK9HnkssPVdWjwKPd8hHgxuGPJElaCd8pKkmNMOiS1AiDLkmNMOiS1AiDLkmNMOiS1AiDLkmNMOiS1IhlvbHoXLVr//Mrfuxdt149xEkkaXQ8Q5ekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRhh0SWqEQZekRgwMepKLknwxyZeTPJvkz7vtVyZ5PMnhJA8kuXD040qSltLnDP0V4Jaqug64HrgtyU3A3cCuqroKeAnYProxJUmDDAx6zftet/qm7qOAW4A93fYZYOtIJpQk9dLrGnqSC5I8BZwC9gMvAC9X1elul6PA5aMZUZLUR6+gV9WrVXU9sAG4Ebim7xMk2ZFkNsns3NzcCseUJA2yrFe5VNXLwCPAu4CLk6zp7toAHFviMburarqqpqemplY1rCRpaX1e5TKV5OJu+UeBW4GDzIf9jm63bcDeUQ0pSRpszeBdWA/MJLmA+T8AHqyqzyR5Dvhkkr8EngTuHeGckqQBBga9qr4C3LDI9iPMX0+XJJ0FfKeoJDXCoEtSIwy6JDXCoEtSIwy6JDXCoEtSIwy6JDXCoEtSIwy6JDXCoEtSIwy6JDXCoEtSIwy6JDXCoEtSIwy6JDXCoEtSIwy6JDXCoEtSIwy6JDXCoEtSIwy6JDXCoEtSIwy6JDXCoEtSIwy6JDXCoEtSIwYGPckVSR5J8lySZ5N8sNu+Nsn+JIe620tGP64kaSl9ztBPA39UVdcCNwG/l+RaYCdwoKo2AQe6dUnShAwMelUdr6r/7Ja/CxwELge2ADPdbjPA1lENKUkabFnX0JNsBG4AHgfWVdXx7q4TwLqhTiZJWpbeQU/yY8CngD+sqv9aeF9VFVBLPG5Hktkks3Nzc6saVpK0tF5BT/Im5mP+iar6l27zySTru/vXA6cWe2xV7a6q6aqanpqaGsbMkqRF9HmVS4B7gYNV9TcL7toHbOuWtwF7hz+eJKmvNT32uRn4DeDpJE912/4M+AjwYJLtwIvAnaMZUZLUx8CgV9W/A1ni7s3DHUeStFK+U1SSGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRBl2SGmHQJakRA4Oe5L4kp5I8s2Db2iT7kxzqbi8Z7ZiSpEH6nKF/HLjtddt2AgeqahNwoFuXJE3QwKBX1WPAt1+3eQsw0y3PAFuHPJckaZlWeg19XVUd75ZPAOuGNI8kaYVW/Y+iVVVALXV/kh1JZpPMzs3NrfbpJElLWGnQTyZZD9Ddnlpqx6raXVXTVTU9NTW1wqeTJA2y0qDvA7Z1y9uAvcMZR5K0Un1etng/8B/AO5IcTbId+Ahwa5JDwLu7dUnSBK0ZtENVvX+JuzYPeRZJ0ir4TlFJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RGGHRJaoRBl6RGrCroSW5L8rUkh5PsHNZQkqTlW3HQk1wA/D3wHuBa4P1Jrh3WYJKk5VnNGfqNwOGqOlJV3wc+CWwZzliSpOVaTdAvB76xYP1ot02SNAFrRv0ESXYAO7rV7yX52qifc4HLgG+u5hN8aEiDnGVWfVwa5DFZnMflTMs+JkPoyNv67LSaoB8DrliwvqHb9v9U1W5g9yqeZ8WSzFbV9CSe+2zmcTmTx2RxHpcznc3HZDWXXL4EbEpyZZILgfcB+4YzliRpuVZ8hl5Vp5P8PvA54ALgvqp6dmiTSZKWZVXX0KvqIeChIc0yChO51HMO8LicyWOyOI/Lmc7aY5KqmvQMkqQh8K3/ktSIpoKe5NeSPJvkB0mW/Ffo8+1HFiRZm2R/kkPd7SVL7Pdqkqe6jyb/gXvQ1z7Jm5M80N3/eJKN459yvHock99KMrfg98bvTGLOcUpyX5JTSZ5Z4v4k+bvumH0lyTvHPeNimgo68AzwK8BjS+1wnv7Igp3AgaraBBzo1hfzP1V1ffdx+/jGG4+eX/vtwEtVdRWwC7h7vFOO1zK+Hx5Y8HvjnrEOORkfB257g/vfA2zqPnYAHxvDTAM1FfSqOlhVg964dD7+yIItwEy3PANsneAsk9Tna7/wWO0BNifJGGcct/Px+2GgqnoM+PYb7LIF+Mea9wXg4iTrxzPd0poKek/n448sWFdVx7vlE8C6Jfa7KMlski8kaTH6fb72P9ynqk4D3wEuHct0k9H3++FXu0sLe5Jcscj955uzsiMjf+v/sCX5PPBTi9z14araO+55zhZvdFwWrlRVJVnqpU1vq6pjSd4OPJzk6ap6Ydiz6pzzr8D9VfVKkt9l/m8wt0x4Ji3inAt6Vb17lZ+i148sONe80XFJcjLJ+qo63v218NQSn+NYd3skyaPADUBLQe/ztX9tn6NJ1gBvBb41nvEmYuAxqaqFv/57gL8ew1xnu7OyI+fjJZfz8UcW7AO2dcvbgDP+JpPkkiRv7pYvA24GnhvbhOPR52u/8FjdATxcbb9ZY+Axed214duBg2Oc72y1D/jN7tUuNwHfWXBZc3KqqpkP4JeZv5b1CnAS+Fy3/aeBhxbs917geebPPj886bnHcFwuZf7VLYeAzwNru+3TwD3d8s8ATwNf7m63T3ruER2LM772wF8At3fLFwH/DBwGvgi8fdIznwXH5K+AZ7vfG48A10x65jEck/uB48D/dk3ZDnwA+EB3f5h/ddAL3ffL9KRnrirfKSpJrTgfL7lIUpMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ1wqBLUiMMuiQ14v8AMWb5rcrhmLYAAAAASUVORK5CYII=\n", 117 | "text/plain": [ 118 | "
" 119 | ] 120 | }, 121 | "metadata": { 122 | "needs_background": "light" 123 | }, 124 | "output_type": "display_data" 125 | }, 126 | { 127 | "name": "stdout", 128 | "output_type": "stream", 129 | "text": [ 130 | "mean 0.17585555054618038 std 0.9831223405397485\n" 131 | ] 132 | } 133 | ], 134 | "source": [ 135 | "x_T = np.random.normal(size=(200, 1))\n", 136 | "samples = diffusion.ddim_sample(x_T, model)\n", 137 | "plt.hist(samples.reshape(-1), 20, alpha=0.5)\n", 138 | "plt.show()\n", 139 | "\n", 140 | "print('mean', np.mean(samples), 'std', np.std(samples))" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 9, 146 | "metadata": {}, 147 | "outputs": [ 148 | { 149 | "data": { 150 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAY4AAAEKCAYAAAAFJbKyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAHI9JREFUeJzt3XuQHWd55/Hv75y5aGZ0G10sy7rajhNjAxtgYiBQu7AxxjgUDlnIGrIbHDalZBNvSNUuuxAvEMhmIaE2uwFTGAFecNYFpAIEsYgYc0kBm5hYVgS+gmVZsiRsSxpd56Izc8559o/uGY3lGen0zPT0TOv3qTrV3W+/c87T9jvzqN/3PW8rIjAzM2tVpegAzMxsYXHiMDOzTJw4zMwsEycOMzPLxInDzMwyceIwM7NMCksckjZI+o6khyU9JOkdk9SRpI9I2i3pR5JeXESsZmZ2RluBn10H/mNE7JS0BLhf0j0R8fCEOq8DrkhfLwU+nm7NzKwghd1xRMRTEbEz3T8FPAKsO6vajcCdkbgXWC5p7RyHamZmExR5xzFO0mbgRcAPzjq1Dtg/4fhAWvbUJO+xBdgC0NPT85Irr7wyj1DNzErp/vvvPxIRq1upW3jikLQY+CLwBxFxcrrvExFbga0AfX19sWPHjlmK0Mys/CTta7VuobOqJLWTJI27IuJLk1Q5CGyYcLw+LTMzs4IUOatKwKeBRyLiz6eotg34jXR21cuAExHxnG4qMzObO0V2Vb0C+LfAA5J2pWV/CGwEiIjbge3ADcBuYAj4zQLiNDOzCQpLHBHxfUDnqRPA781NRGZm1gp/c9zMzDJx4jAzs0ycOMzMLBMnDjMzy8SJw8zMMnHiMDOzTJw4zMwsEycOMzPLxInDzMwyceIwM7NMnDjMzCwTJw4zM8vEicPMzDJx4jAzs0ycOMzMLBMnDjMzy8SJw8zMMnHiMDOzTApNHJLukHRI0oNTnH+VpBOSdqWv9851jGZm9myFPXM89RngNuDOc9T5XkS8fm7CMTOz8yn0jiMivgscLTIGMzPLZiGMcbxc0g8lfV3S1UUHY2Z2oSu6q+p8dgKbImJA0g3A3wBXTFZR0hZgC8DGjRvnLkIzswvMvL7jiIiTETGQ7m8H2iWtmqLu1ojoi4i+1atXz2mcZmYXknmdOCRdLEnp/jUk8fYXG5WZ2YWt0K4qSZ8DXgWsknQAeB/QDhARtwNvAv69pDowDNwUEVFQuGZmRsGJIyLecp7zt5FM1zUzs3liXndVmZnZ/OPEYWZmmThxmJlZJk4cZmaWiROHmZll4sRhZmaZOHGYmVkmThxmZpaJE4eZmWXixGFmZpk4cZiZWSZOHGZmlokTh5mZZeLEYWZmmThxmJlZJvP9meNmZvNeRBABjQia6X4zgma6jUjqTHYcpNvx8qRs7D2CpIy0LNl7dr2xum1VceXFS3O/XicOM5t3ms1geLTB4Eid4ZEGw6MNhkcanB5tcrreoDbaoFZvcnq0wUi9SS19jdSbjDSS7WgjeY3UY3x/tBHUm03qjaSs0QzqzTNljWbQiGQ78biZ1ms2k8TQGEsC6fn58lzSVYs72fFfr839c5w4zGzWRQSnR5scHx7h+NBo+hrh+HC6PzzCyeFRTgyPcnK4zqnTo5w6XedUrc5grc7QSGPan91eFe3VCh1tFTqqFdqrlfGytnS/rSLaqhU62yt0Vyq0VUS1ovFtdbJjicqErQRVJeektFxQqYhKui9BRcl5QVoXRFJPKK0DIjlRSeue+dnkujShHBj/WY2dE3S2zc3ogxOHmbVkeKTBUyeGefrkaQ6drHFkoEb/4AhHB0Y4OjTCiaFRjqXJ4cTQKCON5pTv1VGtsLSrnWVdbSxZ1M6y7g7Wr+hmSWcbPWOvjirdHVW6Otroak/2O9srLGqvsqgt2e9sq9CZ7ndUk1elojn8r3JhKjRxSLoDeD1wKCKeP8l5AX8B3AAMATdHxM65jdLswlCrNzh4bJh9R4c4cHSI/ceG2X90iIPHhzlwbJijgyPP+Zm2iujt6WBlTwfLutq5bHUPvd0dLOtuZ1lXO73dHSzvamdZdzvLuzro7Um2i9orSP4Dv1AVfcfxGeA24M4pzr8OuCJ9vRT4eLo1s2kaHmnw6NMnefTpUzz2zACPHTrFnsOD/PTE8LP66jvaKqzv7WJ9bzdXX7KM9b1dXLx0ERcvW8SapZ2sXryIpV1tTgAXoEITR0R8V9Lmc1S5EbgzIgK4V9JySWsj4qk5CdBsAas3muw/Nsxjz5ziJ8+c4tGnk9eewwM00wTR1V7l8ot66Nvcy6aV69m0optNK7vZsKKb1Ys73e1jkyr6juN81gH7JxwfSMuekzgkbQG2AGzcuHFOgjMrWkRwdHCEvf1D7Dk8wJ4jgzx+aIDHDw/w5NEhRhtnbiE2rOji59Ys4YYXrOXqS5Zy1dqlrFve5eRgmc33xNGyiNgKbAXo6+ubJ5PjzGau3mjy0+OneaJ/kH39gzzZP8SBY8McOD7Evv4hTp2uj9dtr4pNK3u4fPVirrv6Yi5b1cPlFy3mZ9csYXFnaX7drWDzvSUdBDZMOF6flpmVSrMZPHPqNE/2D7Hv6BBP9g+x58gAuw8NsPfI0LNmKC1qr7C+t5v1vV28eGMvm1f2sHlVN5tX9rBhRTftVS8IYfma74ljG3CLpM+TDIqf8PiGLVQnT4/y0+PDHDw2zN7+IfYeGWRv/yAHjiVlE5NDtSI2rujm8tU9vPrKi7hsVU+aIHq4aEmnB6StUEVPx/0c8CpglaQDwPuAdoCIuB3YTjIVdzfJdNzfLCZSs+n78dOneOsn76X/rOmsSxe1sXlVD1ddspTrrl7D+t7u8cHpS5Z3+c7B5q2iZ1W95TznA/i9OQrHLBePHTpF/+AIv/0vLuMF65axbnkXm1f2sLy73XcOtiDN964qswWvNpp0Qb31mo1sWtlTcDRmM+d7YbOc1epJ4uhsqxYcidnscOIwy1mtnizYN1cL0JnlzS3ZLGfjdxzt/nWzcnBLNsvZ2BhHh2dJWUm4JZvlrFZvjD//wawM3JLNclarNz2+YaXi1myWs1q9QWe7Z1RZeThxmOWsNuo7DisXt2aznLmrysrGrdksZ7V6w1/+s1Jx4jDLWa3e9Hc4rFTcms1y5jEOKxu3ZrOcuavKysaJwyxnHhy3snFrNsuZxzisbNyazXLmriorGycOs5x5cNzKptDWLOl6ST+WtFvSuyY5f7Okw5J2pa/fKiJOs5nwGIeVTWGPjpVUBT4GvAY4ANwnaVtEPHxW1S9ExC1zHqDZLBmpN71WlZVKkf8MugbYHRF7ImIE+DxwY4HxmM26iEjHOHzHYeVRZGteB+yfcHwgLTvbv5L0I0l/LWnDVG8maYukHZJ2HD58eLZjNZuWejNohh8ba+Uy31vzV4HNEfFC4B7gs1NVjIitEdEXEX2rV6+eswDNzmX8sbGeVWUlUmTiOAhMvINYn5aNi4j+iKilh58CXjJHsZnNitpoA/Dzxq1cimzN9wFXSLpUUgdwE7BtYgVJayccvgF4ZA7jM5uxM3ccThxWHoXNqoqIuqRbgLuBKnBHRDwk6QPAjojYBvy+pDcAdeAocHNR8ZpNh7uqrIwKSxwAEbEd2H5W2Xsn7L8bePdcx2U2W2r1tKvKdxxWIm7NZjmqjaZ3HB7jsBJxazbLkbuqrIycOMxy5K4qKyO3ZrMcjXdV+Y7DSsSJwyxH411VHuOwEnFrNsuRu6qsjFpqzZK6Jb1H0ifT4yskvT7f0MwWPg+OWxm1+s+g/w3UgJenxweB/5ZLRGYlMr7kiO84rERabc2XR8SfAaMAETEEKLeozEpi7I6jw4nDSqTV1jwiqQsIAEmXk9yBmNk5eK0qK6NWlxx5H/C3wAZJdwGvwOtGmZ1Xrd6gWhFtVScOK4+WEkdE3CNpJ/Ayki6qd0TEkVwjMyuB2qifN27lc87EIenFZxU9lW43StoYETvzCcusHGp1Jw4rn/PdcfyPc5wL4F/OYixmpZM8b9xTca1czpk4IuLVcxWIWRnV6k1/a9xKp6UxDkmLgN8FXklyp/E94PaIOJ1jbGYLnsc4rIxanVV1J3AK+Gh6/FbgL4E35xGUWVm4q8rKqNXE8fyIuGrC8XckPZxHQGZl4sFxK6NWW/ROSS8bO5D0UmDHTD9c0vWSfixpt6R3TXK+U9IX0vM/kLR5pp9pNpc8xmFl1GqLfgnw95L2StoL/APwC5IekPSj6XywpCrwMeB1wFXAWyRddVa1fwcci4ifAf4n8KfT+Syzorirysqo1a6q63P47GuA3RGxB0DS54EbgYldYDcCf5Tu/zVwmyRFROQQj9ms8+C4lVFLLToi9gEngWXAyrFXROxLz03HOmD/hOMDadmkdSKiDpxIP/s5JG2RtEPSjsOHD08zJLPZ5TEOK6NWp+P+McnaVI+TLnTIPPsCYERsBbYC9PX1+Y7E5gV3VVkZtdpV9WskS6uPzOJnHwQ2TDhen5ZNVueApDaSO57+WYzBLFceHLcyarVFPwgsn+XPvg+4QtKlkjqAm4BtZ9XZBrwt3X8T8G2Pb9hC4jEOK6NW7zg+CPyTpAeZ8ByOiHjDdD84IuqSbgHuBqrAHRHxkKQPADsiYhvwaeAvJe0GjpIkF7MFw11VVkatJo7PkkyFfQBoztaHR8R2YPtZZe+dsH8afzvdFqh6o0kz/BAnK59WE8dQRHwk10jMSmb86X8e47CSaTVxfE/SB0nGHCZ2Vfl5HGZTOPPYWHdVWbm0mjhelG5fNqFsXk3HNZtvavUG4K4qK59WHx3r53KYZVQbdVeVlVOrdxxI+mXgamDRWFlEfCCPoMzKwF1VVlYt/VNI0u3Avwb+AyCSmU6bcozLbMFzV5WVVast+hcj4jdIVqp9P/By4GfzC8ts4fMdh5VVq4lj7BGxQ5IuAerA2nxCMisHj3FYWbU6xvFVScuBDwM7SWZUfTK3qMxKwF1VVlatJo5HgUZEfDF92NKLgb/JLyyzhc9dVVZWrf5T6D0RcUrSK0m+u/Ep4OP5hWW28PmOw8qq1RbdSLe/DHwyIr4GdOQTklk5eIzDyqrVFn1Q0idIpuRul9SZ4WfNLkjuqrKyavWP/6+RLH/+2og4DqwA3plbVGYl4K4qK6tWlxwZAr404fgp4Km8gjIrg/GuKicOKxm3aLOc1OpNqhXRVvWvmZWLW7RZTmr1Bh1OGlZCbtVmOanVm55RZaVUSKuWtELSPZIeS7e9U9RrSNqVvrbNdZxmM1EbbXp8w0qpqFb9LuBbEXEF8K30eDLDEfHz6esNcxee2czV6g1PxbVSKipx3Ah8Nt3/LPArBcVhlpta3XccVk5Fteo16ZRegKeBNVPUWyRph6R7JZ0zuUjaktbdcfjw4VkN1mw6PMZhZdXyEwCzkvRN4OJJTt068SAiQlJM8TabIuKgpMuAb0t6ICIen6xiRGwFtgL09fVN9X5mc8ZdVVZWuSWOiLh2qnOSnpG0NiKekrQWODTFexxMt3sk/R3wImDSxGE233hw3MqqqFa9DXhbuv824CtnV5DUm66JhaRVwCuAh+csQrMZ8hiHlVVRrfpDwGskPQZcmx4jqU/Sp9I6zwN2SPoh8B3gQxHhxGELhruqrKxy66o6l4joB35pkvIdwG+l+38PvGCOQzObNR4ct7JyqzbLicc4rKzcqs1y4q4qKysnDrOceHDcysqt2iwnHuOwsnKrNstBvdGk0Qx3VVkpOXGY5eDM88b9K2bl41ZtlgMnDiszt2qzHIyMJY52d1VZ+ThxmOWgVm8AvuOwcnKrNsvBma4q33FY+ThxmOWgNuoxDisvt2qzHIx3Vfl7HFZCbtVmOXBXlZWZE4dZDjw4bmXmVm2Wg/ExDndVWQm5VZvNsojgxPAo4K4qK6dCHuRkVgbHBkfYc2SAxw8PsufwIHsOD7Cvf4gDx4YYHEm6qhZ3+lfMyset2uwcTo822Ns/yBOHB9lzZJAn0teewwMcGxodr9deFZtW9nDpqh5+8WdWsr63m+etXcLqJZ0FRm+Wj0ISh6Q3A39E8lzxa9JHxk5W73rgL4Aq8KmI+NCcBWkXhIjgyMAI+48NcfDYMAeODfPk0UH29Q+xr3+In54YJuJM/TVLO7l0VQ/XP38tl6/u4bLVPVy2ajHre7toq7rn1y4MRd1xPAj8KvCJqSpIqgIfA14DHADuk7QtIh6emxCtTE4Mj7L3yGDStXQo2T5xZIgn+wfHu5XGrOjpYNPKbn5hcy+bV63nstWLuWxVcjfR464ns2ISR0Q8AiDpXNWuAXZHxJ607ueBGwEnDpvU8aERHj+cdCU92T/IvqND6Z3D4LO6laoVsXFFN5eu6uFll61g04puNq7sZn1vN+uWdzk5mJ3HfP4NWQfsn3B8AHjpVJUlbQG2AGzcuDHfyKwQEcHRwREOHBtm/7Ehnjw6ND4o/cSRZyeHimDtsi42rezmdS9Yy+aV3Wxa2cPlq3vYuKKHDn+/wmzacksckr4JXDzJqVsj4iuz/XkRsRXYCtDX1xfnqW7zSERwcrjOM6dO88zJ0xw6WeOZU8n28Kkah06d5umTp3nmZG18ufIxFy159pjDpWmX0vrebicHs5zkljgi4toZvsVBYMOE4/VpmS0AzWbyXYb+wRGODY3QP1Cjf3CE/oERjgzUkqQwkCSFQydr40t0TLS4s42LlnSyekknL97Yy8VLF7Fm6SI2rOhmfW8XG1Z0e7qrWQHm82/dfcAVki4lSRg3AW8tNqQLQ0Qw2giGRxsMjzQYHKkzPNLg1Ok6A7U6A7VRTg7XOTE8ysnhUY4Pj3J8aJTjQ0mSOJbuN6e471vW1f6shLBm6SIuWtLJRUsXsWZJZ3K8tJPujvncPM0uXEVNx30j8FFgNfA1Sbsi4rWSLiGZdntDRNQl3QLcTTId946IeCjPuB59+iSNZoxPv4yAINJt8gc1xson7Dcjnl033W+O1UvrNMfrntkfq9OMoNkcK0vKk1iCRjNoRPKv+EZ6PHF/7FUf3zapN5LjeqPJaLqtN4KRRpPRRpPRRjBSbyavRpPaaINavcnp0Qan600aU/3VP0t3R5XlXe0s6+5geVc7V168lOXd7fR2d7Cip4OVizvo7U62qxZ30tvd4S4kswVOEeUbDujr64sdOyb9asg5Xfmer3N69LldJvOZBFWJtqpoq1SoCNqrFaoVTdgm+23ptr1aoaNaoaOtMn5uUXuVzrYKnW1VujoqdLVXWdRepbujjZ7OKl3tVZYsamfJojZ6OttY1pXst/u7C2alIOn+iOhrpa77Aib4yE0vohnJH2ORTBdOtmNlyQkBFWm8rCLS8mRf6bnxfZIpoJV0+vHY/sQ6FYnq2HsqqZMca3y/Ujnzs2fKzjml2cxs1jlxTHDd1ZNNAjMzs4ncz2BmZpk4cZiZWSZOHGZmlokTh5mZZeLEYWZmmThxmJlZJk4cZmaWiROHmZll4sRhZmaZOHGYmVkmThxmZpaJE4eZmWXixGFmZpk4cZiZWSZOHGZmlokTh5mZZVJI4pD0ZkkPSWpKmvJRhZL2SnpA0i5J2Z8Fa2Zms66oJwA+CPwq8IkW6r46Io7kHI+ZmbWokMQREY9A8qxtMzNbWOb7GEcA35B0v6QtRQdjZmY53nFI+iZw8SSnbo2Ir7T4Nq+MiIOSLgLukfRoRHx3is/bAmwB2Lhx47RiNjOz88stcUTEtbPwHgfT7SFJXwauASZNHBGxFdgK0NfXFzP9bDMzm9y87aqS1CNpydg+cB3JoLqZmRWoqOm4b5R0AHg58DVJd6fll0janlZbA3xf0g+BfwS+FhF/W0S8ZmZ2RlGzqr4MfHmS8p8CN6T7e4B/NsehmZnZeczbriozM5ufnDjMzCwTJw4zM8vEicPMzDJx4jAzs0ycOMzMLBMnDjMzy8SJw8zMMnHiMDOzTJw4zMwsEycOMzPLxInDzMwyceIwM7NMnDjMzCwTJw4zM8vEicPMzDJx4jAzs0ycOMzMLBMnDjMzy6SQxCHpw5IelfQjSV+WtHyKetdL+rGk3ZLeNddxmpnZcxV1x3EP8PyIeCHwE+DdZ1eQVAU+BrwOuAp4i6Sr5jRKMzN7jkISR0R8IyLq6eG9wPpJql0D7I6IPRExAnweuHGuYjQzs8m1FR0A8HbgC5OUrwP2Tzg+ALx0qjeRtAXYkh4OSPrxNONZBRyZ5s/ON2W5lrJcB/ha5qOyXAfM7Fo2tVoxt8Qh6ZvAxZOcujUivpLWuRWoA3fN9PMiYiuwdabvI2lHRPTN9H3mg7JcS1muA3wt81FZrgPm7lpySxwRce25zku6GXg98EsREZNUOQhsmHC8Pi0zM7MCFTWr6nrgPwNviIihKardB1wh6VJJHcBNwLa5itHMzCZX1Kyq24AlwD2Sdkm6HUDSJZK2A6SD57cAdwOPAH8VEQ/NQWwz7u6aR8pyLWW5DvC1zEdluQ6Yo2vR5L1EZmZmk/M3x83MLBMnDjMzy8SJYxKS/jhdDmWXpG9IuqTomKaj1aVdFgJJb5b0kKSmpAU3dbJMy+dIukPSIUkPFh3LTEjaIOk7kh5O29Y7io5puiQtkvSPkn6YXsv7c/08j3E8l6SlEXEy3f994KqI+J2Cw8pM0nXAtyOiLulPASLivxQc1rRIeh7QBD4B/KeI2FFwSC1Ll8/5CfAaki+y3ge8JSIeLjSwaZL0z4EB4M6IeH7R8UyXpLXA2ojYKWkJcD/wKwvx/4skAT0RMSCpHfg+8I6IuDePz/MdxyTGkkaqB1iQ2bXFpV0WhIh4JCKmuxpA0Uq1fE5EfBc4WnQcMxURT0XEznT/FMnszXXFRjU9kRhID9vTV25/t5w4piDpTyTtB34deG/R8cyCtwNfLzqIC9Rky+csyD9QZSVpM/Ai4AfFRjJ9kqqSdgGHgHsiIrdruWATh6RvSnpwkteNABFxa0RsIFkO5ZZio53a+a4jrTNrS7vkqZVrMZttkhYDXwT+4KzehgUlIhoR8fMkPQvXSMqtG3E+LHJYiPMtiTLBXcB24H05hjNts7C0y7yR4f/JQuPlc+apdDzgi8BdEfGlouOZDRFxXNJ3gOuBXCYwXLB3HOci6YoJhzcCjxYVy0y0uLSL5c/L58xD6YDyp4FHIuLPi45nJiStHps1KamLZCJGbn+3PKtqEpK+CPwcySyefcDvRMSC+xeipN1AJ9CfFt27EGeHAUh6I/BRYDVwHNgVEa8tNqrWSboB+F9AFbgjIv6k4JCmTdLngFeRLOH9DPC+iPh0oUFNg6RXAt8DHiD5XQf4w4jYXlxU0yPphcBnSdpXhWSJpg/k9nlOHGZmloW7qszMLBMnDjMzy8SJw8zMMnHiMDOzTJw4zMwsEycOs2mQNHCe88sl/e4MP+Pmhboys5WbE4dZPpYDM0ocwM2AE4fNO04cZjMgabGkb0naKemBCetqfQi4PH2my4fTuu+UdF/6fJT3p2WbJT0i6ZPpcxS+IalL0puAPuCu9D26irlCs+fyFwDNpkHSQEQsltQGdEfESUmrSJavvwLYBPzfsedVpM9GeRPw24BIlhz5M+BJYDfQFxG7JP0VsC0i/o+kv2OBPXvELgwX7CKHZrNEwH9PH27UJFkufc0k9a5LX/+UHi8mSTBPAk9ExK60/H5gc54Bm82UE4fZzPw6yfpZL4mIUUl7gUWT1BPwwYj4xLMKk+dA1CYUNQB3S9m85jEOs5lZBhxKk8arSbqoAE4BSybUuxt4e/rsByStk3TRed777Pcwmxd8x2E2M3cBX5X0ALCDdCnriOiX9P8kPQh8PSLemT43/R+S1bwZAP4NyR3GVD4D3C5pGHh5RAzneB1mLfPguJmZZeKuKjMzy8SJw8zMMnHiMDOzTJw4zMwsEycOMzPLxInDzMwyceIwM7NM/j+nxj8mSSYQAQAAAABJRU5ErkJggg==\n", 151 | "text/plain": [ 152 | "
" 153 | ] 154 | }, 155 | "metadata": { 156 | "needs_background": "light" 157 | }, 158 | "output_type": "display_data" 159 | } 160 | ], 161 | "source": [ 162 | "x_T = np.linspace(-3, 3, 100)\n", 163 | "samples = diffusion.ddim_sample(x_T.reshape([-1, 1]), model)\n", 164 | "plt.xlabel('latent')\n", 165 | "plt.ylabel('sample')\n", 166 | "plt.ylim(-2, 2)\n", 167 | "plt.plot(x_T, samples)\n", 168 | "plt.show()" 169 | ] 170 | } 171 | ], 172 | "metadata": { 173 | "kernelspec": { 174 | "display_name": "Python 3", 175 | "language": "python", 176 | "name": "python3" 177 | }, 178 | "language_info": { 179 | "codemirror_mode": { 180 | "name": "ipython", 181 | "version": 3 182 | }, 183 | "file_extension": ".py", 184 | "mimetype": "text/x-python", 185 | "name": "python", 186 | "nbconvert_exporter": "python", 187 | "pygments_lexer": "ipython3", 188 | "version": "3.6.3" 189 | } 190 | }, 191 | "nbformat": 4, 192 | "nbformat_minor": 2 193 | } 194 | -------------------------------------------------------------------------------- /mnist_conditional.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "metadata": { 3 | "language_info": { 4 | "codemirror_mode": { 5 | "name": "ipython", 6 | "version": 3 7 | }, 8 | "file_extension": ".py", 9 | "mimetype": "text/x-python", 10 | "name": "python", 11 | "nbconvert_exporter": "python", 12 | "pygments_lexer": "ipython3", 13 | "version": "3.8.5-final" 14 | }, 15 | "orig_nbformat": 2, 16 | "kernelspec": { 17 | "name": "python3", 18 | "display_name": "Python 3", 19 | "language": "python" 20 | } 21 | }, 22 | "nbformat": 4, 23 | "nbformat_minor": 2, 24 | "cells": [ 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "from IPython.display import display\n", 32 | "from PIL import Image\n", 33 | "import matplotlib.pyplot as plt\n", 34 | "import numpy as np\n", 35 | "import torch\n", 36 | "from tqdm.auto import tqdm\n", 37 | "\n", 38 | "from ddim import BayesPredictor, CNNPredictor, Diffusion, create_alpha_schedule\n", 39 | "from mnist_train import create_datasets" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 3, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "diffusion = Diffusion(\n", 49 | " create_alpha_schedule(num_steps=100, beta_0=0.001, beta_T=0.2)\n", 50 | ")" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 4, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "def load_trained_model():\n", 60 | " model = CNNPredictor((1, 28, 28))\n", 61 | " model.load_state_dict(torch.load('checkpoints/mnist_model.pt'))\n", 62 | " return model\n", 63 | "\n", 64 | "cnn_model = load_trained_model()" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 5, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "def load_test_data():\n", 74 | " _, loader = create_datasets(10000, False)\n", 75 | " batch = next(iter(loader))[0]\n", 76 | " return batch.numpy()\n", 77 | "\n", 78 | "test_data = load_test_data()" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 6, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "def show_sample(sample):\n", 88 | " image = sample*0.3081 + 0.1307\n", 89 | " image = (image * 255).clip(0, 255).astype('uint8')\n", 90 | " image = image.reshape([4, 4, 1, 28, 28]).transpose(0, 3, 1, 4, 2).reshape([28*4, 28*4, 1])\n", 91 | " image = np.concatenate([image] * 3, axis=-1)\n", 92 | " display(Image.fromarray(image))" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 11, 98 | "metadata": {}, 99 | "outputs": [ 100 | { 101 | "output_type": "display_data", 102 | "data": { 103 | "text/plain": "", 104 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAHAAAABwCAIAAABJgmMcAAAWW0lEQVR4nO2dfVATxxvHkzRBkkEGUIMRobairS+8KFUZFLCKL7Sio4JTiy+jAgNqtdXB1+oERKltba34ii9FK1QoiKJCAamKZgrKS6GCIAUUEME0gJJAJtm9/P7Yur3m9e4S2/46+fzhhNvb5577unu7t/vcLotlxYoVK1as/N+wcuXKmpoajUZDkFAoFBEREfb29sxscjgcW1tb4+fw+Xy7lxg6x8HBoaCgICMjQ2+ql5dXZGTkqFGjqLiUmJiYlpYWGxvr7Ozs4OCAj3t4ePD5fCoWTDNgwIA9e/aoVCpogIyMDEdHR7pmuVzu/v37m5qa3nrrLd1UR0dHkUg0bdq0zs5OfCFDpoKCggiCaGxsHDt2rG5qbGwsQRDLli2j4lVnZycqK7W1tQqForq6Ojc3t7u7W6lU9vb2nj59OvAlbDab+s3+BbFYbEhKDEV3yXz66acob2Njo1AoxMe5XO7GjRubmprI9qVSqVQqNWSquLiYIIj09HS9qUjQa9euUfFq/Pjx+/fv7+jouH79OrkutrS0nDlz5ujRo/jI2rVr6d4yi8ViJSUlqdVq8r1duHDh0KFDNjY2Fy5cwAeTkpLoWs7IyEB579y5g6szh8PZvXs3+XIymay8vHzKlCnjx4/Xa2fZsmUQQpOCEgTx1VdfUfRNKBSKRKK4uLjdu3e//fbbc+fORVUwODi4r68PWZPL5XPmzHFzc6Nxzxs3biSrWVFR4e/vz+PxUKqjo2NaWhpKUqvVGzZsoGGaJOiFCxfwwTVr1qCDe/bsKSwsvHXrllgsNm4H6dXZ2TlhwgS9JwwfPrylpYUgiCdPnowcOZKWk7osWLDgq6++Wrx4cUhISGlpaXV1NVVNhw0bVl5eTi4sQUFBWufY2NhcvnwZpZaVlYlEIuqeYUE7OztdXFzQwevXrxcWFt64cWPw4MGurq5U7CiVSoIgNm/ebOSc+vp6VKw8PT2pe2iScePGpaam5uXlTZ06VSuJo3v2kCFDvL29jVtUqVTFxcXo94QJEwYPHszArcGDB+MGNDw8PDg4ePHixb///ntra6vJvBMnTuRw9DivRVlZGQPHTFJTU1NRUTF58uTk5GStJNM+ZWdnV1ZWWtCb48eP49+44HR2dgIAurq6qFjg8XgbNmzgcrkAgBcvXhg5U/eGLUVeXp6jo+Pw4cNNn+rl5YUr+7Vr12xsbPSe5uvr++zZM3Sah4cHdVe4XO7169dRRolEwqCX5+XlhSpyXV2d8TMDAwNfRZVnsVj+/v4EQZCbAYSJElpUVKRSqfQmlZSUUCxQWgAAbt68iX77+voGBwczMILYtGmT8RMOHDjA2LgR+Hx+eHg4AODLL7/USjJd5Q0hEAhee+01Znlv3Lih0WjYbDaHwwkMDKSbff369ejHmjVrhpJITU09f/780KFDZ8yYkZOTk5OTo/fdwUxsbW3Pnz8fFRXV29vb0tKilcplbDcmJubNN99kllcikRQUFMyZM0ej0Wg0GrrZcau1cOHChQsXaqV++OGHzLyiyNy5cxcuXNja2nru3Llnz55ppZoQdOfOnRKJpLS01OJuJSUlzZkzh1nexMTElJSUY8eOcbl/+D927FiVSvXGG2+oVKpbt26hgyNGjEAv8ikpKQ8ePKB7FS6X6+fnx2azR4wY8ejRI2TQxcUlPj6exWLt3bsXtXgcDocgCGOGxo8f39fXR35B0r2Sg4MDfoPs6+sbN24cXXdFIhF6yzx48CDdvLq8884748ePDw0NDQkJwQfxm9KhQ4foGpw2bVpJSQl6L0D/6rJt27bw8PDU1FTT5g4dOoQFraysHDFiBDl19uzZOFWpVC5YsICKixwOx9/f/5133kF/Ojk51dfXW0pQvTAWNCIi4smTJwRBAAAIgoiLi2toaCAIorm5Gaspl8u9vLxSU1N1+/Z6EAqFSqWSrGlAQACPx+PxeAEBAb/88gtO+uWXXyh6+cUXX0AIFQpFcXFxcXFxdXU1svBvE3Tp0qVIuIcPHx48eHD48OFCoXD69OmnTp0KCwtbtGiRSCQSCoWDBg2i583hw4fhX8nIyMBvjYiqqqo33niDosFTp05BHRobG19FQ4xAgmo0GlqClpaWEgRx/fr1mTNnzpo1Kysra9euXRbwZvLkyWfPntWVAHPv3j3y4JtJRCLR/fv3tYzQHVihBRJUrVaHhoZSz5WWlobKdVdX13fffff555/Tem0xho2NTVxcnO4Ac1dXV1FREa0BEYRIJEpJScHl3dPTEzfTrwIkaH9/P61cYrG4qakJDdbV1dVptR8WYMWKFatXr75y5QoSIicnZ+7cuYytsdlsgUAgEAioDG2YCTNBEW5ubvPmzWM+OG/FihUrVqxYsWLFYsTExBAEUVJS8k878q+GRu8aDQaPHj3azEs6Ozvv2LHjxIkTZtrRYvHixVohEadPn4YQTpo0yUzLXC531KhR8fHxbW1tmpdUVFTEx8cbibsyTXp6OnrvZJbdzs7O19f34sWLUqm0o6MjICDA1dV19EvM8ozF8vHxUSqVDQ0N+EhkZKRarQYAXLlyhbFZDoczatSohoYGQwMad+/efe+99xhaLysrYyYom80ODQ3Ny8sjj8729/dLpVKlUimXy+VyeV1d3dWrV52dnZn5lpWVBQC4dOkSPtLS0gIAaGpqio2NZWbTzs5u69atRoaHEImJiUysjxw5squri5mgLi4uMpmMIIjq6mo0d3b69OmwsLBJkyahoQculxsWFlZTU1NVVcVAU6FQWFlZCQBYvnw5PtjT0wMAmDZtGl1riEGDBj148EBLu76+vra2tra2NiQFhDA7O9vX15fJBb788ktkorCwkEF2e3t7JyenAQMGGDnn22+/JQji3LlzdI3PnDkTAPD8+XMcMhQfH69UKltbWylG9ehCDk5AUhYVFc2ePZvFYg0bNiwrKwsdd3JyYmafdeTIEWTi888/Z2jCKIsXL5bL5QCA7du3082LBL1//z7609XVFUIIADhz5gwzZ0aOHPnw4UOspkKhiIqKEggEo0ePDgkJwcUTQuju7s7sEiy5XA4hlMvl8+fPZ2jCMOvWrUMP1s8++4xBdi1B9+7da6agM2fOJBfPS5cuCQQCvcPtDx48MF7t9BMcHIzy4xlaC7JmzZqnT5/KZDKxWMxs/BEJWldXh6J6sKDvv/8+M5dEIhFZNYlEcuvWLb0t0vHjx5mMkWNzlhVUIBAkJSWpVKquri4/Pz/GdpCguFFCs4EAgGHDhjEzOGDAgLy8PL0KYsrKysLCwnDMLD1OnDgBIVSpVDgmyXyEQmFNTQ1BED09PYzbYgRZ0IEDBwIAIIS3b99m3LedMGFCa2urISlLS0vDwsIM1XTTxZXP56O27NChQ52dncxc1MLd3T0jI2Ps2LFpaWmRkZF9fX3mWGOz2eRnBQqZqqyslMvlDKx5e3tfuXLF0IxZREREenq6WQ4vWbIE/c+8ePHi8OHDzA29xMPD48mTJ0qlkkFAh15wCW1vb9+0aRMqof39/T09PRKJhFbHlsvlbtq0yUhNr6qqYhzR9Qd+fn7t7e0QQh8fH4tMAfr4+PT398tksjfffJPP5/P5fIZPopdgQTHoGVpfXx8bG0vdOJvNPnr0KFm+7u7uzMxMdPuYjIwMs16UJ06ciAzt37+fuRUdm+fPn8/Jybly5YpCoWhoaDhw4IBYLBaLxfPmzaM7IWpIULpBtlwut6amhqwd+oLGz88vLS2to6MDH9+wYYNAIKBl/E9QowkhtEh918XR0dHJyWnRokUff/zxjh07CIKge6GJEydKpVIAgEKhWLdu3ZkzZyCjVp7H48lkMqza5s2bcUlcs2aNQqHASVlZWcwjClAT/+oExQgEggMHDhAE8dNPP9HNO2HChBkzZqC4LdwPpSuojY0NlqylpQWVwVmzZhUWFpLVhBBu3LiRrod/0t3d/TcI6uvrW1tbSxBEcXHx66+/bo4piwiqUqmOHj0qFovJkZ0Qwt9++y0mJsasgJfc3FyCICwuaHR0dGBgoK+vb3JyskKhUKvVTU1NCxYsMLOBYpEEffvtt+nmPXjwoL62/Q9SU1NXrVplpnssT0/P9vb24uJiHx8fc229ZOrUqUqlEgWyKpXKwsLCWbNmMfgOVy9Y0GPHjtHNa2Njk5SUpCtlc3NzVFTU3xA+xJCFCxei0ZCUlJQZM2ZY1nhgYGBeXh7jNyU7O7u4uDhUVDs6OuLi4saMGfNKg9qsWLFixYoVK1b+Jj7++OOuri4U0FFfX894ptMkbDY7JiZGo9E8fvzY8rH35jB9+nT0/SwZsVg8ffp0Btba2trInfAHDx6Eh4db2mXWoEGDyF8GmTnLYEm0pEQ6isVi/Cddg1qCojfx2tpac0eCSTg5OVVVVWH7dMet/yQoKEihULS2tupdGIkBWmreuHEDJ2FNyQepoCsowlIz3h4eHmVlZWhAA6k5dOhQJoaCgoJkMhkass3NzdVKtbe337Zt28WLF6kbnD59OpYMDSdr1XEsNK26HxwcfPLkyZMnT27evFkkEh07dgzd+e3bt6kbMcTQoUOlUimEEAnKXM21a9cqFAqFQrFr167vv/8eAKC1uFp0dDQAoKenh/oKPrh4GqrXWHG6mpKxtbVFMxnNzc1mjgSyWKxVq1ah/x61Wp2ens6wpn/yySfd3d29vb2LFi1isVhDhgyRSCTk2DZ0GQDAvn37KNokV3Yjp5nzMMVMnToVQlhbW2tmlKSrq2tlZSUS1PhqMcawtbUtKysDAOzcuRMfDAgIwH+6uLigx1Zubi71YRiKRY/8WGDgvLOz84IFCwoKCiCECoUChXcxpqioCKlZWVnJvIkLDQ2FEBq6Hy6Xi6Iq2tvb8TJWJqFYPBG4kNJw+iXr16/HjdLVq1fNGWPdsmUL/sY9ISGBsR1WRUUFAEB3NTGEr68vaqb27NlD0SDdJyM+n0Gtv3v3rkVaeTc3N7ysYWlpKXmOc9GiRdu3b1+yZAklQ/Pnz0eR1Hr7SePGjXv27BkSlPpkAN3HoqUEzc/PZ7bkGYfDKSgoQEPgKpXK2dnZzs5u37599+7dQx/go6Ta2lrTyxCEhYUhvXQFFQgEP/zwA0rVXV7HCHQbbvx8YNDQkwWFEFZUVJDDmini7++PsisUiujo6H379qF19nS7uqbjQw0J6uHhkZ2djSMJaK1ySf3pyWLUyjs5OeGpvdWrV1dWVpLDZVtaWmjFGfN4vNzcXJS3sbGxp6cH/Ualtb6+vrm5GR0pKSkZOHCgaYt1dXUow2ESUqlUo9EUFBS8ePEC0lw2lLpAzF6W8vPzt27dSj4yePDgDRs2PH/+HN0IranKadOm6ZbE9vb2kpKSefPm2djYfPPNN+jgN998Q8mit7d3fX09DmjBrF271sHBoaOjAwBAa6afokZYTbqVva2trbGxcdKkSVq97o8++oiBoOjrJi3Onj3r5ubG4/HwekGbNm3SWzz19CKrqqp8fHwmTZrk7++P62l6evrDhw/xOXfu3KHuIsK4Rjdu3CCfQDcQdcSIEWiZpaKiInxw5syZtIwYYdmyZQEBAR0dHZMnT2axWGVlZfn5+b29vRYwjUoos2eo3lpP7lQx68+Tn5i6qNXqpUuXUremt4TCl+/yEEKZTGZyfVUaMBCUXJexpmjITmvwidn7+4oVK9ADSi+//fYbLWvGBb1169bu3bsZOGkQ1GTRGl7VO5ysBbMXTcyRI0f0avrkyZPIyEhapmJiYvQKqlQqExISTC6/T5tly5bRLaEIQ5qaKSUmOjpaKzL23LlzY8aMoWuHzWYvX748ISHh8ePHyE5paWlCQgLzgFDjCIVCAEBJSYlQKKQb2KX1uGRcx/9TODg4PH36FADQ29vr5eX1T7vzn2DKlClSqZTiooxWrFixYsWK5UBrvZo522XlDzw8PCCEAIAhQ4b80778V0hNTVUoFLSXHLZiCKhv5XXzWbVqVXJyslKpVCqVycnJ+fn5sbGxUVFRFr/Qv4ulS5daXNAxY8Y0NTVpbSeGAAAcOXLEgtd6RQwYMKCiogJCuHr1aq0kE2EKjo6OFl8294cffjAUIcNms1EhXbdunQWvaG9vrxuQRB4vN8KYMWN4PF51dTU+4uHhcfjwYU9PT72zZCYEpT65RtG5devWGV/sjcPhREVF8Xi8pKSkX3/9ldmFQkJCyP8lw4YN09oL4tGjR1RCsrhc7vHjxz09PVtaWoqLi9Gi/WKxGC1P0N/fn52dTc8zNDhoqSpvZM0zXb777jsGl/D29jb0PMHDmgcOHKDSxjo5OTU3N2t9Mk6GyUizBQVduXKlkS1CIYR37tzBM7TozmkFpQYFBendewDT3Ny8bds2ik+w9evXo5CO3t5eiUQikUi0BM3IyGCywFBAQACEsLu728wPMUeNGqW1YSdGKpVmZ2eHhYXZ29v7+vqSk3bs2EHRPgplJectLS3NzMzcvXt3ZmZmZmZmSkoK9a402swMCYcX5poyZcqePXtwwOyUKVOYCMFisTQaDYTQzI59cnKyXjUlEgne8cvOzu7cuXPkVCq7/bFYrLVr1/b29uJcKpUqPDyc8f67fD4/Pz8fqfbjjz+S5zxmzJihVCoBABKJhPnXnw8fPtR6U3J1dQ0LC6M12E7eqgVCmJyc7Ofn5+fnR57axsttkTFp2d7eHoeEq9Xq8vJyWnOculRVVSE1ZTKZVsjJpUuXAABKpdKcrRFYpaWlWFBbW9uIiIju7m4AQH9/P8WVHl1dXbUE1W3o3333XfKiHojm5maTxidOnIiLJ/VHhBGQoFKpVOsz6alTp6KtHdvb2826AIQQAIBmOxISEiCE5eXlERERT58+ra+vp2IBb2WFSElJ0ZrwGj169M2bN7XUpB7gh57Ou3btssj37Fu2bElISNBdDQIVIwAA1UBGQ6SmpqJGbdWqVX19fRcvXkRypKenAwCoWNBaiBN9MmVvb8/n893d3b/++mvdsgkhPHXqFEUPf/75ZwjhBx98wPQWTZOVlYXaEgtsFDpu3Dh8k/jtQigU5uXlGdlXmwxyBXP16tX09PTbt28XFBTgnUG1UKvVum91hnBxcamrq7t7967FNoL/K5MnT0ZPFQAAeYs2hggEgoyMDHKMo6ura01NjZEoZy0MNfGGeP78OcVqxeVy0VT28uXLlUqlVqNsKebPn4869snJyUz6nnotorDmpKQk1Cip1ep9+/ZR7zp8//33FNWUyWTUdy9PTExcuXIl+r169er+/v5XMTCWmJgIAFCpVGZ+AvEXzp49i9/Anj59GhcXRyv7+vXrjYQfIVA4K62lbAiCuHz5Mv6zoKBALpdbMpKLxYqMjGxvbwcAMAg4NIabm1tSUpJUKo2Li6O0w7IO0dHRKEJCV0q5XH7kyBEGsdsSiYQsaHx8PISQ0jaG1PD29kY+AwAofkD3d+9oFR0drRunmpmZSaXLqUtwcPCpU6dCQkIqKipCQ0PPnj1ra2sbEBAgkUjMd5XL5TY0NLi6ugIATpw4YdZqYv8vODs7P378uKOjY/v27bh7b5ESyufzUX8RAEC9A/dfwN3dnTzscv/+feqfoxkBf45VXl5u8a3U/+24u7vn5eXV1NTk5ORYajkHLCj17vB/DQcHh3/aBSuvgP8B3hx7W8D2UQcAAAAASUVORK5CYII=\n" 105 | }, 106 | "metadata": {} 107 | }, 108 | { 109 | "output_type": "display_data", 110 | "data": { 111 | "text/plain": "", 112 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAHAAAABwCAIAAABJgmMcAAAML0lEQVR4nO2dfUwb5R/Aj+Z4TW1W3rrCOl8QdLLCppk0bGPENbpq2JKxmizMLDhHBJZNs6DbVFJBN+fL/hA3sZOpM5uAbODcQF4WAq6JzFIEhTGWgQLqulqY0m6X9rmrfzw/ntxa2l6Px6g/n89fcM9z3+d7H+65K71vnqMoAoFAIBD+NWzdunVwcNDj8XA8nE7nM888I5PJxMWUSCRRUVGB+0RHR0tn8ddnwYIFbW1t9fX1c7ZmZmZu3749NTVVSEoHDhw4efJkWVmZQqFYsGAB2q5Wq6Ojo4VECE5kZGRlZaXL5WL9UF9fL5fLQw1L0/TBgwdHR0fvu+8+31a5XK5UKletWmW1WtFA/kJptVqO465evfrAAw/4tpaVlXEct2XLFiFZWa1WeK4MDQ05nc6BgYHm5ubp6WmGYWZmZmpqatbMEhYWJvxgb8NgMPhTiRCYLp+XX34Z7nv16tXExES0nabpXbt2jY6O8uPbbDabzeYvVHd3N8dxdXV1c7ZCoefOnROS1dKlSw8ePHjt2rWOjg7+XBwfHz927NiRI0fQlpKSklAPmaIoqqqqyu1284+ttrb23XffjYiIqK2tRRurqqpCjVxfXw/3vXDhAprOEomkvLycP5zdbu/t7c3Kylq6dOmccbZs2cKybFChHMcdOnRIYG6JiYlKpfLVV18tLy+///77161bB6egTqe7efMmjOZwOB577LHFixeHcMy7du3i27RYLKtXrw4PD4etcrn85MmTsMntdu/cuTOE0DyhtbW1aOO2bdvgxsrKyvb29q6uLoPBEDgO9GW1WpcvXz5nh0WLFo2Pj3Mc9/PPP6ekpISUpC8bNmw4dOhQfn5+Xl5eT0/PwMCAUKdJSUm9vb38k0Wr1Xr1iYiI+OKLL2Cr2WxWKpXCM0NCrVZrcnIy3NjR0dHe3t7Z2RkfH69SqYTEYRiG47jdu3cH6HP58mV4WmVkZAjPMCjp6eknTpxoaWlZuXKlV5PEt3dCQsKyZcsCR3S5XN3d3fDn5cuXx8fHi0grPj4e3UALCgp0Ol1+fv5vv/02MTERdN8HH3xQIpkjeS/MZrOIxIIyODhosVgefvhho9Ho1RQ8p8bGxr6+PozZVFdXo5/RiWO1WgEAU1NTQiKEh4fv3LmTpmkAwB9//BGgp+8B46KlpUUuly9atCh418zMTDTZz507FxERMWc3jUZz/fp12E2tVgtPhabpjo4OuKPJZBLxKS8zMxNO5OHh4cA916xZ81dMeYqiVq9ezXEc/zYACbv77rvxjvQfJ/iUJ4QEEYoZIhQzRChmiFDMEKGYIUIxQ4RihgjFDBGKGSIUM0QoZohQAoFAIPwbKS4u5jjum2+++bsT+UcTwl3e4/F4PJ60tLR5DqlQKPbt2/fBBx/MM44X+fn5XiURNTU1LMuuWLFinpFpmk5NTa2oqJicnPTMYrFYKioqAtRdBaeuro5lWYEPJn2RSqUajeb06dM2m+3atWs5OTkqlSptlnllRlEPPfQQwzBXrlxBW7Zv3+52uwEAX375peiwEokkNTX1ypUr/iqRLl68+Pjjj4uMbjabxQkNCwvbtGlTS0sLv2bo1q1bNpuNYRiHw+FwOIaHh8+ePatQKMTldurUKQBAU1MT2jI+Pg4AGB0dLSsrExdTKpW++OKL/lQiDhw4ICZ6SkrK1NSUOKHJycl2u53juIGBgTNnzpw5c6ampkav169YseKuu+6iKIqmab1ePzg42N/fL8JpYmJiX18fAOCpp55CG2/cuAEAWLVqVajRIHFxcZcuXfJyd/PmzcnJycnJSaiCZdnGxkaNRiNmgLfffhuGaG9vF7G7TCaLjY2NjIwM0Oejjz7iOO748eOhBl+7di0A4Pfff0clQxUVFQzDTExMCKzq8YVfnABVnj9//tFHH6UoKikp6dSpU3B7bGysuPjU4cOHYYg333xTZIiA5OfnOxwOAMDevXtD3RcK/eGHH+CvKpWKZVkAwLFjx8Qlk5KSMjIygmw6nc6ioqKYmJi0tLS8vDx0erIse++994obgnI4HCzLOhyO9evXiwzhn9LSUnhhfeONN0Ts7iX09ddfn6fQtWvX8k/PpqammJiYTz75xPcCeunSpcDTbm50Oh3cv6urS1yKAdi2bduvv/5qt9sNBoO44mAodHh4GFb1IKFPPPGEuJSUSiXfmslk6urqmvOOVF1dTdN0yAOgcHiFxsTEVFVVuVyuqamp7Oxs0XGgUHRTeuutt6DQpKQkcQEjIyNbWlrmNIgwm816vR7VzCJIbRNmyPehmCFCMUOEYoYIxQwRihkiFDNEKGaIUMwQoZghQjFDhGKGCMUMEUogEAj/bZ577rmpqSlY0HH58mXRTzqDEhYWVlxc7PF4fvrpJ/gA/J9Cbm5uZ2en53YMBkNubq6IaJOTk14PzgoKCnCnTMXFxX344YdoFNFP/PHjpRJ6NBgM6NdQA3oJZVnW5XINDQ3dc889uHKOjY3t7+/nP7YTWeSi1WqdTufExMScCyOJwMtmZ2cnakJO+RuF4CsUguuJt1qtNpvNHMchmwsXLhQTSKvV2u12+ECxubnZq1Umk+3Zs+f06dPCA+bm5iJlBoPBd44j0SHNfZ1Od/To0aNHj+7evVupVL7//vvwyL/++mvhQfyxcOFCm83GsiwUKt5mSUmJ0+l0Op2vvPLKZ599BgDwWlzt2WefBQDcuHFD+Ao+6PT0N6+R8VCd8omKijpy5AjLsmNjY3feeae4IIjCwkL453G73XV1dSJn+vPPPz89PT0zM7Nx40aKohISEkwmE7+2DQ4DANi/f7/AmPzJHqDbfC6miJUrV7IsOzQ0NM8qSZVK1dfXB4UGXi0mEFFRUWazGQDw0ksvoY05OTno1+TkZHjZam5uFl43IfDU418WRCSvUCg2bNjQ1tbGsqzT6YTlXaI5f/48tNnX1yf+Frdp0yaWZf0dD03TsKril19+QctYBUXg6QlBJ2kISc+yY8cOdFM6e/asiLUOES+88ALDMDDUa6+9JjoOZbFYAAC+q4lBNBoNvE1VVlYKDBjqlRH1FzHrL168iOUuv3jxYrSsYU9PT0xMDGrauHHj3r17n3zySUGB1q9fDyup5/yclJ6efv36dSi0sLBQYHKhXhZxCW1tbRW35JlEImlra4MFgS6XS6FQSKXS/fv3f/vttxzHeWaXUh0aGoqLi/Pal9Q2YYZ8H4oZIhQzRChmiFDMEKGYIUIxQ4RihgjFDBGKGSIUM0QoZohQzBChBALhNjiOGxkZmefTLsL/UKvV8OlmQkLC353L/wsnTpxwOp2+3/UTRMLe/uIjXBQWFhqNRoZhGIYxGo2tra1lZWVFRUXYB/pnsXnzZuxClyxZMjo66vU6MQgA4PDhwxjH+ouIjIy0WCwsyz799NNeTUHKFORyufh32vnh888/91chExYWBk/S0tJSjCPKZDLfgqSRkREh+y5ZsiQ8PHxgYABtUavV7733XkZGxpyVA0GEii448JdcaWlp4MXeJBJJUVFReHh4VVXV999/L26gvLw8/p8kKSkpPT2d3+HHH38UUpJF03R1dXVGRsb4+Hh3d/d3331HUZTBYIBvOLt161ZjY2NomRUXF2Oc8gHWPPPl008/FTHEsmXL/F1PIAzDvPPOO0LusbGxsWNjY2AWeEXiU15eHnJ+GIVu3bo1wCtCWZa9cOHC2NgY/8hDKkrVarX8CmNfxsbG9uzZI/AKtmPHDljSMTMzYzKZTCaTl9D6+noxCwzl5OSwLDs9PT2fIiGKolJTU71e2Imw2WyNjY16vV4mk2k0Gn7Tvn37BMaHpaz8fXt6ehoaGsrLyxsaGhoaGj7++GPhH6Xhy8ygOLQwV1ZWVmVlJSqYzcrKEiOCoiiPx8Oy7Dw/2BuNxjltmkwmnU4H+0il0uPHj/Nbhbztj6KokpKSmZkZtJfL5SooKBD9/t3o6OjW1lZo7auvvuJXxT7yyCMMwwAATCaTmOWaICMjI17/KalUKr1eH1JBLCpjgxiNxuzs7Ozs7DvuuAP1Qctt8QkaWSaToZJwt9vd29u7efPmUI7Pm/7+fmjTbrd7LW3X1NQEAGAYZt26df52J7VNmCHfh2KGCMUMEYoZIhQzRChmiFDMEKGYIUIxQ4RihgjFDBGKGSIUM0QoZv4EFd651ol7tbUAAAAASUVORK5CYII=\n" 113 | }, 114 | "metadata": {} 115 | }, 116 | { 117 | "output_type": "display_data", 118 | "data": { 119 | "text/plain": "", 120 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAHAAAABwCAIAAABJgmMcAAAi3ElEQVR4nO1de3xM1/bf55yZTB6VB0K81bNUtFQ9KogbitsmqSpaqq7eoOLZj1tUb90Ql7pUxauEVkkpabzV+1FKbhEhXGnIlRJBRl4SmczjnDPn98fqrO7Z58xkMtzP7/5+n6w/5nPmnL3XXnvttdd+rO/Zh5BaqqVaqqVaqqX/MzR27Njr168rimKnyGQyxcXFBQYGeseT53lfX1/3afz8/J5xkKs0QUFBR48eTU1N1XzauXPn8ePHt2nTxhORFi9evG3btlmzZjVo0CAoKAjvh4eH+/n5ecKhejIYDImJiTabTXZBqampwcHBNWWr0+mWLFmSl5fXvn179dPg4OCwsLDevXsbjUYsyBWrAQMG2O32W7dudejQQf30o48+stvt7777ridSGY1GsJXs7GyTyXT16tWDBw+WlZVZrdbHjx9/9dVXfR3EcZznlXWihIQEV6pE8lBcmj799FPIe+vWrdDQULwvCMKMGTPy8vJo/sXFxcXFxa5YnTlzxm6379ixQ/MpKPSHH37wRKpOnTotWbKksLDw+PHjdF/Mz8//+uuv165di3fi4+NrWmVCCFm1apUoinTdtm/fvnLlSh8fn+3bt+PNVatW1ZRzamoq5D179mxAQADc5Hl+3rx5dHElJSWXLl3q3r17p06dNPm8++67sixXq1C73b58+XIPZQsNDQ0LC5s/f/68efPat28/ePBg6IJDhgypqqoCbpWVlYMGDWrWrFkN6jxjxgxam5mZmREREXq9Hp4GBwdv27YNHomiOG3atBqwphS6fft2vBkXFwc3ExMTjx07dvr06YSEBPd8QF9Go/HFF1/UTNCkSZP8/Hy73X7v3r1WrVrVSEg1xcbGLl++fNiwYdHR0efPn7969aqnOm3UqNGlS5doYxkwYACTxsfHZ+/evfA0IyMjLCzMc8lQoUajsXHjxnDz+PHjx44dO3XqVL169Zo2beoJH4vFYrfbZ86c6SbNjRs3wKzCw8M9l7Baev7557du3Xro0KHevXszj3h16tDQUFdtjmSz2c6cOQPXXbp0oV2h51S/fn0cQEeNGjVkyJA333yzpKSkoKCg2rxdunTheQ3hGcrIyPBCsGrp+vXrmZmZ3bt3T05OZh5VL9Pu3bszMzOfojTr1q3D6xdeeAEuHj58KElSWVmZJxz0ev306dN1Op0kSRUVFW5Srl+//klEdUMHDx4MCQnxqDN17twZO/sPP/zg4+Ojmaxnz54PHz6EZDXqUDqd7vjx45AxPT3di1le586doSPn5OS4T9m3b9//RJcnhERERNjtdnoYAOL1ej32HZhhhYSESJJkt9sVRbl8+bLdbodHOp1OEARM3LFjx5CQELvdzvM8x3H07AyS8Tyv0+kgL+aC+y+99BIM0BaLxWw206VDGlpEQRDwJly3bdsWFhrbtm2jEwMHvNOiRYvIyEhFUWRZ9vf3h7yMfUClcLyFvOoa0UVAmrfffpvjuJiYGEZgHioGfxRFAV3odDqmVjzPS5JEJwZpeJ63Wq2KoiiKwvO8IAgcx0EyuAC2KJzdbhdF8cCBA6hrnU7HcZy/vz9dAaaNISMhBBjm5OSUl5cLgjBmzJh+/fo1bdq0RYsWjRs3VhQlODjY19fX398/Li4uLS0tISEBtAMNgBxQNVApURRRA1AWVEetUKhp//79x4wZI4rikSNHULbfdEL/AfMBnWKvlyQJ2pamOnXqdOzYERpfEATQDrICUWRZ5nketEmXAkKDKn18fGRZVhTFarUSQgRBwDbT6XTAXJIk1Cywys7OHjly5KpVq9q3b3/s2DFQGaxn+vfv365dO71er9PpysvLi4uLQ0NDQX4oV1atu0BIWjaQBDofaJAQ4uPjY7PZCCFRUVFLliwJDAy02Wxnz56F1qJ7oZMh8DwfFRUliqKiKJIkJSYmgrLoPtu9e/c5c+aIoihJkiiKGRkZjRo1QlZ0D6U5cxyHj1asWCFJktVq/fnnn1G5anNA8vHxoSWE65YtWx45cqSoqOjBgwdGo9FoNF65cmX16tWjRo3asWPHlClTWrZs+fbbb9tsNrvd3qdPH7oIVzME4EwbELoaQRB69+69f/9+o9Fos9lu3bq1YsUKTAAZf3MdKKter+c47o033oBZvSRJR48eZaavzz///NWrV0VRBIUWFhYCX0IZEYqCTgAIytbr9WvXrhVF0Ww23717NyoqijjcLt0kahWjiwSdYgK4Q9cf/6alpZnNZlmWe/TooW5jMGRGg7/rxXEf/sbGxl65csVsNkuSVFxcPHLkSDqNE/n4+NAtBhZqt9vBiGBW36lTp8GDB588edJoNEqSBPYrSVJcXBytLLignTpdHmgtKCjo+PHj4NQqKytfe+01TIMmTDc7miRdW6Yf0Fqg7y9durSiosJms3Xt2pUWAzWLRSBD9VgUExMTExMjiqLNZrPZbDk5OZMnT2Y40O3Eqvm1114DRwbe7dixY8nJyVlZWdB3LBYLKFoUxXPnzsFSB/VIywG2SRsa3KlXr96FCxdgEysvL4+ZHuA18KQVB8MX0SLNiQHHcZs3b4bmf++99+j5DHH2RXAdERFRUFDQsGFDuhSDwTB79mxQiN1u37NnT5MmTUDvrgz8N9LpdPB40KBB0E2sViuoVXSQoihwYbPZFi5cyBgFM2uhK0YX+corr0CrSJJ069YtvE9PaLCSOBkgjAlQYuM1PfsBeQYPHgy1yM3NxQmv5nQCWC1dujQmJsZgMAiCUKdOHUEQTp48ee3aNVEU//3vf7/44ou4m8MQq1OwHeIw/o8//riqqgp0Z7FYQKcw6FdVVWVlZYWHh9PSgyUyZsJMP/GiX79+FosFWkW9t8YYEd33iaoraHc3BxODwdC5c+fCwkKj0Xj69Gk6ATSeup2mTp0qSdK2bdu++OKLs2fP3rlz5+7duzabbf369W3btmV6ofZcFQcNRrKhQ4cqimIymWBOoyhKcXHx1atXR4wYAXvgbqyGUC2GowQMlISQ1atXg7FbrdYRI0ZgFuiSTMdkNOuGwFRpiw4ICKhXr97p06cVRTEajdiHNFcQeBMmMNAXZVnetWvX8OHDcZyk25IeAH93UGrHBDIFBwePHTsWjHTjxo0TJkx4/fXX3WxWM0Mz4w2RM8/zdrvdZrNZrda8vLygoCD1CgJzocdgyoLZCGO8dInYeziOS0pKunHjxhtvvKHmg9IyM6omTZrExcVNmzZtypQpTqsgF9caiqCJdl509WhDVnc6N8Ss9pYsWVJZWSnL8qRJkzSFc2/7QGjItB9XVwebRM2NbgntUUW1lnXvYX7PAx1NzU7Tdmhe1WrTVWsRqtlcNTVtnrTF8Y4tAiR1J6OntJqP3DgQnFrQd9yvO5h5MVuYerDWTIarF82UarHovAypF1f0+kTtBPAv0Zr/06Q2FNqnIx9IRi8c3DNhJs6u6lVLtVRLtVRL/zGKj4+32+2w4VZLrqgGYxPsY7dr1+4Ji2zQoMHcuXOfevhs2LBhDCTiq6++kmW5W7duT8hZp9O1adNmwYIFBQUFioMyMzMXLFjgBndVPe3YsUOW5dLSUu+yBwQE9OzZc9euXUVFRYWFhX379m3atGlbBz2RZIR07drVYrHk5ubinfHjx8OO7f79+71my/N8mzZtcnNzZRd04cKFP/7xj15yz8jI8E6hHMe99dZbhw4dojFDZrO5uLjYarVWVlZWVlbm5OQcOHCgQYMG3sm2c+dOSZL27NmDd/Lz8yVJysvL++ijj7zjGRAQABt37mnx4sXecG/VqlVpaal3Cm3cuHFJSYndbr927dq+ffv27dv31VdfDR8+vFu3bi1atCCECIIwfPjw69evZ2VleaHT0NDQy5cvS5I0ZswYvPno0SNJktTIDg+pbt26v/zyC6O7qqqqgoKCgoICUIUsy7t37+7Zs6c3BSxbtgxYHDt2zIvsgYGBISEhBoPBTZpNmzbZ7fYtW7bUlHlUVJQkSeXl5QgZWrBggdVqvXv3roeoHjXR4ARQ5YkTJ1599VVCSKNGjXbu3An3Q0JCvONP1qxZAyz+8Y9/eMnCLQ0bNqyyslKSpI8//rimeUGh//rXv+Bv06ZNISD29ddfeydMq1atbt68ido0mUwTJ0709/dv27ZtdHQ0mqcsy61bt/auCAL7Q5WVlRDaf7o0efJkcKyfffaZF9kZhf79739/QoVGRUXR5rlnzx5/f//NmzerHegvv/zivttp05AhQyA/s+/9VCguLu7BgwclJSUJCQmebAaqCRSak5MDQQ5U6GuvveadSGFhYbTWzp07d/r0ac0Rad26da62/twRsnu6CvX391+1apXNZistLe3Vq5fXfEChOCgtXboUFEqjBWpEBoPh0KFDmhpEysjIGD58uOa2p3bQjjYWo9H46NGj4uLiDz/8kH6kDhmpCfYKoQjEmACHQYMGKYpSUVGBCCH1lqt6n1stsJ+f3/bt2xVFefz4cUxMzP3790VRzM3NrVOnDl1H7RCQswZQQc2bN0dfqSiKzWaDWS38BaBv79696bDzb9vKWEPwBQhNINTmLmANHz58yBQPiWmUEq1E4rydTCvdx8dn0qRJJpPJarX26dMHqkFr0xXKQ715Dnf8/f1TUlKg5jabDVYyKKePj48aDEI3GPpBbDCDwXDv3j0wfNkRAIZAOmDDrVYr4HzoVuF5XmMHngkrwgwUQF7NmzfXamBWC/TGs6ZRpKSk2Gy2/Pz82NhYNwzpOjPEBKjhd/To0YmJiWBBNptt+vTpam6a3Qu3seG3WbNmI0aMKCoqglCdxWIBC4VfWHoCFiQiIkKjd9LdkKh27Pv3719RUWGxWKqqqgDBj5pittDdKAUr36FDh5MnTwJO4tSpU8ybTgw3V8EleoMdOykEQsaOHQsVlmX59u3b7kVS3xQEoVu3bllZWZIkQSQR+jgEK7Ozs7Oysn766Se73X79+vVJkybRsZnfOSIug6iiETzPz5w5E5rIYrEQqlO46kSupNfpdPHx8aWlpYqibN68eevWrZIkpaSk+Pv7E2fnCL/PPPNMo0aN2rVrhz2GsSOaENnw/fffS5J048aNYcOGTZo0CbRMV5tTgVywYTiOi4yMRC8JThMQLpIkbdy40dfXt27duoSQ6Oho2IjRdsr48hbty9AMZ8yYASg7RN6iWNiqjAYBxYr8YTbz+uuvl5eXV1RUjBs3ThCEsLCwEydOlJSUzJs3DwcuqKogCFOmTNm3b19hYeHQoUOJs+tE5dJVgCBow4YNwfEB6lxdW94BunIVtYUFIaB3YHcNerrFYsFtNuwfdPMj6XQ6nclkov8TQmCaDTH00tJSAGlCD0LAJsI8ZQfiEu7wPI+ITo7jYB+EEJKSkvLw4cOkpKRNmzYRQgoLC6OiomJjY3fs2BEVFXXv3j2TyQT40Oeee65u3bpbtmwZM2ZMeXk5oZCniN2VKZgnz/OyLHMcFxMTExYWZjab161bxzlgm5iFc8BU9Xo9ImwhL7IyGAyga3SUYLk2m+3mzZuA6sfqQ0Y6uyAIOkmSoFPIsiyKIhogKE4QhAMHDkD+kpKSoKAgqCGIBXlRAhpVSyjwKoquKMq9e/egktBae/fuffXVV8eOHSvLsk6ns1qtBw4csFgs58+fB9AKURFdChgmDOuNGzdet26doihlZWVZWVlYBEqC3KCaNP4WGnL8+PETJ04EY0IlwF9QK6qS4zg/Pz+r1Qr2hJxp5RJXSyidTpeeng4zBiZ8ioO7evRQx3tbt249ceLEvn37MsyZ7sy4ZhpeSjNkuirP83PnzoX5zaJFi+hyUR7EW9CuiVCziD//+c+yLNtsNnRxZrMZLBGsijicD7zYi54K1O0kDVH5AvRohJBBgwaBKwFYGnH2m5rZNQmKp6cTzNSC0QKjODXYkR5ePvnkE6j8ypUrmULVYhBntWIVDh48CBMjeKtMUZSqqqoNGzaAtWHdNedhv/HRnDYTZxt0Zbx0hV3BHWiGjGlUO/GiExNn4BTdogDEnDNnjqIopaWl9MRAbcv4V10uz/NDhw49c+aMoijXrl2bN2/erl27Bg4cyGTXXG4yK7rfC3BlbrQuaJnUeC71jMRzcgMDYWYOdBGuSlGjV9SOiK6LG0BvtZx/J833urAwnPqpvSSypq+ZbkgzdC8rXYpaXMzOaUGleQrry1GvRtDiaTaVulLAHDTL2AqtE3Xd6TQu1w+M/+JcvKtBe1WOwtjTZajh2Jo9kRERewDqyJWoxBlDS5xbiDjeTwAOTCma9eKdAXRMe/PUJgadwD3QqpZqqZZq6b+RZsyYAXsuiqLcuHHD60hntcRxXHx8vKIod+7cgQD4fwtFRkaeOnVKcaaEhITIyEgvuBUUFDCBs9GjRz9tkUndunU3btyIpXgd8X/6xKgS9JiQkIB/a8qQUSgsH7Ozs5/8nBGkkJCQrKws5J+enu4lyGXAgAEmk+nu3buaByN5QYw2T506hY9Qp/RNT0itUKCnFfEODw/PyMiAd+BAmw0bNvSG0YABA0pKSmDD6uDBg8zTwMDAOXPm7Nq1y3OGkZGRqLKEhAR1H0dF16jvDxkyZMOGDRs2bJg5c2ZYWNiXX34JNf/pp588Z+KKGjZsWFxcLDve1fRem/Hx8SaTyWQyffrpp999950kSczhah988IEkSY8ePfK8Z6F5uurXqPGa6pQmX1/ftWvXQgjEVRDMcxo3bhw0jyiKO3bs8LKnf/jhh2VlZY8fP37zzTcJIfXr109PT6exbVCMJEm4V1Yt0Z3dTbIncaZIvXv3lmU5Ozv7CVGSTZs2vXz5MijU/Wkx7sjX1zcjI0OSpL/+9a94s2/fvvi3cePG4LYOHjzoOW7CQ9Oj3YIXwjdo0CA2Nvbo0aOyLJtMJoB3eU0nTpwAbV6+fNn7Ie6tt96SZdlVfQRBAFTF/fv38RiraslD8wRCI62B0A6aOnUqDkoHDhzw4qxDpNmzZ0PUU5blhQsXes2HZGZmSpKkPk0MqGfPnjBMJSYmesiwpp4R03vR6y9cuPBURvlmzZrhsYbnz5+H0CzQm2+++fHHH+OJDtUQHF0gSZLmPOn5558HFIkkSePGjfNQuJq6xael0CNHjtSrV6+mHAghPM8fPXoUokk2m61BgwYBAQGLFi26ePEibOPDo+zsbIgquyN/f/9ffvlFURT65CjYbWvSpMn8+fMBlPL5558TrT03ZtOMECIIwk8//QRxtPfee484tsvojUW4xr2viIgICOlAFEwdR8BctOStWrW6cOHCjRs3FEW5efNmfHw8bVYgW4MGDXCuQu/aMe/YCoIAWjObzcnJyYsXL/7xxx8RNoI4J0mSPvjgA7VgLM2fP1+W5ePHj0+YMKFHjx6RkZEvv/zy1q1by8rKFEWRZbm4uBhWsoz6QCA67iY4zveBjG+//TbWDbeuBeq9cGAF544pioLBXkzGHI+ChQqCMHHixAsXLiiKYjKZDh061LVrV5gw0nvPECWm96o1NTBixAjFgRYpLy+HY8FEUTSZTHfu3Hn06FFVVRUEnb788ku6wbRVy/P8ggULYCqLeDMg+Ks+ohd3kYlzEEVwnIslimJVVRX6Zc3XUYFJUlISHnIEYD/iGt6EqmnYsKEsy3j4xPLly1999dW33nprxowZWVlZV65cGTlyJCIVXEUfgHr16lVSUoJ+A4UpLi5etGjR+PHjq6qq4FFRUdGyZcuwt/1WKU04kV6v79WrV7du3UaNGhUdHT1w4MD+/ftHRUU9ePDAbrePGjWKqRihehC9rw676NDCzNE0hLIR+A0KCvr2228hcqsoyuHDh+kDcNQEufr3779582Z480VRFKvVKknS/fv3QQuK42CGo0ePZmVl0ZB4+j1nWqTJkyeDw4FzPECYJUuWhISExMTEgMFKklRSUjJu3Dgm9KDT6TQQh5pBJEEQ6tSpY7VaCwoKMAyFEQVX0RVoPYvFAvOPuLg4ug0gmb+/f2Bg4MCBAy9evAg1kWU5MzOTXuTw1OlxdAX0ev2WLVugbyJyBk9fg5qjO7bZbLNmzaIbQy0MIWTTpk3gKKFpRVEcMmTI2LFjf/zxx5KSEujpkiTNnj2blsTJ3jnVcRNq1RBCdu/erShKQUEBRniYSAstKHp3QkhqairYS25ubkJCQp06dQIDA+vVqxcYGBgcHNy/f/+UlBSz2QwDlyRJmZmZ9evXdx/Ug6fdunUDMBfABmDxhqBOxQFGtFqtVqu1sLAwKSkJsjPIS6Q2bdoUFRWB9qGFrFZrenp6VVUVvFi1d+/ebt26dejQAUwhODiYlZP+HxAQ0Lp1ayYKDwn8/f2vXLlitVrnzZtHKBibOlClttaWLVvKDkgbdCV0xyA0XBuNRkVRzpw5U79+ffSP7tHVzzzzzL59+6DyOATb7farV69WVVWVl5evXLkyKSlpxYoVixcvpkGQhIpL89SxXc8991xRURGAR+B4KqvV+uDBg8OHDy9atKhjx44tWrQICAiAPeyAgAB3Z3ViVBoLpr0BzCWTkpLUUyW1Kant/bPPPoO+jEc5goOjL8rLy3fs2AEzO1cBTjW9//770K9BlQ8ePDCbzc8++yyhYqXYPIiZwe4F7ouOpZ8+fRqEhOZPTU2lp18w+ODhmUQ9YGJQlFYTHorFOWBJRqNRFEX1ZJ4OtzL9nVCndXAc16pVq5kzZ3799ddgTdDBobfm5uYmJSV17NiRxk65CX9jKTzPw5Tu+vXrK1asCA8PRywUizdywUrtTPfu3Zubm7tw4cKEhISFCxfS/ZVzEB7OiWbnZGeMQTG9zNfXF94CsVqtYD5ucCXMlJ6pPE7pDQaDmx7tClpBtxMzFdcUhlCuiU4A5iKozndhSmQkpCew1WwJMcrGKSQg1tq0aQPu6ZNPPnGlAqZI3nE0nGYNaUOmwXU0hglMWxAEPz8/pkRm1UCopQH0Nl9fX1ftxPQ8tWNRm7Dm7IpJzwAy3OUJDQ2F9xolSaKPDof6090BRyRXK0Xa1zBLT+QJF9gSer3+pZdeclUZdUb1fXUp9MJMfZ/hRi9S6GTq4ff3glypkoGaAFULNWHaX50AOwGjeryvdp3qWmEWNDFmsoE9XT1HVnsSWpVuHBFRtZy26ujKQz3dbEC4EkWTm5p4LdCaG727EYM4q4ZzgTJksnuShp5lqz24ulfRmDI3zqGWaun/Ctnt9ps3b7o6iLSWakbh4eEwxNevX/9/W5b/L7R161aTyVT9Xn8teUiy84ePnhaNGzcuOTnZYrFYLJbk5OQjR47MmjVr4sSJT72g/y565513nrpCO3TokJeXx3xODEiSpDVr1jzFsv5DZDAYMjMzZVl+//33mUfVTNRDQkI83PXxnL7//ntX4EqO4yZMmEAIwYPinwoFBgaqAUn0qVluqEOHDjqd7tq1a3gnPDx89erVnTt31kQOVA+y9w5w4Eq4yZMnuz/sjef5CRMm6PX6VatW0dWoEUVHR9NN0qRJk44dO9IJbt++7cmBNjqdbt26dZ07d87Pzz9z5kxWVhYh5G9/+xscvmE2m2uElSOEkPj4+KfY5d2ceaamlJQUL4p48cUXXfkTIIvF8vnnn3syxoaEhNy+fZsOAUjOBHvtNaOnqNCxY8e6+USoLMtnz5799ddf8a/Vaq0RKHXAgAE0wlhNt2/fnjNnjocebOrUqQDpePz48blz586dO8coNDU11ZsDhvr27SvLcllZ2ZOAhAghbdq0YT7YiVRUVLR79+7hw4cHBgb27NmTfjR37lwP+QOUlc57/vz5tLS0efPmpaWlpaWlffPNN55PpeFjZqA4PJire/fuiYmJCJjt3r27N4ogjhfDn3Bin5ycrKnN9PT0IUOGQJqAgIAtW7bQT+/evesJ8/j4+MePH2Mum802evRor7+/6+fnd+TIEdDa4cOHaVTsH/7wBwhTnzt3zvt3vG7evMmslJo2bTp8+PAaAWIBgoCUnJzcq1evXr160ScB4XFbNFXLOTAwECHhoiheunTpnXfeqUn9WIIDRyDyzhxtt2fPHkmSLBbL4MGD3bFw71ays7NtNlufPn2ICsbkZpON3tRKcHwSGOJ0NpuN/roRJktNTQXYkNlshhhBUVERUe2Y0RmnTZt27do1+MYROGjF+Q1vzRgJveuqFhigXTk5OXjiDXCAt51BKpyEaeycYh56Q1Cv18MGs8FgyMvLe/ToUcuWLQkVGUWZmH14ek8XNxbT0tJkB/IAInTMPvycOXOgK8kO2A80wF/+8hdWXAf5+PgcP34cYAdTp06F7wkBsAkigP/85z+Zk8VoqTjqcwgQzsPqnzx50mw2T548ma5aixYtCgoK4LTTHj160HEnVGA1m6H4eP/+/YqiDBw4UBMhAim7dOkydOjQ/fv30+AplB5BHKCyESNGYKv4+fmlpaXh97AgFCxJ0qFDh+bOndunTx8mZIYRgc8//xxUn5ycDDd9fHz69es3d+7cR48eQXSaEUYzYKeOd/Xu3Rs+QoaDuMFgSE1NBflxOeAqpumkO+Ic9QcNXrhwQRTF6OjoZ599NjY2dtCgQf369evatev48eOTk5Pz8/PBNA4cOLB27dqFCxeqY0oVFRUgjcViefDgQf369SHNzJkzafgN+ATAesyaNQti5a6+bQHaFEURZpSoqRdeeCE9PR3wEwUFBbijDtqBejFvtDNtxlwIgtCqVSsI01dUVLzyyitqYTR0qvmJEo7jDAYDVNVkMlVWVoKgCNBQFOXWrVsbN24cMGCA5oFtoBEMwUuSVFVVtW3btm+//fabb765ePEiTpjBMAH6IsuyZoeg5YbSd+7cSdffYDCMHj0aUST79+93E4+iK06H3lDdmAveOBBFkT4hXR1NqX6G27x58w0bNoBLAvdRWlq6du3apKSk5cuXjxw5skmTJpwzxAxZ05HUkydPQhvAgWKoXPSVsLCxWCzp6enr168fNWoU7GerRQSAJ8/zgK97/Pjx9OnTIyIievTo0a9fv7KyMvDCJpMpKyvrpZdechMTE6hjogghMD3SdIXo1mGfgQkLgi92mkWhLWBhkZGRy5YtKywsVBTl559/HjlyZMOGDdXjsoex/z/96U/l5eVmsxlsB8diUCWC3Hbu3AkHfHAc5+vry1HYKSwI8SBTp05VHGdVwQWOeIqibN26FVYiDJqBGX+w+sTZVIGg6JdffhmRt2q4L3EfmIPiDQZDjx49li5dmpCQ8PrrrzNrJEb11do5pM/NzTWbzfApOtnxkTsENsmybLVambWDOjKMtdXr9Xq9/pNPPsnPz4fsdrvdZDIVFRWdOHGiS5cuiEZijEAtLT3505xLwfsrkiR98cUXhMIw0F4C0WA8QnFw2KWHI/c6UkOamCz0fYPBkJiYePPmTQYHu3379pMnT65Zs0bzi1xuSieE6HS6oKCg2NjYpUuXwgfJABTnat7KfJSJOGN7NLO0b98eB8yoqCjmXULk5jRVwCs4h0wQBFifAFM4/EqWZXiKLNR7eniT5gBHdMF9f3//0NDQGTNm1K1bFw784nl+4sSJcFac4jjmze78aWHmtDJ10VAc/nIcJ0kSJoPaYi2QA9YO/gqCAN4D0wPzuLi4NWvW6HS6e/fude3atbi4mJEKvu/LSoUN4qGNqGFWOMXR5One2GGNwIywmsAFZgpFWxkzaUerofEvrkZ8zMLIqdfrf/31VwCzf/fdd0x6ZIi/HHMYDHZzHPjUEBKe56HODFBLUz7i3ELIX/vMTWduNBhRcxzwcHDgnL/mpVmWJun1ej8/P1jXK4oSERFB8yQq63HyIXRhgvMxyYygtG3SDeCmYkx9eOqgH3VGtVLUvYGpGJNYnUA9OWGIHpfobuHqJu2UNZirTQnJDRKIyajWPnRkmP24YUIXQY+bxLXFqQdlZoQhzr2b7hbqlOpzwNR14bRe53CV+H8AplMfiRjffdAAAAAASUVORK5CYII=\n" 121 | }, 122 | "metadata": {} 123 | } 124 | ], 125 | "source": [ 126 | "latents = np.random.normal(size=[16, 1, 28, 28])\n", 127 | "images = test_data[:16]\n", 128 | "image_mask = np.ones_like(images)\n", 129 | "image_mask[:, :, 14:] = 0\n", 130 | "show_sample(images)\n", 131 | "show_sample(images*image_mask)\n", 132 | "\n", 133 | "samples = diffusion.ddpm_sample_cond_energy_inpaint(\n", 134 | " latents, cnn_model, images, image_mask, temp=1.0, eps=1e-2\n", 135 | ")\n", 136 | "show_sample(images*image_mask + samples*(1-image_mask))" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 12, 142 | "metadata": {}, 143 | "outputs": [ 144 | { 145 | "output_type": "display_data", 146 | "data": { 147 | "text/plain": "", 148 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAHAAAABwCAIAAABJgmMcAAADGklEQVR4nO3av0tqcRzG8Y+hoAgiOChKU26CiGuri0ObS4gIjjUHKTg7+R/4B9ioi+I/kEQtURAK/kCDwlElUY7cwYNXLrebeux0nnueZ+oc9MUXjRJ5i3Acx3Ecx3FbzKLlyS6Xy+fz/XGz1WqZ2dztBT07O7u8vFxf+v3+UCi0+YBer3dyckLz60UikU6ns1gslE82m82KxaLH49n+iCY1Y7FYqVT6zFIUpdvtXl9fWyw7/LL/x+YXeiwWu7m5cbvd6zv39/eDweDx8TEcDovIZDK5uroajUY7ndKk5sXFxXg8Xr8b8/k8mUw6HI7tz0Tz91wu13A4XHGLxeLh4eH8/FzLEU1uSjQaXb9FuVxOK2d6U0Sk0+koipLP54+OjmhuuX89+v39XUTa7fZyudR0OpqrBQKBl5eXu7s7jX/gaYrVarXZbCKSSqVms1m9Xrfb7TT3X6FQSKfTq58zmczHx0e5XKa5/5bLZaVSWV82Go3JZBKJRGh+ub//U7q9vd28bDabDofD6XTue0izmxKPx19fX6PRqIgkEonpdKooyunpKc095/V6+/3+29tbNptdf8TViJrZFBEJBoOrD7erPT09BQIBmlrdWq32/PxcrVaPj48PIJrbVLf5/RVNjuM4juM4jjPk2DYd2GTb9EMmQDOEYhqkGUIx2TbpaGI0QygmSjOEYsI0QyimiMGaIRSTbZN+JkIzhGKiNEMoJkwzhGIasRlCMdk26WHCNEMoJkwzhGKK4DRDKKbqQjRDKKY6lGYIxeQ4juM4juO4bxjbpgObbJt+yARohlBMgzRDKCbbJh1NjGYIxURphlBMmGYIxRQxWDOEYrJt0s9EaIZQTJRmCMWEaYZQTCM2Qygm2yY9TJhmCMWEaYZQTBGcZgjFVF2IZgjFVIfSDKGYHMdxHMdxHPcNY9t0YJNt0w+ZAM0QimmQZgjFZNuko4nRDKGYKM0QignTDKGYIgZrhlBMtk36mQjNEIqJ0gyhmDDNEIppxGYIxWTbpIcJ0wyhmDDNEIopgtMMoZiqC9EMoZjqUJohFJPbbb8A/1X9jy48d6UAAAAASUVORK5CYII=\n" 149 | }, 150 | "metadata": {} 151 | }, 152 | { 153 | "output_type": "display_data", 154 | "data": { 155 | "text/plain": "", 156 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAHAAAABwCAIAAABJgmMcAAAB80lEQVR4nO3aMWrCcBiG8aRbliC4iJs4CuIhsnqA4A3MLOQS3sALeBddHB1U0C2jDmIIOBRSKYiKL/7zmee3Vdq34Wmn8HkeAAAAnuC/88NhGLZarX8frtfrOm++FnQ4HCZJUn7Zbrd7vd7tN+x2u263y+Zjg8Fgs9nkeV7ccT6fp9Nps9l8/hFruhlF0Ww2u7dVFMV2u03T1Pdf+Gf/4s0H61EUzefzRqNRfrJYLPb7/Wq16vf7nuedTqfJZJJl2UtPWdPN8Xh8PB7Lv8blchmNRkEQPP9MbP4Jw/BwOPzO5Xm+XC7jOH7nEWuy6Xc6nTd/PW79uH6Ab0NQMYKKEVSMoGIEFSOoGEHFCCpGUDGCihFUjKAAAABwgdsm8Sa3TY42DdwMWdmsyM2QlU1umz64aeNmyMpm1W6GrGxy2yTG+1AxgooRVIygYgQVI6gYQcUIKkZQMYKKEVSMoGIEBQAAgAvcNok3uW1ytGngZsjKZkVuhqxsctv0wU0bN0NWNqt2M2Rlk9smMd6HihFUjKBiBBUjqBhBxQgqRlAxgooRVIygYgQVIygAAABc4LZJvMltk6NNAzdDVjYrcjNkZZPbpg9u2rgZsrJZtZshK5vcNonxPlSMoGIEFSOoGEHFCCpGUDGCihFUjKBiBBUjqBhBxa4wg/Ve5ngsIQAAAABJRU5ErkJggg==\n" 157 | }, 158 | "metadata": {} 159 | }, 160 | { 161 | "output_type": "display_data", 162 | "data": { 163 | "text/plain": "", 164 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAHAAAABwCAIAAABJgmMcAAAXaUlEQVR4nO1de2xUxfef+9jttiWFYou8amkxagk1SsVETYRgiyhUEQU14AsVFIGQiEBDIgKN0aAQBao8YovPImrFGiMUhVJBq0gLlfJaS9NSSyvQ1273cffe+/vj0NPTmXu3mK9R+nPPH5vduXPPnHPmzJm5cz47l7EIRShCEYpQhCIUoQhdBkn/y81xcXFXX301V3j69On/Ms+/ZtDs7OwXXngBfw4bNmzUqFG0Qm1t7ciRIyM8e6ebbrqppqZG0zTdhvx+/5tvvjlw4MDLF/E/yjMzM3Pr1q12vHRdr62tXbZsmST9BWf/f8yzF+6ZmZnbt28fMGAAlhw6dKi+vv7o0aM33ngjY8zj8SxevPj8+fN/Scr/KM958+Z1dHRgbwSDwZkzZ0ZHR1++TBGe3RQXF3f27Flgp2nar7/++uijj/4vIv6HeCqKIssyLcnPz//jjz/8fn8gEAiFQi0tLVlZWampqViB1ocIgiWSJKmqSrnhpQ8++MA0zVAopGlaKBSqrKy8++67aWVFUbgb4V6pi0Th161bZ5omKN/c3Jyenq4oyl8KlBypqjp16lTTNA3DMAxD1/X169dHRUXROsgfviiKoiiKw+HoLoJPVVVlWR4zZkxzc3MwGAwGg6FQyO/3G4bR3Nzc1tZ24cKF1atX9+vXz1IOakpsUlVVNNOuXbtCXQQSh0KhTz75hOtO5CbLMvIU6zDGnnrqKcMwTNPUNK22tjYtLQ3K0aa0myVJ4jrMkm1cXFx1dXWwizRNW758OTUlGA4VxC7nmcuyDFU3bty4b98+XddN0+zs7AS+fr8f3MowjIqKinvuuQdFZ4w5nc7BgwdzkkmSxHXs999/D2zBBJqmgU3Xrl1Lq9ERI44eVVVBzqioqM2bN7e2tno8nsLCwuHDh6MicAvHBHqI2oKqQO3w5ZdfgmwgbTAY3LZtGyPOoShKbGws1bTbvlhqGIamaYwxt9sdHx8PTtTc3FxYWPjRRx8piqKqKrQxatSoBx544Omnn2aM6brOGIOBTGWCL8FgEJxFlmVVVffs2WOaptfrbWtr83q9pmnCvTNnzsTuFPU0DIPaFDqDMRYIBDIyMkDJ999/v6mpCe4yDANkMAwD9ITBiwaFdmkwYYy5XC7g73A4CgsLVVXVdV2WZU3TJEmaPHny119/HR8fD/rquu7z+bBvwEVs48zWrVv9fn8oFJo7d25CQsKYMWM+//xzr9cbCoU6OztBn+3bt4OlMFbAd3AEalN0lrVr1w4ePHjNmjXFxcUzZsxISEhISUk5ePAgxOi8vDzkQHue2Qx2aGjv3r2BQMDn802YMIFrkfZKj3Ui9SbiwvBl2rRp33777YQJE8CfIDppmubxeHRdz8zM5LocWSGH7uAK/SZJ0tmzZ4cOHer3+0+dOnX8+PHk5OSxY8d6PJ7Y2FhVVZuamg4fPrxo0aJTp05xGmJH4aeqqui5mqaVlpY2NjZKknTmzJmUlJTRo0enpaW1trYmJCTs3r17ypQp0P/IAb6AYBzBpYyMjB9//FFRFK/XW1paKstyS0vLgQMHfvzxx6NHj4KrMuK2rMs9OZWxsLa2tqqqav369cXFxaqqmqapKApYtq2tbfLkyb/88gvqBXGTDk1rJ509ezYEOF3XwSsNwwgGg+3t7WvXrh01alT//v3thienM3yCIyxcuBBWITgdYZyqq6uLi4sDC+K0hvMmDaOctzocjoKCAgjuMNHDjO/xeD7++GNOHvRWl8ulqiq3FAFKSUlhjMXFxTU1NYHWILPf7zdNc+LEiTAyuv2xy9+7LSA6bWpq6qpVq7755ptQKFRRUWGa5qFDh3JyciZOnAgrW7pocDqdtAQNTRcPMDRgDjlx4oSu6xCvTdMMBAJFRUVTpkwRhyGGEdQWSvAnDSYNDQ2NjY1tbW3gBGCFnTt3UmuiEZ1O51VXXSX2PaWMjAyclGCdU1lZKU68f4EkSYIxLssyTNZixOH8BWdA5EA9Fyo7HI6lS5e+/PLLK1asWLx48aBBg8RFK7qk3dqT9RwTKEZ6enpeXp7X6wXn4pyUchOXd8gWS/Lz80+ePFlTU+N2u2fNmjVkyBDqH5QtbwqOO/a8XSy3UxLc0+FwcKLjqoX1NDdnSjojUXvRW7gWxSWqLMsDBw4cPXr0oEGDKE9kQtfFjLE777xTFABDDRqErmctdcfJufu3WJUr4fyR/uTWumEaZmTkcnXAo8O0EoYnd4tlTVrOzQF2JAnLeGaz6uBbxE6g6ww+3Ary0Z90MqE/qSainiAuXYTCY5ylxFhuJ5I4HwBBoLfrG1mWnU4nhBq8nXNJahY6Q9r2MbAI4w7AkTOi6MIYLmC80DkKyW5daVlOBeBsLZFlL1WV61pOBfEnBh+YMJBEU6KVuOjXI/KghtSnOD/HtqGrRVUZCRrcT7grjE/RkG0XZOhaghs3UAemTW6pgNKGiV3he5FZzSX0O8jfK5MIRShCEYpQhCIUoQjxFME2/c08I9imf4lnH8AM9RWeVwhmqK/wjGCb/kGefQMz1Fd4XomYoT7Bk27t0d29NWvW6LoOGcpgMPjUU09xKSOaLUAOsKWIe+90A3DIkCFutxvSwh6PJzc3F6J4mPSOpZBci/Hx8T6fD/J9uq4HAoHs7OyMjIzY2Fi63Sdmd7g9ZtyuZYxlZmYiziMuLk6UQZS2e2NTzEMwxkaOHNnc3AyZ3ubm5nvvvRfV5jYTJSvAEHYPKh8dHb1jxw7I8RqGUVRURA1HETKWejKr7VqoPHToUEh6Q99rmhYMBj0ez+7du1etWoV4EMozTA4C2po9ezZAHLxeb15e3vDhw5OSkiwzb5zdujlzbXz11VeYjF21ahXqQzeYMYeFNhUFRets3rwZsuemae7bt2/58uWwZLOMStR9uBQb0qJFi1auXJmUlMQYe/PNN6GfQqFQMBiEQQDC79ixA1BAlhaknDnTP//888DHMIwzZ87cfvvtDodDxKGg+rw30FygJEklJSWgvGmakIwVN6W5jpKsUIzQEsArQGcM883NzVu3bl2wYMGwYcPAgqLVuAQBleG9996DHrp48aLb7fb5fJ2dnYjP0HUd0ER+v3/JkiWoM/V3yNSjLRAG8Omnn5aVlV24cAH6Bly1pqaGeiiXiLWTv3sIr1ixApQPBoMAEa2trZ03b97gwYP79+/PSJ7DLstE0yoOh6OhoQG4GYbh9/sBkQEx2uv1Hj16dM6cOYmJiVDZMqnA9d/KlSvBgoCOAzjYiRMnAOIBTYDFc3Jy+vXrxyXTMS1IhypIi54EX6AJj8eTnJyMtuNSeGHCyKWIlpWVtX///tbW1pqamlAoFAgEUMqSkhJLl8ESLusJDK+55pqSkhKEsAKoEwl99uGHHxaDtWWuSZble++9d//+/YAK1glgRtd1GPtg0Mcff9xSU8vJgzF27NgxAKABT5C2vLyc9YzjkDFED+2BScBKgI0CbZ1O5x133HH48OHMzMzVq1cDfFnXdUVR2tvbk5KSAoEA60INKoqi6zr+xCYhBkETiYmJ06dP/+mnny5cuGAYxpw5cxRF8fl80dHRuq4/9thjgwYNkiSpsLDwySefRJGcTmcwGEQ7InNs96abbjJNc+7cubfddltiYmJiYiKAuRhjnZ2diqJkZWWVlZXBXVFRURB8RD9AbOX48eP37t2raRrYaPv27Zs2bbp48eKRI0dQHriUkpJSU1PDGAO8GBWP7yix9x544IGKioqTJ09+//33M2bMQJPZJfzEUIDTC5eyZ10ALofDER0dbZmhtMx0MmGySk9Pf+yxxw4ePHj48OHy8vJFixZh31B5YP6hQ4EuSK6++uq6ujqw+6233nrttdciKEpV1eTk5DFjxlAho6Ki+PEePmwxm9EnPhRznYG2oOYQF5sYlWiymruLa9ESXKUoimWWmxOMym+5BFYUpby8HJaMd911FzbhcrnS0tJSU1MHDhwo6o58ZEVRAFZo4a6C0JIkQVjAGRzKQTLTNGHsgIGQIVTTu7Cf0H8wq8K8BPfiSIQxjuUoBk6mCBdGcDOwwuDAwTIoZBmEBDkR2glKwXdd10tLS6Gh3bt3jx07lnVNYnFxce3t7R6PxxKvDGy7R59EcKcgE2hIDQ3fIcjSci6IwFQDHBAxK3WhZwHQDTWxUawPDB0OB1jHMAxAt0IsRigviootUvlBWw6py93Cka7ryGHnzp2g4GeffQZR0jAMn89XVVUFKzPOYijhpS6kf+DgRhM8SnJtY8ShAY4JCzS75TQypA+RHCvWE2ljt7yntyAGjQtfsoBAEoMJLbcbzpagK3yWtVWS3s+ttvC7uPzkaloqYHmJ4wN9YDmaLH/aoR45I1qCovAnrHVETbHE0qtofVuzihzpLIQ9b9eHzMaJRAoPBqJP97gCx1vwkuWUYvlkYSkecIAJkBFwKxP8g1ntp2AnhffrCEUoQhGKUIQiFKEIWVME2/Q384xgm/4lnn0AM9RXeF4hmKG+wjOCbfoHefYNzFBf4XnFYYb6KE/MUJaWlkKiyuPx5OXlDR48mP6jEf8KaBmScAcXt7ZUVU1KSiorK6Op46ampi+++CI3N5cCTyz3s+22yGJjYxMSEqZPn97a2goIJ8MwWlpaJk+ejHXEXXBm8492IJfLtXfvXl3XIa1kmmZLS0tqaqq4p6eQU1d6iCcWTZo0CRAJpmm+/vrrzz777LRp0+6//36l61yEy/kvLeu5//jwww8jKAGRBLqu+3y+7OxsiRATtjVlWRb/7omtq6o6duzY8ePH33PPPQD3cLvdNG8ORDeSOaLlU6dOhQOr2tra3G43oBGgh8RcAKfpJT7irvAjjzwCTgT5fkA5+P3+/fv333jjjTTVYbdVLOaK09LS6uvrPR5PXV3dW2+9NXPmzAMHDrS3t4MLJCYmiqxwq1y2P3WJ0jvvvAPghqqqKkur2Q0paoE///wTEl9HjhzZsGGDrusNDQ3Jyckc/MY2H8OdcAD0ww8/IFSKftE07aGHHsLNcEmSnE6nqqpctFZsDkt76KGHbr75ZirZhx9+CGifnJwcqjbrOW7CoL2QUlJSTNOEga/r+i233MIIuAijimx/UgUQqNnY2Lh8+XJASe7atQvmemooMVXF6DX6RZKkjRs3AvyGonzAVeGMLTGtxMJ2HWPst99+W7lyJYfTu/POO1tbWw3DgFPWOIXDZ2xiYmKOHTv24osvMsbi4+PhlC3DMLxe7yuvvBITE0M7hjsGBznTcA/f6+rqAMzU0tICQzMnJ4cLNXbWlCSJSZIE/kVHcVpaGh78B3aECO33+19//XV6P+UL+R8u8YI/gc/XX399/fXXM8acTmdaWlpzc7NhGNXV1S6Xi8txMiGScr0YExOjaVogEHj66ad//fXXjo4OQMpVV1dz2ip40l9XajYM4GXfvn0QNyD/fvr06SFDhjAhcFum8y4V0jiFlydMmHDhwgV0UsMwAoHAtm3b4HAUKoF4L72EKv3+++/I5+233y4uLg4EAgDqzM/PZz0He68paMZYdHQ0yObz+fCMwvr6+jVr1og30tEjDnP685lnnoHVkqZpFRUV6enp3FFK9DwFu6DcQ1YclTNnzty9e3dRUZHb7T58+PCsWbMQAij3RLFadjigSzAzPGLEiCeffLK8vBysEAgEYFFimubo0aMthQsDNZZlOTY2FrF8EOJ1Xb/vvvuobNy8TEeMnSFiYmJKS0t37tyZm5sLOEjOduIBMJYB0NaZ6bhzOp0UCGUpk11sBYGuuuqqVatWbdq0yefztbe3FxQUZGVliWsGSw8V88a5ubkAX+3s7CwrK5s+ffp1110HfDg8JTefsJ6wDFF4mKu5mT0MddfEGwBV4nA4EGrCXeIaprgi1gVdkgXEocjK4XDExsa2t7dLkqQTZKHYEC1BcCRegnbnzp0bFRV14MABOARN7wlVlGyO0aOH8ok/AfmkqiqYgoqBStlKyzmUIhw012vnIDsqEGUoltNYhhXwSYwTAIhD1vX6ZGE3tC1RPVRTbEjuQrRx1eiiWJSTJ64l6vbc5EhnPfHQV6zpcDjgVCQ7K7tcLrtFMto9DKrJ0kBcmJOsTtGUhH/oiLdYRj9ch2F578tkrqt7fcRkXVMQDU9cUA7jPjCtY6EYuexGCZ27acf3Ki0j4DIRuMotM2AkUSNQDBb8OyQ8uOgSidMUSiAOZE4rSwVEDSXhvDsMjpYBxxICJ4odfryznqNNPIyNqy9eomFBdJcIRShCEYpQhCIUoQj1ThFs09/MM4Jt+pd49gHMUF/heYVghvoKzwi26R/k2TcwQ32F55WOGbpieXKpMdiwSkhI+PTTT+E9NbW1td99992YMWOwjtLzEAFkYvn3W7pBFxMTs23bNrPrLKi2trZDhw4VFBRAz1vm3+nuPfdXQFmWXS7XF198Afm+hoaGI0eOeL3eO+64g27fcfvEvW5cqqr66quvnj17Fl6hVVRUlJGRwciOH013U86XFOcagM3BxMTEgwcP4r+xx48fz2xe98BtjFrmDCgNGzbswQcfLC4uhow8NLFnzx58z53YipiuwJIdO3bA6910Xd+0adMHH3zQ0NCQk5Pz6quv5ubmDh061DKbwoQdXtro0KFDq6urvV4vpP/gJCPuRjFN0K21RF4MgK2OGDHi4MGD+M64Y8eOLV26lG7EXk6iRiTqwunp6W63G/ETEydOZMIfvqOiosJnFyApD3KCQ4VCobKysmXLlv3www8vvfSSKKT4d1oOYDFnzhw4FcYwjJKSEnqJasGxRW6Ms1FcXNy777577Ngxt9vd0dGB7xTSNO348eMxMTHMxqkpHy7vZtn2pEmTYGUHbwKLjo7mYoXomKJxEduCL2gKhUJ//PHHDTfcMGDAgAULFgAghd4iZoQw3wfGLS4uRryMaZpvvPHGli1b7r///qlTp44cOXLQoEElJSV5eXn0/YWi0RnljlKeP3++sLBw3bp1lZWVYNP6+vrLzMVzxN01bty4ixcvghUCgcC1114bhpVFnOqK2hR6FQqFGhsb58+fDwcsAdEQJPYTZA/RAyCTSCFymqa1t7f7/f6ioqK6urpTp06BwGfPng2jHZ/t2bJlC/R5IBDo6OioqKg4d+4cou84tZUu4sotwTmSJLlcrg0bNrS3twOQs6Ojg4JjxUOb7HAfQFVVVXCmFoz9u+66CxuyHKTUvmJeCG5EZeHNmx6PB4Kpz+errq4uLy/PyMiAACWastvxuTT3woULMSRRfOzbb78t5q2483poag+xkpIkOZ3O11577cyZM3rXoY+tra3Dhg1jwsLAcrqnmUi8tHTp0vLy8vr6el3XA4HAuXPncCSGSZTS3qJoTlVVHQ7H7NmzQWWA3cKBZTt27HjiiSfQVtHR0ZZJUFmWbYE3o0ePfu65586fP79r167KysrTp0+/+OKLUVFRvaZ2sVWuWmVlJZz2BSP9559/zsrKogIx4a1zlq0oPU8umz9//rhx42BInTx5ksu/gxviqUucbGJzcHXatGlHjx6tra396aefZsyYER8fj1dpKLfEQHRzcblcdHbC2pQdE8YdLmNpqprCGrD5/Px8gAk2NDQsWbKEvmZPzGtjE9zaluMJlJOTA+OpqqoKXwhribGmgtmhgjkZ6Pv4xAT9ZaXjgSzXg9yqMwwwnDvukTGWnZ1dUFCwYcMGeJEl3niZns6sJmicZ9avXw8LBnyhL1fHkpXlQpITW1RNPMxHpqdYiagVzjs4QlE4d+be7UfloCejWC+GexI3qO0UBqLnfopLK5RWHO/cvCeiOkSy801wte4bLScByzvFS+F9nqtgGSssm7N73KKLc05sy+GCky0XOtAPJHIoHyckdbUw8YHOKz2a5x696URBW0JlLB+WLCnMhMsELJja87WVvb7eUBKgVxQUxL29j+tF2oql/PhwaCmG1BMCZfGgSNe3SFwdy2F1md6NStIKivDPjDCE3mQ3qmjMob5i8YDYVV8MyvST/tuB2sdO4P8D2tDuxzjUcPMAAAAASUVORK5CYII=\n" 165 | }, 166 | "metadata": {} 167 | } 168 | ], 169 | "source": [ 170 | "latents = np.random.normal(size=[16, 1, 28, 28])\n", 171 | "images = np.tile(test_data[14:15], [16, 1, 1, 1])\n", 172 | "image_mask = np.ones_like(images)\n", 173 | "image_mask[:, :, 14:] = 0\n", 174 | "show_sample(images)\n", 175 | "show_sample(images*image_mask)\n", 176 | "\n", 177 | "samples = diffusion.ddpm_sample_cond_energy_inpaint(\n", 178 | " latents, cnn_model, images, image_mask, temp=1.0, eps=1e-2\n", 179 | ")\n", 180 | "show_sample(images*image_mask + samples*(1-image_mask))" 181 | ] 182 | } 183 | ] 184 | } -------------------------------------------------------------------------------- /mnist_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | from torchvision import datasets, transforms 6 | 7 | from ddim import Diffusion, CNNPredictor, create_alpha_schedule 8 | 9 | USE_CUDA = torch.cuda.is_available() 10 | DEVICE = torch.device("cpu" if not USE_CUDA else "cuda") 11 | SAVE_PATH = "mnist_model.pt" 12 | 13 | 14 | def main(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--batch-size", default=32, type=int) 17 | parser.add_argument("--lr", default=1e-4, type=float) 18 | parser.add_argument("--save-interval", default=1000, type=int) 19 | args = parser.parse_args() 20 | train_data, test_data = create_datasets(args.batch_size, USE_CUDA) 21 | diffusion = Diffusion( 22 | create_alpha_schedule(num_steps=100, beta_0=0.001, beta_T=0.2) 23 | ) 24 | model = CNNPredictor((1, 28, 28)) 25 | if os.path.exists(SAVE_PATH): 26 | model.load_state_dict(torch.load(SAVE_PATH)) 27 | model.to(DEVICE) 28 | opt = torch.optim.Adam(model.parameters(), lr=args.lr) 29 | 30 | step = 0 31 | loaders = zip(iterate_loader(train_data), iterate_loader(test_data)) 32 | for train_batch, test_batch in loaders: 33 | train_loss = compute_loss(diffusion, model, train_batch) 34 | with torch.no_grad(): 35 | test_loss = compute_loss(diffusion, model, test_batch) 36 | print(f"step {step}: test={test_loss:.5f} train={train_loss:.5f}") 37 | step += 1 38 | opt.zero_grad() 39 | train_loss.backward() 40 | opt.step() 41 | if not step % args.save_interval: 42 | model.cpu() 43 | torch.save(model.state_dict(), SAVE_PATH) 44 | model.to(DEVICE) 45 | 46 | 47 | def compute_loss(diffusion, model, batch): 48 | ts = torch.randint(low=1, high=diffusion.num_steps + 1, size=(batch.shape[0],)).to( 49 | batch.device 50 | ) 51 | epsilon = torch.randn_like(batch) 52 | samples = ( 53 | torch.from_numpy( 54 | diffusion.sample_q( 55 | batch.cpu().numpy(), ts.cpu().numpy(), epsilon=epsilon.cpu().numpy() 56 | ) 57 | ) 58 | .float() 59 | .to(batch.device) 60 | ) 61 | alphas = torch.from_numpy(diffusion.alphas_for_ts(ts.cpu().numpy())).to( 62 | batch.device 63 | ) 64 | predictions = model(samples, alphas.float()) 65 | return torch.mean((epsilon - predictions) ** 2) 66 | 67 | 68 | def iterate_loader(loader): 69 | while True: 70 | for x, _ in loader: 71 | yield x.to(DEVICE) 72 | 73 | 74 | def create_datasets(batch, use_cuda): 75 | # Taken from pytorch MNIST demo. 76 | kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {} 77 | train_loader = torch.utils.data.DataLoader( 78 | datasets.MNIST( 79 | "mnist_data", 80 | train=True, 81 | download=True, 82 | transform=transforms.Compose( 83 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] 84 | ), 85 | ), 86 | batch_size=batch, 87 | shuffle=True, 88 | **kwargs, 89 | ) 90 | test_loader = torch.utils.data.DataLoader( 91 | datasets.MNIST( 92 | "mnist_data", 93 | train=False, 94 | transform=transforms.Compose( 95 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] 96 | ), 97 | ), 98 | batch_size=batch, 99 | shuffle=True, 100 | **kwargs, 101 | ) 102 | return train_loader, test_loader 103 | 104 | 105 | if __name__ == "__main__": 106 | main() 107 | --------------------------------------------------------------------------------