├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── config ├── config.yaml ├── dataset │ ├── mnist.yaml │ └── vctk.yaml ├── loader │ └── basic.yaml ├── model │ └── mnist_score.yaml └── worker │ └── mnist_worker.yaml ├── diffuse ├── __init__.py ├── datasets │ ├── __init__.py │ └── mnist.py ├── models │ ├── __init__.py │ ├── components │ │ ├── __init__.py │ │ ├── conv_glu.py │ │ ├── unet_parts.py │ │ └── utils.py │ ├── diffusion_process.py │ └── mnist_score.py ├── utils │ └── __init__.py └── workers │ ├── __init__.py │ └── mnist_worker.py ├── notebooks ├── diffusion.ipynb ├── jax.ipynb └── visualise_diffusion.ipynb ├── pyproject.toml ├── requirements.txt ├── setup.py └── train.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # prevent notebooks from being counted in the language statistics on GitHub 2 | notebooks/* linguist-documentation -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/* 2 | notebooks/.ipynb_checkpoints/* 3 | .DS_Store 4 | wandb/* 5 | outputs/* 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | env/ 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # dotenv 89 | .env 90 | 91 | # virtualenv 92 | .venv 93 | venv/ 94 | ENV/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | 109 | #intellij 110 | .idea/ 111 | 112 | # vscode 113 | .vscode/ 114 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Angus Turner 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Diffusion Experiments 2 | 3 | An educational implementation of [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239), 4 | with corresponding [blog post](https://angusturner.github.io/generative_models/2021/06/29/diffusion-probabilistic-models-I.html). 5 | 6 | Includes: 7 | - A toy U-Net model, which can be fit to MNIST - `notebooks/diffusion.ipynb` 8 | - Notebook used for blog visualisations - `notebooks/visualise_diffusion.ipynb` 9 | 10 | This repo is a WIP. I hope to add some more substantial experiments soon. 11 | 12 | ### Requirements 13 | 14 | Required: 15 | - Python >= 3.7 16 | - PyTorch >= 1.7 17 | 18 | Recommended: 19 | - Linux and CUDA 20 | 21 | ### Install 22 | 23 | ```shell 24 | pip install -r requirements.txt 25 | ``` 26 | 27 | Uses the [PopGen](https://github.com/Popgun-Labs/PopGen) framework to manage experiments.
-------------------------------------------------------------------------------- /config/config.yaml: -------------------------------------------------------------------------------- 1 | name: null 2 | model: null 3 | dataset: null 4 | worker: null 5 | nb_epoch: 1000000 6 | overwrite: false 7 | wandb: 8 | project: denoising 9 | defaults: 10 | - loader: basic 11 | -------------------------------------------------------------------------------- /config/dataset/mnist.yaml: -------------------------------------------------------------------------------- 1 | dataset_class: MNIST 2 | dataset: 3 | both: 4 | data_dir: "/home/angusturner/datasets/" 5 | train: 6 | train: true 7 | test: 8 | train: false -------------------------------------------------------------------------------- /config/dataset/vctk.yaml: -------------------------------------------------------------------------------- 1 | dataset_class: AudioDataset 2 | dataset: 3 | both: 4 | formats: ['*/'] 5 | sr: 16000 6 | 7 | train: 8 | paths: ['/home/angusturner/datasets/vctk_16k/train'] 9 | patch_size: 16000 10 | test: 11 | paths: ['/home/angusturner/datasets/vctk_16k/test'] 12 | patch_size: 16000 -------------------------------------------------------------------------------- /config/loader/basic.yaml: -------------------------------------------------------------------------------- 1 | # Good standard settings for the PyTorch DataLoader class. Note that 2 | # the batch_size should probably be set from the CLI for the specific experiment. 
3 | loader: 4 | # settings for both train and validation 5 | both: 6 | num_workers: 1 7 | pin_memory: false 8 | shuffle: true 9 | batch_size: 32 10 | # specific settings 11 | train: 12 | drop_last: true 13 | test: 14 | drop_last: false 15 | -------------------------------------------------------------------------------- /config/model/mnist_score.yaml: -------------------------------------------------------------------------------- 1 | model_class: MnistScore 2 | model: 3 | nb_timesteps: 250 4 | hidden: 128 5 | c_dim: 64 6 | input_dim: 1 7 | -------------------------------------------------------------------------------- /config/worker/mnist_worker.yaml: -------------------------------------------------------------------------------- 1 | worker_class: MnistWorker 2 | worker: 3 | diffusion_settings: 4 | nb_timesteps: 200 5 | start: 1e-4 6 | end: 0.05 7 | optim_class: Adam 8 | optim_settings: 9 | lr: 2e-4 10 | weight_decay: 0. 11 | -------------------------------------------------------------------------------- /diffuse/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.0" 2 | -------------------------------------------------------------------------------- /diffuse/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .mnist import MNIST 2 | -------------------------------------------------------------------------------- /diffuse/datasets/mnist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torchvision import transforms, datasets 3 | 4 | 5 | class MNIST: 6 | mean = 0.1307 7 | std = 0.3081 8 | default_transforms = transforms.Compose( 9 | [ 10 | transforms.ToTensor(), 11 | # transforms.Normalize((mean,), (std,)) 12 | ] 13 | ) 14 | 15 | def __init__(self, data_dir, train: bool = True): 16 | """ 17 | Return continuous MNIST digits, scaled in [-1, 1] 18 | :param data_dir: location 
to save 19 | :param train: 20 | """ 21 | self.dataset = datasets.MNIST(data_dir, train=train, transform=MNIST.default_transforms, download=True) 22 | 23 | def __len__(self): 24 | return len(self.dataset) 25 | 26 | @staticmethod 27 | def denormalize(im): 28 | """ 29 | :param im: (B, C, H, W) 30 | """ 31 | return (im + 1) / 2.0 32 | # return (im * MNIST.std) + MNIST.mean 33 | 34 | def __getitem__(self, idx): 35 | x, y = self.dataset[idx] 36 | x = np.array(x).astype(np.float32) 37 | x = (x * 2.0) - 1 38 | 39 | return x, y 40 | -------------------------------------------------------------------------------- /diffuse/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .mnist_score import MnistScore 2 | from .diffusion_process import DiffusionProcess 3 | -------------------------------------------------------------------------------- /diffuse/models/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angusturner/diffuse/cb0d8dfdd4c07dc09ad167904377109a973b15dc/diffuse/models/components/__init__.py -------------------------------------------------------------------------------- /diffuse/models/components/conv_glu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ConvGLU(nn.Module): 6 | def __init__(self, channels, c_dim=None, dilation=1, kw=3): 7 | """ 8 | Convolution with GLU activation and (optional) global conditioning 9 | :param channels 10 | :param c_dim: 11 | :param dilation: 12 | :param kw: 13 | """ 14 | super().__init__() 15 | 16 | # TODO: try without batch-norm? 17 | self.bn = nn.BatchNorm2d(channels) 18 | 19 | # main op. 
+ parameterised residual connection 20 | padding = (kw - 1) * dilation // 2 21 | self.conv = nn.Conv2d(channels, channels, kernel_size=kw, dilation=dilation, padding=padding) 22 | self.up1x1 = nn.Conv2d(channels // 2, channels, 1) 23 | 24 | self.rescale = 0.5 ** 0.5 25 | 26 | if c_dim is not None: 27 | self.c_proj = nn.Linear(c_dim, channels) 28 | 29 | def forward(self, x, c=None): 30 | """ 31 | :param x: (B, C H, W) 32 | :param c: optional global conditioning (batch, features) 33 | """ 34 | h = self.bn(x) 35 | h = self.conv(h) 36 | a, b = torch.chunk(h, 2, 1) # (B, C // 2, H, W) 37 | 38 | if c is not None: 39 | assert hasattr(self, "c_proj"), "Oops, conditioning dim not specified!" 40 | batch = x.shape[0] 41 | c_proj = self.c_proj(c) 42 | c_a, c_b = torch.chunk(c_proj, 2, -1) # (B, C // 2) 43 | c_a = c_a.reshape(batch, -1, 1, 1) # (B, C // 2, H=1, W=1) 44 | c_b = c_b.reshape(batch, -1, 1, 1) 45 | a = (a + c_a) * self.rescale 46 | b = (b + c_b) * self.rescale 47 | 48 | # main op + residual, and re-scale to preserve variance 49 | out = torch.sigmoid(a) * b 50 | out = self.up1x1(out) 51 | out = self.rescale * (out + x) # (B, C, H, W) 52 | 53 | return out 54 | -------------------------------------------------------------------------------- /diffuse/models/components/unet_parts.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from diffuse.models.components.conv_glu import ConvGLU 5 | 6 | 7 | class DownsampleBlock(nn.Module): 8 | def __init__(self, hidden, main_op=ConvGLU): 9 | """ 10 | Halve the spatial dimensions and double the channels 11 | :param hidden: 12 | """ 13 | super(DownsampleBlock, self).__init__() 14 | 15 | self.down = nn.Conv2d(hidden, hidden * 2, kernel_size=2, stride=2) 16 | self.conv = main_op(hidden * 2) 17 | 18 | def forward(self, x, c=None): 19 | down = self.down(x) 20 | return self.conv(down, c) 21 | 22 | 23 | class UpsampleBlock(nn.Module): 24 | def 
__init__(self, hidden, main_op=ConvGLU): 25 | super(UpsampleBlock, self).__init__() 26 | 27 | self.up = nn.ConvTranspose2d(hidden, hidden // 2, kernel_size=2, stride=2) 28 | self.conv1 = nn.Conv2d(hidden, hidden // 2, 1) 29 | self.main = main_op(hidden // 2) 30 | 31 | def forward(self, x1, x2, c=None): 32 | x1 = self.up(x1) 33 | feats = torch.cat((x1, x2), dim=1) 34 | feats = self.conv1(feats) 35 | return self.main(feats, c) 36 | -------------------------------------------------------------------------------- /diffuse/models/components/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | 5 | 6 | def unsqueeze_as(x: torch.Tensor, shape: Tuple[int, ...]) -> torch.Tensor: 7 | """ 8 | Add trailing dimensions onto `x` until it matches the rank of a second 9 | tensor whose shape is given. 10 | e.g) 11 | x = torch.randn(3) 12 | y = torch.randn(3, 4, 1) 13 | x = unsqueeze_as(x, y.shape) # (3, 1, 1) 14 | :param x: 15 | :param shape: 16 | :return: 17 | """ 18 | extra_dims = len(shape) - x.dim() 19 | return x[(...,) + (None,) * extra_dims] 20 | -------------------------------------------------------------------------------- /diffuse/models/diffusion_process.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from diffuse.models.components.utils import unsqueeze_as 8 | 9 | 10 | class DiffusionProcess(nn.Module): 11 | def __init__(self, nb_timesteps: int = 50, start: float = 1e-4, end: float = 0.05): 12 | """ 13 | Implements the diffusion process presented in DDPM. 14 | No learnable parameters, but relies on a trained 'score' model 15 | for sampling. 
16 | :param nb_timesteps: 17 | :param start: 18 | :param end: 19 | """ 20 | super().__init__() 21 | 22 | self.nb_timesteps = nb_timesteps 23 | 24 | # beta = likelihood variance q(x_t | x_t-1) 25 | beta = torch.linspace(start, end, nb_timesteps) 26 | alpha = 1.0 - beta 27 | alpha_hat = alpha.cumprod(dim=0) 28 | 29 | # variance of conditional prior, q(x_t|x_0) 30 | # = N(x_t ; sqrt(alpha_hat) * x_0, prior_variance) 31 | prior_variance = 1.0 - alpha_hat 32 | 33 | # forward process posterior variance (beta_hat) corresponding to q(x_t-1 | x_t, x_0) 34 | alpha_hat_t_1 = F.pad(alpha_hat, (1, 0))[:-1] 35 | posterior_variance = (1 - alpha_hat_t_1) * beta / (1 - alpha_hat) 36 | posterior_variance[0] = beta[0] 37 | 38 | for (name, tensor) in [ 39 | ("beta", beta), 40 | ("alpha", alpha), 41 | ("alpha_hat", alpha_hat), 42 | ("prior_variance", prior_variance), 43 | ("posterior_variance", posterior_variance), 44 | ]: 45 | self.register_buffer(name, tensor) 46 | 47 | def sample_t(self, batch_size=1, device=None): 48 | """ 49 | Sample a random time-step for each batch item. 50 | """ 51 | return torch.randint(0, self.nb_timesteps, size=(batch_size,), device=device) 52 | 53 | def sample_q(self, x0: torch.Tensor, eps: torch.Tensor, t: torch.Tensor): 54 | """ 55 | The "forward process". Given the data point x_0, we can sample 56 | any latent x_t from q(x_t|x_0) 57 | :param x0: the initial data point (batch, *) 58 | :param eps: noise samples from N(0, I) (batch, *) 59 | :param t: the timesteps in [0, nb_timesteps] (batch) 60 | """ 61 | assert (t >= 0).all() and (t < self.nb_timesteps).all(), "Invalid time step" 62 | 63 | alpha_hat_t = unsqueeze_as(self.alpha_hat[t], x0.shape) # (batch) 64 | return alpha_hat_t.sqrt() * x0 + (1.0 - alpha_hat_t).sqrt() * eps 65 | 66 | def sample_p(self, x_t: torch.Tensor, eps_hat: torch.Tensor, t: torch.Tensor, greedy: bool = False): 67 | """ 68 | The "reverse process". Given a latent `x_t`, draw a sample from p(x_t-1|x_t) 69 | using the noise prediction. 
70 | :param x_t: the previous sample (batch, *) 71 | :param eps_hat: the noise, predicted by neural net (batch, *) 72 | :param t: the time step (batch) 73 | :param greedy: use the mean 74 | """ 75 | 76 | alpha_t = unsqueeze_as(self.alpha[t], x_t.shape) 77 | beta_t = unsqueeze_as(self.beta[t], x_t.shape) 78 | alpha_hat_t = unsqueeze_as(self.alpha_hat[t], x_t.shape) 79 | 80 | # calculate the mean 81 | mu = x_t - ((beta_t * eps_hat) / (1.0 - alpha_hat_t).sqrt()) 82 | mu = (1.0 / alpha_t.sqrt()) * mu 83 | 84 | if greedy: 85 | return mu 86 | 87 | # sample 88 | std = unsqueeze_as(self.posterior_variance[t].sqrt(), x_t.shape) 89 | x_next = mu + std * torch.randn_like(mu) 90 | 91 | return x_next 92 | 93 | @torch.no_grad() 94 | def generate( 95 | self, shape: Tuple[int, ...], score_model: nn.Module, return_freq: int = 20 96 | ) -> Tuple[torch.Tensor, List[torch.Tensor]]: 97 | """ 98 | Generate a batch of samples. Returns intermediate values as well. 99 | :param shape: a tuple indicating the shape of the data 100 | :param score_model: trained pytorch model, which produces the update 101 | :param return_freq: how often to accumulate intermediate values 102 | """ 103 | device = list(score_model.parameters())[0].device 104 | batch, *_ = shape 105 | x_T = torch.randn(shape, device=device) 106 | 107 | x = x_T 108 | 109 | out = [x_T.cpu()] 110 | for t in range(self.nb_timesteps - 1, -1, -1): 111 | t_ = torch.full((batch,), t, dtype=torch.long, device=device) 112 | 113 | eps_hat = score_model(x, t_) 114 | x = self.sample_p(x, eps_hat, t_, greedy=t == 0) 115 | 116 | if (t + 1) % return_freq == 0: 117 | out.append(x.cpu()) 118 | 119 | return x, out 120 | -------------------------------------------------------------------------------- /diffuse/models/mnist_score.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from diffuse.models.components.conv_glu import 
ConvGLU 7 | from diffuse.models.components.unet_parts import DownsampleBlock, UpsampleBlock 8 | 9 | 10 | class MnistScore(nn.Module): 11 | def __init__(self, input_dim=1, hidden=64, c_dim=64, nb_timesteps=50): 12 | """ 13 | A simple U-Net based model, tailored to the dimensions of MNIST. 14 | Predicts `epsilon` required to invert the forward process. 15 | (Equivalently, can be considered as the 'score' network in NCSN) 16 | :param input_dim: 17 | :param hidden: 18 | :param c_dim: 19 | :param nb_timesteps: 20 | """ 21 | super().__init__() 22 | 23 | # create positional embeddings (Vaswani et al, 2018) 24 | dims = torch.arange(c_dim // 2).unsqueeze(0) # (1, c_dim // 2) 25 | steps = torch.arange(nb_timesteps).unsqueeze(1) # (nb_timesteps, 1) 26 | first_half = torch.sin(steps * 10.0 ** (dims * 4.0 / (c_dim // 2 - 1))) 27 | second_half = torch.cos(steps * 10.0 ** (dims * 4.0 / (c_dim // 2 - 1))) 28 | diff_embedding = torch.cat((first_half, second_half), dim=1) # (nb_timesteps, c_dim) 29 | self.register_buffer("diff_embedding", diff_embedding) 30 | 31 | # define the main convolution op. 32 | op = partial(ConvGLU, c_dim=c_dim, kw=3) 33 | 34 | self.init1 = nn.Conv2d(input_dim, hidden, 3, padding=1) 35 | self.init2 = op(hidden) 36 | 37 | self.down1 = DownsampleBlock(hidden, op) # 14x14 38 | self.down2 = DownsampleBlock(hidden * 2, op) # 7x7 39 | self.up1 = UpsampleBlock(hidden * 4, op) # 14x14 40 | self.up2 = UpsampleBlock(hidden * 2, op) # 28x28 41 | 42 | self.out = nn.Conv2d(hidden, 1, 1) 43 | 44 | def forward(self, x, t): 45 | """ 46 | Produces an estimate for the noise term `epsilon`. 
47 | :param x: (batch, 1, H, W) torch.float 48 | :param t: (batch) torch.int 49 | """ 50 | 51 | # get the conditioning for this time-step 52 | c = self.diff_embedding[t] # (batch, c_dim) 53 | 54 | # initial channel up-sampling 55 | x1 = self.init1(x) 56 | x1 = self.init2(x1) 57 | 58 | # u-net 59 | x2 = self.down1(x1, c) 60 | x3 = self.down2(x2, c) 61 | x = self.up1(x3, x2, c) 62 | x = self.up2(x, x1, c) 63 | 64 | # output 65 | return self.out(x) 66 | -------------------------------------------------------------------------------- /diffuse/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angusturner/diffuse/cb0d8dfdd4c07dc09ad167904377109a973b15dc/diffuse/utils/__init__.py -------------------------------------------------------------------------------- /diffuse/workers/__init__.py: -------------------------------------------------------------------------------- 1 | from .mnist_worker import MnistWorker 2 | -------------------------------------------------------------------------------- /diffuse/workers/mnist_worker.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | import numpy as np 8 | from tqdm import tqdm 9 | import wandb 10 | from wandb.wandb_run import Run 11 | 12 | from diffuse.datasets import MNIST 13 | from diffuse.models import DiffusionProcess 14 | from popgen.workers.abstract_worker import AbstractWorker 15 | 16 | 17 | class MnistWorker(AbstractWorker): 18 | def __init__( 19 | self, 20 | exp_name: str, 21 | model: nn.Module, 22 | run_dir: str, 23 | run: Optional[Run], 24 | diffusion_settings: Dict, 25 | optim_class: str, 26 | optim_settings: Dict, 27 | *args, 28 | **kwargs, 29 | ): 30 | """ 31 | :param exp_name: 32 | :param model: 33 | :param run_dir: 34 | :param wandb: 35 | :param diffusion_settings: 36 | 
:param optim_class: 37 | :param optim_settings: 38 | :param args: 39 | :param kwargs: 40 | """ 41 | super(MnistWorker, self).__init__(exp_name, model, run_dir, run, *args, **kwargs) 42 | 43 | self.model = model 44 | 45 | # define the diffusion process 46 | self.diffusion = DiffusionProcess(**diffusion_settings) 47 | 48 | # setup the optimiser 49 | self.params = [p for p in model.parameters() if p.requires_grad] 50 | optim_class = getattr(torch.optim, optim_class) 51 | self.optim = optim_class(self.params, **optim_settings) 52 | 53 | # register the optimiser state, to include in the checkpoints 54 | self.register_state(self.optim, "optim") 55 | 56 | # put everything on GPU if available 57 | if torch.cuda.is_available(): 58 | self.diffusion.cuda() 59 | self.cuda() 60 | 61 | # track the number of iterations 62 | self.iterations = {"train": 0, "test": 0} 63 | 64 | # load existing checkpoint 65 | self.load(checkpoint_id="best") 66 | 67 | # train / evaluation logic 68 | def main(self, loader, train=True): 69 | losses = [] 70 | for i, (x0, _) in enumerate(tqdm(loader)): 71 | if train: 72 | self.optim.zero_grad() 73 | 74 | # put features on GPU 75 | x0 = x0.float().cuda() 76 | 77 | # sample x_t ~ q(x_t|x0), for a random step t ~ {0..nb_timesteps} 78 | eps = torch.randn_like(x0) 79 | t = self.diffusion.sample_t(eps.shape[0], device=eps.device) 80 | x_t = self.diffusion.sample_q(x0, eps, t) 81 | 82 | # predict the noise 83 | eps_hat = self.model(x_t, t) 84 | loss = F.mse_loss(eps_hat, eps, reduction="mean") 85 | 86 | # DEBUG - check for NaN values 87 | if torch.isnan(loss).any(): 88 | raise Exception("NaN :(") 89 | 90 | losses.append(loss.item()) 91 | 92 | if train: 93 | loss.backward() 94 | self.optim.step() 95 | 96 | self._plot_loss({"MSE": loss.item()}, train=train) 97 | 98 | if i % 500 == 0 and not train: 99 | self._plot_sample() 100 | 101 | return (np.mean(losses),) 102 | 103 | def train(self, loader): 104 | self.model.train() 105 | return self.main(loader, train=True) 
106 | 107 | @torch.no_grad() 108 | def evaluate(self, loader): 109 | self.model.eval() 110 | return self.main(loader, train=False) 111 | 112 | @torch.no_grad() 113 | def _plot_sample(self): 114 | x_gen, _ = self.diffusion.generate((1, 1, 28, 28), self.model) 115 | x_gen = MNIST.denormalize(x_gen).clamp(0, 1) 116 | x_np = x_gen.view(28, 28).cpu().numpy() 117 | self.wandb.log({"Forward Process Sample": wandb.Image(x_np)}) 118 | -------------------------------------------------------------------------------- /notebooks/visualise_diffusion.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "c753f8cd-32b7-478f-a869-70bb3452a535", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "import torch.nn as nn\n", 12 | "import torch.nn.functional as F\n", 13 | "import torch.utils.data\n", 14 | "\n", 15 | "import matplotlib.pyplot as plt\n", 16 | "%matplotlib inline\n", 17 | "import seaborn as sns\n", 18 | "\n", 19 | "import numpy as np\n", 20 | "import os\n", 21 | "from functools import partial\n", 22 | "\n", 23 | "from tqdm import tqdm as tqdm\n", 24 | "\n", 25 | "# wandb for plotting\n", 26 | "import wandb\n", 27 | "\n", 28 | "# PopGen for data loader and worker utils.\n", 29 | "from popgen.setup import setup_loaders, setup_config\n", 30 | "from popgen.workers.abstract_worker import AbstractWorker\n", 31 | "\n", 32 | "from sklearn.datasets import make_s_curve, make_swiss_roll" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "id": "4d66f118-e891-4a5d-bcfa-f542803cc605", 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "name": "stdout", 43 | "output_type": "stream", 44 | "text": [ 45 | "(5000, 2)\n" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "X, color = make_swiss_roll(n_samples=5000, noise=0.3)\n", 51 | "X = X / 4\n", 52 | "X = np.stack((X[:, 0], X[:, 2]), axis=1)\n", 53 | "x, y = X[:, 
0], X[:, 1]\n", 54 | "print(X.shape)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 3, 60 | "id": "9b67472a-ca64-4093-b5c1-9412fae5ca9b", 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "name": "stderr", 65 | "output_type": "stream", 66 | "text": [ 67 | "/home/angusturner/miniconda3/envs/py38/lib/python3.8/site-packages/seaborn/_decorators.py:36: FutureWarning: Pass the following variable as a keyword arg: y. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation.\n", 68 | " warnings.warn(\n" 69 | ] 70 | }, 71 | { 72 | "data": { 73 | "text/plain": [ 74 | "" 75 | ] 76 | }, 77 | "execution_count": 3, 78 | "metadata": {}, 79 | "output_type": "execute_result" 80 | }, 81 | { 82 | "data": { 83 | "image/png": "\n", 84 | "text/plain": [ 85 | "
" 86 | ] 87 | }, 88 | "metadata": { 89 | "needs_background": "light" 90 | }, 91 | "output_type": "display_data" 92 | } 93 | ], 94 | "source": [ 95 | "sns.kdeplot(x, y, fill=True, bw_adjust=0.2)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 4, 101 | "id": "8053e264-78d2-44bf-9b43-6eec3354cb3e", 102 | "metadata": {}, 103 | "outputs": [ 104 | { 105 | "data": { 106 | "text/plain": [ 107 | "" 108 | ] 109 | }, 110 | "execution_count": 4, 111 | "metadata": {}, 112 | "output_type": "execute_result" 113 | }, 114 | { 115 | "data": { 116 | "image/png": "\n", 117 | "text/plain": [ 118 | "
" 119 | ] 120 | }, 121 | "metadata": { 122 | "needs_background": "light" 123 | }, 124 | "output_type": "display_data" 125 | } 126 | ], 127 | "source": [ 128 | "plt.figure(figsize=(4, 4))\n", 129 | "plt.scatter(x, y)" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 5, 135 | "id": "4aeb9b09-8e59-45d0-8d99-5306af049b73", 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "class DiffusionProcess(nn.Module):\n", 140 | " def __init__(self, nb_timesteps=50, start=1e-4, end=0.05):\n", 141 | " super().__init__()\n", 142 | " \n", 143 | " self.nb_timesteps = nb_timesteps\n", 144 | " \n", 145 | " # beta = likelihood variance q(x_t | x_t-1)\n", 146 | " beta = torch.linspace(start, end, nb_timesteps)\n", 147 | " alpha = 1. - beta\n", 148 | " alpha_hat = alpha.cumprod(dim=0)\n", 149 | " \n", 150 | " # q(x_t|x_0) = N(x_t ; sqrt(alpha_hat) * x_0, forward_variance)\n", 151 | " prior_variance = (1. - alpha_hat)\n", 152 | " \n", 153 | " # forward process posterior variance (beta_hat) corresponding to (q(x_t-1 | x_t, x_0)\n", 154 | " alpha_hat_t_1 = F.pad(alpha_hat, (1, 0))[:-1]\n", 155 | " posterior_variance = (1 - alpha_hat_t_1) * beta / (1 - alpha_hat)\n", 156 | " posterior_variance[0] = beta[0]\n", 157 | " \n", 158 | " for (name, tensor) in [\n", 159 | " (\"beta\", beta),\n", 160 | " (\"alpha\", alpha),\n", 161 | " (\"alpha_hat\", alpha_hat),\n", 162 | " (\"prior_variance\", prior_variance),\n", 163 | " (\"posterior_variance\", posterior_variance)\n", 164 | " ]:\n", 165 | " self.register_buffer(name, tensor)\n", 166 | " \n", 167 | " def sample_t(self, batch_size=1, device=None):\n", 168 | " \"\"\"\n", 169 | " Sample a random timestep for each batch item.\n", 170 | " \"\"\"\n", 171 | " return torch.randint(0, self.nb_timesteps, size=(batch_size,), device=device)\n", 172 | " \n", 173 | " def sample_q(self, x0, eps, t):\n", 174 | " \"\"\"\n", 175 | " The \"forward process\". 
Given the data point x_0, we can sample\n", 176 | " any latent x_t from q(x_t|x_0)\n", 177 | " :param x0: the initial data point (batch, *)\n", 178 | " :param eps: noise samples from N(0, I) (batch, *)\n", 179 | " :param t: the timesteps in [0, nb_timesteps] (batch)\n", 180 | " \"\"\"\n", 181 | " assert (t >= 0).all() and (t < self.nb_timesteps).all(), \"Invalid timestep\"\n", 182 | " \n", 183 | " alpha_hat_t = self.alpha_hat[t, None] # (batch, 1, 1, 1)\n", 184 | " return alpha_hat_t.sqrt() * x0 + (1. - alpha_hat_t).sqrt() * eps\n", 185 | " \n", 186 | " def sample_q_next(self, x_t, eps, t):\n", 187 | " \"\"\"\n", 188 | " :param x_t: previous step! (batch, *)\n", 189 | " :param eps: noise samples from N(0, I) (batch, *)\n", 190 | " :param t: the timesteps in [0, nb_timesteps] (batch)\n", 191 | " \"\"\"\n", 192 | " assert (t >= 0).all() and (t < self.nb_timesteps).all(), \"Invalid timestep\"\n", 193 | " beta = self.beta[t, None]\n", 194 | " mu = (1. - beta).sqrt() * x_t\n", 195 | " return mu + (beta.sqrt() * eps)\n", 196 | " \n", 197 | " def sample_p(self, x_t, eps_hat, t, greedy=False):\n", 198 | " \"\"\"\n", 199 | " The \"reverse process\". Given a latent `x_t`, draw a sample from p(x_t-1|x_t)\n", 200 | " using the noise prediction.\n", 201 | " :param x_t: the previous sample (batch, *)\n", 202 | " :param eps_hat: the noise, predicted by neural net (batch, *)\n", 203 | " :param t: the timestep (batch)\n", 204 | " \"\"\"\n", 205 | " \n", 206 | " alpha_t = self.alpha[t, None, None, None] # (batch, 1, 1, 1)\n", 207 | " beta_t = self.beta[t, None, None, None]\n", 208 | " alpha_hat_t = self.alpha_hat[t, None, None, None]\n", 209 | " \n", 210 | " # calculate the mean\n", 211 | " mu = x_t - ((beta_t * eps_hat) / (1. - alpha_hat_t).sqrt())\n", 212 | " mu = (1. 
/ alpha_t.sqrt()) * mu\n", 213 | " \n", 214 | " if greedy:\n", 215 | " return mu\n", 216 | " \n", 217 | " # sample\n", 218 | " std = self.posterior_variance[t, None, None, None].sqrt()\n", 219 | " x_next = mu + std * torch.randn_like(mu)\n", 220 | " \n", 221 | " return x_next" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 6, 227 | "id": "e3f20f87-1553-4c25-a0e9-d037e064b6c4", 228 | "metadata": {}, 229 | "outputs": [ 230 | { 231 | "name": "stdout", 232 | "output_type": "stream", 233 | "text": [ 234 | "tensor(0.9939)\n", 235 | "tensor(0.0001)\n", 236 | "tensor(1.0000e-04)\n", 237 | "tensor(1.0000e-04)\n" 238 | ] 239 | } 240 | ], 241 | "source": [ 242 | "nb_steps = 200\n", 243 | "start = 1e-4\n", 244 | "end = 0.05\n", 245 | "diff_process = DiffusionProcess(nb_steps, start, end)\n", 246 | "print(diff_process.prior_variance[-1])\n", 247 | "print(diff_process.prior_variance[0])\n", 248 | "\n", 249 | "print(diff_process.posterior_variance[0])\n", 250 | "print(diff_process.beta[0])" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 7, 256 | "id": "95405ac8-8c2e-462d-9b1c-7aea0d405453", 257 | "metadata": {}, 258 | "outputs": [ 259 | { 260 | "data": { 261 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAfIAAAHxCAYAAACWBT5+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAABC0UlEQVR4nO3dd3xVhf3/8deHEMIIm7D3XiogIlhnRQVnq9ZZ96zWqvVrtbW11rZatbbVOvpzVZzULSq4Klatg70hgMwww0gI2ePz++Ne7DUmcIHcnDvez8fjPnLvuefe+z73Jveds83dERERkcTUIOgAIiIisu9U5CIiIglMRS4iIpLAVOQiIiIJTEUuIiKSwFTkIiIiCUxFLgnJzHaaWe+gc0TDQv5pZtvNbNo+PsdVZva3Oo5W0+vU6/6oZjbNzIbsYZw/mNkWM9tYX7lEEomKXGLOzN41sztrGH6amW00s4Z7+5zununuK+omYcwdDhwHdHX3UXv7YDNrBPwauG9/QpiZm1nf/XmOfXzd88xstZkVmtkbZtYm4u4/A9/53Yh4bHfgJmCwu3esozz7/D6Y2bFmtsTMisxsqpn1qItMe3jNp83sD/vwuJ0RlyozK464fX4sskowVORSHyYAPzYzqzb8AuB5d6+I9on2pfTjQA9glbsX7uPjTwOWuPu6OsxUL8Jz2/+P0GfdASgCHokYZRJwjJnVVtLdga3uvnkfXrtOf1fMrB3wGvAboA0wA/jXXjy+jZml12Wm3Qn/s5vp7pnAGuCUiGHP11cOiT0VudSHN4C2wBG7BphZa+Bk4BkzG2VmX5hZnpltMLOHwnOhu8Z1M7vWzJYByyKG9Q1fP8nMZpvZDjNba2Z3RDy2Z3jci8xsTXgR7W0R96eZ2a/M7GszKzCzmWbWLXzfQDP7wMy2mVm2mZ1V2wSaWWczmxQed7mZXREefhnwBDAmPCf0uxoem2Zmfw5nWxGeVo8oovHAfyLGP9vMVppZi/Dt8eElG1m7yfdJ+OrccI6zaxt3N88x2czuj7g90cye2sPDzgfecvdP3H0noRI83cyaA7h7CTATOKGG1xsLfAB0Dmd+Ojz8VDNbGP59+djMBkU8ZpWZ3WJm84DC6mW+n+/D6cBCd385nPsO4CAzGxjl448DcszsfjMbGs0DzOxKQu/hL8J539qLvJIq3F0XXWJ+AR4Hnoi4fRUwJ3z9YGA00BDoCSwGbogY1wl9obcBmkQM6xu+fjRwAKF/TA8ENgE/CN/XMzzu40AT4CCgFBgUvv9mYD4wALDw/W2BZsBa4JJwruHAFkKLeGuavk8IzWk2BoYBucD3w/ddDHy2m/fmamAJ0C08jVPDmRuG758O/KjaY54Hng5nXQ+cHMVn8M17trtxdnNfR2Az8H1C5bICaL6H53sTuKXasJ3AwRG3HwT+UsvjjwZyIm73BwoJlWI68AtgOdAofP8qYE74vWwSzftAaK4/bzeX88LjPQA8Wu25FgBn7MXfwVBCq0jWhz/Xa4DWe3jM08Afqg17ezd5367hOVYBY+vr712X+r1ojlzqywTgTDNrHL59YXgY7j7T3b909wp3X0VoUexR1R5/t7tvc/fi6k/s7h+7+3x3r3L3ecCLNTz+d+5e7O5zgbmEChvgcuDX7p7tIXPdfSuhpQWr3P2f4VyzgVeBH1V//fAc/PcIFVaJu88hNBd+YZTvzVnA39x9rbtvA+6udn8roKDasGsJFerHhOZ4347ytfaZu28EfkLoc3sAuNDdq+eqLhPIrzYsH2gecbuA0DRG42zgHXf/wN3LCa1jbwIcFjHOg+H38ju/KzVx9zXu3mo3lxf2Ylr29FoL3P1mQv9o3EHoH5WV4aUbLfbieU7eTd6To30eSQ4qcqkX7v4ZoTnaH5hZH2AU8AKAmfU3s7fDi4d3AHcB7ao9xdrantvMDg1veJRrZvmE5nCrPz5yi+ciQl/KEPpC/bqGp+0BHBpefJtnZnmE5kJrWpfbGdh
WrdRWA11qy1zD4yOnb3W1+7dTrSzcPQ94mdAc3v3Un7eANCA7/JnuyU6gekG14Nv/mDQnNCcZjc5EvD/uXkXovYt8r2v9XdlP0UwLAGZ2fsSGZVOq3+/ulYSWBM0FthH6HOtt/bkkFxW51KdnCM2l/hh4z903hYc/SmjRcj93bwH8itBi7ki72y3qBUIbTXVz95bAP2p4fG3WAn1qGf6fanM6me7+kxrGXQ+02bXeN6w7EO3GaRsI/UMR+dhI8wgtUv6GmQ0DLiW09OHBKF+nLvyR0KqPTmZ2bhTjL+R/Sz+w0C6DGcDSiHEGESq0aKwn9E/WruczQu9d5Hu9V7vQmVl3+/YW3tUvu7bwrj4tzQj97iys/pzu/rz/b8Oy8RGPyTSzi83sI2AWoX9Aznb3oeElQTX5zvSY2ZTd5P3OPw6S3FTkUp+eAcYCVxBerB7WHNgB7AxvOFRTWe5Oc0JzxCVmNgo4by8e+wTwezPrZyEHmllbQusg+5vZBWaWHr4cErlh1S7uvhb4HLjbzBqb2YHAZcBzUWZ4CfiZmXW10EaAt1a7fzIRqwrCqyeeI/QPzyVAFzO7JorX2QTs8773ZnZk+PUuBC4C/m5me1rq8DxwipkdES6+O4HXdi29CE/LwYS2gYjGS8BJFtoNLJ3QrmmlhN7/aH3rfQgvWs/czWXXFt6vA0PN7Ixw7tuBee6+JJoXNbNxhP4ROZvQ6qMu7n6Nu0/fm7zhzON3k3d8Lc8jySrolfS6pNaF0Drd7UBGxLAjCc2R7wQ+JfRl/1nE/d/ZSItvb+x2JqHFrQWECvgh4LnwfT2J2HAsIsPl4etphPbRXhl+/HRC+3tDaAO4dwhtuLYV+AgYVst0dQ2/9jZCi+qvjrjvYna/sVtD4K/h11hJaP135MZu6YR2H+ocvv1XYErE4w8Kv26/Pbz3VxOa+88DzqplHK9leAtCG0ydEzHsHuB9wPbwuueF8xcS2vitTcR9PyJU7LU99mgiNnYLD/shsIjQ+un/AEMi7lvFHjbqiuZ92M1jx4Z/V4vDv0c99+KxvXZ9hnv5mv0IbcCXB7yxj393e3xfdEnci4U/ZBGJE2bWk1Chp3t4H/vwbkiD3f2GGL+2u3u0qyXq4vW+Ai5z9wX19ZoiyUZFLhJnairyenztei1yEdl/WkcukiTC66Fr3ABqL57mOwesieJ1/1HL6/5jb59LRPae5shFREQSmObIRUREEpiKXEREJIEl4pmkaNeunffs2TPoGCIiIvVi5syZW9y9xhMjJWSR9+zZkxkzZgQdQ0REpF6YWfVDN39Di9ZFREQSmIpcREQkganIRUREEpiKXEREJIGpyEVERBKYilxERCSBqchFREQSmIpcREQkganIRUREEpiKXEREJIGpyEVERBKYilxERCSBxbTIzewpM9tsZgtqud/M7EEzW25m88xsRCzziIiIJJtYz5E/DYzbzf3jgX7hy5XAozHOIyIiklRiWuTu/gmwbTejnAY84yFfAq3MrFMsM4mIiCSToNeRdwHWRtzOCQ/7DjO70sxmmNmM3NzcegknIiIS74Iu8qi5+2PuPtLdR2ZlZQUdR0REpEYl5ZXkFZXV2+s1rLdXqtk6oFvE7a7hYSIiIoGqqnJ2lJSztbCMbYVlbN1ZGrq+s4ztReXkF4cuO4rLySsu++Z2SXkV/Ttk8v6NR9VLzqCLfBLwUzObCBwK5Lv7hoAziYhIEisuq2TTjpLQpaCUTfn/u76loDRU2oVlbC8qo7LKa3yOZo3SaNkknRZN0mnZJJ1e7ZrRMny9ZZN0OrVsUm/TE9MiN7MXgaOBdmaWA/wWSAdw938Ak4ETgeVAEXBJLPOIiEhyq6xyNu4oYe22InK2F7N2WxHr8orZtKOEjeHC3lFS8Z3HNU5vQIcWjWmXmUH3tk0Z3r0VbTMb0aZZBm2bNaJN+NIuM4PWzdLJaJgWwNTVLKZF7u7n7uF+B66NZQYREUkuO0rKWZFbyKothf8
r7O2hn+vziqmImIs2gw7NG9OxZWP6ZGVyWJ+2tG/RmA4tGtOxRWM6tMigfYvGtGjcEDMLcKr2XdCL1kVERL6jorKKtduLWZG7kxW5hazYspOvcwtZkVvIlp2l3xo3q3kG3Vo3YVi3VpxyUCe6tm5Kt9ZN6dq6CZ1bNaFRw4TZrnufqMhFRCRQW3eWsmRjAYs37CB7YwFLNhawdFMBpRVV34zTplkjerdrxvcHZtE7K5Pe7ZrRO6sZXVs3pXF6/CzmDoKKXERE6oW7sy6vmHk5+czNyWPR+h0s2VhAbsH/5rDbZWYwqFNzLhjdg/4dm9MnK5M+Wc1o1bRRgMnjm4pcRERiYuvOUubl5DNnbR7zcvKYl5PP1sLQ/tXpaUb/Ds05qn8WAzs2Z1CnFgzo2Jx2mRkBp048KnIREdlv7s7qrUVMW7mNr1ZuY9qqrazdVgyENjjr1z6TYwa256CuLTmwaysGdmoeV1t+JzIVuYiI7LWqKmfp5oJvinv6ym1sDi8ib9OsEYf0bM0Fo3twYNdWDO3SkswM1U2s6J0VEZGorM8r5rNlW/hkWS7/Xb6F7UXlAHRs0ZgxfdoyqlcbDu3Vhj5ZmQm7K1ciUpGLiEiNSsor+WLFVj5Zmsuny7awfPNOANo3z+CYge05rE87Du3Vhq6tm6i4A6QiFxGRb2zdWcpHSzbz4eJNfLJ0C8XllWQ0bMCoXm0455BuHNEvi/4dNMcdT1TkIiIpbvXWQqYs2MiHizYxc8123KFTy8aceXBXjh3UntG926b8vtrxTEUuIpKC1m4rYvL8Dbw9bwPz1+UDMLRLC64/th9jB3VgSOcWmutOECpyEZEUsWlHCW/NXc9b8zYwd20eAAd1a8VtJw5i/AEd6dq6abABZZ+oyEVEklhJeSXvL9rEqzNz+HRZLlUemvO+dfxATjqgE93aqLwTnYpcRCTJuDuz1uTxyswc3p63noKSCjq3bMw1R/fl9BFd6J2VGXREqUMqchGRJLGjpJzXZ63juS9Xs2zzTpqkpzF+aEfOOLgrY3q3pUEDrfNORipyEZEEt3B9Ps99uZo356ynqKySA7u25J4zDuCkAzvriGopQJ+wiEgCKq+sYvL8DTz9+Spmr8mjcXoDTj2oMz8OHxZVUoeKXEQkgRSUlDNx2lr++d+VrM8voXe7Ztx+8mDOGNGVlk3Tg44nAVCRi4gkgPV5xTz9+Spe/GoNBaUVjO7dhj/8cChH92+vdd8pTkUuIhLHVuTu5KGpy5k0Zz0OnHRAJ644ojcHdG0ZdDSJEypyEZE4tHzzTv7+0TLemrueRg0bcOGYnlx6eE8dtEW+Q0UuIhJHlm0q4MGPlvP2vPU0bpjGFUf05ooje9MuMyPoaBKnVOQiInFg1ZZC7v9gKW/PW0/T9DSuPqoPlx/ei7YqcNkDFbmISIC27Czl7/9exvNfrSE9rQE/OaoPlx/RmzbNGgUdTRKEilxEJACFpRU88elKHvvka0oqqjjnkG5cP7Yf7Zs3DjqaJBgVuYhIPaqorOLF6Wt54MOlbNlZxvihHbn5hAE6/rnsMxW5iEg9+WrFVn47aSFLNhYwqlcbHrtwICO6tw46liQ4FbmISIxtzC/hrsmLmTR3PV1aNeHR80cwbmhHzHQgF9l/KnIRkRgprajkyc9W8tBHy6mocn52bD9+clQfmjRKCzqaJBEVuYhIDHz+9RZue30BK7cUcvzgDvzm5MF0a6ODuUjdU5GLiNSh/KJy7p6ymInT19KjbVMmXDqKo/pnBR1LkpiKXESkjkyZv4HbJy1kW2EZVx/VhxvG9qNxuhajS2ypyEVE9tOmHSXc/uYC3lu4iSGdW/DPiw9haBed1ETqh4pcRGQfuTuvzVrHHW8tpKyiil+OH8hlh/eiYVqDoKNJClGRi4jsg22FZdz2+nymLNjIIT1bc9+ZB9GzXbOgY0kKUpGLiOylqdmb+cUr88grKuPW8QO54ojepDXQPuESDBW5iEiUisoquGv
yYp77cg0DOjRnwiWjGNy5RdCxJMWpyEVEojA/J5+fTZzNqq2FXHFEL246foC2SJe4oCIXEdkNd+fZL1fzh7cX0zazEc9ffiiH9WkXdCyRb6jIRURqsaOknFtfncfk+Rs5ZkAWfzlrGK11nnCJMypyEZEaLFiXz7UvzCJnezG/DG/Q1kAbtEkcUpGLiERwd577cjW/Dy9K/9eVoxnZs03QsURqpSIXEQkrKa/k1lfn8cac9RwzIIv7zxpGGy1KlzinIhcRAdblFXPVszNYuH4HNx3Xn2uP6atF6ZIQVOQikvKmrdzGNc/PpKS8iscvGMnYwR2CjiQSNRW5iKS0575czR2TFtK9TVMmXjmSvu0zg44ksldU5CKSksoqqrjjrYW88NUajh6QxQPnDKdlk/SgY4nsNRW5iKSc/KJyrnpuBl+u2MZPju7D/x0/QMdKl4SlIheRlLJ2WxGXPD2dNVuL+NvZw/jB8C5BRxLZLypyEUkZc9fmcdmE6ZRVVPHMZaMY3btt0JFE9puKXERSwgeLNvGzF2fTNrMRE68cTd/2zYOOJFInVOQikvQmfL6K3721kKFdWvLkRYeQ1Twj6EgidUZFLiJJy925e8oSHvtkBWMHdeDBc4fRtJG+9iS56DdaRJJSZZXzy9fm8dKMHC4Y3YM7Th2iLdMlKanIRSTplFZUcsPEOUxZsJGffb8vNx7XHzOVuCQnFbmIJJWisgquenYmny7bwq9PGsTlR/QOOpJITKnIRSRp5BeVc8nT05izNo97zziQsw7pFnQkkZhTkYtIUthcUMKFT05jRW4hj5w/gnFDOwUdSaReqMhFJOFtzC/h3Me/ZNOOEp68eCRH9MsKOpJIvVGRi0hC25BfzLmPfcmWnWU8e9koDu7RJuhIIvVKRS4iCWtDfjHnPPYlW3eWMeHSURzco3XQkUTqXYOgA4iI7Iv1eSpxEdAcuYgkoPV5xZz7+Jds21nGM5eNYkR3lbikLhW5iCSUXXPi2wtDJT5cJS4pTovWRSRhbNoR2jpdJS7yPypyEUkI2wvLuODJr9hSUKoSF4mgResiEvcKSsq56J/TWLW1iAmXqMRFImmOXETiWnFZJZdNmMGi9Tt49PwRjOnTNuhIInFFc+QiErfKKqr4yfMzmb5qGw+cM5xjB3UIOpJI3NEcuYjEpcoq58Z/zeHj7Fzu+uEBnHpQ56AjicSlmBe5mY0zs2wzW25mt9Zwf3czm2pms81snpmdGOtMIhLf3J1fvTafd+Zv4LYTB3HuqO5BRxKJWzEtcjNLAx4GxgODgXPNbHC10X4NvOTuw4FzgEdimUlE4t/97y/lXzPWct33+3LFkTqfuMjuxHqOfBSw3N1XuHsZMBE4rdo4DrQIX28JrI9xJhGJY89/tZqHpi7n3FHd+Plx/YOOIxL3Yr2xWxdgbcTtHODQauPcAbxvZtcBzYCxMc4kInHqg0Wb+M0bC/j+wPb8/rShmFnQkUTiXjxs7HYu8LS7dwVOBJ41s+/kMrMrzWyGmc3Izc2t95AiEluz1mznuhdncUCXljx03nAapsXD15NI/Iv1X8o6oFvE7a7hYZEuA14CcPcvgMZAu+pP5O6PuftIdx+ZlZUVo7giEoSVWwq5fMIMOrRozJMXH0LTRtozViRasS7y6UA/M+tlZo0Ibcw2qdo4a4BjAcxsEKEi1yy3SIrILSjloqemYcCES0bRLjMj6EgiCSWmRe7uFcBPgfeAxYS2Tl9oZnea2anh0W4CrjCzucCLwMXu7rHMJSLxoaisgssmTCe3oJQnLz6Enu2aBR1JJOHEfPmVu08GJlcbdnvE9UXA92KdQ0TiS1WV8/N/zWXBunwev3Akw7q1CjqSSELS1iQiEoj7P8jm3YUbue2kwTr0qsh+UJGLSL17bVYOD0/9mnNHdefS7/UMOo5IQlORi0i9mrFqG7e+Op8xvdty52lDtK+4yH5SkYtIvVm7rYirnp1J51aNefTHI0jXvuIi+01/RSJ
SLwpKyrl8wgzKKqt48uJDaNW0UdCRRJKCjrogIjFXWeVcP3EOy3N3MuGSUfTJygw6kkjS0By5iMTcn9/P5qMlm7njlMEc3u87B24Ukf2gIheRmJoyfwOPfhzaQv2CMT2DjiOSdFTkIhIzSzcVcNPLcxnevRV3nDo46DgiSUlFLiIxkV9czlXPzqRpo4Y8ev7BZDRMCzqSSFJSkYtInQsdfnUOa7cV8cj5I+jYsnHQkUSSlopcROrcgx8t499LNvPrkwYxqleboOOIJDUVuYjUqX8v3sTfPlzG6SO6cNFhPYOOI5L0VOQiUmdWbinkhn/NYWiXFtz1wwN0+FWReqAiF5E6UVJeyTXPzyKtgfGPHx9M43Rt3CZSH3RkNxGpE797axGLN+zgnxcfQtfWTYOOI5IyNEcuIvvtjdnreHHaGn5ydB+OGdg+6DgiKUVFLiL7Zfnmnfzq9fmM6tmGm47rH3QckZSjIheRfVZcVsm1z8+iSXoaD547nIY6LalIvdM6chHZZ7e/uYClmwuYcMkoHfRFJCD691lE9skrM3N4eWYO1x3TlyP7ZwUdRyRlqchFZK8t3VTAr9+Yz5jebbl+rNaLiwRJRS4ie6WkvJLrXphNZkZDHjh3GGkNdNAXkSBpHbmI7JW7Jy8me1MBEy4dRfvmWi8uEjTNkYtI1P69eBMTvljN5Yf34iitFxeJCypyEYnK5h0l3PzKPAZ3asHN4wYEHUdEwlTkIrJHVVXOTS/PpaisggfPHU5GQx1HXSReqMhFZI+e+GwFny7bwm9PGULf9plBxxGRCCpyEdmt+Tn53PdeNuOGdOScQ7oFHUdEqlGRi0itCksr+NnE2bRtlsGfztD5xUXikXY/E5Fa/f7tRazaWsgLl4+mVdNGQccRkRpojlxEavTvxZuYOH0tVx3ZhzF92gYdR0RqoSIXke/YVljGLa/OZ2DH5tx4XL+g44jIbmjRuoh8i7tz2+vzyS8u49nLRmlXM5E4pzlyEfmWN+esZ8qCjfz8uAEM6tQi6DgisgcqchH5xob8Yn7z5gIO7tGaK4/sHXQcEYmCilxEgNDR225+eR6VVc5fzjpIZzUTSRAqchEB4LmvVvPZ8i3cdtIgerRtFnQcEYmSilxEWJG7k7smL+ao/lmcN6p70HFEZC+oyEVSXGX4hCgZDdO498wDdfQ2kQSj3c9EUtxTn61k9po8HjhnGB1aNA46jojsJc2Ri6SwFbk7+fP72Ywd1IFTD+ocdBwR2QcqcpEUVVXl3PLqPDIaNuCuHw7VInWRBKUiF0lRE75YxfRV27n9lCG01yJ1kYSlIhdJQau3FnLvu9kcPSCLM0Z0CTqOiOwHFblIitm1SL1hA+Pu03WOcZFEpyIXSTEvTFvDlyu2cdtJg+jUsknQcURkP6nIRVJIzvYi7p68mMP7tuPsQ7oFHUdE6oCKXCRFuDu/fG0+DlqkLpJEVOQiKeL12ev4dNkWbh0/kG5tmgYdR0TqiIpcJAVs3VnK799exIjurfjxoT2CjiMidUhFLpIC/vjOYnaWVvCnMw6kgU5PKpJUVOQiSe6Tpbm8NnsdVx/Vh/4dmgcdR0TqmIpcJIkVlVVw2xvz6Z3VjGuP6Rt0HBGJAZ39TCSJPfDhMtZuK2bilaNpnJ4WdBwRiQHNkYskqQXr8nnis5Wcc0g3RvduG3QcEYkRFblIEqqorOLW1+bRumkjfjl+UNBxRCSGtGhdJAk9/fkqFqzbwcPnjaBl0/Sg44hIDGmOXCTJrMsr5v73l3LswPaceEDHoOOISIypyEWSzB2TFgJw5w+G6jCsIilARS6SRD5ctIkPFm3i+rH96NJKZzYTSQUqcpEkUVRWwW8nLaR/h0wuO7xX0HFEpJ5oYzeRJPHgv5ezLq+Yl64aQ3qa/kcXSRX6axdJAks3FfDEpys48+CujOrVJug4IlKPVOQiCc7d+fUbC2iW0ZBfjh8YdBwRqWcqcpEE9+qsdUxbuY1bxw+kbWZG0HF
EpJ6pyEUSWF5RGXdNXsyI7q04e2S3oOOISABU5CIJ7J53s8kvLuePPzxA5xkXSVEqcpEENWdtHhOnr+GSw3oyqFOLoOOISEBU5CIJqKrK+e2bC8jKzOD6sf2CjiMiAVKRiySgl2euZW5OPr86cRDNG+ukKCKpLOZFbmbjzCzbzJab2a21jHOWmS0ys4Vm9kKsM4kksvyicu55N5tDerbmtGGdg44jIgGL6ZHdzCwNeBg4DsgBppvZJHdfFDFOP+CXwPfcfbuZtY9lJpFE99cPl5JXVMYdp47SSVFEJOZz5KOA5e6+wt3LgInAadXGuQJ42N23A7j75hhnEklYizfs4JkvVnH+oT0Y0rll0HFEJA7Eusi7AGsjbueEh0XqD/Q3s/+a2ZdmNq6mJzKzK81shpnNyM3NjVFckfjl7vz2zYW0bJLOTcf3DzqOiMSJeNjYrSHQDzgaOBd43MxaVR/J3R9z95HuPjIrK6t+E4rEgUlz1zNt1TZ+MW4grZo2CjqOiMSJWBf5OiDycFNdw8Mi5QCT3L3c3VcCSwkVu4iE7Syt4K7JizmgS0vO0hHcRCRCrIt8OtDPzHqZWSPgHGBStXHeIDQ3jpm1I7SofUWMc4kklL9/tIxNO0r53WlDSNMR3EQkQkyL3N0rgJ8C7wGLgZfcfaGZ3Wlmp4ZHew/YamaLgKnAze6+NZa5RBLJ17k7eeqzlZx5cFdGdG8ddBwRiTNR7X5mZhnuXrqnYTVx98nA5GrDbo+47sDPwxcRieDu/O6tRTRumMYt43SKUhH5rmjnyL+IcpiI1KEPFm3ik6W53Hhcf7Ka6xSlIvJdu50jN7OOhHYXa2Jmw4FdK+daAE1jnE0kpZWUV3Ln24vo3yGTC8b0CDqOiMSpPS1aPwG4mNDW5vfzvyLfAfwqdrFE5IlPV5CzvZgXrjiU9LR42FNUROLRbovc3ScAE8zsDHd/tbbxzOyi8LgiUgc27SjhkY+/ZtyQjhzWp13QcUQkjkX1b/7uSjzs+jrIIiJh976bTUWl86sTBwUdRUTiXF0tr9OOrSJ1ZF5OHq/OyuHSw3vRva02RRGR3aurIvc6eh6RlObu3PnWItplNuLaY/oEHUdEEoDmyEXiyNvzNjBj9Xb+7/gBNG+cHnQcEUkAURW5mfXaw7D/1lkikRRVUl7Jn6YsYXCnFvxIx1MXkShFO0de08Zur+y64u4/rZs4IqnriU9XsC6vmN+cPFjHUxeRqO3pgDADgSFASzM7PeKuFkDjWAYTSSWRu5uN6dM26DgikkD2dECYAcDJQCvglIjhBcAVMcokknK0u5mI7Ks9HRDmTeBNMxvj7jq2ukgM7Nrd7Oqj+mh3MxHZa1Gd/QyYbWbXElrM/s0idXe/NCapRFKEdjcTkf0V7cZuzwIdCR17/T+Ejr1eEKtQIqlCu5uJyP6Ktsj7uvtvgMLwMdVPAg6NXSyR5KfdzUSkLkRb5OXhn3lmNhRoCbSPTSSR1PDkZyu1u5mI7Ldo15E/Zmatgd8Ak4BM4PaYpRJJcrkFpTwydTnHD+6g3c1EZL9EVeTu/kT46n+A3rGLI5Ia/vbhUkorqrh1/MCgo4hIgouqyM0sAzgD6Bn5GHe/MzaxRJLXsk0FTJy+lh8f2p3eWZlBxxGRBBftovU3gXxgJlAauzgiye9PU5bQND2N68f2DzqKiCSBaIu8q7uPi2kSkRTw+fIt/HvJZm4dP5A2zRoFHUdEkkC0W61/bmYHxDSJSJKrqnL+OHkxXVo14eLDegYdR0SSxJ5OmjIf8PB4l5jZCkKL1g1wdz8w9hFFksMbc9axcP0OHjhnGI3T04KOIyJJYk+L1k+O5knMrLW7b6+DPCJJqbiskvvey+bAri055cDOQccRkSSyp5OmrI7yef4NjNj/OCLJ6an/rmRDfgl/PXsYDXTwFxGpQ9GuI98TfTOJ1GLXwV+OG9yB0b118BcRqVt1VeReR88jknQe+Hfo4C+/1MFfRCQG6qrIRaQ
GyzcX8OK0tZyvg7+ISIxo0bpIDN09OXTwl58d2y/oKCKSpKI9IAwAZtYeaLzrtruvCV89ti5DiSSDz78OHfzllnEDaZuZEXQcEUlSUc2Rm9mpZrYMWEnoxCmrgCm77nf3bTFJJ5Kg3J0/TVlC55aNueR7PYOOIyJJLNpF678HRgNL3b0XoTnwL2OWSiTBTZ6/kXk5+dx4XH8d/EVEYiraIi93961AAzNr4O5TgZExzCWSsMorq/jz+9n075DJ6SO6Bh1HRJJctOvI88wsE/gEeN7MNgOFsYslkrhemrGWlVsKeeLCkaTp4C8iEmPRzpGfBhQDNwLvAl8Dp8QqlEiiKiqr4IEPlzGyR2uOHdQ+6DgikgKimiN398i57wkxyiKS8P7531VsLijlkfNHYKa5cRGJvWi3Wj/dzJaZWb6Z7TCzAjPbEetwIolke2EZ//j4a8YOas/Inm2CjiMiKSLadeT3Aqe4++JYhhFJZI98vJydZRXcfIIOxSoi9SfadeSbVOIitVuXV8yEL1Zz+vCuDOjYPOg4IpJCdjtHbmanh6/OMLN/AW8Apbvud/fXYhdNJHH87YOl4HDjcToUq4jUrz0tWo/cMr0IOD7itgMqckl5yzYV8OqsHC75Xi+6tm4adBwRSTG7LXJ3v6S+gogkqnvfy6ZZo4Zce0zfoKOISAqKdqv13mb2lpnlmtlmM3vTzHrFOpxIvJu5ehsfLNrElUf2pk2zRkHHEZEUFO3Gbi8ALwGdgM7Ay8DEWIUSSQTuzj1TsmmXmcFlR+j/WhEJRrRF3tTdn3X3ivDlOSJOZyqSiqZmb2baqm1cf2xfmjbaqzMCi4jUmWi/faaY2a2E5sIdOBuYbGZtQKcxldRTWeXc+242Pdo25ZxR3YOOIyIpLNoiPyv886pqw88hVOy96yyRSAJ4c846lmws4MFzh5OeFu2CLRGRuhftsda1AlAkrLSikvvfX8qQzi04+YBOQccRkRQX7QFhaqQDwkgqev7LNazLK+bu0w+ggU5TKiIB25sDwlSnA8JIyikoKeehqcs5rE9bjujXLug4IiI6IIzI3nj805VsKyzjlnEDdZpSEYkLUe8zY2YnAUOI2O3M3e+MRSiReJRbUMoTn67gpAM6cVC3VkHHEREBoj+y2z8I7XJ2HWDAj4AeMcwlEnce+mgZpRVV3HR8/6CjiIh8I9r9Zg5z9wuB7e7+O2AMoG8zSRlrthbxwrQ1nH1IN3pnZQYdR0TkG9EWeXH4Z5GZdQbKCR2uVSQl3P9BNmkNjOuP1WlKRSS+RLuO/G0zawXcB8witMX647EKJRJPFqzL580567nm6D50aKEjE4tIfIn2gDC/D1991czeBhq7e37sYonEj3vfy6Zlk3SuOqpP0FFERL5jr48t6e6lKnFJFZ9/vYVPluZy7TF9aNkkPeg4IiLfoYNEi9TC3bnn3Ww6tWzMhWN6Bh1HRKRGKnKRWry7YCNz1+Zx49j+NE5PCzqOiEiN9rrIzeyOGOQQiSsVlVXc9342fdtncvqILkHHERGp1b7MkZ9a5ylE4szLM3NYkVvIzScMoKFOUyoicWxfvqF0gGlJasVllfztw6WM6N6K4wd3CDqOiMhu7UuRH1znKUTiyNOfr2LTjlKdGEVEEsK+7H5WFYsgIvEgv6icRz9ezjEDsji0d9ug44iI7JFW/olEeOQ/yykoreAX4wYGHUVEJCoqcpGwDfnFPP3fVfxgWBcGdWoRdBwRkajsc5Gb2SV1GUQkaA98uIwqd35+nE7sJyKJY3/myH9XZylEArZ8805emrGWH4/uQbc2TYOOIyIStd2eNMXM5tV2FxDVfjlmNg54AEgDnnD3P9Uy3hnAK8Ah7j4jmucWqSt/fi+bpo0a8tNj+gYdRURkr+zp7GcdgBOA7dWGG/D5np7czNKAh4HjgBxguplNcvdF1cZrDlwPfBVlbpE6M3vNdt5duJEbx/anbWZG0HFERPbKnhatvw1kuvvqapdVwMd
RPP8oYLm7r3D3MmAicFoN4/0euAcoiT66yP4LnRhlCe0yG3H5Eb2CjiMistd2W+Tufpm7f1bLfedF8fxdgLURt3PCw75hZiOAbu7+ThTPJ1Kn/rM0ly9XbOO67/ejWcaeFlCJiMSfQHc/M7MGwF+Am6IY90ozm2FmM3Jzc2MfTpJeVVXoNKXd2zTl3FHdg44jIrJPYl3k64BuEbe7hoft0hwYCnxsZquA0cAkMxtZ/Ync/TF3H+nuI7OysmIYWVLFW/PWs3jDDm46vj+NGuqQCiKSmGL97TUd6GdmvcysEXAOMGnXne6e7+7t3L2nu/cEvgRO1VbrEmtlFVX8+f1sBndqwSkHdg46jojIPotpkbt7BfBT4D1gMfCSuy80szvNTKdDlcC88NVq1m4r5hfjBtCggU6MIiKJK+Zb97j7ZGBytWG31zLu0bHOI7KztIK/f7Sc0b3bcFR/raYRkcSmzXQl5Tzx6Qq2FpbxhE5TKiJJQFv4SErZsrOUxz9ZwbghHRnevXXQcURE9puKXFLKQx8tp7i8kv87YUDQUURE6oSKXFLG2m1FPP/Vas4a2Y2+7TODjiMiUidU5JIy/vLBUhqYccNYnaZURJKHilxSwuINO3hjzjou/l5POrZsHHQcEZE6oyKXlHDvu0tontGQa47SaUpFJLmoyCXpfbViK1Ozc7nmmL60bJoedBwRkTqlIpek5u786d0ldGzRmIsP6xl0HBGROqcil6T2/qJNzF6Txw1j+9E4PS3oOCIidU5FLkmrorKK+97Lpk9WM848uGvQcUREYkJFLknrtVnrWL55JzefMICGafpVF5HkpG83SUol5ZX89cOlDOvWihOGdAw6johIzKjIJSk988UqNuSXcItOjCIiSU5FLkknv7ich6d+zVH9sxjTp23QcUREYkpFLknn//3na/KLy/nFOJ0YRUSSn4pcksqmHSU89d+VnDasM0M6tww6johIzKnIJan87cNlVFY5Nx2nuXERSQ0qckkayzfv5KUZazn/0B50b9s06DgiIvVCRS5J48/vZdMkPY3rvq8To4hI6lCRS1KYuXo77y7cyJVH9qZtZkbQcURE6o2KXBKeu3PPlCW0y8zgssN7BR1HRKReqcgl4X20ZDPTVm3jhrH9aJbRMOg4IiL1SkUuCa2yyrnn3SX0ateMsw/pFnQcEZF6pyKXhPbarByWbgqdGCVdJ0YRkRSkbz5JWCXllfzlg6Uc1K0V44fqxCgikppU5JKwJnweOjHKrToxioikMBW5JKT8onIenrqcowfoxCgiktpU5JKQHvnPcgpKK7hl3MCgo4iIBEpFLglnfV4x//zvKn44vAuDOrUIOo6ISKBU5JJw/vrBUnD4+XH9g44iIhI4FbkklOyNBbw6K4cLx/Sga2udGEVEREUuCeW+95bQLKMh1x6jE6OIiICKXBLItJXb+HDxZn5ydB9aN2sUdBwRkbigIpeE4O78acpiOrTI4JLDdGIUEZFdVOSSEN5buIlZa/K4cWx/mjRKCzqOiEjcUJFL3CuvrOKed5fQt30mZx7cNeg4IiJxRUUuce+Fr9awckshvxw/kIY6MYqIyLfoW1Hi2o6Sch749zLG9G7L9we2DzqOiEjcUZFLXHv046/ZVljGbScN0olRRERqoCKXuLUur5inPlvJD4d3YWiXlkHHERGJSypyiVv3v5eNA/93woCgo4iIxC0VucSlBevyeW32Oi79Xi+6tGoSdBwRkbilIpe44+788Z3FtG6azjXH9Ak6johIXFORS9yZmr2ZL1Zs5Yax/WnROD3oOCIicU1FLnGlorKKuycvoVe7Zpx3aPeg44iIxD0VucSVl2bksGzzTm4ZN5B0HfxFRGSP9E0pcWNnaQV/+WApI3u05oQhHYKOIyKSEFTkEjce+2QFW3aW6uAvIiJ7QUUucWHTjhIe/2QFJx3YieHdWwcdR0QkYajIJS7c/342FVVV3HLCwKCjiIgkFBW5BG7Bunxenpn
DxYf1pHvbpkHHERFJKCpyCZS7c+dbi2jTtBHXHdsv6DgiIglHRS6Bmjx/I9NWbeOm4wfo4C8iIvtARS6BKSmv5K7JixnYsTlnH9It6DgiIglJRS6BefKzlazLK+b2kweT1kC7m4mI7AsVuQRi844SHp66nOMHd+Cwvu2CjiMikrBU5BKI+97Lpryyil+dOCjoKCIiCU1FLvVufk4+r8zK4dLv9aJnu2ZBxxERSWgqcqlX7s6dby+kTdNGXPv9vkHHERFJeCpyqVeT529k+qrt2t1MRKSOqMil3mh3MxGRuqcil3rzze5mp2h3MxGRuqIil3qxa3ezE4Z04LA+2t1MRKSuqMilXtyr3c1ERGJCRS4xN2vNdl6ZmcOlh/eiR1vtbiYiUpdU5BJTlVXOb99cSIcWGfzs+zq7mYhIXVORS0z9a/pa5q/L57aTBtMso2HQcUREko6KXGJme2EZ9763hEN7teGUAzsFHUdEJCmpyCVm7v8gm4KSCn532hDMtLuZiEgsqMglJhasy+f5r9Zw4ZgeDOzYIug4IiJJK+ZFbmbjzCzbzJab2a013P9zM1tkZvPM7N9m1iPWmSS2qqqc299cQNtmjbhhbP+g44iIJLWYFrmZpQEPA+OBwcC5Zja42mizgZHufiDwCnBvLDNJ7L0+ex2z1uRxy7iBtGyi46mLiMRSrOfIRwHL3X2Fu5cBE4HTIkdw96nuXhS++SXQNcaZJIZ2lJRz95QlDO/eijNG6KMUEYm1WBd5F2BtxO2c8LDaXAZMiWkiiakHPlzG1sJS7jx1KA10PHURkZiLmx17zezHwEjgqFruvxK4EqB79+71mEyitWj9Dp7+fBXnjurOAV1bBh1HRCQlxHqOfB0Qeb7KruFh32JmY4HbgFPdvbSmJ3L3x9x9pLuPzMrKiklY2XdVVc6vXp9Pqybp/OKEAUHHERFJGbEu8ulAPzPrZWaNgHOASZEjmNlw4P8RKvHNMc4jMfLi9DXMWZvHr08eRKumjYKOIyKSMmJa5O5eAfwUeA9YDLzk7gvN7E4zOzU82n1AJvCymc0xs0m1PJ3EqdyCUu6ZsoQxvdvyg2G72wRCRETqWszXkbv7ZGBytWG3R1wfG+sMElt/fGcRJeVV/OGHQ3UENxGReqYju8l++e/yLbwxZz1XH92HPlmZQccREUk5KnLZZyXllfz6jQX0aNuUa47uE3QcEZGUFDe7n0ni+cd/vmbllkKeuXQUjdPTgo4jIpKSNEcu+2TllkIemfo1pxzUmSP7a3dAEZGgqMhlr7k7v3ljARnpDfjNyYOCjiMiktJU5LLXXpmZw2fLt/CLEwbQvnnjoOOIiKQ0Fbnslc0FJfz+7UUc0rM15x+qM86KiARNRS575bdvLqSkooo/nXGgTooiIhIHVOQStSnzNzBlwUZuGNtP+4yLiMQJFblEJb+onN+8uZChXVpw5RG9g44jIiJh2o9covKHdxaxvaiMCZceQsM0/f8nIhIv9I0se/TJ0lxenpnD1Uf1ZkhnnWdcRCSeqMhltwpLK/jla/PpndWM677fL+g4IiJSjRaty27d91426/OLefmqMToMq4hIHNIcudRq2sptTPhiFReO7sHInm2CjiMiIjVQkUuNdpZWcNPLc+jepim/GDcw6DgiIlILLVqXGv3xnUWs217MS1eNoVmGfk1EROKV5sjlOz5asokXp63lyiP7aJG6iEicU5HLt2wvLOOWV+czsGNzbjxOW6mLiMQ7LTOVb7g7v35jAXlFZUy4ZBQZDbWVuohIvNMcuXxj0tz1vDN/Azce15/BnVsEHUdERKKgIhcANuaX8Js3FjCieyuuOrJP0HFERCRKKnLB3bn5lbmUVzr3nzWMNJ2eVEQkYajIhQmfr+LTZVv41YkD6dWuWdBxRERkL6jIU9yi9Tu4a/ISjh3Ynh+P7hF0HBER2Usq8hRWVFbBdS/OolXTdO770UGYaZG6iEii0e5nKezOtxaxYkshz192KG2
aNQo6joiI7APNkaeod+ZtYOL0tfzkqD4c1rdd0HFERGQfqchT0NptRdz62jyGdWvFjcf1DzqOiIjsBxV5iimtqOSnL8wC4MFzhpOepl8BEZFEpnXkKeaudxYzNyeff/z4YLq3bRp0HBER2U+aHUshb81dz4QvVnP54b0YN7Rj0HFERKQOqMhTxNe5O7n11Xkc3KM1t4wfGHQcERGpIyryFFBcVsk1z80iIz2Nh87TenERkWSideRJzt259bV5LN1cwIRLRtGpZZOgI4mISB3SrFmSe/zTFbw5Zz3/d/wAjuyfFXQcERGpYyryJPbJ0lz+NGUJJx7QkWuO1qlJRUSSkYo8Sa3eWsh1L86mf4fm3HemjqMuIpKsVORJqLC0giufmQnAYxeMpFmGNoUQEUlWKvIkU1nlXD9xDss2F/DQecN10BcRkSSnIk8yd01ezIeLN/HbU4ZwRD9t3CYikuxU5Enk2S9W8eRnK7n4sJ5cdFjPoOOIiEg9UJEnianZm/ntpIUcO7A9vzl5cNBxRESknqjIk8DiDTv46fOzGNixBQ+eO5y0BtpCXUQkVajIE9zabUVc9NQ0Mhs35MmLtYW6iEiq0bd+AsstKOWCJ7+itKKKl64ao8OvioikIM2RJ6iCknIu/uc0Nu4o4amLRzKgY/OgI4mISABU5AmopLySK56ZQfbGAh798cEc3KNN0JFERCQgWrSeYMoqqvjpC7P5csU2/nb2MI4Z0D7oSCIiEiDNkSeQ8soqrntxFh8u3sTvTxvCD4Z3CTqSiIgETEWeIMorq/jZi7N5b+Em7jhlMBeM6Rl0JBERiQMq8gRQUVnFDRPnMGXBRm4/eTAXf69X0JFERCROqMjjXFlFFddPnMM78zfw65MGcenhKnEREfkfbewWx4rKKrj6uVl8sjSXX580iMuP6B10JBERiTMq8jiVX1TOJU9PY87aPO4940DOOqRb0JFERCQOqcjj0OYdJVz41DRW5BbyyPkjGDe0U9CRREQkTqnI40z2xgIufXo624vKeOriQzi8X7ugI4mISBxTkceRqdmbue6F2TRplMbEK0dzYNdWQUcSEZE4pyKPA+7OhM9XcefbixjYsQVPXDSSzq10AhQREdkzFXnASisqufOtRTz/1RqOG9yBv509TKciFRGRqKkxArRmaxHXvjCL+evyufqoPvzihAE0aGBBxxIRkQSiIg/Iuws2cvMrczHgsQsO5vghHYOOJCIiCUhFXs+Kyyq5590lPP35Kg7q2pKHzhtBtzZNg44lIiIJSkVej2as2sbNr8xj5ZZCLj6sJ788cSAZDdOCjiUiIglMRV4Pissq+fP72Tz135V0btmE5y8/lO/11f7hIiKy/1TkMeTufLh4M394ZxGrtxZxwege3DJ+IJnaKl1EROqIGiVGlmzcwR/eXsxny7fQt30mL1xxKIf10Vy4iIjULRV5HVufV8zDU5fz4rQ1NG+czh2nDOb80T1IT9MZY0VEpO6pyOvIyi2FPPrxcl6fvQ53uGB0D24Y25/WzRoFHU1ERJKYinw/VFU5X63cxnNfrWbK/A2kpzXgvFHdueLI3nRtrV3KREQk9mJe5GY2DngASAOecPc/Vbs/A3gGOBjYCpzt7qtinWt/bMgv5pUZObw8M4c124pontGQK4/sw2WH9yKreUbQ8UREJIXEtMjNLA14GDgOyAGmm9kkd18UMdplwHZ372tm5wD3AGfHMtfeqqxy5uXk8XF2Lh9nb2beunzcYUzvttx4XD/GDelEk0baH1xEROpfrOfIRwHL3X0FgJlNBE4DIov8NOCO8PVXgIfMzNzdY5ytRqUVlazPK2HR+h0s2pDPwvU7mLs2j+1F5ZjB8G6t+PnY/pw2rAvd22rxuYiIBCvWRd4FWBtxOwc4tLZx3L3CzPKBtsCWGGcD4M0565g4bS25O0vJLSglv7j8m/saNjD6ts/k2EEdOKJfO47sl6WN10REJK4kzMZuZnYlcCVA9+7d6+x
5yyudiqoq+rXP5LA+bcnKzKBDy8YM6tiCfh0yaZyuReYiIhK/Yl3k64BuEbe7hofVNE6OmTUEWhLa6O1b3P0x4DGAkSNH1tli9zMP7sqZB3etq6cTERGpV7E+Ssl0oJ+Z9TKzRsA5wKRq40wCLgpfPxP4KKj14yIiIokmpnPk4XXePwXeI7T72VPuvtDM7gRmuPsk4EngWTNbDmwjVPYiIiIShZivI3f3ycDkasNuj7heAvwo1jlERESSkQ4ALiIiksBU5CIiIglMRS4iIpLAVOQiIiIJTEUuIiKSwFTkIiIiCUxFLiIiksBU5CIiIglMRS4iIpLAVOQiIiIJTEUuIiKSwFTkIiIiCUxFLiIiksBU5CIiIglMRS4iIpLAzN2DzrDXzCwXWF0HT9UO2FIHzxMPNC3xSdMSnzQt8UnTUrse7p5V0x0JWeR1xcxmuPvIoHPUBU1LfNK0xCdNS3zStOwbLVoXERFJYCpyERGRBJbqRf5Y0AHqkKYlPmla4pOmJT5pWvZBSq8jFxERSXSpPkcuIiKS0FK2yM1snJllm9lyM7s16Dx7w8y6mdlUM1tkZgvN7Prw8DvMbJ2ZzQlfTgw6azTMbJWZzQ9nnhEe1sbMPjCzZeGfrYPOuSdmNiDivZ9jZjvM7IZE+VzM7Ckz22xmCyKG1fg5WMiD4b+feWY2Irjk31XLtNxnZkvCeV83s1bh4T3NrDji8/lHYMFrUMu01Po7ZWa/DH8u2WZ2QjCpa1bLtPwrYjpWmdmc8PC4/Vx28x0czN+Lu6fcBUgDvgZ6A42AucDgoHPtRf5OwIjw9ebAUmAwcAfwf0Hn24fpWQW0qzbsXuDW8PVbgXuCzrmX05QGbAR6JMrnAhwJjAAW7OlzAE4EpgAGjAa+Cjp/FNNyPNAwfP2eiGnpGTlevF1qmZYaf6fC3wNzgQygV/h7Li3oadjdtFS7/37g9nj/XHbzHRzI30uqzpGPApa7+wp3LwMmAqcFnClq7r7B3WeFrxcAi4Euwaaqc6cBE8LXJwA/CC7KPjkW+Nrd6+LARfXC3T8BtlUbXNvncBrwjId8CbQys071EjQKNU2Lu7/v7hXhm18CXes92D6o5XOpzWnARHcvdfeVwHJC33dxYXfTYmYGnAW8WK+h9sFuvoMD+XtJ1SLvAqyNuJ1DghahmfUEhgNfhQf9NLzo5qlEWBwd5sD7ZjbTzK4MD+vg7hvC1zcCHYKJts/O4dtfSIn4uUDtn0Oi/w1dSmgOaZdeZjbbzP5jZkcEFWov1fQ7lcifyxHAJndfFjEs7j+Xat/Bgfy9pGqRJwUzywReBW5w9x3Ao0AfYBiwgdBiqkRwuLuPAMYD15rZkZF3emjZVMLsXmFmjYBTgZfDgxL1c/mWRPscamNmtwEVwPPhQRuA7u4+HPg58IKZtQgqX5SS4neqmnP59j+/cf+51PAd/I36/HtJ1SJfB3SLuN01PCxhmFk6oV+g5939NQB33+Tule5eBTxOHC1S2x13Xxf+uRl4nVDuTbsWPYV/bg4u4V4bD8xy902QuJ9LWG2fQ0L+DZnZxcDJwPnhL1rCi6G3hq/PJLReuX9gIaOwm9+pRP1cGgKnA//aNSzeP5eavoMJ6O8lVYt8OtDPzHqF557OASYFnClq4XVJTwKL3f0vEcMj17n8EFhQ/bHxxsyamVnzXdcJbZC0gNDncVF4tIuAN4NJuE++NWeRiJ9LhNo+h0nAheGtcUcD+RGLFOOSmY0DfgGc6u5FEcOzzCwtfL030A9YEUzK6Ozmd2oScI6ZZZhZL0LTMq2+8+2DscASd8/ZNSCeP5favoMJ6u8l6K3/groQ2opwKaH/8m4LOs9eZj+c0CKbecCc8OVE4Flgfnj4JKBT0FmjmJbehLaynQss3PVZAG2BfwPLgA+BNkFnjXJ6mgFbgZYRwxLicyH0z8cGoJzQOrzLavscCG19+3D472c+MDLo/FFMy3JC6yl3/c38IzzuGeHfvTnALOCUoPN
HMS21/k4Bt4U/l2xgfND59zQt4eFPA1dXGzduP5fdfAcH8veiI7uJiIgksFRdtC4iIpIUVOQiIiIJTEUuIiKSwFTkIiIiCUxFLiIiksBU5CJS78ysoZm9Z2ZDdjNOCzP71MwS7fC8IvVKRS6SYMysbcSpHTdGnM5yp5k9EsPXPdrMDquL5/LQyUsuAO4OHyELM2tlZtdEjLMDuAL4S83PIiKA9iMXSWRmdgew093/nOivFT75xNvuPjQWzy+SrDRHLpIkwnPMb4ev32FmE8KLpleb2elmdq+ZzTezdyPmgg8On1lqZnhR967jRP/MzBaFz641MVyyVwM3huf+jwgfQvNVM5sevnwv4rWfNbMvzGyZmV0R5ST8CegTfv776vwNEklSDYMOICIx0wc4BhgMfAGc4e6/MLPXgZPM7B3g78Bp7p5rZmcDfyR0is9bgV7uXmpmrdw9z8z+QcQcuZm9APzV3T8zs+7Ae8Cg8GsfCIwmdMja2Wb2jruv30PeW4Gh7j6s7t4CkeSnIhdJXlPcvdzM5gNpwLvh4fOBnsAAYCjwQegcEKQROg42hI4h/byZvQG8UcvzjwUGhx8L0CJ8WkeAN929GCg2s6mEzs5V2/OIyH5QkYskr1IAd68ys3L/3wYxVYT+9g1Y6O5janjsScCRwCnAbWZ2QA3jNABGu3tJ5MBwsVff+EYb44jEiNaRi6SubCDLzMZA6PzKZjbEzBoA3dx9KnAL0BLIBAqA5hGPfx+4btcNMxsWcd9pZtbYzNoCRxM6dfCeVH9+EYmCilwkRbl7GXAmcI+ZzSV0KsbDCC1ify68SH428KC75wFvAT/ctbEb8DNgZHiDuEWENobbZR4wFfgS+H0U68dx963Af81sgTZ2E4medj8TkTpVn7vEiYjmyEVERBKa5shFREQSmObIRUREEpiKXEREJIGpyEVERBKYilxERCSBqchFREQSmIpcREQkgf1/EmmhdZ249TIAAAAASUVORK5CYII=\n", 262 | "text/plain": [ 263 | "
" 264 | ] 265 | }, 266 | "metadata": { 267 | "needs_background": "light" 268 | }, 269 | "output_type": "display_data" 270 | } 271 | ], 272 | "source": [ 273 | "# visualise the various terms\n", 274 | "t = torch.arange(1, nb_steps+1)\n", 275 | "plt.figure(figsize=(8, 8))\n", 276 | "plt.title(\"Variance of q(x_t | x_0) for t=0 -> t=T\")\n", 277 | "plt.xlabel(\"Timestep `t`\")\n", 278 | "plt.ylabel(\"1 - alpha_hat_t\")\n", 279 | "plt.plot(t[:], diff_process.prior_variance[:])\n", 280 | "plt.show()" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": 8, 286 | "id": "ef83a83b-1b8c-4df9-b2a7-8834dbd1829e", 287 | "metadata": {}, 288 | "outputs": [ 289 | { 290 | "name": "stderr", 291 | "output_type": "stream", 292 | "text": [ 293 | "100%|██████████| 199/199 [00:00<00:00, 4911.25it/s]\n" 294 | ] 295 | } 296 | ], 297 | "source": [ 298 | "with torch.no_grad():\n", 299 | " x0 = torch.tensor(X).float()\n", 300 | " x_t = x0\n", 301 | " out = [x0]\n", 302 | " batch = x0.shape[0]\n", 303 | " for t in tqdm(range(1, nb_steps)):\n", 304 | " t_ = torch.full((batch,), t, dtype=torch.long, device=x0.device)\n", 305 | " eps = torch.randn_like(x0)\n", 306 | "# x_t = diff_process.sample_q(x0, eps, t_)\n", 307 | " x_t = diff_process.sample_q_next(x_t, eps, t_)\n", 308 | " out.append(x_t.cpu())" 309 | ] 310 | }, 311 | { 312 | "cell_type": "code", 313 | "execution_count": 9, 314 | "id": "2b228690-3308-4f01-82f5-7daac59e09d7", 315 | "metadata": {}, 316 | "outputs": [ 317 | { 318 | "name": "stderr", 319 | "output_type": "stream", 320 | "text": [ 321 | "/home/angusturner/miniconda3/envs/py38/lib/python3.8/site-packages/seaborn/_decorators.py:36: FutureWarning: Pass the following variable as a keyword arg: y. 
From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation.\n", 322 | " warnings.warn(\n" 323 | ] 324 | }, 325 | { 326 | "data": { 327 | "image/png": "\n", 328 | "text/plain": [ 329 | "
" 330 | ] 331 | }, 332 | "metadata": { 333 | "needs_background": "light" 334 | }, 335 | "output_type": "display_data" 336 | } 337 | ], 338 | "source": [ 339 | "# plot the final step\n", 340 | "x, y = out[-1][:, 0], out[-1][:, 1]\n", 341 | "fig, ax = plt.subplots()\n", 342 | "# ax.scatter(x, y)\n", 343 | "sns.kdeplot(x, y, fill=True, ax=ax, bw_adjust=0.2)\n", 344 | "ax.set_xlim(-3.5, 3.5)\n", 345 | "ax.set_ylim(-3.5, 3.5)\n", 346 | "plt.show()\n", 347 | "\n", 348 | "# ax.set_xlim(-3, 3)\n", 349 | "# plt.scatter(x, y)\n", 350 | "# # plt.contour(out[-1])\n", 351 | "# sns.kdeplot(x, y, fill=True)" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": 10, 357 | "id": "0bdd9800-f07a-4329-ad4d-c9415dfec2b8", 358 | "metadata": {}, 359 | "outputs": [], 360 | "source": [ 361 | "import os\n", 362 | "\n", 363 | "def plot_density(x, y, i=0, out_dir='pngs'):\n", 364 | " fig, ax = plt.subplots()\n", 365 | "# ax.scatter(x, y)\n", 366 | " sns.kdeplot(x=x, y=y, fill=True, ax=ax, bw_adjust=0.2)\n", 367 | " ax.set_xlim(-3.5, 3.5)\n", 368 | " ax.set_ylim(-3.5, 3.5)\n", 369 | " i_str = \"{}\".format(i).zfill(3)\n", 370 | " out_path = os.path.join(out_dir, f\"fig_{i_str}.png\")\n", 371 | " plt.savefig(out_path, dpi=200, bbox_inches='tight')\n", 372 | " plt.close(fig)" 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": 11, 378 | "id": "444fa390-050c-4705-bec0-eb34f09e57b2", 379 | "metadata": {}, 380 | "outputs": [ 381 | { 382 | "name": "stderr", 383 | "output_type": "stream", 384 | "text": [ 385 | "100%|██████████| 200/200 [06:58<00:00, 2.09s/it]\n" 386 | ] 387 | } 388 | ], 389 | "source": [ 390 | "for i in tqdm(range(0, nb_steps)):\n", 391 | " x, y = out[i][:, 0], out[i][:, 1]\n", 392 | " plot_density(x, y, i)" 393 | ] 394 | }, 395 | { 396 | "cell_type": "code", 397 | "execution_count": 29, 398 | "id": "65414b45-5e93-421c-aa78-87df22dc3a83", 399 | "metadata": {}, 400 | "outputs": [ 401 | { 402 | "data": { 403 | "text/plain": [ 404 | 
"tensor([0])" 405 | ] 406 | }, 407 | "execution_count": 29, 408 | "metadata": {}, 409 | "output_type": "execute_result" 410 | } 411 | ], 412 | "source": [ 413 | "torch.randint(0, 1, size=(1,))" 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": null, 419 | "id": "f5d16e75-558c-46e5-abdd-82164dda7eb3", 420 | "metadata": {}, 421 | "outputs": [], 422 | "source": [] 423 | }, 424 | { 425 | "cell_type": "code", 426 | "execution_count": null, 427 | "id": "13beb3ec-4d24-4b3e-ad2b-8c2bbcd8a920", 428 | "metadata": {}, 429 | "outputs": [], 430 | "source": [] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": null, 435 | "id": "c48e6eeb-df69-4ec3-8451-b8fe132f7826", 436 | "metadata": {}, 437 | "outputs": [], 438 | "source": [] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": null, 443 | "id": "276a9cc3-17bb-4b69-b931-54a2518c4e93", 444 | "metadata": {}, 445 | "outputs": [], 446 | "source": [] 447 | } 448 | ], 449 | "metadata": { 450 | "kernelspec": { 451 | "display_name": "Python 3", 452 | "language": "python", 453 | "name": "python3" 454 | }, 455 | "language_info": { 456 | "codemirror_mode": { 457 | "name": "ipython", 458 | "version": 3 459 | }, 460 | "file_extension": ".py", 461 | "mimetype": "text/x-python", 462 | "name": "python", 463 | "nbconvert_exporter": "python", 464 | "pygments_lexer": "ipython3", 465 | "version": "3.8.5" 466 | } 467 | }, 468 | "nbformat": 4, 469 | "nbformat_minor": 5 470 | } 471 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | target-version = ['py38'] 4 | include = '\.pyi?$' 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # popgen dependency 2 | # -e 
git+git://github.com/Popgun-Labs/PopGen.git#popgen 3 | 4 | # note: for GPU version, install via conda 5 | torch>=1.9 6 | torchvision~=0.9.1 7 | 8 | tqdm>=4.46.0 9 | numpy>=1.16.1 10 | seaborn 11 | hydra-core==0.11.3 12 | wandb~=0.10.21 13 | omegaconf==1.4.1 14 | black~=19.10b0 15 | setuptools~=50.3.2 16 | popgen~=0.0.2 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from setuptools import setup, find_packages 3 | from diffuse import __version__ 4 | 5 | readme = open("README.md").read() 6 | requirements_txt = open("requirements.txt").read().split("\n") 7 | requirements = list(filter(lambda x: "--extra" not in x and x is not "", requirements_txt)) 8 | 9 | dependency_links = list(filter(lambda x: "--extra" in x, requirements_txt)) 10 | dependency_links = list(map(lambda x: x.split(" ")[-1], dependency_links)) 11 | 12 | setup( 13 | # Metadata 14 | name="diffuse", 15 | version=__version__, 16 | author="Angus Turner", 17 | author_email="angusturner27@gmail.com", 18 | url="https://github.com/angusturner/phasenet", 19 | description="Experiments in Diffusion Probabilistic Modelling", 20 | long_description=readme, 21 | packages=find_packages(exclude=("test",)), 22 | zip_safe=True, 23 | install_requires=requirements, 24 | dependency_links=dependency_links, 25 | include_package_data=True, 26 | ) 27 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from omegaconf import DictConfig 3 | 4 | from popgen.setup import setup_worker, setup_loaders 5 | 6 | 7 | @hydra.main(config_path="config/config.yaml") 8 | def train(cfg: DictConfig) -> None: 9 | # get the experiment name 10 | name = cfg.get("name", False) 11 | if not name: 12 | raise Exception("Must specify experiment name on CLI. e.g. 
`python train.py name=vae ...`") 13 | 14 | # setup the worker 15 | overwrite = cfg.get("overwrite", False) 16 | worker, cfg = setup_worker(name, cfg, overwrite=overwrite, module="diffuse") 17 | 18 | # setup data loaders 19 | train_loader, test_loader = setup_loaders( 20 | dataset_class=cfg["dataset_class"], data_opts=cfg["dataset"], loader_opts=cfg["loader"], module="diffuse" 21 | ) 22 | 23 | # train 24 | worker.run(train_loader, test_loader, cfg["nb_epoch"]) 25 | 26 | 27 | if __name__ == "__main__": 28 | train() 29 | --------------------------------------------------------------------------------