├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── config ├── config.yaml ├── dataset │ ├── mnist.yaml │ └── vctk.yaml ├── loader │ └── basic.yaml ├── model │ └── mnist_score.yaml └── worker │ └── mnist_worker.yaml ├── diffuse ├── __init__.py ├── datasets │ ├── __init__.py │ └── mnist.py ├── models │ ├── __init__.py │ ├── components │ │ ├── __init__.py │ │ ├── conv_glu.py │ │ ├── unet_parts.py │ │ └── utils.py │ ├── diffusion_process.py │ └── mnist_score.py ├── utils │ └── __init__.py └── workers │ ├── __init__.py │ └── mnist_worker.py ├── notebooks ├── diffusion.ipynb ├── jax.ipynb └── visualise_diffusion.ipynb ├── pyproject.toml ├── requirements.txt ├── setup.py └── train.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # prevent notebooks from being counted in the language statistics on GitHub 2 | notebooks/* linguist-documentation -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/* 2 | notebooks/.ipynb_checkpoints/* 3 | .DS_Store 4 | wandb/* 5 | outputs/* 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | env/ 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # dotenv 89 | .env 90 | 91 | # virtualenv 92 | .venv 93 | venv/ 94 | ENV/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | 109 | #intellij 110 | .idea/ 111 | 112 | # vscode 113 | .vscode/ 114 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Angus Turner 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Diffusion Experiments 2 | 3 | An educational implementation of [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239), 4 | with corresponding [blog post](https://angusturner.github.io/generative_models/2021/06/29/diffusion-probabilistic-models-I.html). 5 | 6 | Includes: 7 | - A toy U-Net model, which can be fit to MNIST - `notebooks/diffusion.ipynb` 8 | - Notebook used for blog visualisations - `notebooks/visualise_diffusion.ipynb` 9 | 10 | This repo is a WIP. I hope to add some more substantial experiments soon. 11 | 12 | ### Requirements 13 | 14 | Required: 15 | - Python >= 3.7 16 | - PyTorch >= 1.7 17 | 18 | Recommended: 19 | - Linux and CUDA 20 | 21 | ### Install 22 | 23 | ```shell 24 | pip install -r requirements.txt 25 | ``` 26 | 27 | Uses the [PopGen](https://github.com/Popgun-Labs/PopGen) framework to manage experiments.
-------------------------------------------------------------------------------- /config/config.yaml: -------------------------------------------------------------------------------- 1 | name: null 2 | model: null 3 | dataset: null 4 | worker: null 5 | nb_epoch: 1000000 6 | overwrite: false 7 | wandb: 8 | project: denoising 9 | defaults: 10 | - loader: basic 11 | -------------------------------------------------------------------------------- /config/dataset/mnist.yaml: -------------------------------------------------------------------------------- 1 | dataset_class: MNIST 2 | dataset: 3 | both: 4 | data_dir: "/home/angusturner/datasets/" 5 | train: 6 | train: true 7 | test: 8 | train: false -------------------------------------------------------------------------------- /config/dataset/vctk.yaml: -------------------------------------------------------------------------------- 1 | dataset_class: AudioDataset 2 | dataset: 3 | both: 4 | formats: ['*/'] 5 | sr: 16000 6 | 7 | train: 8 | paths: ['/home/angusturner/datasets/vctk_16k/train'] 9 | patch_size: 16000 10 | test: 11 | paths: ['/home/angusturner/datasets/vctk_16k/test'] 12 | patch_size: 16000 -------------------------------------------------------------------------------- /config/loader/basic.yaml: -------------------------------------------------------------------------------- 1 | # Good standard settings for the PyTorch DataLoader class. Note that 2 | # the batch_size should probably be set from the CLI for the specific experiment. 
3 | loader: 4 | # settings for both train and validation 5 | both: 6 | num_workers: 1 7 | pin_memory: false 8 | shuffle: true 9 | batch_size: 32 10 | # specific settings 11 | train: 12 | drop_last: true 13 | test: 14 | drop_last: false 15 | -------------------------------------------------------------------------------- /config/model/mnist_score.yaml: -------------------------------------------------------------------------------- 1 | model_class: MnistScore 2 | model: 3 | nb_timesteps: 250 4 | hidden: 128 5 | c_dim: 64 6 | input_dim: 1 7 | -------------------------------------------------------------------------------- /config/worker/mnist_worker.yaml: -------------------------------------------------------------------------------- 1 | worker_class: MnistWorker 2 | worker: 3 | diffusion_settings: 4 | nb_timesteps: 200 5 | start: 1e-4 6 | end: 0.05 7 | optim_class: Adam 8 | optim_settings: 9 | lr: 2e-4 10 | weight_decay: 0. 11 | -------------------------------------------------------------------------------- /diffuse/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.0" 2 | -------------------------------------------------------------------------------- /diffuse/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .mnist import MNIST 2 | -------------------------------------------------------------------------------- /diffuse/datasets/mnist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torchvision import transforms, datasets 3 | 4 | 5 | class MNIST: 6 | mean = 0.1307 7 | std = 0.3081 8 | default_transforms = transforms.Compose( 9 | [ 10 | transforms.ToTensor(), 11 | # transforms.Normalize((mean,), (std,)) 12 | ] 13 | ) 14 | 15 | def __init__(self, data_dir, train: bool = True): 16 | """ 17 | Return continuous MNIST digits, scaled in [-1, 1] 18 | :param data_dir: location
to save 19 | :param train: load the training split if True, otherwise the test split 20 | """ 21 | self.dataset = datasets.MNIST(data_dir, train=train, transform=MNIST.default_transforms, download=True) 22 | 23 | def __len__(self): 24 | return len(self.dataset) 25 | 26 | @staticmethod 27 | def denormalize(im): 28 | """ 29 | :param im: (B, C, H, W), expected in [-1, 1]; returns values mapped to [0, 1] 30 | """ 31 | return (im + 1) / 2.0 32 | # return (im * MNIST.std) + MNIST.mean 33 | 34 | def __getitem__(self, idx): 35 | x, y = self.dataset[idx] 36 | x = np.array(x).astype(np.float32) 37 | x = (x * 2.0) - 1 38 | 39 | return x, y 40 | -------------------------------------------------------------------------------- /diffuse/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .mnist_score import MnistScore 2 | from .diffusion_process import DiffusionProcess 3 | -------------------------------------------------------------------------------- /diffuse/models/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angusturner/diffuse/cb0d8dfdd4c07dc09ad167904377109a973b15dc/diffuse/models/components/__init__.py -------------------------------------------------------------------------------- /diffuse/models/components/conv_glu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ConvGLU(nn.Module): 6 | def __init__(self, channels, c_dim=None, dilation=1, kw=3): 7 | """ 8 | Convolution with GLU activation and (optional) global conditioning 9 | :param channels: number of input and output channels (must be even: the GLU splits them in half) 10 | :param c_dim: size of the optional global conditioning vector; None disables conditioning 11 | :param dilation: dilation of the main convolution 12 | :param kw: kernel size of the main convolution (odd sizes preserve H and W) 13 | """ 14 | super().__init__() 15 | 16 | # TODO: try without batch-norm? 17 | self.bn = nn.BatchNorm2d(channels) 18 | 19 | # main op.
+ parameterised residual connection 20 | padding = (kw - 1) * dilation // 2 21 | self.conv = nn.Conv2d(channels, channels, kernel_size=kw, dilation=dilation, padding=padding) 22 | self.up1x1 = nn.Conv2d(channels // 2, channels, 1) 23 | 24 | self.rescale = 0.5 ** 0.5 25 | 26 | if c_dim is not None: 27 | self.c_proj = nn.Linear(c_dim, channels) 28 | 29 | def forward(self, x, c=None): 30 | """ 31 | :param x: (B, C, H, W) 32 | :param c: optional global conditioning (batch, c_dim) 33 | """ 34 | h = self.bn(x) 35 | h = self.conv(h) 36 | a, b = torch.chunk(h, 2, 1) # (B, C // 2, H, W) 37 | 38 | if c is not None: 39 | assert hasattr(self, "c_proj"), "Oops, conditioning dim not specified!" 40 | batch = x.shape[0] 41 | c_proj = self.c_proj(c) 42 | c_a, c_b = torch.chunk(c_proj, 2, -1) # (B, C // 2) 43 | c_a = c_a.reshape(batch, -1, 1, 1) # (B, C // 2, H=1, W=1) 44 | c_b = c_b.reshape(batch, -1, 1, 1) 45 | a = (a + c_a) * self.rescale 46 | b = (b + c_b) * self.rescale 47 | 48 | # main op + residual, and re-scale to preserve variance 49 | out = torch.sigmoid(a) * b 50 | out = self.up1x1(out) 51 | out = self.rescale * (out + x) # (B, C, H, W) 52 | 53 | return out 54 | -------------------------------------------------------------------------------- /diffuse/models/components/unet_parts.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from diffuse.models.components.conv_glu import ConvGLU 5 | 6 | 7 | class DownsampleBlock(nn.Module): 8 | def __init__(self, hidden, main_op=ConvGLU): 9 | """ 10 | Halve the spatial dimensions and double the channels 11 | :param hidden: number of input channels (the output has hidden * 2) 12 | """ 13 | super(DownsampleBlock, self).__init__() 14 | 15 | self.down = nn.Conv2d(hidden, hidden * 2, kernel_size=2, stride=2) 16 | self.conv = main_op(hidden * 2) 17 | 18 | def forward(self, x, c=None): 19 | down = self.down(x) 20 | return self.conv(down, c) 21 | 22 | 23 | class UpsampleBlock(nn.Module): 24 | def
__init__(self, hidden, main_op=ConvGLU): # upsample 2x to hidden // 2 ch, concat skip x2 (hidden // 2 ch), fuse to hidden // 2 ch 25 | super(UpsampleBlock, self).__init__() 26 | 27 | self.up = nn.ConvTranspose2d(hidden, hidden // 2, kernel_size=2, stride=2) 28 | self.conv1 = nn.Conv2d(hidden, hidden // 2, 1) 29 | self.main = main_op(hidden // 2) 30 | 31 | def forward(self, x1, x2, c=None): 32 | x1 = self.up(x1) 33 | feats = torch.cat((x1, x2), dim=1) 34 | feats = self.conv1(feats) 35 | return self.main(feats, c) 36 | -------------------------------------------------------------------------------- /diffuse/models/components/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | 5 | 6 | def unsqueeze_as(x: torch.Tensor, shape: Tuple[int, ...]) -> torch.Tensor: 7 | """ 8 | Add trailing dimensions onto `x` until it matches the rank of a second 9 | tensor whose shape is given. 10 | e.g) 11 | x = torch.randn(3) 12 | y = torch.randn(3, 4, 1) 13 | x = unsqueeze_as(x, y.shape) # (3, 1, 1) 14 | :param x: tensor to expand; its existing dims stay as the leading dims 15 | :param shape: target shape; only its length (rank) is used 16 | :return: a view of `x` with len(shape) - x.dim() trailing singleton dims (no dims added if already at or above that rank) 17 | """ 18 | extra_dims = len(shape) - x.dim() 19 | return x[(...,) + (None,) * extra_dims] 20 | -------------------------------------------------------------------------------- /diffuse/models/diffusion_process.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from diffuse.models.components.utils import unsqueeze_as 8 | 9 | 10 | class DiffusionProcess(nn.Module): 11 | def __init__(self, nb_timesteps: int = 50, start: float = 1e-4, end: float = 0.05): 12 | """ 13 | Implements the diffusion process presented in DDPM. 14 | No learnable parameters, but relies on a trained 'score' model 15 | for sampling.
16 | :param nb_timesteps: number of diffusion steps 17 | :param start: first beta of the linear schedule 18 | :param end: last beta of the linear schedule 19 | """ 20 | super().__init__() 21 | 22 | self.nb_timesteps = nb_timesteps 23 | 24 | # beta = likelihood variance q(x_t | x_t-1) 25 | beta = torch.linspace(start, end, nb_timesteps) 26 | alpha = 1.0 - beta 27 | alpha_hat = alpha.cumprod(dim=0) 28 | 29 | # variance of conditional prior, q(x_t|x_0) 30 | # = N(x_t ; sqrt(alpha_hat) * x_0, prior_variance) 31 | prior_variance = 1.0 - alpha_hat 32 | 33 | # forward process posterior variance (beta_hat) corresponding to q(x_t-1 | x_t, x_0) 34 | alpha_hat_t_1 = F.pad(alpha_hat, (1, 0))[:-1] 35 | posterior_variance = (1 - alpha_hat_t_1) * beta / (1 - alpha_hat) 36 | posterior_variance[0] = beta[0] 37 | 38 | for (name, tensor) in [ 39 | ("beta", beta), 40 | ("alpha", alpha), 41 | ("alpha_hat", alpha_hat), 42 | ("prior_variance", prior_variance), 43 | ("posterior_variance", posterior_variance), 44 | ]: 45 | self.register_buffer(name, tensor) 46 | 47 | def sample_t(self, batch_size=1, device=None): 48 | """ 49 | Sample a random time-step for each batch item, uniformly from [0, nb_timesteps). 50 | """ 51 | return torch.randint(0, self.nb_timesteps, size=(batch_size,), device=device) 52 | 53 | def sample_q(self, x0: torch.Tensor, eps: torch.Tensor, t: torch.Tensor): 54 | """ 55 | The "forward process". Given the data point x_0, we can sample 56 | any latent x_t from q(x_t|x_0) 57 | :param x0: the initial data point (batch, *) 58 | :param eps: noise samples from N(0, I) (batch, *) 59 | :param t: the timesteps in [0, nb_timesteps) (batch) 60 | """ 61 | assert (t >= 0).all() and (t < self.nb_timesteps).all(), "Invalid time step" 62 | 63 | alpha_hat_t = unsqueeze_as(self.alpha_hat[t], x0.shape) # (batch, 1, ..., 1) 64 | return alpha_hat_t.sqrt() * x0 + (1.0 - alpha_hat_t).sqrt() * eps 65 | 66 | def sample_p(self, x_t: torch.Tensor, eps_hat: torch.Tensor, t: torch.Tensor, greedy: bool = False): 67 | """ 68 | The "reverse process". Given a latent `x_t`, draw a sample from p(x_t-1|x_t) 69 | using the noise prediction.
70 | :param x_t: the current latent x_t (batch, *) 71 | :param eps_hat: the noise, predicted by neural net (batch, *) 72 | :param t: the time step (batch) 73 | :param greedy: if True, return the mean of p(x_t-1|x_t) without adding noise 74 | """ 75 | 76 | alpha_t = unsqueeze_as(self.alpha[t], x_t.shape) 77 | beta_t = unsqueeze_as(self.beta[t], x_t.shape) 78 | alpha_hat_t = unsqueeze_as(self.alpha_hat[t], x_t.shape) 79 | 80 | # calculate the mean 81 | mu = x_t - ((beta_t * eps_hat) / (1.0 - alpha_hat_t).sqrt()) 82 | mu = (1.0 / alpha_t.sqrt()) * mu 83 | 84 | if greedy: 85 | return mu 86 | 87 | # sample 88 | std = unsqueeze_as(self.posterior_variance[t].sqrt(), x_t.shape) 89 | x_next = mu + std * torch.randn_like(mu) 90 | 91 | return x_next 92 | 93 | @torch.no_grad() 94 | def generate( 95 | self, shape: Tuple[int, ...], score_model: nn.Module, return_freq: int = 20 96 | ) -> Tuple[torch.Tensor, List[torch.Tensor]]: 97 | """ 98 | Generate a batch of samples. Returns the final sample and a CPU list of x_T plus x at every step where (t + 1) % return_freq == 0. 99 | :param shape: a tuple indicating the shape of the data 100 | :param score_model: trained pytorch model, which produces the update 101 | :param return_freq: how often to accumulate intermediate values 102 | """ 103 | device = next(score_model.parameters()).device 104 | batch, *_ = shape 105 | x_T = torch.randn(shape, device=device) 106 | 107 | x = x_T 108 | 109 | out = [x_T.cpu()] 110 | for t in range(self.nb_timesteps - 1, -1, -1): 111 | t_ = torch.full((batch,), t, dtype=torch.long, device=device) 112 | 113 | eps_hat = score_model(x, t_) 114 | x = self.sample_p(x, eps_hat, t_, greedy=t == 0) 115 | 116 | if (t + 1) % return_freq == 0: 117 | out.append(x.cpu()) 118 | 119 | return x, out 120 | -------------------------------------------------------------------------------- /diffuse/models/mnist_score.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from diffuse.models.components.conv_glu import
ConvGLU 7 | from diffuse.models.components.unet_parts import DownsampleBlock, UpsampleBlock 8 | 9 | 10 | class MnistScore(nn.Module): 11 | def __init__(self, input_dim=1, hidden=64, c_dim=64, nb_timesteps=50): 12 | """ 13 | A simple U-Net based model, tailored to the dimensions of MNIST. 14 | Predicts `epsilon` required to invert the forward process. 15 | (Equivalently, can be considered as the 'score' network in NCSN) 16 | :param input_dim: number of image channels (also the number of predicted noise channels) 17 | :param hidden: base channel width of the U-Net 18 | :param c_dim: size of the time-step embedding used as conditioning (even, >= 4) 19 | :param nb_timesteps: rows in the time-step embedding table; must cover every t the diffusion process samples 20 | """ 21 | super().__init__() 22 | 23 | # create positional embeddings (Vaswani et al, 2017) 24 | dims = torch.arange(c_dim // 2).unsqueeze(0) # (1, c_dim // 2) 25 | steps = torch.arange(nb_timesteps).unsqueeze(1) # (nb_timesteps, 1) 26 | first_half = torch.sin(steps * 10.0 ** (dims * 4.0 / (c_dim // 2 - 1))) 27 | second_half = torch.cos(steps * 10.0 ** (dims * 4.0 / (c_dim // 2 - 1))) 28 | diff_embedding = torch.cat((first_half, second_half), dim=1) # (nb_timesteps, c_dim) 29 | self.register_buffer("diff_embedding", diff_embedding) 30 | 31 | # define the main convolution op. 32 | op = partial(ConvGLU, c_dim=c_dim, kw=3) 33 | 34 | self.init1 = nn.Conv2d(input_dim, hidden, 3, padding=1) 35 | self.init2 = op(hidden) 36 | 37 | self.down1 = DownsampleBlock(hidden, op) # 14x14 38 | self.down2 = DownsampleBlock(hidden * 2, op) # 7x7 39 | self.up1 = UpsampleBlock(hidden * 4, op) # 14x14 40 | self.up2 = UpsampleBlock(hidden * 2, op) # 28x28 41 | 42 | self.out = nn.Conv2d(hidden, input_dim, 1) 43 | 44 | def forward(self, x, t): 45 | """ 46 | Produces an estimate for the noise term `epsilon`.
47 | :param x: (batch, input_dim, H, W) torch.float, with H and W divisible by 4 48 | :param t: (batch) integer time-steps in [0, nb_timesteps), used to index the embedding table 49 | """ 50 | 51 | # get the conditioning for this time-step 52 | c = self.diff_embedding[t] # (batch, c_dim) 53 | 54 | # initial channel up-sampling 55 | x1 = self.init1(x) 56 | x1 = self.init2(x1) 57 | 58 | # u-net 59 | x2 = self.down1(x1, c) 60 | x3 = self.down2(x2, c) 61 | x = self.up1(x3, x2, c) 62 | x = self.up2(x, x1, c) 63 | 64 | # output: (batch, input_dim, H, W) noise estimate 65 | return self.out(x) 66 | -------------------------------------------------------------------------------- /diffuse/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angusturner/diffuse/cb0d8dfdd4c07dc09ad167904377109a973b15dc/diffuse/utils/__init__.py -------------------------------------------------------------------------------- /diffuse/workers/__init__.py: -------------------------------------------------------------------------------- 1 | from .mnist_worker import MnistWorker 2 | -------------------------------------------------------------------------------- /diffuse/workers/mnist_worker.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | import numpy as np 8 | from tqdm import tqdm 9 | import wandb 10 | from wandb.wandb_run import Run 11 | 12 | from diffuse.datasets import MNIST 13 | from diffuse.models import DiffusionProcess 14 | from popgen.workers.abstract_worker import AbstractWorker 15 | 16 | 17 | class MnistWorker(AbstractWorker): 18 | def __init__( 19 | self, 20 | exp_name: str, 21 | model: nn.Module, 22 | run_dir: str, 23 | run: Optional[Run], 24 | diffusion_settings: Dict, 25 | optim_class: str, 26 | optim_settings: Dict, 27 | *args, 28 | **kwargs, 29 | ): 30 | """ 31 | :param exp_name: experiment name (forwarded to AbstractWorker) 32 | :param model: noise-prediction network, called as model(x_t, t) 33 | :param run_dir: run directory (forwarded to AbstractWorker) 34 | :param run: optional wandb run (forwarded to AbstractWorker) 35 | :param diffusion_settings: keyword arguments for DiffusionProcess 36 |
:param optim_class: name of a torch.optim class (looked up with getattr) 37 | :param optim_settings: keyword arguments for the optimiser constructor 38 | :param args: forwarded to AbstractWorker 39 | :param kwargs: forwarded to AbstractWorker 40 | """ 41 | super(MnistWorker, self).__init__(exp_name, model, run_dir, run, *args, **kwargs) 42 | 43 | self.model = model 44 | 45 | # define the diffusion process 46 | self.diffusion = DiffusionProcess(**diffusion_settings) 47 | 48 | # setup the optimiser 49 | self.params = [p for p in model.parameters() if p.requires_grad] 50 | optim_class = getattr(torch.optim, optim_class) 51 | self.optim = optim_class(self.params, **optim_settings) 52 | 53 | # register the optimiser state, to include in the checkpoints 54 | self.register_state(self.optim, "optim") 55 | 56 | # put everything on GPU if available 57 | if torch.cuda.is_available(): 58 | self.diffusion.cuda() 59 | self.cuda() 60 | 61 | # track the number of iterations 62 | self.iterations = {"train": 0, "test": 0} 63 | 64 | # load existing checkpoint 65 | self.load(checkpoint_id="best") 66 | 67 | # shared train / evaluation loop; returns a 1-tuple holding the mean MSE 68 | def main(self, loader, train=True): 69 | losses = [] 70 | for i, (x0, _) in enumerate(tqdm(loader)): 71 | if train: 72 | self.optim.zero_grad() 73 | 74 | # put features on the model's device (GPU only when one is available) 75 | x0 = x0.float().to(next(self.model.parameters()).device) 76 | 77 | # sample x_t ~ q(x_t|x0), for a random step t ~ {0..nb_timesteps-1} 78 | eps = torch.randn_like(x0) 79 | t = self.diffusion.sample_t(eps.shape[0], device=eps.device) 80 | x_t = self.diffusion.sample_q(x0, eps, t) 81 | 82 | # predict the noise 83 | eps_hat = self.model(x_t, t) 84 | loss = F.mse_loss(eps_hat, eps, reduction="mean") 85 | 86 | # DEBUG - check for NaN values 87 | if torch.isnan(loss).any(): 88 | raise Exception("NaN :(") 89 | 90 | losses.append(loss.item()) 91 | 92 | if train: 93 | loss.backward() 94 | self.optim.step() 95 | 96 | self._plot_loss({"MSE": loss.item()}, train=train) 97 | 98 | if i % 500 == 0 and not train: 99 | self._plot_sample() 100 | 101 | return (np.mean(losses),) 102 | 103 | def train(self, loader): 104 | self.model.train() 105 | return self.main(loader, train=True)
106 | 107 | @torch.no_grad() 108 | def evaluate(self, loader): 109 | self.model.eval() 110 | return self.main(loader, train=False) 111 | 112 | @torch.no_grad() 113 | def _plot_sample(self): 114 | x_gen, _ = self.diffusion.generate((1, 1, 28, 28), self.model) # one 28x28 sample via the reverse process 115 | x_gen = MNIST.denormalize(x_gen).clamp(0, 1) 116 | x_np = x_gen.view(28, 28).cpu().numpy() 117 | self.wandb.log({"Forward Process Sample": wandb.Image(x_np)}) 118 | -------------------------------------------------------------------------------- /notebooks/visualise_diffusion.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "c753f8cd-32b7-478f-a869-70bb3452a535", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "import torch.nn as nn\n", 12 | "import torch.nn.functional as F\n", 13 | "import torch.utils.data\n", 14 | "\n", 15 | "import matplotlib.pyplot as plt\n", 16 | "%matplotlib inline\n", 17 | "import seaborn as sns\n", 18 | "\n", 19 | "import numpy as np\n", 20 | "import os\n", 21 | "from functools import partial\n", 22 | "\n", 23 | "from tqdm import tqdm as tqdm\n", 24 | "\n", 25 | "# wandb for plotting\n", 26 | "import wandb\n", 27 | "\n", 28 | "# PopGen for data loader and worker utils.\n", 29 | "from popgen.setup import setup_loaders, setup_config\n", 30 | "from popgen.workers.abstract_worker import AbstractWorker\n", 31 | "\n", 32 | "from sklearn.datasets import make_s_curve, make_swiss_roll" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "id": "4d66f118-e891-4a5d-bcfa-f542803cc605", 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "name": "stdout", 43 | "output_type": "stream", 44 | "text": [ 45 | "(5000, 2)\n" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "X, color = make_swiss_roll(n_samples=5000, noise=0.3)\n", 51 | "X = X / 4\n", 52 | "X = np.stack((X[:, 0], X[:, 2]), axis=1)\n", 53 | "x, y = X[:,
0], X[:, 1]\n", 54 | "print(X.shape)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 3, 60 | "id": "9b67472a-ca64-4093-b5c1-9412fae5ca9b", 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "name": "stderr", 65 | "output_type": "stream", 66 | "text": [ 67 | "/home/angusturner/miniconda3/envs/py38/lib/python3.8/site-packages/seaborn/_decorators.py:36: FutureWarning: Pass the following variable as a keyword arg: y. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation.\n", 68 | " warnings.warn(\n" 69 | ] 70 | }, 71 | { 72 | "data": { 73 | "text/plain": [ 74 | "" 75 | ] 76 | }, 77 | "execution_count": 3, 78 | "metadata": {}, 79 | "output_type": "execute_result" 80 | }, 81 | { 82 | "data": { 83 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAoL0lEQVR4nO2dbYik2XXf/6dbWy7TPbv6oEFSt6YzhgiRRTgOagkHL4TYcrIOThQrMVgfHLQxDIY4tiHgl+wHvwRDjMEQYkGyYK0dYuyYxIscW0ZeEZFlwUp21myclWWFjbOj2WkjjSN2d3pMubxT1x+qbtV9bt3X5/3l/4Nmt6e7q5+uep5fnefcc84VpRQIIYQMl4OuD4AQQkg1KHJCCBk4FDkhhAwcipwQQgYORU4IIQPnbV380ne84x3q+vXrXfxqQggZLC+++OKfKqWu2v/eicivX7+OmzdvdvGrCSFksIjILde/M7VCCCEDhyInhJCBQ5ETQsjAocgJIWTgdLLYSUiXPP38G4XPn3jskY6OhJB6oMjJoLGlXOdjUPBkKFDkZHDkyvv2/cPo91w7ehD9PRQ76SsUORkUPomHZH2xCC8FncxXez+fIvYQlD5pE4qcDAaXSG0B29K+s4yv599ZHuB0ttp+7hI74Ja7D0bzpE0ocjIatMRNeV8sV75vt/AL/2S+foyUFA2QF81T8KQOKHIyWEyx2hI3Bb5c3Hf+/Gx+tPv55Qons/03gtPZKpqa0biEH4viKXhSBxQ5GSzXjh4kR8kulov7XplrUlIzwL7wQ1F8SorGFDylTmJIF3t2np+fKw7NImWwI9iyUbkpcI0t8TKYufbt487T0jsxwVPoREReVEqd7/07RU6GRlmZ29QhbhcumYdwiZ5SJy4ocjIqUmSusdMjqaJNTavY5Ipc44vcU1IxFPs0oMjJ6EgpRyyLb4HTJ/ey8nYRSsXklECaUPTjgCInoySlSef2/cNkAcbeCFIrWOqkrhy7CcU+TChyMmrqmLkSIjfS75vwmXMfBz6Rs/yQjAJbRGU6K0NvBinRrin7WBSdKvqcPL1d/x7C/nv0306hDxNG5IQEqCvST43oQwu16V2qO8zKHC13+03G9yZFqfcPRuSElKBqJK/xydIW/M
k8vZPUxNe9egGzXn7/cc25MozShwsjckIaouq43dSaeJ/EbXQTlI7SXRF6KIVEoXcPFzsJ6Qkxwftq4kNCz5U5sBa6PfVRQ6H3E4qckB4Sknpqk1PKgDCNLXIgP3euodDbp7EcuYjMATwH4Os2j/eflVI/UfVxCZkCWoYuoafWhZ/OjGgdxRkypthd82VMLhYHBZn7cuca5tD7Q+WIXEQEwJFS6lJEHgLwPIAfUkp93vczjMgJ8ZPa5GSTOo/dVcmy/VqkbJIpl25pLCJX63eCy82nD20+2s/XEDISYjXxwL5Qb98/3ErYrHw5nbl+g1vWKR2koSido3e7o5YcuYgcAngRwF8F8Aml1I86vucGgBsAcHZ29oFbt25V/r2ETJWcBdMcQqWPZaY0aij2emhlsVNE3g7gGQD/XCn1su/7mFohpH7qkHtKDXuVKY0ApV6FVhqClFKvi8jnADwOwCtyQkj9xFIyrnRMGVw7IfkejymYdqijauUqgL/YSPzrAXw7gJ+tfGSEkEqYkgzl2e0ZMTmdpfb32tG6LXdX9yhlXp06qla+EcAvAzjEugf415VSPx36GaZWCOkGX/rF11VaFyk16hR6HDYEEUK29EHolHk+FDkhZI9UoQP1SZ3ReXkockKIlxyhm1SRe8psFwq9CEVOCIlSpYSxrNQp9HQockJIFqljeHPSMKFdjJhyiUORE0IqUSZad43hdZEzTnfKMqfICSG1kTp+F3AP8/JBoYfxibz9rb4JIYPnicce8YrUl+eObQgNrGW/Hcm7ONi+Cdy+f7j3BlHXfqpjgBE5IaQysaqX1BSLTWqEPpXonBE5IaQxYtH5yXy1FfHpbLX9iKEj9DvLg70I3eTp59+YdIRe69AsQsh08e12ZM50MaPqi0Vxz9DYhhi+uNOMzqe6axFTK4SQRiizH2log2mNudeofmOYys5FrFohhHRGSufoxeLAKXLXvqMumQPjFzpz5ISQzvBVuZjiPZnH8+Za6lr0rty5r/t0zDl0ipwQ0houoV87elBYFD2drQobRPswZQ4UUzRTkzlFTghpHZ/QgX2Zz+ZHhY8QKTIfIxQ5IaQzUmTui87Nfy9UvERkPsaonCInhHRKTOam0M0PmxyZjw2KnBDSOSGZm0KPLYb6ZG4ztqicIieDYuodfGMmtghqd4b6cI0AGHuKhXXkJIshnvxjqB+eEnXMbdGiH9uMczYEkVLkjCsNEWrUKPN4OY9blSFd6GPCde6lyHzMG1ZQ5CSJMm3VQHjnF2D/YvI9Tll8j29Sl/SHctGPBfucdMk8hOvcGKrMGxO5iFwD8B8AvBOAAvCUUurfhH6GIu+OnNSIS9x2C7XZPq0xa31TGjvMx8sh5bFj0VkKY2/7HgI+mWtSpJ4SmQP9fk2bFPm7AbxbKfX7InIFwIsA/qFS6g99P0ORt0OZfLZP3qa4V4vLpMc6mB8DQLCJw34jcD22fpxc7N+b+qZS5tYc6LcAxkAo1eLCJ/chp1paS62IyKcA/IJS6lnf91Dk9ZMi7VgO2jeFTstWS/bw3t3o73pw5er2/0MiTn1TqErOm4HvDYDbkHVPyvAtk1CkPkShtyJyEbkO4DkA71dKvWl97QaAGwBwdnb2gVu3btX2e6dI1Y1wNa5VfzttkiNwk5jMm5C46xjN4yiDfWdhiz00SrVvIhgLqUJ3reWE7rj6/ho2LnIROQbw3wH8jFLqN0Lfy4g8nzKb3QL7oo7los1UhylaW5APf7n4+r15Vjy3UiLyukWe+0aTg/1mcDA/LoidUu+GWLolVqo4tLusRkUuIg8B+C0An1FK/Xzs+ynydFJP1NgcZx+rxeWeaG3B1hGJ2+mKlNx4WeoQun6jst+gNPpvNP++VKFr+iCGsZBS2RLaK9Ql9D7KvMnFTgHwywC+ppT64ZSfocjDpNTP+qpHmso55+TFXfL2LTSWqVbRxN4McoVu32X4cN19xISu6aMcxkJM5rFNn4ew0bNP5HXs2f
ktAL4XwP8WkZc2//YvlVKfruGxJ0OuvJeLewD20x85LTU5uePY9x7Mj3EAfx4ZcAv7dOZ+vKSd1mdXtv97sVwB86PCG5p9zHWlXuxo/fDeXehLfQng1QVwsRG6Kyer0YIwX3tKvX5O5qvt/qCh80p/7XS2/n4t89v3D1tpPKsCG4I6JKUVuUz1iBlZ+lIDgF/OqRUedsTti2hMrh09qH0anX377Kpx992pxHL/IVLWBVx3JbH0C2VenlC9eWpkDrhb/M3XqavXqMmInGSSkvfeCWlVqN12Rd058tG4Fu80qfXXxZM9bd/ElK9XQUdfAHBnE7HraF1jCt48ksN7dwtyLvO8avTrtdj892B+jFcX669d6LuW5dsKb36M0OvHDBzMyBywJiUaKT59J7X93BGUPP38G716bSjyFkmNFnwCt6kq8Fgu25UWSW13tgmd9FUHcfl+vz7WvRI0S/Ba7CGpV8W8I1hsFpgvcISLpV/owO656ZM0+swTjz2ydz65ZB5CS90U+sl81esUC1MrLRFaiKmaPrFxCci3EOkWdzg1EqIr4VSpq29i4TiWjzdfD71AChR3hWcHaXmqVHuZxCqR2n4tODSrQ8yTyhZ4Xc03LmLyjom7aqTdJbmdrr4yNdcF7irrLFsx4xJ6LD+r6etz3xfKVH+5sF+TLtc0KPIOCAkc2EXgocabHEJpk9hiZOot45DlUUc3rK/ByhXFl5E5AG+EDnBhNJdUmdsiXy7uOxfzu47KKfIWiaVR6hC4r3My1kauGXK0XRdV59O4y0LL3Vn53oQp9OrE1qZskdt3XHYw1GW1EUXeMLF3flPgQLkUStMR99RlUHbELxC+0yojc8D/GlPo+aSuUQHh0cxdy5wib4Cc27aqEwSB/fkeQLzeNQQv+jRyUjKutY+UO6+UhqtUmWj4+hZJuVMG/KMt+iBz1pHXSEr0Dbg7MOvKmwLhyMwFL+xyuJ438xzQz/3t+4fbTYK358HsCi6MEkc1P06qhrHPkwfY1aS/Oj9ed45aZYvmsZjHyNd9jas00eRktp8rD6G7P/tQlsiIPJOcd/Uq6ZOU2R2MxPpBTjdhStdpmdJFTl9MJ6eKzCRnjHFTzzUj8oqkLWCWi76B/fx38QLlbXSf0c+9Pkdcr9GuqxC7SF1LY/M9rkjdNS5Yn1c6St9G6I4OXDtC53lSRDcL7QKk9bWspe0S+sVyFZ2j0/ZzTZEnkFOFUmVwVa7AeVH2C19XoY0t9QscefOyNubArjJCZ7olnGJZX3e7u6fQNoWaPqRYmFqJELsNC6VQcmZa8xZ5fKTsYhM7l2JjGMxzyzdOl8O53LiubaC4zgW4m4Xsjui2nmNWrZSgisQBv8jLCHzqF93QySlPrSpzIL3KZernVeoM85DM28yVU+SZ1CFxCpzY5ArdPLd8Qo/N1olF51M/x8rKvIuxxBR5Iqn5cCC/JjxnQNLUL66xk9v9G5qbnpu64zlXJKXzc/u5Ne5W05bMWbWSQGppIZAu8dyLCZjuBTUl7AU38zzY1qEbNej6q/p8Sxmxay6GLgFc4Ah6Ic/1e6da1eJ6LUyZmzsL+Wbzdw0j8g1lx8xqthdNYGYGBU5c+NItru7gstMxmWqJU9een01G5YzIA1SVOOAuIwTMhaa3APB2luxj16EDrjf6TZkidvuR5sg8FJ3bpXNTjczrps1yxElH5CnvwL5OL1Pm9vRB5sBJFVznZW7nsKuRSGOv1+jzlOdo/6NyRuQWuVG4TWx0LAVOyuLLn5szXHKjc7uRyJc373pmSNf4moXM/T5TNm/WtBWV15K5F5FPishXReTlOh6vaezSQh3xmFHPq2/ei3bbzeZH26hGRzYffPitbXRDiZOyPPHYI4XzxZT56WzdIn794SuYv/2dOJgf48GVq4X0nm8xVAv98N5drBaXWC7uFxbxzUW+qnupjoGyEvbtC9rUc1rXEuwvAXi8psdqjKeffyNYH/7C5X4u3PwwsXPgMYFT4q
QMtsyvHT3AyXx9vmmhz+ZH2zvE2DhcYF/mwCaNGNmUeCr4rlV9l+2brwK4Uy+hzUnqopZXTin1HICv1fFYTeHLO9p1uzr3GBo1qqPwDx7H0ygUOKlKSnRuyjwXOyo3YVS+XxoKhGWuaTMqb+0tWERuiMhNEbl59271jYVz8KVS7Cg8JHBfKaErCqfASRP4onNb5lWj8jYiyL5jX7+uFItP5neWxjrbRub2c1q3zFsTuVLqKaXUuVLq/OrV+IlWFymt9vYuLj7YHUe6xiUYU+YmKU1DGubK9/HJ3N5GUXOxXG0/ALQq81EnxVxPVKjV3kdqEwUlTtrAdZ7pczJl7KrGLFGM5cqnKnMftszt1JQWekzmdTFqkZuYnXKuVnsfZnOPCTvhSJe4cuY6KjfTKzmt/DoqZ4plR0qKxaztNz/013LKFctSV/nhrwL4PQDvE5HXROT76njcKvhmDesn1ZVOCdXi+honKHHSFfa5p1MsOblyE12OaMIUi/sa71uFT11VKx9TSr1bKfWQUuo9SqlfrONxy5IzghZYC9wlcV9KhZC+UaissBY+3zw7j0bm5vlvRuUAZT4EJtHZ6Zpe6CK07RqHC5E+YnYiFm/7d92fD7AWdUqaZbW43HZ8ns6aOGLSBKMTuS+lonF1a9q3oMyLkyGRI/MQh/fuFq6F7ehW3okm4XKGuXlMk/Qr0VORlJQKEF7g9G2RRYmTPmMvfpatLwfc1wfTK0XMNzffjHLz30Pjq+tgdBG5SWpKBdhv+AHAvDgZFPbAp925mxeZA5sKls1ALV4DfvQQLZ/MQ9F4nQHhaCJyOxp3jf0E9sfP6g9zANbpbMW8OBk0ZSM/l+RZV17EbgzSvrCxR9o2yShEnpoXj80Qt1vvAY71JMNGp1hC2LPL7cmIOlc+5brylC5PLXSf2Jt0yShEbpOTFw9tBMH2ezI0XI1CQDFfa+bKtcQf/vLNgtBNmQPuzsSpReWxqYi5X2OLvoEvGs9pwQfCUwwBSpwMn5QJibbQV4vLvbpyssY1FbErRvfKmDv8AGGJmznxEJQ4GRK+js9c7NZ9DaPyHbbMtdDN/9c0mZoatMhjufEcOM2QjBFfegXYz427vsbNJ/LJic7reiMc1aviyo27sLs2NZQ4GQs8d6fFYOvIU6JxV1rFlSN0vYPyQiBj4mS+znPP5kdYJKwZ2SwX94HZlQaObFz47li0Y5rajHk0EXlKNO7b5QdgrTgZNy7B5Gw8Qda4tozUhNJOTaekBinyMrlxSpyQavgWPEkeTTx3gxR5KmbnJuAeasOGH0LSaWOThDGg9+00n68mN2MexauSc9sSmi3OaJyMhZAcUncQsitX7OaglN9F2mFwIk85aVyRd85ehoSMjVAteU6unOmVfjI4kaeih2DpD2B/NjBz42SseNvJPa36PlLryRmVFwnNXGmCwZYfAnkRgWs2MCFTZDZfj7RNGV1hwzLEOPY6AjeWyCBlEYbROJkKoQhZ71mbMpecpKHvUFweamOBeFARuevkNG/xzPzdcnG/kFLRMBonU8a1yUqoVZ/4sXchA9zP75pmt3wblMhNzLSKrxHIlDkhU2Ynm2oy8V1PU7ir9QWSpsRdzYj6OWtyD9RaYn4ReVxEviQir4jIj9XxmDap+3G6sN8lWTtOpoyZVkmNxh9cubrtx3DtYzs17PU5e2y2uf5gT4+0qeNNsLLIReQQwCcAfAeARwF8TEQerfq4JqF3QtcTaD6RLsGzbIqQNOzSRN/elFPEtSewKfAyi8llqeNV+RCAV5RSf6yUWgL4NQAfqeFxneQMyGrziSRkCITuXG20xHU0PpsfOZvpppBWiZHzvDZBHSI/BXDb+Py1zb8VEJEbInJTRG7evVtttdyOxkPCtvce5DxlQuK4JG5G41NLq8Tq5HUe3JyuGhoNUjetWU0p9ZRS6lwpdX71arwRoQyH9+4WPgghYVxdnabEgf1Bc1OTuE0sNWvOd9Kk7ERWhTqqVu4AuG
Z8/p7NvxFCOsYnnQdXrm6DHZfMtcQP5secT1QBOxpvqvy5joj8BQDvFZFvEJEZgO8B8Js1PG4WrgicUTmZInYawLUoF8KUuC0iNtL1k8oRuVLqLRH5AQCfAXAI4JNKqS9UPrIaWS0uAdaTk4lRpTrL3g6RjXT5uBoSm6KWhiCl1KcBfLqOx8oltTKFMyLIFAhF42ZhgO9uNTRIi9F4mNn8qLPqlUF1dtoRRtclP4T0idDOWSkSd8FoPM7JbJe22suJb6Jx345kdTH6WjzmyclUceXGU64HV7khcaMrUVzPlSlxH3Xd2fQ+IuecY0Li2NG4K6ViStxszU/ZWIJplR3Xjh7g9v1DnMxXhb4Ul8y3om94f+DeizwFRt1kqrh2dQ9JPHW2Chc50zmdrbwzyM3nr8n6+1HcO6XsdAKERkwSMjzsKLyqxM0uTqD5KHKomEP7XLhSKbbE634eByVy+8ngiFoyVVKmgaZIPNTFCVDiQDi9e2d54N04wnc308TzOPjUysH8mMOxyGTwpVIA4IVLYLm4t1edEpO4xu7ipMT3iW0mcTJb//vpbJ0/bys1NaiIXBPaETz4c1yFJwMmlA9fS3y/xDAlJx6bbjhlQnc+wH66Vn/exvZuJoOOyE9mB7jAeiPZB3Avetp5P0KGSG4qBQhL3D/dkCmVEHZJp93LMpsf4WK5aj1oHLTIbcxBQPpzQoZOXfnwEEyp7BNqsALcDYldbS85KpEDlDcZF01I3BWNA9yYPEZsHwR7dG2bjCZp7HsSOfyHDJW2JM4qlX1SGqxSaeM57b3Iq/zhLrm7aj/ZPUr6xNPPv+GV+AuXwKtv3sPi9a9A7v4/HN67i4e/fDNb4kB4wwhKfI2ZUolt8t4lo0it6KljoVLE9eLELvK4ff9w8judkP7RVD48FInzLnVHuEpl/TylRONtP6+DE7k552BNcSiQS+bmAoQuC7Kf5Keff2PSUQjplnir/b2twA9RbkHTJ/Gmuw6HTG5Kxa7+aYtBiPyJxx5J2vzUvu3RT/rB/Hi9ATOKE90YlZOucQkcKApk8fpXKgncbPxhOsVPrEqlKk0+v4MQuY2OygE91yC8jdVqcVmQuV4aOJmvCjJnVE7aIkXgZhReRuBAMaUyf/s7KXEPrkAxZ4s8X7FFW4HiYETuisrtMZKhXLmWObDOl5/O0GoLLSEuWcQEXjYKN9EpFUo8jVA07nKL9kqX6w6DEXmIlKhco/Pleh4CAEblpDF8KcGQwAFUisI1b56dbyV+/eErlHgJTLeYgaKJuTenbyegpp/nwYrcN9w9pYKFkCZJkTcAr8CB6lG4KXFzhgolvk9K+bHezs3VtRnbzq2N53lQIrfTK2Vz5cCuHBFgRxspT0gChRrkPXkXB1zVkULR2BL/4PH+OU6J5+PfAaidXYBCVBK5iHw3gJ8E8NcAfEgpVc+ZmIkvKvfhmofA9ApJoYy4gWIzSZ35bxO7VtxMpwA7wfDc3uFbdDbRKVj9WtqTV9vaBShE1Yj8ZQAfBfDvaziWJEKliFUrWAgxid1y54gb2KVN6o6+NVrgc5YYVsbVAe4anR2SeJvPdSWRK6W+CAAiUs/RlMCXKwf8m06YKRYTRuXDpK0RC6FUCYBOxA3sqiaKAn8LAGenhAiVgGpcc8VT9uNs+7luLUcuIjcA3ACAs7Oztn5tEEblwyU1xVEHrgu7rYgbcEfdwFrcAApDr1wCByhxm5SUStubQ1QhKnIR+SyAdzm+9KRS6lOpv0gp9RSApwDg/PxcJR+hA9+iZ6FtH+E8uS8qJ/0mdAHGIikb162y/XNmmi4kbqBeeWv0wiWAQtoEwJ68AQo8hZRzaEgSBxJErpT6cBsHkktK2z6Qv6cn0yv9JNYJCexPp3O97uab96uB32f/7OG9u9AJxCbFrbEFXhw36y5zM+F566YJiffhzXNQ5YchykTlmlApImXePakDpYDidn/OBItjO0AX5s82KWwTe8ysWXkC+NMmGp6nYV
Lv5oZI1fLD7wLwbwFcBfDbIvKSUurv1nJkCSQP00La6EkNh2n1F1vii9e/AqC+Rpq2iVed+NMmGgo8Tl2RuL3Q2YdoHKhetfIMgGdqOpZSmDIPjbi1Uyy+/Lhr/gqj8u6I7dQC1NPO3ia+ypNY1YkJz8c0mlhT6ZvEgRGlVnyczHYpltjipm/+CkCZ9xF71/g+kyJvIJw+4fmXR6y8MDUXbi6K91HiwAhFbkfl+sWKdXteLFfbagAdlVPmpAymtDWhyhPKu358O/0ARYG7mgfNVvwhSBwYichjuXJz4E1s8VNH5abMARQqWfTvJO2jL6aL5VqKD7BOrbx5dt54esUlaBstbBMtbwBJuW+eW9WISTxlSqpNnyUOjETkNvYwLfMdOCRzHZXbMgeYaukK3xrIyfJt65QZUJA5UM/kQBempH1pugOgMMdnfzKev2yQ51N1wntuuvsCXBMNgX7nxG1EqUq9OaU4Pz9XN2/WHz3V9U7surUawos5ZlI3JdaUyZ37omnALWcXuQOVeP7UR6rE7UDO98YL9O+6F5EXlVJ7kcaoRA6k5cZShE6Z94+UxStXxJWCHZXF8qSpUN7Nk9osBrjPCVvkfb7eJylyYL/cKLbQYdP3d+gpEtoyTVOlwcMn66q9BTxH6qeqxIHiDj+Af4MIoPvXcDIi1+SUHsWETpn3l9TJh6FBWk02f/GcaIbUgG1MEgcmKHIgv5srReiUeb+pe6QtX8/+UVeKbWgSByYqck1qdO4Sub2bEGVOSHe41sCAvCgcGKbEAb/Ihz0pJhH7hdAvlH7htmI2FriWi/vbE8H8/4vlrpxRvxHYt+1PP/9Ga5sdEDIVfIUMZvWSrmAao8RDTELkQFzmKdgyv7M8KMicQiekGVJKi2MCB8YpcWBCIgfCL4xrkwEXpswBbGXui86B9rYiI2Rs2MFQSn/IanG5/XAxNokDExO5jSsq1y+yr9sLKMrcl2pxReeEkHRi27G59ks1sbtvzWt6TBIHJijyUIrFzpWnyBzYT7WEcueEkDg5/SA2B/Njp8RPZgfGrJvxSBwY6ayVGPE9PwE9y9yUuf2ub85q2N3e7U4w3wRFfQyEkCJlN4BwzVCaUrXZJEUOhGV+sdAverGcyTdwqxCdQ588+ycbh24R4idWJgyE+z58YxbsfU6BcUkcmEgdeYicCCClPhUorozHogBg+CcRIVWoq3EvZT7O0AU+6YagGKHZHakyXy0uC3m5kMwBCp2QMkGUjT2JcuzXGkWeQM78BlPkdpmTLfTUWS3AcE8wQlKpQ+BAeEIpMM7riyJPpA6Z+8qeUqPzoZ9shNjEJlbmjprOlfhYrimKPINYqsUnc2A/xaJxdZQB41xBJwTwl9tWETjgv4aAcUbhJo2IXER+DsDfB7AE8H8BPKGUej32c30XuSZnylrKJgaUOZkCdUffJjkSH+P105TI/w6A/6aUektEfhYAlFI/Gvu5oYgcyIvOgbyKFoB5czIeUooGAHdbfagGHJh2FG7iE3mlOnKl1O8an34ewD+u8nh9RJ8QKQ1EgFlHHpa6b6NnfeK7mojM4yGkL+RXfe2nJEMw6IlTW45cRP4rgP+klPqPnq/fAHADAM7Ozj5w69atWn5v21SZbe6izMq7ZqonLekHKSmU3DtWoNxmyMA0rofSqRUR+SyAdzm+9KRS6lOb73kSwDmAj6qEd4YhpVZclKk7jxETOuCX+hROYNIfyubA61hHAqaRC/dROrWilPpw5IE/DuA7AXxbisTHgC/d4iZN6MWv77f3+9Iu+jimdDKTbqiyiNlEMQAwLYmHqJQjF5HHAfwIgL+llPqzeg5pOFQVuuvkdg3g0jl0wC905tFJk6Q08QDlKlEYhVenatXKKwC+DsD/3/zT55VS3x/7uaGnVnzk5s9TtqMC9k9wgHl00jxl6sABv8TN8z0lDw4wCrdhQ1CLhIReVuZAXOga5tJJLimz8u35+r6xsjlNPQAFnkMj5YfETZl55y
k5RLNkMRVf+kUfJ5kmMXG7tiwEimkUH/YgKxc5FVo8T+MwIm8Y136DgDs6zyF2seSWMtrw4hkXueJOEbZJKLiw98NlNVZ5mFrpkNTcOVBO6j5CuUdNTO68mIZLSN6xxUpg/1z0zfuOwfRffVDkHZNaugWEo5tYa7ONryJAw4tsPOTmuXNnBpU5lzRckK8HirwnxISusW9tfZGSvvDsmejA/lx0E99AfoAVMUOibMrENcHTdQ5p9LkUmq9vwju9ZuBiZ09IqT3fXxjdcWe53gk8JQWjL8yD+fFepPXqYv1fLfiL5fpzc/6LjasRyYQXZ/OUWaT01Xn7BH547y4eXLka/D1l1mB4fjQHI/IekHNxhm6HQxFVDN82dUC5aN2EF3B5qpQFAu47uZDATUyZH8yP9zZIYblg+zC1MiBijRg5uc2yck8Re2r1QQpTv+hThG2SK28Awe0JNVrmtsSBXVrFlDgF3i4U+UCpY2NoICz0UCSmKSN2k1zJj00GuaK2qWMdRZP65u7Li1Pi3cEc+UBJn4cOnM42F/Psyt6FrC/KlIvYJXbz5zbpdGP2+u4NJacszabN5qUcsZb53VUj7BAhgQP5w6o0sTdrCry/MCIfEDmzL4Bq+VEXdqRuRmyalK4+Tah8rUqapq+kyLqORpwyIucM8GHA1MqISG30AMKT6VxSjwndlTvV2CWOqTSZe3eRE/1WOY6UdAgQ7hsAws03vp9tcmMTyrs7mFoZEeaFZEs9JJuT+Xocrp2C0amS1eIyWnYWI+dWHnCP7bWPuax4XeRGvHX/TncaxN9Baf+MLXX9uf24qSMcgBVb5kcAI/KRUVcpI+DPp4ci8ZSftx/Dl5qpkm+3KTOhz3VMPlKj5twRDKHGrbLk3PFQ3v2CEflEsC+8lIhdR+oAcGd2BcBGOCVTJSmsFpdbmS8X951pmZwpjy6K0gznjc03HfuN6tVFPG1kNlSlHU8+PomHqoVcUNzjgyIfOblpGC2FWGWECz2WNyW9Ehof4HvsMuSWYvq+vnR8n00xTbSTetlj992dxMQ99VLPKUKRT4hQtB67+F3y2M83HxRkbke2NqnVLjkiDL2JVOl8Ne8gQr/b9aaUK/RUgZdZBKa0xwlFPmHKLpoC/nkwvlp2m1ju17wDSJ0tAyD5jqAJYm9MqaWZdcibwp4WFDkBEM+t26QM+jqdrf/rrrpwy0pH+XY1hinBmNR9uy4dzI+zo/JYFG7/TqDagi0nCZIyUOTEiU8OPsHHyh5TSfle1xuEi13n6ZqUdI9Jal38vrjjIwtCUNwkF4qcZOGSSG5aJpXQOF9gX5R2zn6b5tFsKnLqoBhpl5/LnQrlTUJUErmI/CsAH8H6TP4qgI8rpS7qODAyHFLkXoY6JOirwvF9X1WqHjOFTcpQqSFIRB5WSr25+f8fBPCoUur7Yz/HhiASo443ApNQd2jTc10oZ1IXjTQEaYlvOALQfpsoGSVl5Vcmh18nlDbpgso5chH5GQD/BMAbAP525SMipAIhkXJrOjJWoqkVEfksgHc5vvSkUupTxvf9OIC5UuonPI9zA8ANADg7O/vArVu3Sh80IYRMkcbH2IrIGYBPK6XeH/te5sgJISQfn8grTSUSkfcan34EwB9VeTxCCCH5VM2R/2sReR/W5Ye3AEQrVgghhNRL1aqVf1TXgRBCCClHJxtLiMhdrCP4IfAOAH/a9UHUyNj+HoB/0xAY298DdPM3/RWl1N42Xp2IfEiIyE3X4sJQGdvfA/BvGgJj+3uAfv1N7W9gSAghpFYockIIGTgUeZynuj6Amhnb3wPwbxoCY/t7gB79TcyRE0LIwGFETgghA4ciJ4SQgUORRxCRnxORPxKRPxCRZ0Tk7V0fU1VE5LtF5AsishKRXpRPlUVEHheRL4nIKyLyY10fT1VE5JMi8lURebnrY6kDEbkmIp8TkT/cnHM/1PUxVUVE5iLyP0Xkf23+pp/q+pgo8jjPAni/UuobAfwfAD/e8f
HUwcsAPgrgua4PpAoicgjgEwC+A8CjAD4mIo92e1SV+SUAj3d9EDXyFoB/oZR6FMA3A/hnI3iN/hzAtyql/jqAbwLwuIh8c5cHRJFHUEr9rlLqrc2nnwfwni6Ppw6UUl9USn2p6+OogQ8BeEUp9cdKqSWAX8N6eNtgUUo9B+BrXR9HXSil/kQp9fub/78H4IsATrs9qmqoNZebTx/afHRaNUKR5/FPAfxO1wdBtpwCuG18/hoGLokxIyLXAfwNAP+j40OpjIgcishLWO9V/KxSqtO/qfIOQWMgZfMMEXkS69vEX2nz2MqSuiEIIW0gIscA/guAH7a2iBwkSqkHAL5ps2b2jIi8XynV2boGRQ5AKfXh0NdF5OMAvhPAt6mBFN7H/qaRcAfANePz92z+jfQIEXkIa4n/ilLqN7o+njpRSr0uIp/Del2jM5EztRJBRB4H8CMA/oFS6s+6Ph5S4AUA7xWRbxCRGYDvAfCbHR8TMRARAfCLAL6olPr5ro+nDkTkqq5eE5GvB/Dt6HhTHYo8zi8AuALgWRF5SUT+XdcHVBUR+S4ReQ3A3wTw2yLyma6PqQybRegfAPAZrBfRfl0p9YVuj6oaIvKrAH4PwPtE5DUR+b6uj6ki3wLgewF86+b6eUlE/l7XB1WRdwP4nIj8AdbBxLNKqd/q8oDYok8IIQOHETkhhAwcipwQQgYORU4IIQOHIieEkIFDkRNCyMChyAkhZOBQ5IQQMnD+EkJCvs+6QBLDAAAAAElFTkSuQmCC\n", 84 | "text/plain": [ 85 | "
" 86 | ] 87 | }, 88 | "metadata": { 89 | "needs_background": "light" 90 | }, 91 | "output_type": "display_data" 92 | } 93 | ], 94 | "source": [ 95 | "sns.kdeplot(x, y, fill=True, bw_adjust=0.2)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 4, 101 | "id": "8053e264-78d2-44bf-9b43-6eec3354cb3e", 102 | "metadata": {}, 103 | "outputs": [ 104 | { 105 | "data": { 106 | "text/plain": [ 107 | "" 108 | ] 109 | }, 110 | "execution_count": 4, 111 | "metadata": {}, 112 | "output_type": "execute_result" 113 | }, 114 | { 115 | "data": { 116 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQMAAAD8CAYAAABzYsGzAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAiGUlEQVR4nO2df5Ac5Xnnv88MLZjFjmexdTkYaRHncy1GCGnKe0Y+/XGHkkM4suSJjCwTlJQvqZBUXe4ioJasDuFdEe6QawNR6pyqO/KjcldSYQkBc5JlSiKFUq5TvNiSZ8WyAZ3NL8GQKuSTFkfsIEa7z/0x02J2tmd3dt6nu9/ufj5VW9LuznS/O9Pz7ed9fhIzQ1EUJRX2AhRFsQMVA0VRAKgYKIpSR8VAURQAKgaKotRRMVAUBYCgGBBRmohKRPQ9qWMqihIckpbBHwJ4RfB4iqIEiIgYENESAOsB/KXE8RRFCZ4rhI6zG8ADAD7ZzoM/85nP8LJly4ROrShKu5w8efLnzLzY63fGYkBEXwHwHjOfJKJ/O8fj7gFwDwD09PTgxIkTpqdWFGWBENFbrX4nsU1YA2AjEb0J4LsA1hLRnuYHMfMTzNzHzH2LF3sKk6IoIWIsBsy8nZmXMPMyAN8A8AIzbzVemaIogaJ5BoqiAJBzIAIAmPnvAPyd5DEVRQkGtQwURQEgbBkodlMslTF85DTenajgumwG/et6Ucjnwl6WYgkqBgmhWCpj+zNjqFSnAADliQq2PzMGALMEoVE0sl0OmIH3K1UVkJhDYbQ96+vrY80zCJY1u15AeaIy6+e5bAbHB9Ze/r5ZNFqRU2GIJER0kpn7vH6nlkFCeNdDCICahfDZ7YcxzcBCbgtzWRZKNFExSAjXZTOelgEATHVoHFaqU9i2bxRDB8cxtHE5AGDnoXGcn6wCALIZB0Mbl6tYRATdJiSEYqmMe/eNLujuL0V3l4PBDSoKNqDbBAWFfA7b9o2Gcu7zk1X0Hzh1+fuhg+OYqFQvf68WhB2oZZAgWjkRbaG7y8H6W67FsVfPavjTJ+ayDDTpKEH0r+sNewlzcn6yij0jZ1CeqIDxsZOyWCqHvbREoGKQIAr5HLqcaL3lleoUho+cDnsZiSBaV4ZiTKU6HfYSFkyrsKgiizoQE8ZcIUZbSRFd3ipoOrV/qBgkjP51vaGFGDtlirkWjWCgOl1buSY9yaPbhIRRyOdw9+oeUNgLWSDVKb4sBC5u0tOygcNYs+sFdTQaoqHFhFIslUPLO/CbNBHuunUpHimsCHsp1qGhRWUWhXwOuWxm3sflshns3rKqrcfawhQz9oycwY7iWNhLiRQqBgmmf10vMk665e8zTvqyk+74wNrIbS32jJzBDbqFaBsVgwRTyOfw6KYVyGUzINTSgru7HBBqFsGjm1bMcM5d18I
66HJSSFNNKtJE+Nw/uzqA1beHm7zU/9QpFYR5UJ+B0jZevQ4yTnqWaDQ+/v79pzAVwjXmRTbjYHTw9svfJ7HzkxYqKSK4H5R2P0Duz72apRAW1j9BgolKFcsGDiOXzeC2Gxfj6ZPltjo/JQVjy4CIrgLwAwBXoiYuB5h5cK7nqGWQLLzuwGFHMlqJUXPnp7jht2VwEcBaZr5ARA6A/0NEzzHziMCxlRhQyOdm3W3DFoNWt8Akpz5LTFRiZr5Q/9apf9mxSVSsJZtxwl6CJ62cpElAaiR7mohGAbwH4HlmflHiuEp8Gdq4HE5qZrDSSRG2ru5BOhVOEDNFwAcXLyU2HCkiBsw8xcyrACwB8EUiurn5MUR0DxGdIKITZ8+elTitEmEK+RyGN6+8HNbMZTMY3rwSjxRW4LGGn3d3OYFZEdNcczImtZeCeGiRiL4FYJKZ/6TVY9SBqHSC64gMsuoybg7FuRyIEtGExQCqzDxBRBkARwF8m5m/1+o5KgaKKTcMHA7MMZXLZmKTi+B3bcK1AI4R0UsAfoyaz6ClECiKBEE6+pLShk0imvASM+eZ+RZmvpmZH5ZYmKLMxXx1FX4R5zZsmoGoRJLGbMigOzfFNRdBaxOUWFAslWfNY/CTqM6a1H4GSuwp5HMYHbwdW1f3BHK+OPoPVAyUWPFIYQW2BtTWLW7+A90mKLGkuThq2aczOP7aOV/OFaXQo5YwK4nDqzhqR3EMe0fOiOcnuA7MqJdB6zZBSQyPFFbgjV3rfd1CRHnroGKgJA6/E5aiGnpUMVASh98DaKNaBq0+AyVxFPI57Dw0jvOT8jkJTorQv64XO4pjePLFtzHFHJk5DmoZKIlkcMNyX9KZq9OMP3r6JewZOXO5Eaw7x+HzDz1ndV6CWgZKImlOZ5Zs0Hrxkvek60p12upog4qBklgaw4+NeQnZLgcXPrw0a7ajBG60QcVAUSylOS/Bz0YqtkYb1GegKB64I+X8mDFpa7RBxUBR5qB/Xe+sxq0Sx7QRFQNFmQ9BLehyUlb6CwD1GSSS5iKe225cjGOvno1MsU2QDB85jeqUnCNxsjqN/MNHMbhhuXWvsVYtRgCv8WQAZni/mYH3K9V5P8xew1ObmW+YapKGlfrdeDXoJim+dkfuBBWD9vH68DopAggt71hOivCJq67AxORMcdhRHMOekTNtnderRfhCpzDHgTW7XvC9rVqQr6GKQUQJe6T5m7vWz7AEUkSea4nbbIFG2rGkJAjqNfS1nwERLQXwvwD8MmpJXE8w85+ZHjeOtGti+1V3v1Bueug5VKf5sgXSSpTKExUUS+VYWgfNmYrpuiCmWwhjp9iQeyAxROVaANcy80+I6JMATgIoMPM/tHpOEi0DrzuMmwLrXljdXQ4ufFhF1Tub1WqcNGH4zpWxFIRWFEtlsWnSaSI89nX/Xz9fG6Iy8z8y80/q//8nAK8ASM4V0SbDR07PMjVdGXbvMOcnoykEQM1/sW3faKIGlhbyOUilIEwxh95gVTTPgIiWAcgD0CnMTQTd2z8s4tg1eC5+41a5bsxhd0kSyzMgok8AeBrANmb+hcfv7wFwDwD09ATTzjoMvBpxjrx+PuxlBUqlOoVt+0Zx4q1z1tfwm/JIYQXeOHtBrNlqmL4DkWgCETkAvgfgCDM/Pt/j4+ozCMrzHCWuvCKFb3/tltj7EqT8B91dDkrfut18QS3wewozAfifAM4x87Z2nhMnMdhRHMPeF88gpOhf5MhmHAxttC/7ToKF5HHMBQG+JXT5PVFpDYDfBLCWiEbrX78mcFzrcd98FYL2mahUsW3fKPIPH42dX8Ed4JImM69iWBOfNemoDVrlB/idqhp34hyOXDZwWOQ40slIOkTFgGY/QHmigvv2j+LefaMqBIZUpxj31vfZcRIEybt5kA5FLWGeB6/8gGmW65eXdBjAtn2j2FEcC3spYkiGB4NshKKWwTzYkCY6H5LNPMNiz8gZvHH2At78f5XIV0RKXTO
EYBuhqGUwD8JNboxIEWZ13ck4ady9uge5bAaE2h5z6+oezzbgTsO73d3lYPeWVdi9ZZUvLcM74fhr51CeqFx2oG3bN4pVO6PnaJS6mzOC3T6pZdCCYqmMBw6cgmBfC2M+lXEwuGF5W8VOfddfs6C+A360DJdgolK1ur24F/3rekVyDhalg70TaTTBA1uThwjAG7vW+3qO5siJLWnUUSuTXrXzKCYq5hObtq7uEc3i1GjCAvFyGtpAEM6k5pbhQTT3aIco+G4aGdq4XOSGsmfkDPaMnAmkI5KKAWp3w6GD4yJKLkE6RUgBM4Z4ZJx0KF11+9f1WmElXZfNRKrlWnMfBFPcJKTGY0uTuG2CVzPQfT9625fpOQshl8207HEY9oXfqgfjf37mJUwGVHPtpDCrvDsqLdckrSvT7ZK2Patjqy8gavvhRhqnDafq3scgWzJE4bWTvu7eNPAbqc+gTti+AK9GpmGZ/1I8Ulgxw8Hl51gyL2zwZ8yHa7lIdUXyq8VcovIMgnZCpahWpefG/4c3r8TwnStn5AREwcxdCO5YsiCDYraPOgdqr4vUqDa/GqAkyjIIMlQ2V6lunD78rQjyta5Up9H/1CkAdr+2UvkHft3UEmUZdC3y/8/NZTN4c9d6jA7ebvWF6Tf963oDzWysTnOoLcPaoZDPobvLMT5OVuAYXsReDIqlMtbsegHLBg7jp+994Ou5or7/l6SQz+HRTStmbIl2b1nl6zmjkIswuGG5sUj65fOP9TYhiOhBNuO0NdYsiTQnMAHwdSiMm5Rlcz6Cuw6T12GiUvXFiRhrMdh5aNxXISAAo4P+9auLI35Oh3p3ooJlA4dn1FcEkayzUAr53OU+Dp1ynw99IGK7TSiWyjg/6W9GYZC15nFByqPuBTf96+J2a7ZppoPptTMNYPszL8kspk5sxWDnoXFfj6/+gc4I2rHYiE0zHSSunYpw9mcsxcAPq6C7y4l1fkBQeDkWW/Vf8IOwB5W4SEUWJImlz2DooKxV4KQIgxvi2d47DLwci33XXxPYxGlbog6DG5aj/8CpGRmpC0XSkShiGRDRXxPRe0T0ssTxOsUNI0pWH2YzDoY3x7ODr00U8jk89vWVszo5+YEtvp5CPoerF5ndjyW3w1KWwd8A+A5qo9lDwY8wokYLguXjsNuobx2mbPP1vG9445LcDotYBsz8AwAyw+Y6oFgq4/79p8TDiLbcQZJEIZ/Da4/KdnOy2dcjcY1JdZaOvM+gWCqj/yn5vaZtd5CkkROsbbjtxsXWDoDtX9drPIPDHelm+jcGFk0gonuI6AQRnTh79qzYcYcOjos3JrHxDpI0JEOQe0bOWBFO9KKQz+Hu1eZTyfcK/I2BWQbM/ASAJ4BacxOp40o6C3dvWaUCYAnSbcOGDo5b+966d3SToa2M2mtl8jfGMs+gE3LZjLUXS1JxeyNI4Obz24o7tNUE05CpVGjxSQA/BNBLRO8Q0e9IHHc+pN5c9Q/YjVSw8f79p6wXBJPKTlNnpFQ04S5mvpaZHWZewsx/JXHcuSiWyug/cMr4OCmC+gcsR2JPDdSKpGxJR26FyXV4242Ljc4d2W3CzkPjRplbLo9/Xf0EtiNhQrvYko7cChOhOvaqmWM+smIgkWyxdXWPCkFEeKSwQiyX3+Ymqiap9Fb4DKKKrbFnxZvBDcuRFkhXJsj5m6QxiY5Z4TMIA9OZlDZfEIo3hXwOj21eaXwcNwwXN0yd4JEUgx3FMePc9bheEEp72FK52EyX09lHkmDe9SiSYvDki2+LHMfWC0JpjZSA21p3cmWHWZcSWXyRFAOpOgRbLwilNVICbmteyYSBYzz/8FGjrW8kxUACTTSKJlICvvPQuJU+I5O/7/xk1SiPIlJi4DYvMaW7y9FEo4giJeCmHxy/MC3QMsmjiIwYuKXKpjHirat7UPpWsqcdRZlCPhfrBKTGHpGd0ulWKjJiIFWqbJqlpYSPm8MvUbNgoxO5kM8ZWUCdbjUiIwZ
Spco2vvlKZ0i4kW10Irst/DrBxBcWGTGQwsY3X1k4Uo1AbXQiDx853VELvzSRkS8sEmKgpcpKMxK1KdmMY6XvqFO/2F23Lo1/c5NOizecNCGbcaxthqmEy9DG5WEvwRPq0BmyZ+QM7v6LH3Z8XusbohZL5Y79BV9c1o29v/sl4RUpNpDNOEZ+pKsXpa27MbjTo01y6o6/dg47imMdFeFZbxmYhH5GXj8vuBLFJr6y8lqj5390acqqHAPXaShRXt1pur71YmDi/Q9iVJcSDqYh4up0bay5LYLQqdPQi06ve+vFQL3/ihcSd9BpyM/l7BTJkHe6Q6eD9WLQv65XrCGmEh+kRjJKtto3QfKmd9etSzt6nvViUMjn8K8/e01HzzVJ6VTsRnhuTuhIDY3JOKmOO3iJRBOI6A4AfwYgDeAvmXmX6TFdz+q7E5WOQi2aU6C0g1RfRVPcyMbQwfGOrRUC8OimWzpeg7FlQERpAH8O4MsAbgJwFxHdZHLMRs8qo7O7wFUddoxRkkOKan0VbaGQz2F08HZkM50J1N2GDX4lPjFfBPAzZn6dmT8C8F0AXzU5oIRn1dYSVcUefuNWO7tjd2oZ9F3f2XbaRUIMcgAaA5vv1H/WMVKtrG0sUVXsYd+P3rbyZtFpNMD05mflFOZOXwwvtEpRaUV1mq28WXSaJ2B685NwIJYBNMYyltR/NoOFTGGWTBbSPAVlLmy5WbgOc1Or2OTvkbAMfgzgc0R0AxEtAvANAAdNDigVEtSIQjzZUeys1t8LG24WkqnIJn+PsRgw8yUAfwDgCIBXAOxnZqO0LqmYq1YpxpO9I2dEjuOkyIqbhWQqssnfI5JnwMzfB/B9iWMBH8dc799/ymjLoEIQTyQ2kdmMg6GNy624RiS3KmGHFn2hkM9h9b/o7vj5kk5IxR4kvP8EYHTQnqa4UlsV0+21tWIAAH//2rmOn6sVi/FEot2ZDX6CRqS2xbfduNjo+dY2NymWykbmoNYlxI9iqSzS7sz0QyONa6GYRhNMy7qttQxM4782OIYUWaRyAmxsl1/I53B8YC12b1nV8TFMfQ/WioFUFqISH6SuCVtyC5oxaZEOmG9/rBUDUwegjZllSmfsKI5h2cBhsePZ5jNwMQkxSuTUWOszMHUA2qr+SntIZeQ1Y0tugRedXrPdXQ4GN5iHSa20DIqlsnEnG1vVX5kfyYy8WVgcce70mmWWyamxTgzcC8G0k41tHmOlfSQz8pqpTtlZnAR07vSeqFRF8i+sEwOpC8FGj7HSHn47j+O4hZQQOOvEQOqN0mhENAmiv4CtW0iTD7TE58Y6MZB6ozQbOXqYhta8aM7ss7mS1eQDLfG5sU4MpFqjMwdzl1HMKJbKWLPrBdwwcBj37z8l7it4dNMK5LIZ6+dtmmTcSkVIrAstFvI5nHjrHPYIlKkOHzlt5Ruv1HAtAVcApOtJiGrXk+3XQLFURv+BUx0//xNXXSHyN1onBgDwSGGFiBjE0VEUZRrb31+XzeCDi5d8ixoAwN239vh2bEmGDo6jOtW5EE4I1GsAlooBUDPpTJ2AWUt64iuzrQC/HbxbV/d0PEwkCCSTqqT8bNb5DFwkfAdS8Velc1yfwLZ9o75aAY3kshnrhUAqqUrSIWqtGBTyOeOONszm7aOVzvE1k3AObI0WuEjl0hDJtvazVgwAmW5FOjshPPzMJGxFd5djvcNQzJcllIbsYrUYSHmX1ZEYDmG87jaNS2uF1B5fOnnKajGQ6lakjsRgcf0EQTeei4JVAMi0OfOj+tJIDIhoMxGNE9E0EfVJLcpFqjfchQ8vqd/AZ1wBWDZwGPfuGw3cT5Bx0pGwCoCaaf+1L3QuWhknheHNK8WFzzS0+DKATQD+h8BaZuH+sfftHzWqYnTHaEXhrhEkzXH//nW9Hb1GO4pjM/JCwrAIJOr5g0AipPjKH39ZcEUfYyQGzPwKAJCPhQCFfA7b9o0aH0f9BjPxivs31gV4iYSXeEhli5p
Q+tbtoZ6/XZpf807ZURzzJXRqbdKRNLZWqgXNXHemSnVqlvCWJyrYtm8UDxw4hY8asuTcnyvtIxVdefLFt30Rg3l9BkT0t0T0ssfXVxdyooVMYW4mmzF3AE5+pH4Dk7j/Rwbpsn4SpZb4UtapXzNB5hUDZv5VZr7Z4+t/L+REzPwEM/cxc9/ixQvrQjS0cblxG7Tzk9XEJyCFEff3E5vLkb2Qimr5NS3M6tCiSyGfw+NfX2XcoyDJCUjFUjnyDV+cFCJRjtyKD4WE+K5bl4ocpxkjnwER/TqA/wZgMYDDRDTKzOtEVtaE+6ab7lOj/oHoBNMSWVu4NA0cH1gb9jI6olgqo1KdNj5Ol5Pyre7CyDJg5meZeQkzX8nMv+yXELgU8jl0OWbGTBIHsu48ZFYiawtRdQIXS2XcK+BszThp/NdNt5gvqAWR2CY0YqquU8yJ8xtIzCcMkkVpmlWxGjX/gMuO4hi27Rs1zr0IYlsUOTGQuDv0HziVOEGIEtUpxp9uWRVp/wBQswikcjCOD6z1/e+PXJ5B/7pe3LdvFCb2QXWKsfPQeOQurk7JZhxMVKJjHVyXzUSiXdl8DB00Hx8PBLe1jZxlUMjn8LjBpFqX85NV3DBwGGt2vRB7K2Fo43I4prFZn0g3rSuq24FmiqWymAD7FT1oJnJiAMjVcDM+TsONsyAU8jkMb14Z9jJmkM04eHPXejy2eWXktwPNSEZv1nz2msC6NkVum+CScVIioRrg4/yDqF+Ec1HI53wZZNoJToowtLFWYRiH7UAjxVLZuLDOJeg+jpG0DADgUeEQSxIKmaRKwk3IZhxfym9twLUIJIRg95ZVgfdxjKxlIDlfAYhuDHshuB9At/LwUxkH1alpfPDR7My4X7oyjV9clE9dHh2MRoXhQmku4zYlDLGMrBgAtfkKfddfg52Hxo1j6R9crBUyxfGO1YiXWT5XX4Pm391242Ice/UsyhMVpIkwxYxsxgFRzSnr/syLKBUVLYS7/+KHOP7aObHjhfU6EftUATUXfX19fOLECdFj3vTQc5g09CE4acLwnfE0YYPEq24/46Rj4Rxsplgqi5Zy+/06EdFJZvbsShYbMZB6U9xAl0nnH0Wui5KtSA5BaWT3llW+vk6JEAMAyD98VDT1Nq53M8UM11HoR73Hm7vWix+zkbnEILLRBC+kG2ImueRZaY1fhV/dIXfxjpUY+HEHT0LIUVkYfhR+OWkKvbtzrMTAD3TmguJSLJXx+YeeEz9umuxwXEc6tBgE5yeryD98FMzA+5VqLJ1hyvwUS2X0P3UKVYmMogacFFmThBU7y8CPGO35ySomKtXE1DIosxk+clpcCGzLxoydGASRclupTomVpyr2I90/MuOk8Oau9RgdvN0aIQBiKAaFfA6PblrhexbXRKWq1kECcBOopCDI19VIETsxAGqCcHxg7azWWdJo2DHeuIlsUu3lCcDdq3ussgYaMe2OPAxgA4CPALwG4N8z84TAukS4LpvxtWTXhnJgRZZiqSxS69JMLgKOZ9NowvMAtjPzJSL6NoDtAP7IfFky9K/r9S1TDIDxYBfFDvxKLW4kCi3eTQevHm34dgTAnWbLkcVVYT+UHoBI3boSLjuKY9g7csbXydFRac8vmWfw2wD2CR5PhMaSXT/uAGt2vTCjnDcK5mBS8SrH9lsIgOB6GJoyb6ESEf0tgH/u8asH3XmLRPQggD4Am7jFAYnoHgD3AEBPT88X3nrrLZN1GyFddtqMTYkkSg2pcegLJejWZfPha9UiEX0TwO8B+BVmnmznOX5VLS6EVTuP+to+PJtxYtvVJ4q4FlxQ2HpD8K1qkYjuAPAAgI3tCoEtDG1c7mtykuYh2EWQBWe2ZRa2i6nP4DsArgTwPNWcJCPM/PvGqwqAxn6Aft0x3GSVqF0UccTvMDPgf2MSvzEdvPovmXkpM6+qf0VCCFzc5KTdAkNZvNB+CPaw7NP+9xWMshAAWrUIQL7TciPliQo+/9B
zeHTTLZG/WGzAq50agDlbrBVLZdGGpV5cvSjcFvQSxKrtmSnFUhn37z/VsruvKW46aqN3uTHcqeHJufGKCDhpwtQ0z8j5cNKELf9qKZ4++Y7YoJ25SKcIj0XER5CYHogSBBWCytXj3E+fLHueS/svziboiEA7RE245xID3SY04b6pDz475jlcRIryRGXObUkSRr4tFFta0BGAP424s9CLWFYtmlLI5zD+8B1Y89lrQl2HLRe/DRRLZaQsSOvNOOlYCgGgYjAne3/3S8hmwuuByACWDRxG/uGjic5ZcLdufvly5sOVoLhMiW6FbhPmYWjj8lDSWBs5P1nF/U/VRnzH9UJspDli8PMLF3Hxkv+OQC+i5hMwQcVgHoJITmqHqWnGg8+OoZDPRXJaUbvzHLNdDt6vVC9HB8J4zZPqvNVoggF+1ze0i+0Xr1eExg2z9l1/TeiWVyNxtwQ0tOgTNwwc9r381YRsxsHQxuUiF/aO4hiefPFtTDEjTYS7bl3adjXeXCHBFNnTF4IAvOHzeLOw0dCiTwSR727CRKWK+xpKtTvdWuwojs0Ig04xX/6+HUGYKypiixAAtfczyahlYEBYNfKSZDMOiICJydkDYuZrBuPG212RucpJ4eKlaUxz7Y5/5RUpfFidttp6crF9qyWFbhN8pFgqY+jguBW+AwkyThpLuq/CT9/7IOyl+AahNjYviVOydJvgI25btSCaagZBpToVayHIZTORaE4aBioGQvjda1ExJ+OkL1c5KrNRMfABFQY7uHpRGk46lbitQKeoGPhMszBE3eEYFXQ7sHC0NiFAho+c9hSC8Mtv4oVuBzpDxSBA5oq3+9V6La50OSlsXd1zecCuO6gk7sVEfqLbhABplaR0XTaDQj4XqxClNLa2Ho8TahkESP+63lnt2RtNWr/bt0cVJwUVggAwncL8xwC+CmAawHsAvsnM70osLI40VkB6pQXPNRuSgEhk8kmy0BoIxQyjDEQi+iVm/kX9//8JwE3ttEuPUwaiX3jNBWzVL7EZAnBFCgigF6g43V0OBjfIFFcps/EtA9EVgjpXI3k3L99oDEm69F1/zQyBWPbpDEZePz+jA1CuoX247WHM7i4HXYuuiFRfhjhj7EAkov8C4LcAvA/gNuMVKS3xEoj5cMXjU/WCpPOT1cst2bsb8vNBgF9lKt1dDj6sTs8QpoyTVgvAMkSmMNcftx3AVcw82OI41kxhVmZTLJXRf+AUqlNzXw8ZJ4WvfWEJjr16dsash6sXpT27STtpwvCdKwF0XkKtyBFI1SIR9QD4PjPfPN9j1WdgJ15+imOvnm37A9xcwan7f/vwTQyI6HPM/NP6//8jgH/DzHfO9zwVA0UJBz9LmHcRUS9qocW3AERq8KqiKB9jGk34mtRCFEUJF81AVBQFgIqBoih1VAwURQEQUkNUIjqLmsMxaD4D4OchnLcZW9YB2LMWW9YB2LMWP9ZxPTMv9vpFKGIQFkR0olVYJYnrAOxZiy3rAOxZS9Dr0G2CoigAVAwURamTNDF4IuwF1LFlHYA9a7FlHYA9awl0HYnyGSiK0pqkWQaKorQgcWJARMNE9CoRvUREzxJRNqR1bCaicSKaJqLAPddEdAcRnSainxHRQNDnb1jHXxPRe0T0clhraFjLUiI6RkT/UH9v/jCkdVxFRD8iolP1dewM4ryJEwMAzwO4mZlvAfB/AWwPaR0vA9gE4AdBn5iI0gD+HMCXAdwE4C4iuinoddT5GwB3hHTuZi4BuJ+ZbwKwGsB/COl1uQhgLTOvBLAKwB1EtNrvkyZODJj5KDNfqn87AmBJSOt4hZlPh3FuAF8E8DNmfp2ZPwLwXdQa2wYOM/8AwLkwzt0MM/8jM/+k/v9/AvAKgMCbMXCNC/VvnfqX7869xIlBE78N4LmwFxECOQBvN3z/DkK46G2GiJYByAN4MaTzp4loFLWu488zs+/riOUQlXZatRHRg6iZhXvDXIdiH0T0CQBPA9jW1PQ
3MJh5CsCquk/rWSK6mZl99avEUgyY+Vfn+j0RfRPAVwD8CvsYW51vHSFSBrC04fsl9Z8lHiJyUBOCvcz8TNjrYeYJIjqGml/FVzFI3DaBiO4A8ACAjcw8GfZ6QuLHAD5HRDcQ0SIA3wBwMOQ1hQ4REYC/AvAKMz8e4joWu1EuIsoA+HcAXvX7vIkTAwDfAfBJAM8T0SgR/fcwFkFEv05E7wD4EoDDRHQkqHPXHah/AOAIak6y/cw8HtT5GyGiJwH8EEAvEb1DRL8TxjrqrAHwmwDW1q+NUSL6tRDWcS2AY0T0EmrC/Twzf8/vk2oGoqIoAJJpGSiK4oGKgaIoAFQMFEWpo2KgKAoAFQNFUeqoGCiKAkDFQFGUOioGiqIAAP4//SQg1DawBiAAAAAASUVORK5CYII=\n", 117 | "text/plain": [ 118 | "
" 119 | ] 120 | }, 121 | "metadata": { 122 | "needs_background": "light" 123 | }, 124 | "output_type": "display_data" 125 | } 126 | ], 127 | "source": [ 128 | "plt.figure(figsize=(4, 4))\n", 129 | "plt.scatter(x, y)" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 5, 135 | "id": "4aeb9b09-8e59-45d0-8d99-5306af049b73", 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "class DiffusionProcess(nn.Module):\n", 140 | " def __init__(self, nb_timesteps=50, start=1e-4, end=0.05):\n", 141 | " super().__init__()\n", 142 | " \n", 143 | " self.nb_timesteps = nb_timesteps\n", 144 | " \n", 145 | " # beta = likelihood variance q(x_t | x_t-1)\n", 146 | " beta = torch.linspace(start, end, nb_timesteps)\n", 147 | " alpha = 1. - beta\n", 148 | " alpha_hat = alpha.cumprod(dim=0)\n", 149 | " \n", 150 | " # q(x_t|x_0) = N(x_t ; sqrt(alpha_hat) * x_0, forward_variance)\n", 151 | " prior_variance = (1. - alpha_hat)\n", 152 | " \n", 153 | " # forward process posterior variance (beta_hat) corresponding to (q(x_t-1 | x_t, x_0)\n", 154 | " alpha_hat_t_1 = F.pad(alpha_hat, (1, 0))[:-1]\n", 155 | " posterior_variance = (1 - alpha_hat_t_1) * beta / (1 - alpha_hat)\n", 156 | " posterior_variance[0] = beta[0]\n", 157 | " \n", 158 | " for (name, tensor) in [\n", 159 | " (\"beta\", beta),\n", 160 | " (\"alpha\", alpha),\n", 161 | " (\"alpha_hat\", alpha_hat),\n", 162 | " (\"prior_variance\", prior_variance),\n", 163 | " (\"posterior_variance\", posterior_variance)\n", 164 | " ]:\n", 165 | " self.register_buffer(name, tensor)\n", 166 | " \n", 167 | " def sample_t(self, batch_size=1, device=None):\n", 168 | " \"\"\"\n", 169 | " Sample a random timestep for each batch item.\n", 170 | " \"\"\"\n", 171 | " return torch.randint(0, self.nb_timesteps, size=(batch_size,), device=device)\n", 172 | " \n", 173 | " def sample_q(self, x0, eps, t):\n", 174 | " \"\"\"\n", 175 | " The \"forward process\". 
Given the data point x_0, we can sample\n", 176 | " any latent x_t from q(x_t|x_0)\n", 177 | " :param x0: the initial data point (batch, *)\n", 178 | " :param eps: noise samples from N(0, I) (batch, *)\n", 179 | " :param t: the timesteps in [0, nb_timesteps] (batch)\n", 180 | " \"\"\"\n", 181 | " assert (t >= 0).all() and (t < self.nb_timesteps).all(), \"Invalid timestep\"\n", 182 | " \n", 183 | " alpha_hat_t = self.alpha_hat[t, None] # (batch, 1, 1, 1)\n", 184 | " return alpha_hat_t.sqrt() * x0 + (1. - alpha_hat_t).sqrt() * eps\n", 185 | " \n", 186 | " def sample_q_next(self, x_t, eps, t):\n", 187 | " \"\"\"\n", 188 | " :param x_t: previous step! (batch, *)\n", 189 | " :param eps: noise samples from N(0, I) (batch, *)\n", 190 | " :param t: the timesteps in [0, nb_timesteps] (batch)\n", 191 | " \"\"\"\n", 192 | " assert (t >= 0).all() and (t < self.nb_timesteps).all(), \"Invalid timestep\"\n", 193 | " beta = self.beta[t, None]\n", 194 | " mu = (1. - beta).sqrt() * x_t\n", 195 | " return mu + (beta.sqrt() * eps)\n", 196 | " \n", 197 | " def sample_p(self, x_t, eps_hat, t, greedy=False):\n", 198 | " \"\"\"\n", 199 | " The \"reverse process\". Given a latent `x_t`, draw a sample from p(x_t-1|x_t)\n", 200 | " using the noise prediction.\n", 201 | " :param x_t: the previous sample (batch, *)\n", 202 | " :param eps_hat: the noise, predicted by neural net (batch, *)\n", 203 | " :param t: the timestep (batch)\n", 204 | " \"\"\"\n", 205 | " \n", 206 | " alpha_t = self.alpha[t, None, None, None] # (batch, 1, 1, 1)\n", 207 | " beta_t = self.beta[t, None, None, None]\n", 208 | " alpha_hat_t = self.alpha_hat[t, None, None, None]\n", 209 | " \n", 210 | " # calculate the mean\n", 211 | " mu = x_t - ((beta_t * eps_hat) / (1. - alpha_hat_t).sqrt())\n", 212 | " mu = (1. 
/ alpha_t.sqrt()) * mu\n", 213 | " \n", 214 | " if greedy:\n", 215 | " return mu\n", 216 | " \n", 217 | " # sample\n", 218 | " std = self.posterior_variance[t, None, None, None].sqrt()\n", 219 | " x_next = mu + std * torch.randn_like(mu)\n", 220 | " \n", 221 | " return x_next" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 6, 227 | "id": "e3f20f87-1553-4c25-a0e9-d037e064b6c4", 228 | "metadata": {}, 229 | "outputs": [ 230 | { 231 | "name": "stdout", 232 | "output_type": "stream", 233 | "text": [ 234 | "tensor(0.9939)\n", 235 | "tensor(0.0001)\n", 236 | "tensor(1.0000e-04)\n", 237 | "tensor(1.0000e-04)\n" 238 | ] 239 | } 240 | ], 241 | "source": [ 242 | "nb_steps = 200\n", 243 | "start = 1e-4\n", 244 | "end = 0.05\n", 245 | "diff_process = DiffusionProcess(nb_steps, start, end)\n", 246 | "print(diff_process.prior_variance[-1])\n", 247 | "print(diff_process.prior_variance[0])\n", 248 | "\n", 249 | "print(diff_process.posterior_variance[0])\n", 250 | "print(diff_process.beta[0])" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 7, 256 | "id": "95405ac8-8c2e-462d-9b1c-7aea0d405453", 257 | "metadata": {}, 258 | "outputs": [ 259 | { 260 | "data": { 261 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAfIAAAHxCAYAAACWBT5+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAABC0UlEQVR4nO3dd3xVhf3/8deHEMIIm7D3XiogIlhnRQVnq9ZZ96zWqvVrtbW11rZatbbVOvpzVZzULSq4Klatg70hgMwww0gI2ePz++Ne7DUmcIHcnDvez8fjPnLvuefe+z73Jveds83dERERkcTUIOgAIiIisu9U5CIiIglMRS4iIpLAVOQiIiIJTEUuIiKSwFTkIiIiCUxFLgnJzHaaWe+gc0TDQv5pZtvNbNo+PsdVZva3Oo5W0+vU6/6oZjbNzIbsYZw/mNkWM9tYX7lEEomKXGLOzN41sztrGH6amW00s4Z7+5zununuK+omYcwdDhwHdHX3UXv7YDNrBPwauG9/QpiZm1nf/XmOfXzd88xstZkVmtkbZtYm4u4/A9/53Yh4bHfgJmCwu3esozz7/D6Y2bFmtsTMisxsqpn1qItMe3jNp83sD/vwuJ0RlyozK464fX4sskowVORSHyYAPzYzqzb8AuB5d6+I9on2pfTjQA9glbsX7uPjTwOWuPu6OsxUL8Jz2/+P0GfdASgCHokYZRJwjJnVVtLdga3uvnkfXrtOf1fMrB3wGvAboA0wA/jXXjy+jZml12Wm3Qn/s5vp7pnAGuCUiGHP11cOiT0VudSHN4C2wBG7BphZa+Bk4BkzG2VmX5hZnpltMLOHwnOhu8Z1M7vWzJYByyKG9Q1fP8nMZpvZDjNba2Z3RDy2Z3jci8xsTXgR7W0R96eZ2a/M7GszKzCzmWbWLXzfQDP7wMy2mVm2mZ1V2wSaWWczmxQed7mZXREefhnwBDAmPCf0uxoem2Zmfw5nWxGeVo8oovHAfyLGP9vMVppZi/Dt8eElG1m7yfdJ+OrccI6zaxt3N88x2czuj7g90cye2sPDzgfecvdP3H0noRI83cyaA7h7CTATOKGG1xsLfAB0Dmd+Ojz8VDNbGP59+djMBkU8ZpWZ3WJm84DC6mW+n+/D6cBCd385nPsO4CAzGxjl448DcszsfjMbGs0DzOxKQu/hL8J539qLvJIq3F0XXWJ+AR4Hnoi4fRUwJ3z9YGA00BDoCSwGbogY1wl9obcBmkQM6xu+fjRwAKF/TA8ENgE/CN/XMzzu40AT4CCgFBgUvv9mYD4wALDw/W2BZsBa4JJwruHAFkKLeGuavk8IzWk2BoYBucD3w/ddDHy2m/fmamAJ0C08jVPDmRuG758O/KjaY54Hng5nXQ+cHMVn8M17trtxdnNfR2Az8H1C5bICaL6H53sTuKXasJ3AwRG3HwT+UsvjjwZyIm73BwoJlWI68AtgOdAofP8qYE74vWwSzftAaK4/bzeX88LjPQA8Wu25FgBn7MXfwVBCq0jWhz/Xa4DWe3jM08Afqg17ezd5367hOVYBY+vr712X+r1ojlzqywTgTDNrHL59YXgY7j7T3b909wp3X0VoUexR1R5/t7tvc/fi6k/s7h+7+3x3r3L3ecCLNTz+d+5e7O5zgbmEChvgcuDX7p7tIXPdfSuhpQWr3P2f4VyzgVeBH1V//fAc/PcIFVaJu88hNBd+YZTvzVnA39x9rbtvA+6udn8roKDasGsJFerHhOZ4347ytfaZu28EfkLoc3sAuNDdq+eqLhPIrzYsH2gecbuA0DRG42zgHXf/wN3LCa1jbwIcFjHOg+H38ju/KzVx9zXu3mo3lxf2Ylr29FoL3P1mQv9o3EHoH5WV4aUbLfbieU7eTd6To30eSQ4qcqkX7v4ZoTnaH5hZH2AU8AKAmfU3s7fDi4d3AHcB7ao9xdrantvMDg1veJRrZvmE5nCrPz5yi+ciQl/KEPpC/bqGp+0BHBpefJtnZnmE5kJrWpfbGdh
WrdRWA11qy1zD4yOnb3W1+7dTrSzcPQ94mdAc3v3Un7eANCA7/JnuyU6gekG14Nv/mDQnNCcZjc5EvD/uXkXovYt8r2v9XdlP0UwLAGZ2fsSGZVOq3+/ulYSWBM0FthH6HOtt/bkkFxW51KdnCM2l/hh4z903hYc/SmjRcj93bwH8itBi7ki72y3qBUIbTXVz95bAP2p4fG3WAn1qGf6fanM6me7+kxrGXQ+02bXeN6w7EO3GaRsI/UMR+dhI8wgtUv6GmQ0DLiW09OHBKF+nLvyR0KqPTmZ2bhTjL+R/Sz+w0C6DGcDSiHEGESq0aKwn9E/WruczQu9d5Hu9V7vQmVl3+/YW3tUvu7bwrj4tzQj97iys/pzu/rz/b8Oy8RGPyTSzi83sI2AWoX9Aznb3oeElQTX5zvSY2ZTd5P3OPw6S3FTkUp+eAcYCVxBerB7WHNgB7AxvOFRTWe5Oc0JzxCVmNgo4by8e+wTwezPrZyEHmllbQusg+5vZBWaWHr4cErlh1S7uvhb4HLjbzBqb2YHAZcBzUWZ4CfiZmXW10EaAt1a7fzIRqwrCqyeeI/QPzyVAFzO7JorX2QTs8773ZnZk+PUuBC4C/m5me1rq8DxwipkdES6+O4HXdi29CE/LwYS2gYjGS8BJFtoNLJ3QrmmlhN7/aH3rfQgvWs/czWXXFt6vA0PN7Ixw7tuBee6+JJoXNbNxhP4ROZvQ6qMu7n6Nu0/fm7zhzON3k3d8Lc8jySrolfS6pNaF0Drd7UBGxLAjCc2R7wQ+JfRl/1nE/d/ZSItvb+x2JqHFrQWECvgh4LnwfT2J2HAsIsPl4etphPbRXhl+/HRC+3tDaAO4dwhtuLYV+AgYVst0dQ2/9jZCi+qvjrjvYna/sVtD4K/h11hJaP135MZu6YR2H+ocvv1XYErE4w8Kv26/Pbz3VxOa+88DzqplHK9leAtCG0ydEzHsHuB9wPbwuueF8xcS2vitTcR9PyJU7LU99mgiNnYLD/shsIjQ+un/AEMi7lvFHjbqiuZ92M1jx4Z/V4vDv0c99+KxvXZ9hnv5mv0IbcCXB7yxj393e3xfdEnci4U/ZBGJE2bWk1Chp3t4H/vwbkiD3f2GGL+2u3u0qyXq4vW+Ai5z9wX19ZoiyUZFLhJnairyenztei1yEdl/WkcukiTC66Fr3ABqL57mOwesieJ1/1HL6/5jb59LRPae5shFREQSmObIRUREEpiKXEREJIEl4pmkaNeunffs2TPoGCIiIvVi5syZW9y9xhMjJWSR9+zZkxkzZgQdQ0REpF6YWfVDN39Di9ZFREQSmIpcREQkganIRUREEpiKXEREJIGpyEVERBKYilxERCSBqchFREQSmIpcREQkganIRUREEpiKXEREJIGpyEVERBKYilxERCSBxbTIzewpM9tsZgtqud/M7EEzW25m88xsRCzziIiIJJtYz5E/DYzbzf3jgX7hy5XAozHOIyIiklRiWuTu/gmwbTejnAY84yFfAq3MrFMsM4mIiCSToNeRdwHWRtzOCQ/7DjO70sxmmNmM3NzcegknIiIS74Iu8qi5+2PuPtLdR2ZlZQUdR0REpEYl5ZXkFZXV2+s1rLdXqtk6oFvE7a7hYSIiIoGqqnJ2lJSztbCMbYVlbN1ZGrq+s4ztReXkF4cuO4rLySsu++Z2SXkV/Ttk8v6NR9VLzqCLfBLwUzObCBwK5Lv7hoAziYhIEisuq2TTjpLQpaCUTfn/u76loDRU2oVlbC8qo7LKa3yOZo3SaNkknRZN0mnZJJ1e7ZrRMny9ZZN0OrVsUm/TE9MiN7MXgaOBdmaWA/wWSAdw938Ak4ETgeVAEXBJLPOIiEhyq6xyNu4oYe22InK2F7N2WxHr8orZtKOEjeHC3lFS8Z3HNU5vQIcWjWmXmUH3tk0Z3r0VbTMb0aZZBm2bNaJN+NIuM4PWzdLJaJgWwNTVLKZF7u7n7uF+B66NZQYREUkuO0rKWZFbyKothf8
r7O2hn+vziqmImIs2gw7NG9OxZWP6ZGVyWJ+2tG/RmA4tGtOxRWM6tMigfYvGtGjcEDMLcKr2XdCL1kVERL6jorKKtduLWZG7kxW5hazYspOvcwtZkVvIlp2l3xo3q3kG3Vo3YVi3VpxyUCe6tm5Kt9ZN6dq6CZ1bNaFRw4TZrnufqMhFRCRQW3eWsmRjAYs37CB7YwFLNhawdFMBpRVV34zTplkjerdrxvcHZtE7K5Pe7ZrRO6sZXVs3pXF6/CzmDoKKXERE6oW7sy6vmHk5+czNyWPR+h0s2VhAbsH/5rDbZWYwqFNzLhjdg/4dm9MnK5M+Wc1o1bRRgMnjm4pcRERiYuvOUubl5DNnbR7zcvKYl5PP1sLQ/tXpaUb/Ds05qn8WAzs2Z1CnFgzo2Jx2mRkBp048KnIREdlv7s7qrUVMW7mNr1ZuY9qqrazdVgyENjjr1z6TYwa256CuLTmwaysGdmoeV1t+JzIVuYiI7LWqKmfp5oJvinv6ym1sDi8ib9OsEYf0bM0Fo3twYNdWDO3SkswM1U2s6J0VEZGorM8r5rNlW/hkWS7/Xb6F7UXlAHRs0ZgxfdoyqlcbDu3Vhj5ZmQm7K1ciUpGLiEiNSsor+WLFVj5Zmsuny7awfPNOANo3z+CYge05rE87Du3Vhq6tm6i4A6QiFxGRb2zdWcpHSzbz4eJNfLJ0C8XllWQ0bMCoXm0455BuHNEvi/4dNMcdT1TkIiIpbvXWQqYs2MiHizYxc8123KFTy8aceXBXjh3UntG926b8vtrxTEUuIpKC1m4rYvL8Dbw9bwPz1+UDMLRLC64/th9jB3VgSOcWmutOECpyEZEUsWlHCW/NXc9b8zYwd20eAAd1a8VtJw5i/AEd6dq6abABZZ+oyEVEklhJeSXvL9rEqzNz+HRZLlUemvO+dfxATjqgE93aqLwTnYpcRCTJuDuz1uTxyswc3p63noKSCjq3bMw1R/fl9BFd6J2VGXREqUMqchGRJLGjpJzXZ63juS9Xs2zzTpqkpzF+aEfOOLgrY3q3pUEDrfNORipyEZEEt3B9Ps99uZo356ynqKySA7u25J4zDuCkAzvriGopQJ+wiEgCKq+sYvL8DTz9+Spmr8mjcXoDTj2oMz8OHxZVUoeKXEQkgRSUlDNx2lr++d+VrM8voXe7Ztx+8mDOGNGVlk3Tg44nAVCRi4gkgPV5xTz9+Spe/GoNBaUVjO7dhj/8cChH92+vdd8pTkUuIhLHVuTu5KGpy5k0Zz0OnHRAJ644ojcHdG0ZdDSJEypyEZE4tHzzTv7+0TLemrueRg0bcOGYnlx6eE8dtEW+Q0UuIhJHlm0q4MGPlvP2vPU0bpjGFUf05ooje9MuMyPoaBKnVOQiInFg1ZZC7v9gKW/PW0/T9DSuPqoPlx/ei7YqcNkDFbmISIC27Czl7/9exvNfrSE9rQE/OaoPlx/RmzbNGgUdTRKEilxEJACFpRU88elKHvvka0oqqjjnkG5cP7Yf7Zs3DjqaJBgVuYhIPaqorOLF6Wt54MOlbNlZxvihHbn5hAE6/rnsMxW5iEg9+WrFVn47aSFLNhYwqlcbHrtwICO6tw46liQ4FbmISIxtzC/hrsmLmTR3PV1aNeHR80cwbmhHzHQgF9l/KnIRkRgprajkyc9W8tBHy6mocn52bD9+clQfmjRKCzqaJBEVuYhIDHz+9RZue30BK7cUcvzgDvzm5MF0a6ODuUjdU5GLiNSh/KJy7p6ymInT19KjbVMmXDqKo/pnBR1LkpiKXESkjkyZv4HbJy1kW2EZVx/VhxvG9qNxuhajS2ypyEVE9tOmHSXc/uYC3lu4iSGdW/DPiw9haBed1ETqh4pcRGQfuTuvzVrHHW8tpKyiil+OH8hlh/eiYVqDoKNJClGRi4jsg22FZdz2+nymLNjIIT1bc9+ZB9GzXbOgY0kKUpGLiOylqdmb+cUr88grKuPW8QO54ojepDXQPuESDBW5iEiUisoquGv
yYp77cg0DOjRnwiWjGNy5RdCxJMWpyEVEojA/J5+fTZzNqq2FXHFEL246foC2SJe4oCIXEdkNd+fZL1fzh7cX0zazEc9ffiiH9WkXdCyRb6jIRURqsaOknFtfncfk+Rs5ZkAWfzlrGK11nnCJMypyEZEaLFiXz7UvzCJnezG/DG/Q1kAbtEkcUpGLiERwd577cjW/Dy9K/9eVoxnZs03QsURqpSIXEQkrKa/k1lfn8cac9RwzIIv7zxpGGy1KlzinIhcRAdblFXPVszNYuH4HNx3Xn2uP6atF6ZIQVOQikvKmrdzGNc/PpKS8iscvGMnYwR2CjiQSNRW5iKS0575czR2TFtK9TVMmXjmSvu0zg44ksldU5CKSksoqqrjjrYW88NUajh6QxQPnDKdlk/SgY4nsNRW5iKSc/KJyrnpuBl+u2MZPju7D/x0/QMdKl4SlIheRlLJ2WxGXPD2dNVuL+NvZw/jB8C5BRxLZLypyEUkZc9fmcdmE6ZRVVPHMZaMY3btt0JFE9puKXERSwgeLNvGzF2fTNrMRE68cTd/2zYOOJFInVOQikvQmfL6K3721kKFdWvLkRYeQ1Twj6EgidUZFLiJJy925e8oSHvtkBWMHdeDBc4fRtJG+9iS56DdaRJJSZZXzy9fm8dKMHC4Y3YM7Th2iLdMlKanIRSTplFZUcsPEOUxZsJGffb8vNx7XHzOVuCQnFbmIJJWisgquenYmny7bwq9PGsTlR/QOOpJITKnIRSRp5BeVc8nT05izNo97zziQsw7pFnQkkZhTkYtIUthcUMKFT05jRW4hj5w/gnFDOwUdSaReqMhFJOFtzC/h3Me/ZNOOEp68eCRH9MsKOpJIvVGRi0hC25BfzLmPfcmWnWU8e9koDu7RJuhIIvVKRS4iCWtDfjHnPPYlW3eWMeHSURzco3XQkUTqXYOgA4iI7Iv1eSpxEdAcuYgkoPV5xZz7+Jds21nGM5eNYkR3lbikLhW5iCSUXXPi2wtDJT5cJS4pTovWRSRhbNoR2jpdJS7yPypyEUkI2wvLuODJr9hSUKoSF4mgResiEvcKSsq56J/TWLW1iAmXqMRFImmOXETiWnFZJZdNmMGi9Tt49PwRjOnTNuhIInFFc+QiErfKKqr4yfMzmb5qGw+cM5xjB3UIOpJI3NEcuYjEpcoq58Z/zeHj7Fzu+uEBnHpQ56AjicSlmBe5mY0zs2wzW25mt9Zwf3czm2pms81snpmdGOtMIhLf3J1fvTafd+Zv4LYTB3HuqO5BRxKJWzEtcjNLAx4GxgODgXPNbHC10X4NvOTuw4FzgEdimUlE4t/97y/lXzPWct33+3LFkTqfuMjuxHqOfBSw3N1XuHsZMBE4rdo4DrQIX28JrI9xJhGJY89/tZqHpi7n3FHd+Plx/YOOIxL3Yr2xWxdgbcTtHODQauPcAbxvZtcBzYCxMc4kInHqg0Wb+M0bC/j+wPb8/rShmFnQkUTiXjxs7HYu8LS7dwVOBJ41s+/kMrMrzWyGmc3Izc2t95AiEluz1mznuhdncUCXljx03nAapsXD15NI/Iv1X8o6oFvE7a7hYZEuA14CcPcvgMZAu+pP5O6PuftIdx+ZlZUVo7giEoSVWwq5fMIMOrRozJMXH0LTRtozViRasS7y6UA/M+tlZo0Ibcw2qdo4a4BjAcxsEKEi1yy3SIrILSjloqemYcCES0bRLjMj6EgiCSWmRe7uFcBPgfeAxYS2Tl9oZnea2anh0W4CrjCzucCLwMXu7rHMJSLxoaisgssmTCe3oJQnLz6Enu2aBR1JJOHEfPmVu08GJlcbdnvE9UXA92KdQ0TiS1WV8/N/zWXBunwev3Akw7q1CjqSSELS1iQiEoj7P8jm3YUbue2kwTr0qsh+UJGLSL17bVYOD0/9mnNHdefS7/UMOo5IQlORi0i9mrFqG7e+Op8xvdty52lDtK+4yH5SkYtIvVm7rYirnp1J51aNefTHI0jXvuIi+01/RSJ
SLwpKyrl8wgzKKqt48uJDaNW0UdCRRJKCjrogIjFXWeVcP3EOy3N3MuGSUfTJygw6kkjS0By5iMTcn9/P5qMlm7njlMEc3u87B24Ukf2gIheRmJoyfwOPfhzaQv2CMT2DjiOSdFTkIhIzSzcVcNPLcxnevRV3nDo46DgiSUlFLiIxkV9czlXPzqRpo4Y8ev7BZDRMCzqSSFJSkYtInQsdfnUOa7cV8cj5I+jYsnHQkUSSlopcROrcgx8t499LNvPrkwYxqleboOOIJDUVuYjUqX8v3sTfPlzG6SO6cNFhPYOOI5L0VOQiUmdWbinkhn/NYWiXFtz1wwN0+FWReqAiF5E6UVJeyTXPzyKtgfGPHx9M43Rt3CZSH3RkNxGpE797axGLN+zgnxcfQtfWTYOOI5IyNEcuIvvtjdnreHHaGn5ydB+OGdg+6DgiKUVFLiL7Zfnmnfzq9fmM6tmGm47rH3QckZSjIheRfVZcVsm1z8+iSXoaD547nIY6LalIvdM6chHZZ7e/uYClmwuYcMkoHfRFJCD691lE9skrM3N4eWYO1x3TlyP7ZwUdRyRlqchFZK8t3VTAr9+Yz5jebbl+rNaLiwRJRS4ie6WkvJLrXphNZkZDHjh3GGkNdNAXkSBpHbmI7JW7Jy8me1MBEy4dRfvmWi8uEjTNkYtI1P69eBMTvljN5Yf34iitFxeJCypyEYnK5h0l3PzKPAZ3asHN4wYEHUdEwlTkIrJHVVXOTS/PpaisggfPHU5GQx1HXSReqMhFZI+e+GwFny7bwm9PGULf9plBxxGRCCpyEdmt+Tn53PdeNuOGdOScQ7oFHUdEqlGRi0itCksr+NnE2bRtlsGfztD5xUXikXY/E5Fa/f7tRazaWsgLl4+mVdNGQccRkRpojlxEavTvxZuYOH0tVx3ZhzF92gYdR0RqoSIXke/YVljGLa/OZ2DH5tx4XL+g44jIbmjRuoh8i7tz2+vzyS8u49nLRmlXM5E4pzlyEfmWN+esZ8qCjfz8uAEM6tQi6DgisgcqchH5xob8Yn7z5gIO7tGaK4/sHXQcEYmCilxEgNDR225+eR6VVc5fzjpIZzUTSRAqchEB4LmvVvPZ8i3cdtIgerRtFnQcEYmSilxEWJG7k7smL+ao/lmcN6p70HFEZC+oyEVSXGX4hCgZDdO498wDdfQ2kQSj3c9EUtxTn61k9po8HjhnGB1aNA46jojsJc2Ri6SwFbk7+fP72Ywd1IFTD+ocdBwR2QcqcpEUVVXl3PLqPDIaNuCuHw7VInWRBKUiF0lRE75YxfRV27n9lCG01yJ1kYSlIhdJQau3FnLvu9kcPSCLM0Z0CTqOiOwHFblIitm1SL1hA+Pu03WOcZFEpyIXSTEvTFvDlyu2cdtJg+jUsknQcURkP6nIRVJIzvYi7p68mMP7tuPsQ7oFHUdE6oCKXCRFuDu/fG0+DlqkLpJEVOQiKeL12ev4dNkWbh0/kG5tmgYdR0TqiIpcJAVs3VnK799exIjurfjxoT2CjiMidUhFLpIC/vjOYnaWVvCnMw6kgU5PKpJUVOQiSe6Tpbm8NnsdVx/Vh/4dmgcdR0TqmIpcJIkVlVVw2xvz6Z3VjGuP6Rt0HBGJAZ39TCSJPfDhMtZuK2bilaNpnJ4WdBwRiQHNkYskqQXr8nnis5Wcc0g3RvduG3QcEYkRFblIEqqorOLW1+bRumkjfjl+UNBxRCSGtGhdJAk9/fkqFqzbwcPnjaBl0/Sg44hIDGmOXCTJrMsr5v73l3LswPaceEDHoOOISIypyEWSzB2TFgJw5w+G6jCsIilARS6SRD5ctIkPFm3i+rH96NJKZzYTSQUqcpEkUVRWwW8nLaR/h0wuO7xX0HFEpJ5oYzeRJPHgv5ezLq+Yl64aQ3qa/kcXSRX6axdJAks3FfDEpys48+CujOrVJug4IlKPVOQiCc7d+fUbC2iW0ZBfjh8YdBwRqWcqcpEE9+qsdUxbuY1bxw+kbWZG0HF
EpJ6pyEUSWF5RGXdNXsyI7q04e2S3oOOISABU5CIJ7J53s8kvLuePPzxA5xkXSVEqcpEENWdtHhOnr+GSw3oyqFOLoOOISEBU5CIJqKrK+e2bC8jKzOD6sf2CjiMiAVKRiySgl2euZW5OPr86cRDNG+ukKCKpLOZFbmbjzCzbzJab2a21jHOWmS0ys4Vm9kKsM4kksvyicu55N5tDerbmtGGdg44jIgGL6ZHdzCwNeBg4DsgBppvZJHdfFDFOP+CXwPfcfbuZtY9lJpFE99cPl5JXVMYdp47SSVFEJOZz5KOA5e6+wt3LgInAadXGuQJ42N23A7j75hhnEklYizfs4JkvVnH+oT0Y0rll0HFEJA7Eusi7AGsjbueEh0XqD/Q3s/+a2ZdmNq6mJzKzK81shpnNyM3NjVFckfjl7vz2zYW0bJLOTcf3DzqOiMSJeNjYrSHQDzgaOBd43MxaVR/J3R9z95HuPjIrK6t+E4rEgUlz1zNt1TZ+MW4grZo2CjqOiMSJWBf5OiDycFNdw8Mi5QCT3L3c3VcCSwkVu4iE7Syt4K7JizmgS0vO0hHcRCRCrIt8OtDPzHqZWSPgHGBStXHeIDQ3jpm1I7SofUWMc4kklL9/tIxNO0r53WlDSNMR3EQkQkyL3N0rgJ8C7wGLgZfcfaGZ3Wlmp4ZHew/YamaLgKnAze6+NZa5RBLJ17k7eeqzlZx5cFdGdG8ddBwRiTNR7X5mZhnuXrqnYTVx98nA5GrDbo+47sDPwxcRieDu/O6tRTRumMYt43SKUhH5rmjnyL+IcpiI1KEPFm3ik6W53Hhcf7Ka6xSlIvJdu50jN7OOhHYXa2Jmw4FdK+daAE1jnE0kpZWUV3Ln24vo3yGTC8b0CDqOiMSpPS1aPwG4mNDW5vfzvyLfAfwqdrFE5IlPV5CzvZgXrjiU9LR42FNUROLRbovc3ScAE8zsDHd/tbbxzOyi8LgiUgc27SjhkY+/ZtyQjhzWp13QcUQkjkX1b/7uSjzs+jrIIiJh976bTUWl86sTBwUdRUTiXF0tr9OOrSJ1ZF5OHq/OyuHSw3vRva02RRGR3aurIvc6eh6RlObu3PnWItplNuLaY/oEHUdEEoDmyEXiyNvzNjBj9Xb+7/gBNG+cHnQcEUkAURW5mfXaw7D/1lkikRRVUl7Jn6YsYXCnFvxIx1MXkShFO0de08Zur+y64u4/rZs4IqnriU9XsC6vmN+cPFjHUxeRqO3pgDADgSFASzM7PeKuFkDjWAYTSSWRu5uN6dM26DgikkD2dECYAcDJQCvglIjhBcAVMcokknK0u5mI7Ks9HRDmTeBNMxvj7jq2ukgM7Nrd7Oqj+mh3MxHZa1Gd/QyYbWbXElrM/s0idXe/NCapRFKEdjcTkf0V7cZuzwIdCR17/T+Ejr1eEKtQIqlCu5uJyP6Ktsj7uvtvgMLwMdVPAg6NXSyR5KfdzUSkLkRb5OXhn3lmNhRoCbSPTSSR1PDkZyu1u5mI7Ldo15E/Zmatgd8Ak4BM4PaYpRJJcrkFpTwydTnHD+6g3c1EZL9EVeTu/kT46n+A3rGLI5Ia/vbhUkorqrh1/MCgo4hIgouqyM0sAzgD6Bn5GHe/MzaxRJLXsk0FTJy+lh8f2p3eWZlBxxGRBBftovU3gXxgJlAauzgiye9PU5bQND2N68f2DzqKiCSBaIu8q7uPi2kSkRTw+fIt/HvJZm4dP5A2zRoFHUdEkkC0W61/bmYHxDSJSJKrqnL+OHkxXVo14eLDegYdR0SSxJ5OmjIf8PB4l5jZCkKL1g1wdz8w9hFFksMbc9axcP0OHjhnGI3T04KOIyJJYk+L1k+O5knMrLW7b6+DPCJJqbiskvvey+bAri055cDOQccRkSSyp5OmrI7yef4NjNj/OCLJ6an/rmRDfgl/PXsYDXTwFxGpQ9GuI98TfTOJ1GLXwV+OG9yB0b118BcRqVt1VeReR88jknQe+Hfo4C+/1MFfRCQG6qrIRaQ
GyzcX8OK0tZyvg7+ISIxo0bpIDN09OXTwl58d2y/oKCKSpKI9IAwAZtYeaLzrtruvCV89ti5DiSSDz78OHfzllnEDaZuZEXQcEUlSUc2Rm9mpZrYMWEnoxCmrgCm77nf3bTFJJ5Kg3J0/TVlC55aNueR7PYOOIyJJLNpF678HRgNL3b0XoTnwL2OWSiTBTZ6/kXk5+dx4XH8d/EVEYiraIi93961AAzNr4O5TgZExzCWSsMorq/jz+9n075DJ6SO6Bh1HRJJctOvI88wsE/gEeN7MNgOFsYslkrhemrGWlVsKeeLCkaTp4C8iEmPRzpGfBhQDNwLvAl8Dp8QqlEiiKiqr4IEPlzGyR2uOHdQ+6DgikgKimiN398i57wkxyiKS8P7531VsLijlkfNHYKa5cRGJvWi3Wj/dzJaZWb6Z7TCzAjPbEetwIolke2EZ//j4a8YOas/Inm2CjiMiKSLadeT3Aqe4++JYhhFJZI98vJydZRXcfIIOxSoi9SfadeSbVOIitVuXV8yEL1Zz+vCuDOjYPOg4IpJCdjtHbmanh6/OMLN/AW8Apbvud/fXYhdNJHH87YOl4HDjcToUq4jUrz0tWo/cMr0IOD7itgMqckl5yzYV8OqsHC75Xi+6tm4adBwRSTG7LXJ3v6S+gogkqnvfy6ZZo4Zce0zfoKOISAqKdqv13mb2lpnlmtlmM3vTzHrFOpxIvJu5ehsfLNrElUf2pk2zRkHHEZEUFO3Gbi8ALwGdgM7Ay8DEWIUSSQTuzj1TsmmXmcFlR+j/WhEJRrRF3tTdn3X3ivDlOSJOZyqSiqZmb2baqm1cf2xfmjbaqzMCi4jUmWi/faaY2a2E5sIdOBuYbGZtQKcxldRTWeXc+242Pdo25ZxR3YOOIyIpLNoiPyv886pqw88hVOy96yyRSAJ4c846lmws4MFzh5OeFu2CLRGRuhftsda1AlAkrLSikvvfX8qQzi04+YBOQccRkRQX7QFhaqQDwkgqev7LNazLK+bu0w+ggU5TKiIB25sDwlSnA8JIyikoKeehqcs5rE9bjujXLug4IiI6IIzI3nj805VsKyzjlnEDdZpSEYkLUe8zY2YnAUOI2O3M3e+MRSiReJRbUMoTn67gpAM6cVC3VkHHEREBoj+y2z8I7XJ2HWDAj4AeMcwlEnce+mgZpRVV3HR8/6CjiIh8I9r9Zg5z9wuB7e7+O2AMoG8zSRlrthbxwrQ1nH1IN3pnZQYdR0TkG9EWeXH4Z5GZdQbKCR2uVSQl3P9BNmkNjOuP1WlKRSS+RLuO/G0zawXcB8witMX647EKJRJPFqzL580567nm6D50aKEjE4tIfIn2gDC/D1991czeBhq7e37sYonEj3vfy6Zlk3SuOqpP0FFERL5jr48t6e6lKnFJFZ9/vYVPluZy7TF9aNkkPeg4IiLfoYNEi9TC3bnn3Ww6tWzMhWN6Bh1HRKRGKnKRWry7YCNz1+Zx49j+NE5PCzqOiEiN9rrIzeyOGOQQiSsVlVXc9342fdtncvqILkHHERGp1b7MkZ9a5ylE4szLM3NYkVvIzScMoKFOUyoicWxfvqF0gGlJasVllfztw6WM6N6K4wd3CDqOiMhu7UuRH1znKUTiyNOfr2LTjlKdGEVEEsK+7H5WFYsgIvEgv6icRz9ezjEDsji0d9ug44iI7JFW/olEeOQ/yykoreAX4wYGHUVEJCoqcpGwDfnFPP3fVfxgWBcGdWoRdBwRkajsc5Gb2SV1GUQkaA98uIwqd35+nE7sJyKJY3/myH9XZylEArZ8805emrGWH4/uQbc2TYOOIyIStd2eNMXM5tV2FxDVfjlmNg54AEgDnnD3P9Uy3hnAK8Ah7j4jmucWqSt/fi+bpo0a8tNj+gYdRURkr+zp7GcdgBOA7dWGG/D5np7czNKAh4HjgBxguplNcvdF1cZrDlwPfBVlbpE6M3vNdt5duJEbx/anbWZG0HFERPbKnhatvw1kuvvqapdVwMd
RPP8oYLm7r3D3MmAicFoN4/0euAcoiT66yP4LnRhlCe0yG3H5Eb2CjiMistd2W+Tufpm7f1bLfedF8fxdgLURt3PCw75hZiOAbu7+ThTPJ1Kn/rM0ly9XbOO67/ejWcaeFlCJiMSfQHc/M7MGwF+Am6IY90ozm2FmM3Jzc2MfTpJeVVXoNKXd2zTl3FHdg44jIrJPYl3k64BuEbe7hoft0hwYCnxsZquA0cAkMxtZ/Ync/TF3H+nuI7OysmIYWVLFW/PWs3jDDm46vj+NGuqQCiKSmGL97TUd6GdmvcysEXAOMGnXne6e7+7t3L2nu/cEvgRO1VbrEmtlFVX8+f1sBndqwSkHdg46jojIPotpkbt7BfBT4D1gMfCSuy80szvNTKdDlcC88NVq1m4r5hfjBtCggU6MIiKJK+Zb97j7ZGBytWG31zLu0bHOI7KztIK/f7Sc0b3bcFR/raYRkcSmzXQl5Tzx6Qq2FpbxhE5TKiJJQFv4SErZsrOUxz9ZwbghHRnevXXQcURE9puKXFLKQx8tp7i8kv87YUDQUURE6oSKXFLG2m1FPP/Vas4a2Y2+7TODjiMiUidU5JIy/vLBUhqYccNYnaZURJKHilxSwuINO3hjzjou/l5POrZsHHQcEZE6oyKXlHDvu0tontGQa47SaUpFJLmoyCXpfbViK1Ozc7nmmL60bJoedBwRkTqlIpek5u786d0ldGzRmIsP6xl0HBGROqcil6T2/qJNzF6Txw1j+9E4PS3oOCIidU5FLkmrorKK+97Lpk9WM848uGvQcUREYkJFLknrtVnrWL55JzefMICGafpVF5HkpG83SUol5ZX89cOlDOvWihOGdAw6johIzKjIJSk988UqNuSXcItOjCIiSU5FLkknv7ich6d+zVH9sxjTp23QcUREYkpFLknn//3na/KLy/nFOJ0YRUSSn4pcksqmHSU89d+VnDasM0M6tww6johIzKnIJan87cNlVFY5Nx2nuXERSQ0qckkayzfv5KUZazn/0B50b9s06DgiIvVCRS5J48/vZdMkPY3rvq8To4hI6lCRS1KYuXo77y7cyJVH9qZtZkbQcURE6o2KXBKeu3PPlCW0y8zgssN7BR1HRKReqcgl4X20ZDPTVm3jhrH9aJbRMOg4IiL1SkUuCa2yyrnn3SX0ateMsw/pFnQcEZF6pyKXhPbarByWbgqdGCVdJ0YRkRSkbz5JWCXllfzlg6Uc1K0V44fqxCgikppU5JKwJnweOjHKrToxioikMBW5JKT8onIenrqcowfoxCgiktpU5JKQHvnPcgpKK7hl3MCgo4iIBEpFLglnfV4x//zvKn44vAuDOrUIOo6ISKBU5JJw/vrBUnD4+XH9g44iIhI4FbkklOyNBbw6K4cLx/Sga2udGEVEREUuCeW+95bQLKMh1x6jE6OIiICKXBLItJXb+HDxZn5ydB9aN2sUdBwRkbigIpeE4O78acpiOrTI4JLDdGIUEZFdVOSSEN5buIlZa/K4cWx/mjRKCzqOiEjcUJFL3CuvrOKed5fQt30mZx7cNeg4IiJxRUUuce+Fr9awckshvxw/kIY6MYqIyLfoW1Hi2o6Sch749zLG9G7L9we2DzqOiEjcUZFLXHv046/ZVljGbScN0olRRERqoCKXuLUur5inPlvJD4d3YWiXlkHHERGJSypyiVv3v5eNA/93woCgo4iIxC0VucSlBevyeW32Oi79Xi+6tGoSdBwRkbilIpe44+788Z3FtG6azjXH9Ak6johIXFORS9yZmr2ZL1Zs5Yax/WnROD3oOCIicU1FLnGlorKKuycvoVe7Zpx3aPeg44iIxD0VucSVl2bksGzzTm4ZN5B0HfxFRGSP9E0pcWNnaQV/+WApI3u05oQhHYKOIyKSEFTkEjce+2QFW3aW6uAvIiJ7QUUucWHTjhIe/2QFJx3YieHdWwcdR0QkYajIJS7c/342FVVV3HLCwKCjiIgkFBW5BG7Bunxenpn
DxYf1pHvbpkHHERFJKCpyCZS7c+dbi2jTtBHXHdsv6DgiIglHRS6Bmjx/I9NWbeOm4wfo4C8iIvtARS6BKSmv5K7JixnYsTlnH9It6DgiIglJRS6BefKzlazLK+b2kweT1kC7m4mI7AsVuQRi844SHp66nOMHd+Cwvu2CjiMikrBU5BKI+97Lpryyil+dOCjoKCIiCU1FLvVufk4+r8zK4dLv9aJnu2ZBxxERSWgqcqlX7s6dby+kTdNGXPv9vkHHERFJeCpyqVeT529k+qrt2t1MRKSOqMil3mh3MxGRuqcil3rzze5mp2h3MxGRuqIil3qxa3ezE4Z04LA+2t1MRKSuqMilXtyr3c1ERGJCRS4xN2vNdl6ZmcOlh/eiR1vtbiYiUpdU5BJTlVXOb99cSIcWGfzs+zq7mYhIXVORS0z9a/pa5q/L57aTBtMso2HQcUREko6KXGJme2EZ9763hEN7teGUAzsFHUdEJCmpyCVm7v8gm4KSCn532hDMtLuZiEgsqMglJhasy+f5r9Zw4ZgeDOzYIug4IiJJK+ZFbmbjzCzbzJab2a013P9zM1tkZvPM7N9m1iPWmSS2qqqc299cQNtmjbhhbP+g44iIJLWYFrmZpQEPA+OBwcC5Zja42mizgZHufiDwCnBvLDNJ7L0+ex2z1uRxy7iBtGyi46mLiMRSrOfIRwHL3X2Fu5cBE4HTIkdw96nuXhS++SXQNcaZJIZ2lJRz95QlDO/eijNG6KMUEYm1WBd5F2BtxO2c8LDaXAZMiWkiiakHPlzG1sJS7jx1KA10PHURkZiLmx17zezHwEjgqFruvxK4EqB79+71mEyitWj9Dp7+fBXnjurOAV1bBh1HRCQlxHqOfB0Qeb7KruFh32JmY4HbgFPdvbSmJ3L3x9x9pLuPzMrKiklY2XdVVc6vXp9Pqybp/OKEAUHHERFJGbEu8ulAPzPrZWaNgHOASZEjmNlw4P8RKvHNMc4jMfLi9DXMWZvHr08eRKumjYKOIyKSMmJa5O5eAfwUeA9YDLzk7gvN7E4zOzU82n1AJvCymc0xs0m1PJ3EqdyCUu6ZsoQxvdvyg2G72wRCRETqWszXkbv7ZGBytWG3R1wfG+sMElt/fGcRJeVV/OGHQ3UENxGReqYju8l++e/yLbwxZz1XH92HPlmZQccREUk5KnLZZyXllfz6jQX0aNuUa47uE3QcEZGUFDe7n0ni+cd/vmbllkKeuXQUjdPTgo4jIpKSNEcu+2TllkIemfo1pxzUmSP7a3dAEZGgqMhlr7k7v3ljARnpDfjNyYOCjiMiktJU5LLXXpmZw2fLt/CLEwbQvnnjoOOIiKQ0Fbnslc0FJfz+7UUc0rM15x+qM86KiARNRS575bdvLqSkooo/nXGgTooiIhIHVOQStSnzNzBlwUZuGNtP+4yLiMQJFblEJb+onN+8uZChXVpw5RG9g44jIiJh2o9covKHdxaxvaiMCZceQsM0/f8nIhIv9I0se/TJ0lxenpnD1Uf1ZkhnnWdcRCSeqMhltwpLK/jla/PpndWM677fL+g4IiJSjRaty27d91426/OLefmqMToMq4hIHNIcudRq2sptTPhiFReO7sHInm2CjiMiIjVQkUuNdpZWcNPLc+jepim/GDcw6DgiIlILLVqXGv3xnUWs217MS1eNoVmGfk1EROKV5sjlOz5asokXp63lyiP7aJG6iEicU5HLt2wvLOOWV+czsGNzbjxOW6mLiMQ7LTOVb7g7v35jAXlFZUy4ZBQZDbWVuohIvNMcuXxj0tz1vDN/Azce15/BnVsEHUdERKKgIhcANuaX8Js3FjCieyuuOrJP0HFERCRKKnLB3bn5lbmUVzr3nzWMNJ2eVEQkYajIhQmfr+LTZVv41YkD6dWuWdBxRERkL6jIU9yi9Tu4a/ISjh3Ynh+P7hF0HBER2Usq8hRWVFbBdS/OolXTdO770UGYaZG6iEii0e5nKezOtxaxYkshz192KG2
aNQo6joiI7APNkaeod+ZtYOL0tfzkqD4c1rdd0HFERGQfqchT0NptRdz62jyGdWvFjcf1DzqOiIjsBxV5iimtqOSnL8wC4MFzhpOepl8BEZFEpnXkKeaudxYzNyeff/z4YLq3bRp0HBER2U+aHUshb81dz4QvVnP54b0YN7Rj0HFERKQOqMhTxNe5O7n11Xkc3KM1t4wfGHQcERGpIyryFFBcVsk1z80iIz2Nh87TenERkWSideRJzt259bV5LN1cwIRLRtGpZZOgI4mISB3SrFmSe/zTFbw5Zz3/d/wAjuyfFXQcERGpYyryJPbJ0lz+NGUJJx7QkWuO1qlJRUSSkYo8Sa3eWsh1L86mf4fm3HemjqMuIpKsVORJqLC0giufmQnAYxeMpFmGNoUQEUlWKvIkU1nlXD9xDss2F/DQecN10BcRkSSnIk8yd01ezIeLN/HbU4ZwRD9t3CYikuxU5Enk2S9W8eRnK7n4sJ5cdFjPoOOIiEg9UJEnianZm/ntpIUcO7A9vzl5cNBxRESknqjIk8DiDTv46fOzGNixBQ+eO5y0BtpCXUQkVajIE9zabUVc9NQ0Mhs35MmLtYW6iEiq0bd+AsstKOWCJ7+itKKKl64ao8OvioikIM2RJ6iCknIu/uc0Nu4o4amLRzKgY/OgI4mISABU5AmopLySK56ZQfbGAh798cEc3KNN0JFERCQgWrSeYMoqqvjpC7P5csU2/nb2MI4Z0D7oSCIiEiDNkSeQ8soqrntxFh8u3sTvTxvCD4Z3CTqSiIgETEWeIMorq/jZi7N5b+Em7jhlMBeM6Rl0JBERiQMq8gRQUVnFDRPnMGXBRm4/eTAXf69X0JFERCROqMjjXFlFFddPnMM78zfw65MGcenhKnEREfkfbewWx4rKKrj6uVl8sjSXX580iMuP6B10JBERiTMq8jiVX1TOJU9PY87aPO4940DOOqRb0JFERCQOqcjj0OYdJVz41DRW5BbyyPkjGDe0U9CRREQkTqnI40z2xgIufXo624vKeOriQzi8X7ugI4mISBxTkceRqdmbue6F2TRplMbEK0dzYNdWQUcSEZE4pyKPA+7OhM9XcefbixjYsQVPXDSSzq10AhQREdkzFXnASisqufOtRTz/1RqOG9yBv509TKciFRGRqKkxArRmaxHXvjCL+evyufqoPvzihAE0aGBBxxIRkQSiIg/Iuws2cvMrczHgsQsO5vghHYOOJCIiCUhFXs+Kyyq5590lPP35Kg7q2pKHzhtBtzZNg44lIiIJSkVej2as2sbNr8xj5ZZCLj6sJ788cSAZDdOCjiUiIglMRV4Pissq+fP72Tz135V0btmE5y8/lO/11f7hIiKy/1TkMeTufLh4M394ZxGrtxZxwege3DJ+IJnaKl1EROqIGiVGlmzcwR/eXsxny7fQt30mL1xxKIf10Vy4iIjULRV5HVufV8zDU5fz4rQ1NG+czh2nDOb80T1IT9MZY0VEpO6pyOvIyi2FPPrxcl6fvQ53uGB0D24Y25/WzRoFHU1ERJKYinw/VFU5X63cxnNfrWbK/A2kpzXgvFHdueLI3nRtrV3KREQk9mJe5GY2DngASAOecPc/Vbs/A3gGOBjYCpzt7qtinWt/bMgv5pUZObw8M4c124pontGQK4/sw2WH9yKreUbQ8UREJIXEtMjNLA14GDgOyAGmm9kkd18UMdplwHZ372tm5wD3AGfHMtfeqqxy5uXk8XF2Lh9nb2beunzcYUzvttx4XD/GDelEk0baH1xEROpfrOfIRwHL3X0FgJlNBE4DIov8NOCO8PVXgIfMzNzdY5ytRqUVlazPK2HR+h0s2pDPwvU7mLs2j+1F5ZjB8G6t+PnY/pw2rAvd22rxuYiIBCvWRd4FWBtxOwc4tLZx3L3CzPKBtsCWGGcD4M0565g4bS25O0vJLSglv7j8m/saNjD6ts/k2EEdOKJfO47sl6WN10REJK4kzMZuZnYlcCVA9+7d6+x
5yyudiqoq+rXP5LA+bcnKzKBDy8YM6tiCfh0yaZyuReYiIhK/Yl3k64BuEbe7hofVNE6OmTUEWhLa6O1b3P0x4DGAkSNH1tli9zMP7sqZB3etq6cTERGpV7E+Ssl0oJ+Z9TKzRsA5wKRq40wCLgpfPxP4KKj14yIiIokmpnPk4XXePwXeI7T72VPuvtDM7gRmuPsk4EngWTNbDmwjVPYiIiIShZivI3f3ycDkasNuj7heAvwo1jlERESSkQ4ALiIiksBU5CIiIglMRS4iIpLAVOQiIiIJTEUuIiKSwFTkIiIiCUxFLiIiksBU5CIiIglMRS4iIpLAVOQiIiIJTEUuIiKSwFTkIiIiCUxFLiIiksBU5CIiIglMRS4iIpLAzN2DzrDXzCwXWF0HT9UO2FIHzxMPNC3xSdMSnzQt8UnTUrse7p5V0x0JWeR1xcxmuPvIoHPUBU1LfNK0xCdNS3zStOwbLVoXERFJYCpyERGRBJbqRf5Y0AHqkKYlPmla4pOmJT5pWvZBSq8jFxERSXSpPkcuIiKS0FK2yM1snJllm9lyM7s16Dx7w8y6mdlUM1tkZgvN7Prw8DvMbJ2ZzQlfTgw6azTMbJWZzQ9nnhEe1sbMPjCzZeGfrYPOuSdmNiDivZ9jZjvM7IZE+VzM7Ckz22xmCyKG1fg5WMiD4b+feWY2Irjk31XLtNxnZkvCeV83s1bh4T3NrDji8/lHYMFrUMu01Po7ZWa/DH8u2WZ2QjCpa1bLtPwrYjpWmdmc8PC4/Vx28x0czN+Lu6fcBUgDvgZ6A42AucDgoHPtRf5OwIjw9ebAUmAwcAfwf0Hn24fpWQW0qzbsXuDW8PVbgXuCzrmX05QGbAR6JMrnAhwJjAAW7OlzAE4EpgAGjAa+Cjp/FNNyPNAwfP2eiGnpGTlevF1qmZYaf6fC3wNzgQygV/h7Li3oadjdtFS7/37g9nj/XHbzHRzI30uqzpGPApa7+wp3LwMmAqcFnClq7r7B3WeFrxcAi4Euwaaqc6cBE8LXJwA/CC7KPjkW+Nrd6+LARfXC3T8BtlUbXNvncBrwjId8CbQys071EjQKNU2Lu7/v7hXhm18CXes92D6o5XOpzWnARHcvdfeVwHJC33dxYXfTYmYGnAW8WK+h9sFuvoMD+XtJ1SLvAqyNuJ1DghahmfUEhgNfhQf9NLzo5qlEWBwd5sD7ZjbTzK4MD+vg7hvC1zcCHYKJts/O4dtfSIn4uUDtn0Oi/w1dSmgOaZdeZjbbzP5jZkcEFWov1fQ7lcifyxHAJndfFjEs7j+Xat/Bgfy9pGqRJwUzywReBW5w9x3Ao0AfYBiwgdBiqkRwuLuPAMYD15rZkZF3emjZVMLsXmFmjYBTgZfDgxL1c/mWRPscamNmtwEVwPPhQRuA7u4+HPg58IKZtQgqX5SS4neqmnP59j+/cf+51PAd/I36/HtJ1SJfB3SLuN01PCxhmFk6oV+g5939NQB33+Tule5eBTxOHC1S2x13Xxf+uRl4nVDuTbsWPYV/bg4u4V4bD8xy902QuJ9LWG2fQ0L+DZnZxcDJwPnhL1rCi6G3hq/PJLReuX9gIaOwm9+pRP1cGgKnA//aNSzeP5eavoMJ6O8lVYt8OtDPzHqF557OASYFnClq4XVJTwKL3f0vEcMj17n8EFhQ/bHxxsyamVnzXdcJbZC0gNDncVF4tIuAN4NJuE++NWeRiJ9LhNo+h0nAheGtcUcD+RGLFOOSmY0DfgGc6u5FEcOzzCwtfL030A9YEUzK6Ozmd2oScI6ZZZhZL0LTMq2+8+2DscASd8/ZNSCeP5favoMJ6u8l6K3/groQ2opwKaH/8m4LOs9eZj+c0CKbecCc8OVE4Flgfnj4JKBT0FmjmJbehLaynQss3PVZAG2BfwPLgA+BNkFnjXJ6mgFbgZYRwxLicyH0z8cGoJzQOrzLavscCG19+3D472c+MDLo/FFMy3JC6yl3/c38IzzuGeHfvTnALOCUoPN
HMS21/k4Bt4U/l2xgfND59zQt4eFPA1dXGzduP5fdfAcH8veiI7uJiIgksFRdtC4iIpIUVOQiIiIJTEUuIiKSwFTkIiIiCUxFLiIiksBU5CJS78ysoZm9Z2ZDdjNOCzP71MwS7fC8IvVKRS6SYMysbcSpHTdGnM5yp5k9EsPXPdrMDquL5/LQyUsuAO4OHyELM2tlZtdEjLMDuAL4S83PIiKA9iMXSWRmdgew093/nOivFT75xNvuPjQWzy+SrDRHLpIkwnPMb4ev32FmE8KLpleb2elmdq+ZzTezdyPmgg8On1lqZnhR967jRP/MzBaFz641MVyyVwM3huf+jwgfQvNVM5sevnwv4rWfNbMvzGyZmV0R5ST8CegTfv776vwNEklSDYMOICIx0wc4BhgMfAGc4e6/MLPXgZPM7B3g78Bp7p5rZmcDfyR0is9bgV7uXmpmrdw9z8z+QcQcuZm9APzV3T8zs+7Ae8Cg8GsfCIwmdMja2Wb2jruv30PeW4Gh7j6s7t4CkeSnIhdJXlPcvdzM5gNpwLvh4fOBnsAAYCjwQegcEKQROg42hI4h/byZvQG8UcvzjwUGhx8L0CJ8WkeAN929GCg2s6mEzs5V2/OIyH5QkYskr1IAd68ys3L/3wYxVYT+9g1Y6O5janjsScCRwCnAbWZ2QA3jNABGu3tJ5MBwsVff+EYb44jEiNaRi6SubCDLzMZA6PzKZjbEzBoA3dx9KnAL0BLIBAqA5hGPfx+4btcNMxsWcd9pZtbYzNoCRxM6dfCeVH9+EYmCilwkRbl7GXAmcI+ZzSV0KsbDCC1ify68SH428KC75wFvAT/ctbEb8DNgZHiDuEWENobbZR4wFfgS+H0U68dx963Af81sgTZ2E4medj8TkTpVn7vEiYjmyEVERBKa5shFREQSmObIRUREEpiKXEREJIGpyEVERBKYilxERCSBqchFREQSmIpcREQkgf1/EmmhdZ249TIAAAAASUVORK5CYII=\n", 262 | "text/plain": [ 263 | "
" 264 | ] 265 | }, 266 | "metadata": { 267 | "needs_background": "light" 268 | }, 269 | "output_type": "display_data" 270 | } 271 | ], 272 | "source": [ 273 | "# visualise the various terms\n", 274 | "t = torch.arange(1, nb_steps+1)\n", 275 | "plt.figure(figsize=(8, 8))\n", 276 | "plt.title(\"Variance of q(x_t | x_0) for t=0 -> t=T\")\n", 277 | "plt.xlabel(\"Timestep `t`\")\n", 278 | "plt.ylabel(\"1 - alpha_hat_t\")\n", 279 | "plt.plot(t[:], diff_process.prior_variance[:])\n", 280 | "plt.show()" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": 8, 286 | "id": "ef83a83b-1b8c-4df9-b2a7-8834dbd1829e", 287 | "metadata": {}, 288 | "outputs": [ 289 | { 290 | "name": "stderr", 291 | "output_type": "stream", 292 | "text": [ 293 | "100%|██████████| 199/199 [00:00<00:00, 4911.25it/s]\n" 294 | ] 295 | } 296 | ], 297 | "source": [ 298 | "with torch.no_grad():\n", 299 | " x0 = torch.tensor(X).float()\n", 300 | " x_t = x0\n", 301 | " out = [x0]\n", 302 | " batch = x0.shape[0]\n", 303 | " for t in tqdm(range(1, nb_steps)):\n", 304 | " t_ = torch.full((batch,), t, dtype=torch.long, device=x0.device)\n", 305 | " eps = torch.randn_like(x0)\n", 306 | "# x_t = diff_process.sample_q(x0, eps, t_)\n", 307 | " x_t = diff_process.sample_q_next(x_t, eps, t_)\n", 308 | " out.append(x_t.cpu())" 309 | ] 310 | }, 311 | { 312 | "cell_type": "code", 313 | "execution_count": 9, 314 | "id": "2b228690-3308-4f01-82f5-7daac59e09d7", 315 | "metadata": {}, 316 | "outputs": [ 317 | { 318 | "name": "stderr", 319 | "output_type": "stream", 320 | "text": [ 321 | "/home/angusturner/miniconda3/envs/py38/lib/python3.8/site-packages/seaborn/_decorators.py:36: FutureWarning: Pass the following variable as a keyword arg: y. 
From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation.\n", 322 | " warnings.warn(\n" 323 | ] 324 | }, 325 | { 326 | "data": { 327 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAA4RklEQVR4nO2df6iu2VXfv+uM6Z25iU6auUMzTXIzBQethKJ0EiGIlUZrKLZpbIXa/GP949Y/RIUWbQwYrAQMQqmIfzg0AQODVohBsZEmoaFRQmJmJGp+qEmFaWIHMzOXmTjc5NrM3f3jnH3vOuusvfba+9nPr/ddH7jMnPM+z372+57n/e71rL1+UEoJQRAEwX45WXsCQRAEwTRCyIMgCHZOCHkQBMHOCSEPgiDYOSHkQRAEO+fr1rjolStX0oMPPrjGpYMAAHD9+RfU37/sJXctPJMg8PP4448/nVK6X/5+FSF/8MEH8dhjj61x6SAAADz60WfV37/l9S9ddB5B0AIRPaH9PlwrwdFREvEg2Csh5MHRUbK6wxoP9soqrpUgWJsQ7eCQCIs8CIJg54SQB0EQ7JxwrQQHC9/UDFdKcMiEkAcHS4h3cCyEayUIgmDnhJAHQRDsnBDyIKgQCUTB1gkhD4Ig2DmThZyI7iaiPyCiPyKiTxPRz46YWBAEQeBjRNTKTQD/OKX0PBG9CMDvE9HvppQ+NmDsIFidiH4Jts5kIU+n3ZufP/vxRWf/oqNz0Mwacd/5miHWwZ4ZEkdORHcBeBzANwL45ZTSx5VjrgG4BgBXr14dcdlgYZYS2iVFda8CLjdg9/o+gjEMEfKU0gsAvpWIXgrgfUT0mpTSp8QxjwB4BAAefvjhsNiDC4QYrUc8meyboVErKaVnAXwYwBtHjhtsg7e8/qW3/wXbYkqI5JLhlRHKOQ906uKeMADR/QD+X0rpWSK6B8AHALwzpfQ7pXMefvjhFB2CghrhPrDRRHHLn1HUvpkOET2eUnpY/n6Ea+UBAL965ic/AfAblogHgZeRX/ZDdB285fUvbbZw1/wc8nwP6W+wFUZErfwxgG8bMJcgCBrZmyjubb57IaofBkfB0gKiWb5bcC2EkB4mIeRBMBDp6ii5PsLFEIwkhDzoYm4h2oL1qmHNq8VfXXpPW33fwbYJIQ82z6FsVO59/q0cyt9tD4SQB5vDckcA6wqDdW0risQ75xC9oIfJceQ9RBx5kNFixUe4KOagZSHZwqITHB6lOPIQ8mB1uOj1Zv7NJZhbm09w3JSEPBpLBKuz1bT/vaS9B0EIebAZtmT9hhAHeyI2Ow+MXt/sXOGEnnFbRXNu690zn5obqCU0cYtPI5mpm7fBMoSQHzn8izpSzOW4wLS460zP/EZuPPIxvOPV3udWi4PFU8l+CCHfIbWkkbWFYM5sRu8Y2hxGJ9vM8bRRGzcShgKN8JEfGC1f7jUtLn7tp2+enPvXcm7P61N59KPP3r4G///R16j9bu73OeXpJ1iWsMh3yChrci9If7T1/pcQ8am0hlmu6afei2Fw7ISQHzCH8Bie593rTimNNxJPtqc8Jv9szXmODdI5/fF7vccOgRDyYDHkRmFNeFuFQRtzC+JSE3pgTMGtTESaHB8h5DvFs9G2pS+uNZcrl241Hd9ynZ6IlR7Xx9TP2ivOUzZXo2zA4RIp+jtkhMtk7n6PLZZxqxVdm/soS3+p8EjPPFpcS3zDWC6SIeL7JmqtLMRSj/Y9V
uBoV0bLtXp93K2JNTIaRsKFbe5N0ymfZ6t7RBPxzBpivnTjjEPYD/IQQr4QPTfU3AK7lmU5wh3QQxazZxRRu+/SrdvC5onVHsHoyA9rcfMIuTXWKKID0jyUhDx85IMZbSXnY9bqOZnpuf6oSJPRPH3zBFcu3XJ9rp64dg0unp7PcoRPvmeuIbiHQVjkR4JXKKQYeF0RI+fQSsmlIi3y+9h7sazymoUrx77PsHgta3gOrL9fiUMT8kN2s8xmkRPRqwC8B8DfAZAAPJJS+sWp487Nse3gezYDuQg8c/ME9126ddt6zed5wuWmWpwtyDjzRz/6LK6czbsksCVx0+aYx7J4xrhWi5U8QvRbxzjk+/+Q35tkhGvlawD+fUrpD4no6wE8TkQfTCl9ZsDYwQy0RHhIemuHZKyIih6kayBb5y2WqGWBXznzqVuCbFnkEs1vn8fI1yjN3ZrD0pY/sF3Ld0tzWYrJQp5SehLAk2f//9dE9FkArwCwaSE/xj92DS5YLeIk0cReEyFu7Wd6FpmSmFtIEa9Zznmetc/HcueURDy/po1ZE3c5JncZHXLKfPj2zzPUR05EDwL4CIDXpJS+LF67BuAaAFy9evUfPvHEE8OuO5KtWhlzUHOxjLDyvAKZWcIt47HEM3J+2nGWQAN+Mc/HyUWjRMlPf8julWP6fmrMHn5IRC8B8L8AvCOl9JvWsbHZuT5zWmtT/cJL+tgBv4hrx3MxffKr58d54O7T8zVLu7YRa4URapu5NSHv2QTNHKNgbpVZww+J6EUA3gvg0ZqIB9tgrkfvkshlPC6bERvRHl+8d8GR7h7PBmiN1o3Y0hgl10pm6jyBcGPsgckWORERgF8FcD2l9BOec8Ii3w5TxNzKpLQ29ThzbdLVhNzr8rGeDqRl3GKRt+IR6Jb3OefmaIj+fMzmWiGi7wDwewD+BEC+O346pfT+0jkh5NtilGW+JSHn85ki5ECbmEtGiLicTwsei3zuiJeaqyyE30+k6C9M6Ubd6g28lJhvRcj56xKP315+Xp7EIcAf7WIdO+Uz877n1uih2rhbutf3TKToK6yZFNRaDXCOJgPWGCOEfWTGY2sNkdbX+fjaAlD7PDS/eS2apUY+X36OvVFFPXsC/OdalJO2aGkhpsF4jlrIM3Nt5mhj9lxnT/HAtdA5bwEnT2jgyDDJ0vgtY4/YBNWoxZjLOczF1ESysMrnI1wrO2GkRV4r4tQTXz2CmpBb9U3mcDf0jqu5l0b5yr1Y8/aEIlqVIa04fLnohGtlLOEjn8gWarOMaOHVYklptVc4SzwyWz53mUDDqbmjvJbzVCG3xhtpvY9KDsp4Yvlr9d+XEvEtfDeXInzkG6WlFrkmTkvdvB6L0lM50Vs7e4q/WSukJa8pE3lyqCBgLxDatbyLoybi3iqK8ljPcdJH7Sl45qVUrEw7RmPuDlU1Dk38Q8idzNXCy3uM51G3B8vPXPP5cpHw1lKR1wUuCo/8fQ0ZM177zJ65eYInv3qCz7HzPnfzBA9dOi/mozfqaotTyQ3j/Rx4HPsDd98yqzLWaHXljXgqXMqPzq9/KL77cK0sTE28LfGwrL/eL5I3dbvVDcBFxIrjnhrZIaldk4v49Rs3zr3+ssuX8dClW9VEnhb/s0S+31oSUcvnI4U8jzNyMVqqJv0hiOschGtlZjyWswdpBXtD4HoeFXtqonjO6RVnKWq9Y3A3CZ+3FPEXbjx3+7i7Lt8LALet9GzRZu5T3CItlD6TB+6+Zb5vLQ2/BH/f+VyLpaNeWrsgZULU64SQbxDLIgeWqSG+FFyoshB5BD0LLreqs0UNnG+MkX8uWeKchwqfvSbqnkqG2nH5/eX3KwVYG6s3c3SpGO4RT4rW2CHmNiHkg1jyRvMK+hwC7omVbvXLSqEqCVsWQE3EMyWLWvrEOdkaL11f8z1rr3ngx2uuEECvhKgtHi2bsnOiueq0rlJTrxFiXiaEfCFK4usN87LG7KU3gaXlPK9LxhJ/z
Rq1rGovWcBfdvkyAJzzj2ekUHuFW1rdLedy5Oc2JQR0DldKab9FaxFY2+PZ8pPj1gkhb2CpzvJWgs6erJIWv7qF5lp42eXLzWL+0KVb+NzNk9vCnX8HXBTwUbQIf56Dxyd+H/tsW0rXykiW1sicWkKQnIcWkli6n0eGRx4bIeQTWEJg5U3d8ojpsXKuGIIwFW8hKY6nCUOGC3Km5OMuvV4T8NpmpES6b2rz8cxDc8FwQW7tWiTdM62lDnri5jl7Mkb2Qgj5IOZIMBgRb+sthtUq4iPqirQIpPSPc2ri7KkTnjdDtTG0TdiSv70Fr3jL3498epALwqjw0zmY6ief+h3dchJRCHkDW/bjLXGTyUffHjGXTRi4ID506ZYqVB4RnyJuUsTzNTXBl8dpsegW2qbmlM3TEi1hi7XEoVoiU0vXp0xvmO4UMd+iAI9i/WV2Z6wRnVL7vcxUk6/xf6PndaUh4cQScQuP+wM4FZT8T75eCvVriXm/T2yGaj56/rtWN460+LXPR1tcSp8//zymFO26Ip5cJD15A9q9qP3u6Zsnt/9Zxy1B3rDdImGRd+CxzEf5z7kVbAl7aRNKY432XzURl5a1Zq3mzUqODDe0BGuKiGtoG66WRW5Z/ZzSAlfKOC3Vj+HnaOeNYkoZgK0K494IIe+kVMCqZGlMvWFr53vHr9VO4UhhL1lMoynFRz/51RNVzPNrMsa7BW1Ts/Qzv74l3C1+9JanDg2tfow8v7bQ1YqcaYyotMivKa83KvJpCdb0oR+1kE8V2K1aE6N8+Zawt6Tqt1hsWSx50SqeENOzQcqxBLMUoVLa5Lx+40bVJ14jF+vi1+CLFb8eX6isjVo+N1kMjJ8jx8loIi7/Dpw5armUBJ0fE9zhaIW8R+hq7pI5fHc9sevWPLwWTqltV+l171hZLLlYSaEpibXXbWK5b6TLQUuX9ywW2aUyQsy1+UhKgl46RkNzV2mC3nIf91jxHloS6LbCmnOL6oeN9BTH0m7K2h+9pU659xxJa/xxxhI6S5BaxCqjVSO05ujZSPXGdnvnOAWZTVqbv5Z9ap3H9x5qn9+USolTDYzWKJYtC/qcRPXDQXh81Uts4iy1c88tVinAlsVrjWWRLd0XbjyHuy7fi4/fOF8MC6gnDZVqsWTL2RLJGiUfvZy/vKZF7XNtQXsPLRmjtSe10UlD8njv92bKk+ohLgJDLHIiejeA7wPwpZTSa2rH79kiX4rSF6Gn9Gym1poL8Mc0exJz5HGWUHLx5eVlgdOaKLVa4XKzTwsN5KJqvZ7nKV0/1mLWU/ulJPKyoiOnFDdfSnpqxbOf0VMvvwVvEpvnXG2cPQv5rD07ieg7ATwP4D0h5GPwRIh4Ut2lD1OOO0rMayKuHcddLby87As3nsPJV+6I+a177lXFnKO5FvJ4mmBaC0aJ0nucWsArX7M0jlxg5qoLU3K99DTYmMpcYr53ZnWtpJQ+QkQPjhgr8NOSvQdc/DJ4fOTW5p/Xwj73e+hi/NClW/gc7sRl37rnTmnZXKWwVp1QLhwXi2Ppi8vJV57DrXvuxQs3nsP14js6nXttrB74Z6QtLvya3Nedqd0DlrXOxbulfyjgT+nvaY4tredjKajVm3+ymI+ciK4BuAYAV69eXeqym6OnjG1GS4m36lQD/mJYlhh4Gj5wwfe4G3gyzLlIFlxWxfRUlP2p+xltscnWeLb6v+GJT+DLr34tACCdibpGnlcW9DxWjdITQSuWiPeEZXrj7luqK3opCVatSNyhC3qve2pY1MqZRf474VrRGX3jTfkyyQWhp7lvKT5ZE/FsZfL639Lf3Zr52Sri0gfPRZyTBV0Tc9mAwkNLaKLlWrGid6xEIEltQ1dew+q3Wnq9FUvItWN6jj0UImplJ3iaIU+1iDQRLzXu5cfVYqxLlrgmgFpqPb+GlcFpXUsTzpL/XYp4KyU3SGkeJbQ66fJ1jrbo1TZ3M6UFOMMzZHlSFjBfdqXHyNGeZLdcxG5pQ
sg3hKc4/1Kpytwie5oJbUaKgcedIoXFEmqJPLZ0HZmgw8+Tm6jZ+uauFUD3z0ukiHO8Is4/Q0/oo/bZaI2kM9kVpO0VtF4nM7WW+SjXSM2/fshWucaoqJVfA/BdAK4A+CsAb08pvat0/DG6VgD75vPE79ZadU3p7i4zHGVySO5Cz4/NaJEiHClsmhXtiXrhVrWFFn3CxY6LeUa6UmQbOM1VJCNr+Hn53FokSkaL5gHqriYp5HlOfD61qJeScFvRQdqG6dq9QyWHKOZzR6384IhxDp0p2ZiefostbhhpacsvpjZWLUqmNwzv+o0b5zYQrbG5iEtB5mKshR7edfk0MsXj/y5Z1dKq53PJkS95POvzkK/lxCeOJ8yQbxDXFrgSns3sUgGxkhtmDeZq/LwHwrWyMtnP1+qH9H5hrONKr5WiFLTiTEDdx5uPAcqbk56aJZaIy+OyMFtiyq1Xfo7HNaL517l75oXqCBfRomJqewaSUsSNZY2XarXzY2qsLeZ7qI44JyHkG4Bv2vSEek31O9au6YkosYRmSm0TwGfpt/i15SKQf34Bdyxp7o7Ii1TLE4e10JTI4n8ddxYTT1XETH7i8MDH1dwlU2j1m0t60vS93ap6C99t3bLftZDvtbu8RikEq3aDyhhbrx++5oYpZXdqApIjIbyCXctk5JT84iXLs2dzUh7Hxbwl6acUBcMt9Zolr6FFC3k+O/5ZaOVyS3ijlWpNK1qs89bvr3afa9dqca+UyldcuXRr826a3Qr5IYcd1Xpj8htWS5jwpPeXomF6+0d6a69wF0vNheHd3AQubk7KMVrgYp7J4951+V68gPMRL15qx54UkpFaWuNpC1mtNjngs8azoM/ZdciqHlqitn/EvyOekMWS4bRlMd+lkC8l4mta/FZhfe3YmnvFWgykiFsCbrlQ5AYjFw65oVYKX5RjtWzglYpeSUrWPCc983/OHf8U86FnMT/5ynNVa5vz5Ve/tpiAJOfl2QBuIX+essyARFrjwJ2/WakhdQnrntU2JktIAa0Vfsu/t4p7HRq7q0fes2JvGY910PJ42rPItQh5RhPJUnJOqeZ2acxWEZcZo3KsUmgePzdfryfjk/vDLaH2YPn6PZuw8m9QSlqSix5ghxvKY2TIoSeRLeOJwJKUimhZbsfWMfm43uOX5mAzO7fw4ZaoLTreUrVTN48sNJeKRBNhTxQF3zTkG2reLjwZLV67hLYYSHHmopzdKFzwNMvaEvY8J2mh98Dfa677orlLSk80fFNW29zNY8oImZL/vNTFCTh/P85d2pbTs8nppdSLd+vsTsiXatwwF61z5624pHWujTPFB5jxRi9otUw0uIvFEgug7tMu+ZHluVzEpAhLQe8JFeRj5PncuudeV7SKzCStwUMjgfLeg3dj9uQrz92O0MnUwkJrqf2ZnvKzLZazp4JnaQxvPfI9asvuXCs1RhWPnzpOrxViFbMq1Rj3XLsUUthaElWjVP+kxxXAadnk1M7TRFwiXSY9dViymOe5yHlrrhfpG7eeGmrx7Z7iYBJr3FrTDV6F0boXR1nirf7xzCE2bD5Y14pkxKMcP3/uriJyvp5HRe5qabG+tXR8Keg9wi4teG+WJmCLdUkYOdJabRFx4I513BP3XUIuLnljVENa89qGaEZ+bvk6cgOz9lQjRdwb2aPVg+dPiXM8KVvjrVGDaKscnJAD84juXIJuRZzUUuK9ceE8I5M/QufHZZ5mnY8H2sIPObWIlIxV+wTwuzykmGcsUSzRUxFRulm0hYXPKVNyEUm3i5bReoLz/m5PzHue213wF/bSNkVrESuevZ/R+z3WeD3f2a1vekoOUsinYoXyeayO3icCLV2/9KUpVaHTjpOt1Dia9VzzYwN2Sr0U8lp3HsCXTMPxhBC24knWsYRe85sD5YXK8vfL47RreNCiYHJ0T0n4z4t3uU+qhsetl3/XG6lSOmapTcot7tEdnI98NDVRnlIIq3S+HMPTkg0ou0W4k
FuNjSW1EEPZMCKjWXF8TBlR4rGCNeHSKg7KuXldJtZCwv3nXotdm68npFGGSFoLnXwC8Pb+1JBlCQC9Lj3H0+bNosWKHiHSraHLW7TKZ22+3MqehDzTKuheMW+9kTzhghytsTFHi8HWsOKygXIZ19JGXJ5LizgCF10TVtlYbxy6J3a8dRPUa8lb8enW3PK5vDE1x9NwQ8Mr4pladnBGG2fJhJ0porylUiBHs9k5F55msVpvwfxaaUw5Rul1ma7f0nhZNjbOtFT7q4W3ZcHhPS+vo77xCfT5sjUR5//N17co+a95eCAX2KkdhUrM4SY69zncuNEs7LJTUInsIqlVyMyUapevLZAWW55bJoR8AaxFQIuLraUtW2JuFTx62eXzjY1rnWNkIpC39jX3/UpfuhQUb+w1oIfkyXH5tazrlrI4v/zq11at5JZ5Au2VEM9tloox5SKjWeMcbyelTGmx9vjH+b1oLfoyJ2IPQrl1QsgbsAR5qs+O+xU9u/tczD3c3nDE+Y2sjCc9m8PrjZTQokm4++N2eCF8Rah4qGAposWywrlL6KRwLStZp6dQlkfE88Kn+fpv/+4s+/TZb/7u279viT6ReGLHP4fL5zI+PfdaqRE0TyiqWflLciiLSAj5BvD4vTVBL8WcWxUMa7U15HkyRTtb5cCpaEpB5dZ4yVqUlnJeFIC6m0WL++bim8eRMeiWeGtIa1ym4XuibLz+f3kt+aR0WxjZItgaPmhh9fzMYg6UBXh0DPdon/QhV0rNhJA3Mnfnblnkn3de4VxhvsaeolcSGUPOKYYT4vI5cQEAKJaiFRFzO9Li7Pce69yqhZI3AGXsda9/W0tK8oZK9lxTi/g5t2HcURHRG5EE2KV0eQJZRgp7vm9LPUa1czKlcF/gcCznuTg6Ie8JJxyBt054Rsu+LNEr4KVryK40Etka7MlLdyuj24/OMpolu1ykdZ7ximK2zqduIGqCNmVB0MjvT+4pAGXBrmVhWn5vbXGwfNkv3HjOvWF936Vbav/XUu1y+YR5DFbznBydkNdYMtjf80gqC/lnpLhrXXy8HXtaFgBtLh4XTQ9WVcHWJg0tkTHnXDWFrFFrzpZf3ErkqfUa9ZQzsIReq0UuN7C1BbDULSn7z3l0S6nUA1De75lbxJeqaLhmotDRCbnHNdLzB/Gc46mropV4rbXZqrVi48dPEdbSNbWYY+0pwrL+SvHpwEUfcs1Pzf3ovMxsy2Zlj2Wv1U3hY2nvRVIS6+zjL55XmRvveFSysGU539oCVquI2JuGP0fZ5kN3zQwRciJ6I4BfBHAXgP+aUvr5EePOxZRmxS2x5KVr167bWq+bH18LI2xB67gOXBRk+YXWrDGZDCSxwgEzPMqlp8iVpzCVxIqQ0Y7N0Sf8HBlpYqEJuTfyxRwXttXOXVteWuqveJHRWzIUd8uCvObcJgs5Ed0F4JcBfA+ALwL4BBH9dkrpM1PHnpup3ec1PDvuXMy5Vc7jwlvF19Mg17s41ATcqtli+fQtd0HJ/8wtY29lw5r12+piaVk0tFrfgL/htKfi4xxYtXMyrUbCHJb11O/qlheCKYywyF8H4PMppb8AACL6dQBvArB5Ic94I1FabyLLiiiJuRetrkqv5Z0pxZJ7WrDJTTFrLlq9bk/IYQuaOyaPtQTc36117ZFI4deSgoD+JxE5t9r1gfIei/zbliKe8u9b2hXO0QWLM8KPvaWU/cwIIX8FgC+wn78I4NvlQUR0DcA1ALh69eqAy/rxbHR4/iCjQw81MW9Jved4zrNcNtLPzasmAvU+mjzCodQ2TNIiSj1i3nMdjnY9j1smXy9b5lbrtQyPf9fi7PP5tWxRywevlcPVjpf1ViyssFWru9VajBTx/PMWxHyxzc6U0iMAHgFOi2bVjh/hE6ulxXsqD8rjesTca5l7KfkjPSVvS1/O2kall9z9XbPmSk0M5rSSe0S8tmBIMS8dr/nXW0ranku+Ev7rFr99C1zEW/zeHkNib
TEfLeJbYoSQ/yWAV7GfX3n2u4Ol17deWiB4HfIpVnmJli/k1KYSwB0xz0hRl9ERPZQsY48IA/5qhCUh9D4deGuOl8iZndLHbm1Kap9tqbKjhlzseX18i9J922qZT/Wtew202vF7YnIZWyL6OgB/DuANOBXwTwD4NymlT5fOWaqMbavQatUG52oiYdGb8ixvfM84Vjp/Rmsq4UHLUuR10Ue0V7MsY0tAS375ksB5GzrUBNNaIGRpA2+nJU7JjWOV6JWljLWenK33ZK2UrdXRStIi6C31/Xt6erbWNB/NbGVsU0pfI6IfBfA/cBp++G5LxJekZDn3lJidgvUlKNVQGTWOdY60vrlQlFLEAQCX7q6GE+bXSpEQMkRPvpaxGhkDtrBKa1i6JrQ4b463Toq1kSrH5r5p+V64b/06gI/fOF/TvSXVXlJqzKGJOKfHqCg9TVrFsubsuZm/4/Iasvetd59siwzxkaeU3g/g/SPGmgPrw5/6h/FYAJagtvQx9I7jQXYPAs6XgtWyQzX/OvfheuC10Wsxy61dfixq1+rx03trlWvWvHxPJYv5Oq6qew6lpyS58FlPJXxjdY6YcAsuoqNEfKsiuwRHl9k5Fa/LZVSES+mpwnp0le6S0gbn1NR8LzJhiG/iAXqt8IzlgqgJfI9rZSvcFn8AT924F9cLlSRlR6RSVqzVREQuEtb+zBSR76mcuIVIlz0QQt5ArzB7LQ7vrr62IckTcUop/lpykPVoXgpD85wrkW4bXrSpFGfdKtSy9Rx3reS47ttWvnvmY5At6oC2J40c3mm9ri1OvMyA1Us1Yy3u/B6zBF0zKDwt40oVPr1sJRRwDULIG5jzJmltfeURa+0cHksuhVgrPesV8VJooVY/u7QATHWlaE2MpZiPglu62d/ekvYvsc7nbfS063vxZG9y5OJrPd1peK13XpKZ0xO90hq27OlStMUEIMnuhHwPNRckJT/glOazWnIPLy9bsqxKv2/5kns2Ojklkc+vlWgJmeOM8KlryOu3iCjf9Gxx6dTCJHk5gFv33FtcEHimaa3Y1dLMEV+uPT3L76H3mnvQmt0J+Z6QmZu1YzW0G5LXfS5Z4q2t2zRB5dmePXHtJZHOlnHOeszH9cSXy05BgE/ws+jx80poIYSaiHusco+I99Rhz58df18asgKiJeg5hl0eY1naWmRKT1isdt4ItDH3INQ1JseR97BUHPmWaKmHXPPF8y+C5Yf0xIgDdjcXTcxzpEupip5GKaqiVI/EqoQofcyeAloWLeKqiXjJP92LR8iz71tbUGqLIT9Hi1fXkCUctPjy0r3IxbM1hHYqteS9vYn4bHHkgQ/PDVNqxixvcP4l0sSbF/f3WNLex2w5nuUWMTsAMYucjyO7BGkNJbiIe+uKA/VEH884svFzzxOE9RTQklUqY8AzF1L5lfnJZhnc6m4toVxiCw2WZWmNQyaEfINYiQsWWWQ18bZ86hnrC1yq11JKIuKNLM41EIZdaY/HmGfB5ILW6/+eUnRLwhN3eudRStABztddl8JvWeLAxf0IrSyuXABkIlCpuQk/piX2m/vAW/eLahy6QHsJId8A3rDGnsQJWeNc86lrvRZrlPoy8jElWmeaUwG/eOzIZCCL1hZw/DweEWNtNHrQ4r2v485ikZtT8GNKaJvLWn10GZnkeYLTLG3PU58U8xGMFvE9BlJkQshXYo4aLfyLyL9sVsMKq8ciUP+SlprrckqWvhSS1kf6ligQzcc9pd4Lt/DlRmMt0kRDWvnXcUfMNUqJPfnJRxPz83Hkty74vb14hLuU/NMj4qX48mOOG5eEkK9ATcR70pY9iRo9yOQOrVendl2rjosHy/8sE4G4tZqxMkJHoIUiZjF/9pu/26z3rUXayLFzvZVyJqb+d5ZifvE8vccqUBdoLY68dD+UaBFzqxzFnBujeySEfKO0inmvgE/ZlColiHijW0rizq1JXlyqVABLQ3bV4ed6W8ZxLJ+2tuBky1/r0
dnbc7TW8MF6oinVGS+Jd2nT0/vUVLqv5ogX37MAjyKEfKPMWQ1uTrRiXBo8flzGKpfK3EpBlhZnrXCUVXmwBS2DVL4mQ/z4HGVUTmkxyWn53/637wagW9JSiHmylvx8vB1/NHHXiqhthT37tkcRQr5BWkR8avU4zWoq9V6soQmAt8ytlS3KIza4mPPzs/gD4zdIZZ0Uz/g8xI/P8/Z8AeDMHaNZ+/x9yoWulB/Q6r6SyLBVzSrXruFJGOqxxHtqrQDzC/oWF44Q8hVoqYxYK9Df2jyiNI51bH5MtrJJtfN4yVqNlrofQDmEUFrwVuGojFeQ5fVK/m1tXvzJQPq7Zcy8HJuPcR16xA9wXlh5XRvvZ2u1Dcx/U2uRzS6yhy5dfFoA1qleuJS7ZUtunYMT8rUK3LSu0vI4TditAv2cWsq0xBOhovm/W8IUueBbERSASB5ShL8l/rvm85Zx1d4oE034PVEombx5Ceix3nmjtLRJytvnybh8+XlJt0r+fPPfVG6Mt9wjXlqao3jEXvteWcbQKKtZy8jeinhzDk7Is7W7xQ/bIs873+gje3bWKCUQaZR8qJrw5595z0lJsUkzoFqqpcgTbxXCWgs04GKGpbx2yS+e0Szrk688h3TmLnlKZFbenv8NX39PS8Qllk/busdkuGqt3HFtcfe6/1raLXqebLfiBpl7Hgcn5MA8H9bcf4jeuPJWa3zqNbz1WzLeLFA51rkYaiUUMfugs/XJBa0m5p4aLdqm5W2/txRh+EvkynhxPv79Vx5wV5S00OqOT2kMkscqJ3SdIuuwZHr3cLQ6KWsLcg9z5IxIDlLI16C3eWvLjdn6uKtZzz2PzK3iLamJSK0+uqzRIuGd5Us1WjI1vzg/T7Oa+bz4tTk9UTFyM7REqYaNRa+Iy3vFqrY5qhenfJqW3xf5uscqX1v8l7h+VD/cEJZrRasm17rR2TKGJzLCY/FZflg+dmnTTv5OqxnC0SJXrHBBDa2WiVXVsWaNl+bCe2bWKNWxKbk8Wq3xWnq+xwDwVjm0zmthlEC2NGdfe1GI6oc7IYcTekK4PIJeiyKwypF6Qs84WlGuHPFSS0KRG3gSrXpiD5a1XPJPaxuHOdFJLhzaGKVEploPznxMqSaKx21Vo1TKQaJlDns3NGtZxz3NHoB5okbkeKN6785NCPmGyDdNq4WiHV+LDpBWR15Anrl5csECr4mslQVoFdQqxSrXrFQrq7O0aXjO162cd/KV586JsSz1ypkar8192C0bl0BZDEvibVnj2lhSzEt9Yb2ulGeEgVCz7HtFfQqeRupbZ9IdSUQ/QESfJqJbRHTB3D8mHv3os0NW7lE3zZVLty7808hz1oQ/i0wW8ywW/P81ofjczfOP/LU49jwGH5f/nP9Z8dE18c+Ws8eav70heXZsFljL0r11z72uEgIv3HjuTpeeBhHPn9EzZ4st/33+p/HkV0/O/eO05BfksThP3zy58K+GvKYl1nvNbl6DqRb5pwB8P4BfGTCXwKB1U8cblsVpbULRs4kmXTlaolFJMHnNco2a+8UScSuGO6O1PZPX4gLu2aAt4el4n/G6wHIDZeD8wlCyzEt4E8lamy+3CvceLOWlmCTkKaXPAgARjZnNAdOyacJFWzuu5DbhPsNai6vaeEDb43pNiLNgWF9Wz8KQG1UAFzMOuVXORdRbCdGK4c5hjppIahujvWGEmj+efy61qKGSlS8zMEu0ujM8sejea7VcO0T8PIv5yInoGoBrAHD16tWlLrsYc2+61MilPbmYt7p6rBobnFqnodLvNf+q9WX39BnlYiVrr3g3RmuVFHvgTwcZy++ezwHqVQ5HY/muvfsvU67Xer+GiF+kKuRE9CEAL1deeltK6be8F0opPQLgEeA0/NA9ww3SkqAwase7ZZwpu/mlNPyasFt+WsmSWasaWjceTktTaY28oFw4xzHGVBGvNZeoje2p9c0LtWnhpbU651PS/0PEdapCnlK6WLX/yPGGJ
I2qBTF3+JPmt85obeIkpciUOQU7CxYXK0tstY3IO6F9F5NsAL0UbC2CpzZnSetYngVWay7RskB46qHIqputPnU+Xlji04nww05GJyN4imh56Z2btukkQwUBqNl91oZZq6B7s0e5mFtoSTiy3Zl2zVKHHXldqyCYPFf6u0s+axm6WXJnzeF+sf5elrXeUiirhRDwOpOEnIjeDOCXANwP4L8T0SdTSt87ZGZHTEnErS/EiOy4fF1L0OX/e6hZ517h1qxRufEpsZoO82t7QgF5tAy/rtftohUV87x3La5fLqbyKaom8D19Oq3N6h6Bb2VrWZZbYmrUyvsAvG/QXDaPJrDe6BPrmBo9X565qaVja4/etfKpUoytUMdaQ4WWaotWvW1ZU/z2OYUQSO81M96NZf4eZfSJFOW53VpLYn2HRhay23thrqi14sQS5BY/d49PfA0h916zJwlEWo9Wg2YrBM/jX66F25VEvBbtIjdKa80zejcxS++11IMT6O9yXxujds7I+7ElJ6J03iEStVZWpuUGW6O+A49W8Po016idXkMKt+UDt6iFLvIyuvL3fAwef+5picapfa6aiI+iddweEc/3jzx3ihjv3bLuJYTcyShx9fr55hZzTaxrAq65U1pFXBMIj6+Yb/7l42UEiWX1tiTVZGQ3IYmMC28p5mXFbdf+DnM2QNYW8zme/Pj43IjQvg+9eRF7bDDTS7hWGpmy4VK6EUtj8ONHfrGmlhiVIi4LK1nI5gO5UBcfp9Sx3fKT90Rv8Ot5MzE1sdYE/2KIY92CliJaqkKpbXBmWhbWtfprthR0s/AI+6EJecm1EkK+EL2bnlNibHuiXyRzCblEVl3UGB1qJxeOWhihVhMd0DNDuW98igvE66se8WS0FSEH+gT4GKJawkd+oLQ2pQXaihS1iL5HaC3B8Qo1D72zxN8ar7ZhKjcpefKR5T8vCXhv04WpPvBRPnRvswXPU6R0H2kLSGuEmOf1QyaEfCO0+vNG3LRWUwlOqwXYW0vFSjDSrH+t6p+MEdfEvNZaLmPFqJc6CPFGxPddulg+uLa34I1A8Qi0pwmEF7mJqN2vJRHXWDN09hAJIV8Ar3tE+3LMaWVYTQC4CFrNBGrukFJMs3VejxUpa6cDukVem6/0y9fEXp4n/dd5I6+2WPLXPBmx8m8izxkp4hLtfq6JuKzhspfOO3shhHxGem7UJXbaa6KSxa7UaNeLJiC12i3aZqiHUiSHJYpToz+shCUuqJoVrr13raGxdY7VfSnTK+KePIKW+1vbzByRMHes4YaSEPIZkTeqdaONzFKTaF9K7+aYFSHBX89oPuGSoPcmx5Sur7lSWt9n7XpAOcyx1LPUO27G0xLNi/zbe4TdWjxHRE9lY2VE1nNY9qeEkC/Ekhs1U2/uWjEs6/dWX9D8ekkotEWgBU+FRu0Ybz1167XWhCONkjtLe4rROvx4YtAPkVrBOY9Vv3drPoR8ZrZ6g3jqSGdahdV6z95IGe81ufhp4YTSp53rplubpvm4/N+eeHVtMfSUBNaE2hqbY7nMpM/eYvSCMMVyB6bHlVvj5Nf2LuYh5I3swSfXYo2P8qlOjYOXSCHJ86wJoEzuKaXIe4pV9ZaKLbmjStZ1DXmexLNB2uqumbJBWju3xUcOzCuyW/0OtxJC3oC86UYkIIxcGEpfiiVCvbxftmztadUQPWgC2OvWsOqw1GLQtZK0GU/seOucp1Y09HzGrS65nhwGLSprCTE/FMEuEULewCFsrHiy6oB53ysXc89cAFwQfwutw08pw7J109UTNll6D6WUewtZB35Ks+MSXOSsv7t385Ez5yZ+7drHRKToT2TJG7VGTXhbG1P0FiuqnddblpeH8XkaM2e0WieaH1gKYSmhqISWACTnncetUYsW6nma4Vh/b2DZdPdDL1Nbe5poeSqPWitHgvWlWLNBhWfzzKq3kd+XLLLlCRuUougR2xZqyTc9tWlKc50jKqWnYNVoDrUA1mhDL
2qtHAmtVvTcIq5FUvRYkfmRX5tva7p67bgp8yzV2OZ4EnlqY1i0in9LlJDFSJfMobDU+wohPzAsAc/uBO+X1ir8v8RegXwklWJesn49pWJLv5fnW5+VJ3Hmypm7pTYvzxx7nmr47+aMMd/S3tFUK3iPoYgh5DOz5E3h+TJ53BdAufC/dp732j1Wp/xS1lqAWaJVeg+eaJ/e/YL82bVYydrn3Svi2us92Z7Hwl7jysNHPjNLb4aODB/TLPKW2hdyHM7oeHXr+rUxuP8duBhzLbvXtFT5k2NwtM3W0tPAUtUER7VZGzFecJGSj/z48nlXYMmb2Xstb8z3CIHI4/B/Gk/fPCn+yzz60WcnN8LmyE3UHBGT/z+/xo/l1yi9n2du6mNwWj5byyXkScDJ/2pMcZHIa/T8Lbwx5S0VRVuO3yuTXCtE9AsA/hmAvwHwvwH825TSswPmdTAsbZHM9QTgDZ+aC/l0YL3PqbHQHuQ1LJfOqCJYIxZVj4toqmUdVvjyTPWRfxDAW1NKXyOidwJ4K4Cfmj6tYCqtPr5e4WvN5PNsxlpogt4q5iVyxIo3drzmr+fjluDvuRRyaImvZ/N6xEK7hM94jvEPIYnPwzAfORG9GcC/Sim9pXas5iPfQw2TPdBjkU+xwHp6inrOadnca63zUvORl65jXatlHO/Gqed9ecIdR7FWwlBwhyXiyH8YwH8zJnANwDUAuHr16oXX88oZN8Q0ej6/lrrpFlY0RKtV1LLR1/v0IYs3Ta3Sl9GeLGqZlL1ExEkAOCxyIvoQgJcrL70tpfRbZ8e8DcDDAL4/OUz8Y4paOWRGRW9wam4Ea5xWq390xUYPowyVOTZ8a5Q+3zC+lmO2FH0i+iEA/w7AG1JKNyqHAwghb2HUU0rPF6/mpvEIuVbv3BLzVjGeytRN3NayB1sVvZaFb3SY4ZbqFW2dWVwrRPRGAD8J4B95RTxoZ4SYj/6SeL74sm5Jjt6oJRhZ/teehJxMb4MMec0pzTH2asnuaa7HyCSLnIg+D+ASgGfOfvWxlNKP1M7bq0W+1y/hHHiLc9Usck9N69a616OSZzxZr1OvM2VTeeQ96HXVRNJPmSWeLGaxyFNK3zjl/GBftBbiqtUXKd3w3iJLva4XaS2PTllvedrwPG2V3qd1bks4aYj4/okU/cBFa+r/yM04C29JAIm3JkuvW6fX9z6ibMKovYUQ8e0R9ciDSWy9XnSLH9tTDdDarJwreao0Tsmt02L5t2JZ+iHg6xH1yINh9BTTmhu5MamJtRS+Uiap5VrpFbKpNUz4+8pzli6iJXy0a/+dA50omhU001NMa6k0aSk0tUJd3oJenDVSvrVCXaW5zmWNb41jKIblJYQ8mIzXpbCkmGtz0qop9rK2mM81/l5EPDhP+MiDJvYWglmqf+JJme+theKZz5Rx51hE9vC3DGKzc/NM9W/GJtRFrNjv1kJYGnNE30xNr+/Z8Iz7Zj/EZucRoIn5MYaNjQoT5K/P7Urp/ZtY53lE/dDvhWMhLPIDQbPolrImRzOqFK/Eqt295c8jCDLhWjlSlhDzLdSmHmkxjygANXWcINAI18qRsoRr4NAEa8p+xaF9FsE+CCE/ErYkMHMkrnhrrxxL66/guIg48uAo4DHSW1rUgvU4pISiEPJgcVoTT7xfuJZqit5olUPkkAQsOCVcKxsmNs7mZe3PdOm/b08J3UPmkN57CHmweVq+cKMaSS/BEv76rYegRpu3MYSQb5i4ufvY0+e29Fy3+tkc+9PBVELIg2Bm1hSprYvj1ue3F0LIg92xhQQkL7zy45bnGeybEPKgmy34N+e+9lQfc/aDT62UGItAYBFCHkxmLmuztFAsJWpz9L4MgjkIIQ+6WVJQlxbDLcRZxwIQeJkk5ET0cwDeBOAWgC8B+KGU0v8dMbEgmEvIai6hrTeaDgLJpOqHRPQNKaUvn/3/jwH4lpTSj9TOi+qHwVaZoytQEIyiVP1wUop+FvEzXgxg+Zq4Q
bAQa5Xn7XHzbME1FCzH5ForRPQOIvoCgLcA+BnjuGtE9BgRPfbUU09NvWywAw6ppseeLPGlm10H61N1rRDRhwC8XHnpbSml32LHvRXA3Smlt9cuGq6V42Fv8dPH2Bov2A+zdwgioqsA3p9Sek3t2BDyYCRbiGc/dva2YO+VWXzkRPQQ+/FNAP50ynhB0Eq4D4Jgehz5zxPRN+E0/PAJANWIlSAYSViBQTBRyFNK/3LURIIg2C+xoK5LdAgKgiDYOSHkQRAEOyeEPAiCYOdE0awgEET52GBvhJAHB8fUuPIQ72BvhGslCIJg54RFHhwcXos6XCjBoRBCHhwtId7BoRCulSAIgp0TQh4EQbBzhlU/bLoo0VM4rc0yF1cAPD3j+HMT81+PPc8diPmvzdzzf3VK6X75y1WEfG6I6DGt1ONeiPmvx57nDsT812at+YdrJQiCYOeEkAdBEOycQxXyR9aewERi/uux57kDMf+1WWX+B+kjD4IgOCYO1SIPgiA4GkLIgyAIds7BCjkR/RwR/TERfZKIPkBEf3ftOXkhol8goj89m//7iOila8+pBSL6ASL6NBHdIqLdhJIR0RuJ6M+I6PNE9B/Xnk8LRPRuIvoSEX1q7bn0QESvIqIPE9Fnzu6dH197Tl6I6G4i+gMi+qOzuf/s4nM4VB85EX1DSunLZ///YwC+JaW0i+bQRPRPAPzPlNLXiOidAJBS+qmVp+WGiP4+Thty/wqA/5BSemzlKVUhorsA/DmA7wHwRQCfAPCDKaXPrDoxJ0T0nQCeB/CelNJr1p5PK0T0AIAHUkp/SERfD+BxAP9iD58/ERGAF6eUnieiFwH4fQA/nlL62FJzOFiLPIv4GS8GsJsVK6X0gZTS185+/BiAV645n1ZSSp9NKf3Z2vNo5HUAPp9S+ouU0t8A+HUAb1p5Tm5SSh8BcH3tefSSUnoypfSHZ///1wA+C+AV687KRzrl+bMfX3T2b1G9OVghBwAiegcRfQHAWwD8zNrz6eSHAfzu2pM4Al4B4Avs5y9iJ0JyaBDRgwC+DcDHV56KGyK6i4g+CeBLAD6YUlp07rsWciL6EBF9Svn3JgBIKb0tpfQqAI8C+NF1Z3ue2tzPjnkbgK/hdP6bwjP/IGiFiF4C4L0AfkI8VW+alNILKaVvxenT8+uIaFH31q7rkaeUvtt56KMA3g/g7TNOp4na3InohwB8H4A3pA1uZDR89nvhLwG8iv38yrPfBQtx5l9+L4BHU0q/ufZ8ekgpPUtEHwbwRgCLbTzv2iK3IKKH2I9vAvCna82lFSJ6I4CfBPDPU0o31p7PkfAJAA8R0d8jor8F4F8D+O2V53Q0nG0YvgvAZ1NK/3nt+bRARPfnyDIiugenG+aL6s0hR628F8A34TR64gkAP5JS2oWFRUSfB3AJwDNnv/rYXiJuAICI3gzglwDcD+BZAJ9MKX3vqpNyQET/FMB/AXAXgHenlN6x7oz8ENGvAfgunJZR/SsAb08pvWvVSTVARN8B4PcA/AlOv7MA8NMppfevNysfRPQPAPwqTu+bEwC/kVL6T4vO4VCFPAiC4Fg4WNdKEATBsRBCHgRBsHNCyIMgCHZOCHkQBMHOCSEPgiDYOSHkQRAEOyeEPAiCYOf8f8FBDlvUS2HtAAAAAElFTkSuQmCC\n", 328 | "text/plain": [ 329 | "
" 330 | ] 331 | }, 332 | "metadata": { 333 | "needs_background": "light" 334 | }, 335 | "output_type": "display_data" 336 | } 337 | ], 338 | "source": [ 339 | "# plot the final step\n", 340 | "x, y = out[-1][:, 0], out[-1][:, 1]\n", 341 | "fig, ax = plt.subplots()\n", 342 | "# ax.scatter(x, y)\n", 343 | "sns.kdeplot(x, y, fill=True, ax=ax, bw_adjust=0.2)\n", 344 | "ax.set_xlim(-3.5, 3.5)\n", 345 | "ax.set_ylim(-3.5, 3.5)\n", 346 | "plt.show()\n", 347 | "\n", 348 | "# ax.set_xlim(-3, 3)\n", 349 | "# plt.scatter(x, y)\n", 350 | "# # plt.contour(out[-1])\n", 351 | "# sns.kdeplot(x, y, fill=True)" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": 10, 357 | "id": "0bdd9800-f07a-4329-ad4d-c9415dfec2b8", 358 | "metadata": {}, 359 | "outputs": [], 360 | "source": [ 361 | "import os\n", 362 | "\n", 363 | "def plot_density(x, y, i=0, out_dir='pngs'):\n", 364 | " fig, ax = plt.subplots()\n", 365 | "# ax.scatter(x, y)\n", 366 | " sns.kdeplot(x=x, y=y, fill=True, ax=ax, bw_adjust=0.2)\n", 367 | " ax.set_xlim(-3.5, 3.5)\n", 368 | " ax.set_ylim(-3.5, 3.5)\n", 369 | " i_str = \"{}\".format(i).zfill(3)\n", 370 | " out_path = os.path.join(out_dir, f\"fig_{i_str}.png\")\n", 371 | " plt.savefig(out_path, dpi=200, bbox_inches='tight')\n", 372 | " plt.close(fig)" 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": 11, 378 | "id": "444fa390-050c-4705-bec0-eb34f09e57b2", 379 | "metadata": {}, 380 | "outputs": [ 381 | { 382 | "name": "stderr", 383 | "output_type": "stream", 384 | "text": [ 385 | "100%|██████████| 200/200 [06:58<00:00, 2.09s/it]\n" 386 | ] 387 | } 388 | ], 389 | "source": [ 390 | "for i in tqdm(range(0, nb_steps)):\n", 391 | " x, y = out[i][:, 0], out[i][:, 1]\n", 392 | " plot_density(x, y, i)" 393 | ] 394 | }, 395 | { 396 | "cell_type": "code", 397 | "execution_count": 29, 398 | "id": "65414b45-5e93-421c-aa78-87df22dc3a83", 399 | "metadata": {}, 400 | "outputs": [ 401 | { 402 | "data": { 403 | "text/plain": [ 404 | 
"tensor([0])" 405 | ] 406 | }, 407 | "execution_count": 29, 408 | "metadata": {}, 409 | "output_type": "execute_result" 410 | } 411 | ], 412 | "source": [ 413 | "torch.randint(0, 1, size=(1,))" 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": null, 419 | "id": "f5d16e75-558c-46e5-abdd-82164dda7eb3", 420 | "metadata": {}, 421 | "outputs": [], 422 | "source": [] 423 | }, 424 | { 425 | "cell_type": "code", 426 | "execution_count": null, 427 | "id": "13beb3ec-4d24-4b3e-ad2b-8c2bbcd8a920", 428 | "metadata": {}, 429 | "outputs": [], 430 | "source": [] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": null, 435 | "id": "c48e6eeb-df69-4ec3-8451-b8fe132f7826", 436 | "metadata": {}, 437 | "outputs": [], 438 | "source": [] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": null, 443 | "id": "276a9cc3-17bb-4b69-b931-54a2518c4e93", 444 | "metadata": {}, 445 | "outputs": [], 446 | "source": [] 447 | } 448 | ], 449 | "metadata": { 450 | "kernelspec": { 451 | "display_name": "Python 3", 452 | "language": "python", 453 | "name": "python3" 454 | }, 455 | "language_info": { 456 | "codemirror_mode": { 457 | "name": "ipython", 458 | "version": 3 459 | }, 460 | "file_extension": ".py", 461 | "mimetype": "text/x-python", 462 | "name": "python", 463 | "nbconvert_exporter": "python", 464 | "pygments_lexer": "ipython3", 465 | "version": "3.8.5" 466 | } 467 | }, 468 | "nbformat": 4, 469 | "nbformat_minor": 5 470 | } 471 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | target-version = ['py38'] 4 | include = '\.pyi?$' 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # popgen dependency 2 | # -e 
git+git://github.com/Popgun-Labs/PopGen.git#popgen 3 | 4 | # note: for GPU version, install via conda 5 | torch>=1.9 6 | torchvision~=0.9.1 7 | 8 | tqdm>=4.46.0 9 | numpy>=1.16.1 10 | seaborn 11 | hydra-core==0.11.3 12 | wandb~=0.10.21 13 | omegaconf==1.4.1 14 | black~=19.10b0 15 | setuptools~=50.3.2 16 | popgen~=0.0.2 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from setuptools import setup, find_packages 3 | from diffuse import __version__ 4 | 5 | readme = open("README.md").read() 6 | requirements_txt = open("requirements.txt").read().split("\n") 7 | requirements = list(filter(lambda x: "--extra" not in x and x is not "", requirements_txt)) 8 | 9 | dependency_links = list(filter(lambda x: "--extra" in x, requirements_txt)) 10 | dependency_links = list(map(lambda x: x.split(" ")[-1], dependency_links)) 11 | 12 | setup( 13 | # Metadata 14 | name="diffuse", 15 | version=__version__, 16 | author="Angus Turner", 17 | author_email="angusturner27@gmail.com", 18 | url="https://github.com/angusturner/phasenet", 19 | description="Experiments in Diffusion Probabilistic Modelling", 20 | long_description=readme, 21 | packages=find_packages(exclude=("test",)), 22 | zip_safe=True, 23 | install_requires=requirements, 24 | dependency_links=dependency_links, 25 | include_package_data=True, 26 | ) 27 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from omegaconf import DictConfig 3 | 4 | from popgen.setup import setup_worker, setup_loaders 5 | 6 | 7 | @hydra.main(config_path="config/config.yaml") 8 | def train(cfg: DictConfig) -> None: 9 | # get the experiment name 10 | name = cfg.get("name", False) 11 | if not name: 12 | raise Exception("Must specify experiment name on CLI. e.g. 
`python train.py name=vae ...`") 13 | 14 | # setup the worker 15 | overwrite = cfg.get("overwrite", False) 16 | worker, cfg = setup_worker(name, cfg, overwrite=overwrite, module="diffuse") 17 | 18 | # setup data loaders 19 | train_loader, test_loader = setup_loaders( 20 | dataset_class=cfg["dataset_class"], data_opts=cfg["dataset"], loader_opts=cfg["loader"], module="diffuse" 21 | ) 22 | 23 | # train 24 | worker.run(train_loader, test_loader, cfg["nb_epoch"]) 25 | 26 | 27 | if __name__ == "__main__": 28 | train() 29 | --------------------------------------------------------------------------------