├── .github └── workflows │ ├── python-publish.yml │ └── python-test-build.yml ├── .gitignore ├── .projectignore ├── LICENSE ├── __init__.py ├── callbacks ├── __init__.py ├── ema.py └── logger.py ├── config ├── dataset │ └── mnist.yaml ├── model │ ├── unet_class_conditioned.yaml │ └── unet_paper.yaml ├── model_dataset │ ├── unet_class_conditioned-mnist.yaml │ └── unet_paper-mnist.yaml ├── model_scheduler │ ├── unet_paper-cosine.yaml │ └── unet_paper-linear.yaml ├── optimizer │ └── adam_ddpm.yaml ├── scheduler │ ├── cosine.yaml │ ├── linear.yaml │ └── tan.yaml └── train.yaml ├── generate.py ├── model ├── __init__.py ├── classifier_free_ddpm.py ├── ddpm.py ├── distributions.py ├── unet.py └── unet_class.py ├── pyproject.toml ├── readme.md ├── readme_pip.md ├── tests └── test_unet.py ├── train.py ├── utils ├── __init__.py └── paths.py └── variance_scheduler ├── __init__.py ├── abs_var_scheduler.py ├── cosine.py ├── hyperbolic_secant.py └── linear.py /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | workflow_dispatch: 15 | 16 | permissions: 17 | contents: read 18 | 19 | jobs: 20 | deploy: 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.9' 29 | - name: Assemble python package 30 | run: | 31 | mv model/ ddpm/ 32 | cp -r variance_scheduler/ ddpm/ 33 | - name: Install dependencies 34 | run: | 35 | python -m pip install --upgrade pip 36 | pip install build 37 | python -m pip install .[dev] 38 | - name: Test model/ folder 39 | run: | 40 | mv ddpm model 41 | python -m pytest 42 | mv model ddpm 43 | - name: Build package 44 | run: python -m build 45 | - name: Publish package 46 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 47 | with: 48 | user: __token__ 49 | password: ${{ secrets.PYPI_API_TOKEN }} 50 | -------------------------------------------------------------------------------- /.github/workflows/python-test-build.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 9 | name: Test and build python package 10 | 11 | on: 12 | push: 13 | tags: 14 | - 'v*.*.*' 15 | branches: 16 | - master 17 | schedule: 18 | # weekly 19 | - cron: '0 0 * * 1' 20 | 21 | permissions: 22 | contents: read 23 | 24 | jobs: 25 | # test on several python versions 26 | test-build: 27 | strategy: 28 | matrix: 29 | python-version: [3.8, 3.9, '3.10', 3.11] 30 | name: Test and build python package version ${{ matrix.python-version }} 31 | 32 | runs-on: ubuntu-latest 33 | 34 | steps: 35 | - uses: actions/checkout@v3 36 | - name: Set up Python 37 | uses: actions/setup-python@v3 38 | with: 39 | python-version: '${{ matrix.python-version }}' 40 | - name: Assemble python package 41 | run: | 42 | mv model/ ddpm/ 43 | cp -r variance_scheduler/ ddpm/ 44 | - name: Install dependencies 45 | run: | 46 | python -m pip install --upgrade pip 47 | pip install build 48 | python -m pip install .[dev] 49 | - name: Test model/ folder 50 | run: | 51 | mv ddpm model 52 | python -m pytest 53 | mv model ddpm 54 | - name: Build package 55 | run: python -m build 56 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | envs/ 2 | <<<<<<< HEAD 3 | saved_models/ 4 | .idea/ 5 | .github/ 6 | .vscode/ 7 | 8 | 9 | 10 | ======= 11 | .idea 12 | saved_models/ 13 | >>>>>>> bef2d25 (update gitignore) 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | 19 | # C extensions 20 | *.so 21 | 22 | # Distribution / packaging 23 | .Python 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | share/python-wheels/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | MANIFEST 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .nox/ 56 | .coverage 57 | .coverage.* 58 | .cache 59 | nosetests.xml 60 | coverage.xml 61 | *.cover 62 | *.py,cover 63 | .hypothesis/ 64 | .pytest_cache/ 65 | cover/ 66 | 67 | # Translations 68 | *.mo 69 | *.pot 70 | 71 | # Django stuff: 72 | *.log 73 | local_settings.py 74 | db.sqlite3 75 | db.sqlite3-journal 76 | 77 | # Flask stuff: 78 | instance/ 79 | .webassets-cache 80 | 81 | # Scrapy stuff: 82 | .scrapy 83 | 84 | # Sphinx documentation 85 | docs/_build/ 86 | 87 | # PyBuilder 88 | .pybuilder/ 89 | target/ 90 | 91 | # Jupyter Notebook 92 | .ipynb_checkpoints 93 | 94 | # IPython 95 | profile_default/ 96 | ipython_config.py 97 | 98 | # pyenv 99 | # For a library or package, you might want to ignore these files since the code is 100 | # intended to run in multiple environments; otherwise, check them in: 101 | # .python-version 102 | 103 | # pipenv 104 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 105 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 106 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 107 | # install all needed dependencies. 108 | #Pipfile.lock 109 | 110 | # poetry 111 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
112 | # This is especially recommended for binary packages to ensure reproducibility, and is more 113 | # commonly ignored for libraries. 114 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 115 | #poetry.lock 116 | 117 | # pdm 118 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 119 | #pdm.lock 120 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 121 | # in version control. 122 | # https://pdm.fming.dev/#use-with-ide 123 | .pdm.toml 124 | 125 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 126 | __pypackages__/ 127 | 128 | # Celery stuff 129 | celerybeat-schedule 130 | celerybeat.pid 131 | 132 | # SageMath parsed files 133 | *.sage.py 134 | 135 | # Environments 136 | .env 137 | .venv 138 | env/ 139 | venv/ 140 | ENV/ 141 | env.bak/ 142 | venv.bak/ 143 | 144 | # Spyder project settings 145 | .spyderproject 146 | .spyproject 147 | 148 | # Rope project settings 149 | .ropeproject 150 | 151 | # mkdocs documentation 152 | /site 153 | 154 | # mypy 155 | .mypy_cache/ 156 | .dmypy.json 157 | dmypy.json 158 | 159 | # Pyre type checker 160 | .pyre/ 161 | 162 | # pytype static type analyzer 163 | .pytype/ 164 | 165 | # Cython debug symbols 166 | cython_debug/ 167 | 168 | # PyCharm 169 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 170 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 171 | # and can be added to the global gitignore or merged into this file. For a more nuclear 172 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 173 | #.idea/ 174 | 175 | 176 | launch.json 177 | -------------------------------------------------------------------------------- /.projectignore: -------------------------------------------------------------------------------- 1 | # This file contains a list of match patterns that instructs 2 | # anaconda-project to exclude certain files or directories when 3 | # building a project archive. The file format is a simplfied 4 | # version of Git's .gitignore file format. In fact, if the 5 | # project is hosted in a Git repository, these patterns can be 6 | # merged into the .gitignore file and this file removed. 7 | # See the anaconda-project documentation for more details. 8 | 9 | # Python caching 10 | *.pyc 11 | *.pyd 12 | *.pyo 13 | __pycache__/ 14 | 15 | # Jupyter & Spyder stuff 16 | .ipynb_checkpoints/ 17 | .Trash-*/ 18 | /.spyderproject 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Michele De Vita 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Michedev/DDPMs-Pytorch/4cdb52d5cee7070ddc243ecb781c9fe5e68d0ba8/__init__.py -------------------------------------------------------------------------------- /callbacks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Michedev/DDPMs-Pytorch/4cdb52d5cee7070ddc243ecb781c9fe5e68d0ba8/callbacks/__init__.py -------------------------------------------------------------------------------- /callbacks/ema.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | 6 | 7 | class EMA(pl.Callback): 8 | """ 9 | Exponential Moving Average 10 | Let \beta the smoothing parameter, p the current parameter value and v the accumulated value, the EMA is calculated 11 | as follows 12 | 13 | v_t = \beta * p_{t-1} + (1 - \beta) * v_{t-1} 14 | p_t = v_t 15 | """ 16 | 17 | def __init__(self, decay_factor: float): 18 | assert 0.0 <= decay_factor <= 1.0 19 | self.decay_factor = decay_factor 20 | self.dict_params = dict() 21 | 22 | @torch.no_grad() 23 | def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: 24 | """ 25 | For each parameter in the model, we add the parameter to the dictionary 26 | 27 | :param trainer: The trainer object 28 | :type trainer: "pl.Trainer" 29 | :param pl_module: The LightningModule that is being trained 30 | :type pl_module: "pl.LightningModule" 31 | """ 32 | for n, p in pl_module.named_parameters(): 33 | self.dict_params[n] = p 34 | 35 | @torch.no_grad() 36 | def on_train_batch_start( 37 | self, 38 | trainer: "pl.Trainer", 39 | pl_module: "pl.LightningModule", 40 | batch: Any, 41 | batch_idx: int, 42 | unused: int = 0, 43 | ) -> None: 44 | """ 45 | For each parameter in the model, we multiply the parameter by a decay factor and add the current 46 | parameter multiplied by the decay factor to the parameter in the dictionary 47 | 48 | :param trainer: The trainer object 49 | :type trainer: "pl.Trainer" 50 | :param pl_module: The LightningModule that is being trained 51 | :type pl_module: "pl.LightningModule" 52 | :param batch: The batch of data that is being passed to the model 53 | :type batch: Any 54 | :param batch_idx: the index of the batch within the current epoch 55 | :type batch_idx: int 56 | :param unused: int = 0, defaults to 0 57 | :type unused: int (optional) 58 | """ 59 | for n, p in pl_module.named_parameters(): 60 | self.dict_params[n] = self.dict_params[n] * (1.0 - self.decay_factor) + p * self.decay_factor 61 | p[:] = self.dict_params[n][:] 62 | -------------------------------------------------------------------------------- /callbacks/logger.py: -------------------------------------------------------------------------------- 1 | from typing import Any 
2 | from pytorch_lightning.callbacks import Callback 3 | import pytorch_lightning as pl 4 | import torchvision 5 | import torch 6 | 7 | class LoggerCallback(Callback): 8 | 9 | def __init__(self, freq_train_log: int, freq_train_norm_gradients: int, batch_size_gen_images: int) -> None: 10 | super().__init__() 11 | self.freq_train = freq_train_log 12 | self.freq_train_norm_gradients = freq_train_norm_gradients 13 | self.batch_size_gen_images = batch_size_gen_images 14 | 15 | def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: dict, batch: Any, batch_idx: int) -> None: 16 | if trainer.global_step % self.freq_train == 0: 17 | pl_module.log("train/loss", outputs["loss"], on_step=True, on_epoch=False, prog_bar=True, logger=True) 18 | pl_module.log("train/noise_loss", outputs["noise_loss"], on_step=True, on_epoch=False, prog_bar=True, logger=True) 19 | if outputs['vlb_loss'] is not None: 20 | pl_module.log("train/vlb_loss", outputs["vlb_loss"], on_step=True, on_epoch=False, prog_bar=True, logger=True) 21 | 22 | def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: 23 | if trainer.global_step % self.freq_train_norm_gradients == 0: 24 | norm_grad = 0 25 | for p in pl_module.parameters(): 26 | if p.grad is not None: 27 | norm_grad += p.grad.norm(2).item() 28 | pl_module.log("train/norm_grad", norm_grad, on_step=True, on_epoch=False, prog_bar=True, logger=True) 29 | 30 | def on_validation_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: dict, batch: Any, batch_idx: int) -> None: 31 | pl_module.log("val/loss", outputs["loss"], on_step=True, on_epoch=True, prog_bar=True, logger=True) 32 | pl_module.log("val/noise_loss", outputs["noise_loss"], on_step=True, on_epoch=True, prog_bar=True, logger=True) 33 | if outputs['vlb_loss'] is not None: 34 | pl_module.log("val/vlb_loss", outputs["vlb_loss"], on_step=True, on_epoch=True, prog_bar=True, logger=True) 35 | 36 | def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: 37 | gen_images = pl_module.generate(batch_size=self.batch_size_gen_images) # Generate images 38 | gen_images = torchvision.utils.make_grid(gen_images) # Convert to grid 39 | pl_module.logger.experiment.add_image('gen_val_images', gen_images, trainer.current_epoch) # Log the images 40 | torchvision.utils.save_image(gen_images, f'gen_images/epoch={pl_module.current_epoch}.png') # Save the images 41 | -------------------------------------------------------------------------------- /config/dataset/mnist.yaml: -------------------------------------------------------------------------------- 1 | width: 28 2 | height: 28 3 | channels: 1 4 | num_classes: 10 5 | files_location: ~/.cache/torchvision_dataset 6 | train: 7 | _target_: torchvision.datasets.MNIST 8 | root: ${dataset.files_location} 9 | train: true 10 | download: true 11 | transform: 12 | _target_: torchvision.transforms.ToTensor 13 | val: 14 | _target_: torchvision.datasets.MNIST 15 | root: ${dataset.files_location} 16 | train: false 17 | download: true 18 | transform: 19 | _target_: torchvision.transforms.ToTensor -------------------------------------------------------------------------------- /config/model/unet_class_conditioned.yaml: -------------------------------------------------------------------------------- 1 | _target_: model.classifier_free_ddpm.GaussianDDPMClassifierFreeGuidance 2 | denoiser_module: 3 | _target_: model.unet_class.UNetTimeStepClassConditioned 4 | channels: 
[3, 128, 256, 256, 384] 5 | kernel_sizes: [3, 3, 3, 3] 6 | strides: [1, 1, 1, 1] 7 | paddings: [1, 1, 1, 1] 8 | p_dropouts: [0.1, 0.1, 0.1, 0.1] 9 | time_embed_size: 100 # did not find this hyperparameter in the paper 10 | downsample: true 11 | num_classes: ${dataset.num_classes} 12 | class_embed_size: 3 13 | assert_shapes: false 14 | T: ${noise_steps} 15 | width: ${dataset.width} 16 | height: ${dataset.height} 17 | logging_freq: 1_000 18 | input_channels: ${dataset.channels} 19 | num_classes: ${dataset.num_classes} 20 | v: 0.2 21 | w: 0.3 22 | p_uncond: 0.2 23 | -------------------------------------------------------------------------------- /config/model/unet_paper.yaml: -------------------------------------------------------------------------------- 1 | _target_: model.ddpm.GaussianDDPM 2 | denoiser_module: 3 | _target_: model.unet.UNetTimeStep 4 | channels: [3, 128, 256, 256, 384] 5 | kernel_sizes: [3, 3, 3, 3] 6 | strides: [1, 1, 1, 1] 7 | paddings: [1, 1, 1, 1] 8 | p_dropouts: [0.1, 0.1, 0.1, 0.1] 9 | time_embed_size: 100 # did not find this hyperparameter in the paper 10 | downsample: true 11 | T: ${noise_steps} 12 | lambda_variational: 0.0001 13 | width: ${dataset.width} 14 | height: ${dataset.height} 15 | logging_freq: 1_000 16 | input_channels: ${dataset.channels} 17 | vlb: false 18 | init_step_vlb: 1_000 -------------------------------------------------------------------------------- /config/model_dataset/unet_class_conditioned-mnist.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | model: 4 | denoiser_module: 5 | channels: [1, 32, 64, 64, 96] 6 | -------------------------------------------------------------------------------- /config/model_dataset/unet_paper-mnist.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | model: 4 | denoiser_module: 5 | channels: [1, 32, 64, 64, 96] 6 | -------------------------------------------------------------------------------- /config/model_scheduler/unet_paper-cosine.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | model: 4 | denoiser_module: 5 | p_dropouts: [0.3, 0.3, 0.3, 0.3] -------------------------------------------------------------------------------- /config/model_scheduler/unet_paper-linear.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | model: 4 | denoiser_module: 5 | p_dropouts: [0.1, 0.1, 0.1, 0.1] -------------------------------------------------------------------------------- /config/optimizer/adam_ddpm.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.Adam 2 | _partial_: true 3 | lr: 0.0001 4 | weight_decay: 0.0 5 | -------------------------------------------------------------------------------- /config/scheduler/cosine.yaml: -------------------------------------------------------------------------------- 1 | _target_: variance_scheduler.cosine.CosineScheduler 2 | s: 0.008 3 | T: ${noise_steps} -------------------------------------------------------------------------------- /config/scheduler/linear.yaml: -------------------------------------------------------------------------------- 1 | _target_: variance_scheduler.linear.LinearScheduler 2 | beta_start: 0.000025 3 | beta_end: 0.005 4 | T: ${noise_steps} -------------------------------------------------------------------------------- /config/scheduler/tan.yaml:
-------------------------------------------------------------------------------- 1 | _target_: variance_scheduler.hyperbolic_secant.HyperbolicSecant 2 | lambda_min: -20 3 | lambda_max: 20 4 | T: ${noise_steps} -------------------------------------------------------------------------------- /config/train.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model: unet_paper 3 | - scheduler: linear 4 | - dataset: mnist 5 | - optimizer: adam_ddpm 6 | - optional model_dataset: ${model}-${dataset} 7 | - optional model_scheduler: ${model}-${scheduler} 8 | 9 | batch_size: 128 10 | noise_steps: 4_000 # T 11 | accelerator: null # from pytorch-lightning, the hardware platform used to train the neural network 12 | devices: null # the devices to use in a given hardware platform (see argument above) 13 | gradient_clip_val: 0.0 # gradient clip value - set to 0.0 to disable 14 | gradient_clip_algorithm: norm # gradient clip algorithm - either 'norm' or 'value' 15 | ema: true # exponential moving average 16 | ema_decay: 0.99 # exponential moving average decay rate 17 | early_stop: true # stop training if the validation loss does not improve for patience epochs 18 | patience: 10 # early stopping patience; set to -1 to disable 19 | min_delta: 0.0 # minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than min_delta, will count as no improvement. 20 | ckpt: null # path to checkpoint 21 | seed: 1337 # random seed 22 | freq_logging: 100 # frequency of logging 23 | freq_logging_norm_grad: 100 # frequency of logging the norm of the gradient 24 | batch_size_gen_images: 64 # batch size for generating images 25 | 26 | hydra: 27 | run: 28 | dir: saved_models/${now:%Y_%m_%d_%H_%M_%S} # where run train.py it will create under {current working directory}/saved_models a folder with the current date and time and it will be setted as new cwd 29 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import pytorch_lightning as pl 3 | import argparse 4 | 5 | import torch 6 | from omegaconf import OmegaConf 7 | from path import Path 8 | from tqdm import tqdm 9 | 10 | from model.classifier_free_ddpm import GaussianDDPMClassifierFreeGuidance 11 | import torchvision 12 | 13 | from utils.paths import SCHEDULER 14 | 15 | scheduler_paths = [p for p in SCHEDULER.files('*.yaml')] 16 | scheduler_names = [x.basename().replace('.yaml', '') for x in scheduler_paths] 17 | scheduler_map = {name: path for name, path in zip(scheduler_names, scheduler_paths)} 18 | 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('-r', '--run', type=Path, required=True, help='Path to the checkpoint file') 23 | parser.add_argument('--seed', '-s', type=int, default=0, help='Random seed') 24 | parser.add_argument('--device', '-d', type=str, default='cpu', help='Device to use') 25 | parser.add_argument('--batch-size', '-b', type=int, default=16, help='Batch size') 26 | parser.add_argument('-w', type=float, default=None, help='Class guidance') 27 | parser.add_argument('--scheduler', choices=scheduler_names, default=None, 28 | help='use a custom scheduler', dest='scheduler') 29 | parser.add_argument('-T', type=int, default=None, help='Number of diffusion steps') 30 | return parser.parse_args() 31 | 32 | 33 | @torch.no_grad() 34 | def main(): 35 | """ 36 | Generate images from a 
trained model in the checkpoint folder 37 | """ 38 | args = parse_args() 39 | 40 | print(args) 41 | run_path = args.run.absolute() 42 | pl.seed_everything(args.seed) 43 | assert run_path.exists(), run_path 44 | assert run_path.basename().endswith('.ckpt'), run_path 45 | print('loading model from', run_path) 46 | hparams = OmegaConf.load(run_path.parent / 'config.yaml') 47 | if args.T is not None: 48 | hparams.T = args.T 49 | if args.w is None: 50 | args.w = hparams.model.w 51 | model_hparams = hparams.model 52 | denoiser = hydra.utils.instantiate(model_hparams.denoiser_module) 53 | if args.scheduler is None: 54 | scheduler = hydra.utils.instantiate(hparams.scheduler) 55 | else: 56 | scheduler_conf = OmegaConf.load(scheduler_map[args.scheduler]) 57 | scheduler_conf.T = hparams.noise_steps 58 | scheduler = hydra.utils.instantiate(scheduler_conf) 59 | model = GaussianDDPMClassifierFreeGuidance( 60 | denoiser_module=denoiser, T=model_hparams.T, 61 | w=args.w, p_uncond=model_hparams.p_uncond, width=model_hparams.width, 62 | height=model_hparams.height, input_channels=model_hparams.input_channels, 63 | num_classes=model_hparams.num_classes, logging_freq=1000, v=model_hparams.v, 64 | variance_scheduler=scheduler).to(args.device) 65 | model.load_state_dict(torch.load(run_path, map_location=args.device)['state_dict']) 66 | model = model.eval() 67 | images = [] 68 | model.on_fit_start() 69 | 70 | for i_c in tqdm(range(model.num_classes)): 71 | c = torch.zeros((args.batch_size, model.num_classes), device=args.device) 72 | c[:, i_c] = 1 73 | gen_images = model.generate(batch_size=args.batch_size, c=c) 74 | images.append(gen_images) 75 | images = torch.cat(images, dim=0) 76 | # save images 77 | torchvision.utils.save_image(images, run_path.parent / 'generated_images.png', nrow=4, padding=2, normalize=True) 78 | 79 | 80 | if __name__ == '__main__': 81 | main() 82 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | # Auto-generated by initgen - Made by Mikedev 2 | from .unet import positional_embedding_vector 3 | from .unet import timestep_embedding 4 | from .unet import init_zero 5 | from .unet import ResBlockTimeEmbed 6 | from .unet import ImageSelfAttention 7 | from .unet import UNetTimeStep 8 | from .unet_class import ResBlockTimeEmbedClassConditioned 9 | from .unet_class import UNetTimeStepClassConditioned 10 | from .classifier_free_ddpm import GaussianDDPMClassifierFreeGuidance 11 | from .ddpm import GaussianDDPM 12 | 13 | __all__ = ['positional_embedding_vector', 'timestep_embedding', 'init_zero', 'ResBlockTimeEmbed', 'ImageSelfAttention', 14 | 'UNetTimeStep', 'ResBlockTimeEmbedClassConditioned', 'UNetTimeStepClassConditioned', 15 | 'GaussianDDPMClassifierFreeGuidance', 'GaussianDDPM' ] -------------------------------------------------------------------------------- /model/classifier_free_ddpm.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, List, Union, Optional 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | import torchvision 6 | from path import Path 7 | from torch import nn 8 | from torch.nn.functional import one_hot 9 | 10 | from variance_scheduler.abs_var_scheduler import Scheduler 11 | from .distributions import x0_to_xt 12 | 13 | 14 | class GaussianDDPMClassifierFreeGuidance(pl.LightningModule): 15 | """ 16 | Implementation of "Classifier-Free Diffusion Guidance" 17 | """ 18 | 
19 | def __init__(self, denoiser_module: nn.Module, T: int, 20 | w: float, p_uncond: float, width: int, 21 | height: int, input_channels: int, num_classes: int, 22 | logging_freq: int, v: float, variance_scheduler: 'Scheduler'): 23 | """ 24 | :param denoiser_module: The nn which computes the denoise step i.e. q(x_{t-1} | x_t, c) 25 | :param T: the amount of noising steps 26 | :param w: strength of class guidance 27 | :param p_uncond: probability of train a batch without class conditioning 28 | :param variance_scheduler: the variance scheduler cited in DDPM paper. See folder variance_scheduler for practical implementation 29 | :param width: image width 30 | :param height: image height 31 | :param input_channels: image input channels 32 | :param num_classes: number of classes 33 | :param logging_freq: frequency of logging loss function during training 34 | :param v: generative variance hyper-parameter 35 | """ 36 | assert 0.0 <= v <= 1.0, f'0.0 <= {v} <= 1.0' 37 | assert 0.0 <= w, f'0.0 <= {w}' 38 | assert 0.0 <= p_uncond <= 1.0, f'0.0 <= {p_uncond} <= 1.0' 39 | super().__init__() 40 | self.input_channels = input_channels 41 | self.denoiser_module = denoiser_module 42 | self.T = T 43 | self.w = w 44 | self.v = v 45 | self.var_scheduler = variance_scheduler 46 | self.alphas_hat: torch.FloatTensor = self.var_scheduler.get_alpha_hat().to(self.device) 47 | self.alphas: torch.FloatTensor = self.var_scheduler.get_alphas().to(self.device) 48 | self.betas = self.var_scheduler.get_betas().to(self.device) 49 | self.betas_hat = self.var_scheduler.get_betas_hat().to(self.device) 50 | 51 | self.p_uncond = p_uncond 52 | self.mse = nn.MSELoss() 53 | self.width = width 54 | self.height = height 55 | self.logging_freq = logging_freq 56 | self.iteration = 0 57 | self.num_classes = num_classes 58 | self.gen_images = Path('training_gen_images') 59 | self.gen_images.mkdir_p() 60 | 61 | def forward(self, x: torch.Tensor, t: torch.Tensor, c: torch.Tensor) -> torch.Tensor: 62 | """ 63 | predict the score (noise) to transition from step t to t-1 64 | :param x: input image [bs, c, w, h] 65 | :param t: time step [bs] 66 | :param c: class [bs, num_classes] 67 | :return: the predicted noise to transition from t to t-1 68 | """ 69 | return self.denoiser_module(x, t, c) 70 | 71 | def training_step(self, batch, batch_idx): 72 | return self._step(batch, batch_idx, 'train') 73 | 74 | def validation_step(self, batch, batch_idx): 75 | if batch_idx == 0 and self.current_epoch % 10 == 0: 76 | batch_size = 32 77 | for i_c in range(self.num_classes): 78 | c = torch.zeros(batch_size, self.num_classes, device=self.device) 79 | c[:, i_c] = 1 80 | x_c = self.generate(batch_size, c) 81 | x_c = torchvision.utils.make_grid(x_c) 82 | self.logger.experiment.add_image(f'epoch_gen_val_images_class_{i_c}', x_c, self.current_epoch) 83 | torchvision.utils.save_image(x_c, self.gen_images / f'epoch_{self.current_epoch}_class_{i_c}.png') 84 | 85 | return self._step(batch, batch_idx, 'valid') 86 | 87 | def _step(self, batch, batch_idx, dataset: Literal['train', 'valid']) -> torch.Tensor: 88 | """ 89 | train/validation step of DDPM. The logic is mostly taken from the original DDPM paper, 90 | except for the class conditioning part. 
91 | """ 92 | X, y = batch 93 | with torch.no_grad(): 94 | X = X * 2 - 1 # normalize to -1, 1 95 | y = one_hot(y, self.num_classes).float() 96 | 97 | # dummy flags that with probability p_uncond, we train without class conditioning 98 | is_class_cond = torch.rand(size=(X.shape[0],1), device=X.device) >= self.p_uncond 99 | y = y * is_class_cond.float() # set to zero the batch elements not class conditioned 100 | t = torch.randint(0, self.T - 1, (X.shape[0],), device=X.device) # sample t uniformly from [0, T-1] 101 | t_expanded = t.reshape(-1, 1, 1, 1) 102 | eps = torch.randn_like(X) # [bs, c, w, h] 103 | alpha_hat_t = self.alphas_hat[t_expanded] # get \hat{\alpha}_t 104 | x_t = x0_to_xt(X, alpha_hat_t, eps) # go from x_0 to x_t in a single equation thanks to the step 105 | pred_eps = self(x_t, t / self.T, y) # predict the noise to transition from x_t to x_{t-1} 106 | loss = self.mse(eps, pred_eps) # compute the MSE between the predicted noise and the real noise 107 | 108 | # log every batch on validation set, otherwise log every self.logging_freq batches on training set 109 | if dataset == 'valid' or (self.iteration % self.logging_freq) == 0: 110 | self.log(f'loss/{dataset}_loss', loss, on_step=True) 111 | if dataset == 'train': 112 | norm_params = sum( 113 | [torch.norm(p.grad) for p in self.parameters() if 114 | hasattr(p, 'grad') and p.grad is not None]) 115 | self.log('grad_norm', norm_params) 116 | self.logger.experiment.add_image(f'{dataset}_pred_score', eps[0], self.iteration) 117 | with torch.no_grad(): 118 | self.log(f'noise/{dataset}_mean_eps', eps.mean(), on_step=True) 119 | self.log(f'noise/{dataset}_std_eps', eps.flatten(1).std(dim=1).mean(), on_step=True) 120 | self.log(f'noise/{dataset}_max_eps', eps.max(), on_step=True) 121 | self.log(f'noise/{dataset}_min_eps', eps.min(), on_step=True) 122 | 123 | self.iteration += 1 124 | return loss 125 | 126 | def configure_optimizers(self): 127 | return torch.optim.Adam(params=self.parameters(), lr=1e-4) 128 | 129 | def on_fit_start(self) -> None: 130 | self.betas = self.betas.to(self.device) 131 | self.betas_hat = self.betas_hat.to(self.device) 132 | self.alphas = self.alphas.to(self.device) 133 | self.alphas_hat = self.alphas_hat.to(self.device) 134 | 135 | def generate(self, batch_size: Optional[int] = None, c: Optional[torch.Tensor] = None, T: Optional[int] = None, 136 | get_intermediate_steps: bool = False) -> Union[torch.Tensor, List[torch.Tensor]]: 137 | """ 138 | Generate a new sample starting from pure random noise sampled from a normal standard distribution 139 | :param batch_size: the generated batch size 140 | :param c: the class conditional matrix [batch_size, num_classes]. By default, it will be deactivated by passing a matrix of full zeroes 141 | :param T: the number of generation steps. 
By default, it will be the number of steps of the training 142 | :param get_intermediate_steps: if true, it will all return the intermediate steps of the generation 143 | :return: the generated image or the list of intermediate steps 144 | """ 145 | T = T or self.T 146 | batch_size = batch_size or 1 147 | is_unconditioned = c is None 148 | if is_unconditioned: 149 | c = torch.zeros(batch_size, self.num_classes, device=self.device) 150 | if get_intermediate_steps: 151 | steps = [] 152 | z_t = torch.randn(batch_size, self.input_channels, # start with random noise sampled from N(0, 1) 153 | self.width, self.height, device=self.device) 154 | for t in range(T - 1, 0, -1): 155 | if get_intermediate_steps: 156 | steps.append(z_t) 157 | t = torch.LongTensor([t] * batch_size).to(self.device).view(-1, 1) 158 | t_expanded = t.view(-1, 1, 1, 1) 159 | if is_unconditioned: 160 | # compute unconditioned noise 161 | eps = self(z_t, t / T, c) # predict via nn the noise 162 | else: 163 | # compute class conditioned noise 164 | eps1 = (1 + self.w) * self(z_t, t / T, c) # compute class conditioned noise 165 | eps2 = self.w * self(z_t, t / T, c * 0) # compute unconditioned noise 166 | eps = eps1 - eps2 # compute noise difference 167 | alpha_t = self.alphas[t_expanded] 168 | z = torch.randn_like(z_t) 169 | alpha_hat_t = self.alphas_hat[t_expanded] 170 | # denoise step from x_t to x_{t-1} following the DDPM paper 171 | z_t = (z_t - ((1 - alpha_t) / torch.sqrt(1 - alpha_hat_t)) * eps) / (torch.sqrt(alpha_t)) + \ 172 | self.betas[t_expanded] * z 173 | z_t = (z_t + 1) / 2 # bring back to [0, 1] 174 | if get_intermediate_steps: 175 | steps.append(z_t) 176 | return z_t if not get_intermediate_steps else steps 177 | -------------------------------------------------------------------------------- /model/ddpm.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | from typing import Callable, Iterator, Tuple, Optional, Type, Union, List, ClassVar 3 | 4 | import pytorch_lightning as pl 5 | import torch 6 | import torchvision.utils 7 | from torch import nn 8 | 9 | from model.distributions import sigma_x_t, mu_x_t, mu_hat_xt_x0, sigma_hat_xt_x0, x0_to_xt 10 | from variance_scheduler.abs_var_scheduler import Scheduler 11 | from torch.nn.parameter import Parameter 12 | 13 | class GaussianDDPM(pl.LightningModule): 14 | """ 15 | Gaussian De-noising Diffusion Probabilistic Model 16 | This class implements both original DDPM model (by setting vlb=False) and Improved DDPM paper 17 | """ 18 | 19 | def __init__(self, denoiser_module: nn.Module, opt: Union[Type[torch.optim.Optimizer], Callable[[Iterator[Parameter]], torch.optim.Optimizer]], T: int, variance_scheduler: Scheduler, lambda_variational: float, width: int, height: int, input_channels: int, logging_freq: int, vlb: bool, init_step_vlb: int): 20 | """ 21 | :param denoiser_module: The nn which computes the denoise step i.e. q(x_{t-1} | x_t, t) 22 | :param T: the amount of noising steps 23 | :param variance_scheduler: the variance scheduler cited in DDPM paper. 
See folder variance_scheduler for practical implementation 24 | :param lambda_variational: the coefficient in front of the variational loss 25 | :param width: image width 26 | :param height: image height 27 | :param input_channels: image input channels 28 | :param logging_freq: frequency of logging loss function during training 29 | :param vlb: true to include the variational lower bound into the loss function 30 | :param init_step_vlb: the step at which the variational lower bound is included into the loss function 31 | """ 32 | super().__init__() 33 | self.input_channels = input_channels 34 | self.denoiser_module = denoiser_module 35 | self.T = T 36 | self.opt_class = opt 37 | 38 | self.var_scheduler = variance_scheduler 39 | self.lambda_variational = lambda_variational 40 | self.alphas_hat: torch.FloatTensor = self.var_scheduler.get_alpha_hat().to(self.device) 41 | self.alphas: torch.FloatTensor = self.var_scheduler.get_alphas().to(self.device) 42 | self.betas = self.var_scheduler.get_betas().to(self.device) 43 | self.betas_hat = self.var_scheduler.get_betas_hat().to(self.device) 44 | self.mse = nn.MSELoss() 45 | self.width = width 46 | self.height = height 47 | self.logging_freq = logging_freq 48 | self.vlb = vlb 49 | self.init_step_vlb = init_step_vlb 50 | self.iteration = 0 51 | self.init_step_vlb = max(1, self.init_step_vlb) # make sure that init_step_vlb is at least 1 52 | 53 | def forward(self, x: torch.FloatTensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 54 | """ 55 | Forward pass of the DDPM model. 56 | 57 | Args: 58 | x: Input image tensor. 59 | t: Time step tensor. 60 | 61 | Returns: 62 | Tuple of predicted noise tensor and predicted variance tensor. 63 | """ 64 | return self.denoiser_module(x, t) 65 | 66 | 67 | def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int): 68 | """ 69 | Training step of the DDPM model. 70 | 71 | Args: 72 | batch: Tuple of input image tensor and target tensor. 73 | batch_idx: Batch index. 74 | 75 | Returns: 76 | Dictionary containing the loss.
77 | """ 78 | X, y = batch 79 | with torch.no_grad(): 80 | # Map image values from [0, 1] to [-1, 1] 81 | X = X * 2 - 1 82 | # Sample a random time step t from 0 to T-1 for each image in the batch 83 | t: torch.Tensor = torch.randint(0, self.T - 1, (X.shape[0],), device=X.device) # todo add importance sampling 84 | # Compute alpha_hat for the selected time steps 85 | alpha_hat = self.alphas_hat[t].reshape(-1, 1, 1, 1) 86 | # Sample noise eps from a normal distribution with the same shape as X 87 | eps = torch.randn_like(X) 88 | # Compute the intermediate image x_t from the original image X, alpha_hat, and eps 89 | x_t = x0_to_xt(X, alpha_hat, eps) # go from x_0 to x_t with the formula 90 | # Run the intermediate image x_t through the model to obtain predicted noise and scale vectors (pred_eps, v) 91 | pred_eps, v = self(x_t, t) 92 | # Compute the loss for the predicted noise 93 | loss = 0.0 94 | noise_loss = self.mse(eps, pred_eps) 95 | loss = loss + noise_loss 96 | # If using the VLB loss, compute the VLB loss and add it to the total loss 97 | use_vlb = self.iteration >= self.init_step_vlb and self.vlb 98 | if use_vlb: 99 | loss_vlb = self.lambda_variational * self.variational_loss(x_t, X, pred_eps, v, t).mean(dim=0).sum() 100 | loss = loss + loss_vlb 101 | 102 | self.iteration += 1 103 | 104 | return dict(loss=loss, noise_loss=noise_loss, vlb_loss=loss_vlb if use_vlb else None) 105 | 106 | def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int): 107 | 108 | # Unpack the batch into inputs X and ground truth y 109 | X, y = batch 110 | 111 | # Normalize inputs to [-1, 1] range 112 | with torch.no_grad(): 113 | X = X * 2 - 1 114 | 115 | # Sample a time step t uniformly from [0, T-1] for each sample in the batch 116 | # TODO: Replace uniform sampling with importance sampling 117 | t: torch.Tensor = torch.randint(0, self.T - 1, (X.shape[0],), device=X.device) 118 | 119 | # Compute alpha_hat from the precomputed alphas_hat for the sampled t 120 | alpha_hat = self.alphas_hat[t].reshape(-1, 1, 1, 1) 121 | 122 | # Sample a noise vector eps from the standard normal distribution with the same shape as X 123 | eps = torch.randn_like(X) 124 | 125 | # Compute x_t, the input to the model at time step t 126 | x_t = x0_to_xt(X, alpha_hat, eps) 127 | 128 | # Forward pass through the model to get predicted noise and v 129 | pred_eps, v = self(x_t, t) 130 | 131 | # Compute the reconstruction loss between the input and the predicted noise 132 | loss = eps_loss = self.mse(eps, pred_eps) 133 | 134 | # If using the variational lower bound (VLB), compute the VLB loss and add it to the reconstruction loss 135 | # self.iteration > 0 is to avoid computing the VLB loss before the first training step because gives NaNs 136 | if self.iteration >= self.init_step_vlb and self.vlb: 137 | loss_vlb = self.lambda_variational * self.variational_loss(x_t, X, pred_eps, v, t).mean(dim=0).sum() 138 | loss = loss + loss_vlb 139 | 140 | # Return the loss as a dictionary 141 | return dict(loss=loss, noise_loss=eps_loss, vlb_loss=loss_vlb if self.vlb else None) 142 | 143 | def variational_loss(self, x_t: torch.Tensor, x_0: torch.Tensor, 144 | model_noise: torch.Tensor, v: torch.Tensor, t: torch.Tensor): 145 | """ 146 | Compute variational loss for time step t 147 | 148 | Parameters: 149 | - x_t (torch.Tensor): the image at step t obtained with closed form formula from x_0 150 | - x_0 (torch.Tensor): the input image 151 | - model_noise (torch.Tensor): the unet predicted noise 152 | - v (torch.Tensor): the 
unet predicted coefficients for the variance 153 | - t (torch.Tensor): the time step 154 | 155 | Returns: 156 | - vlb (torch.Tensor): the pixel-wise variational loss, with shape [batch_size, channels, width, height] 157 | """ 158 | vlb = 0.0 159 | t_eq_0 = (t == 0).reshape(-1, 1, 1, 1) 160 | 161 | # Compute variational loss for t=0 (i.e., first time step) 162 | if torch.any(t_eq_0): 163 | p = torch.distributions.Normal(mu_x_t(x_t, t, model_noise, self.alphas_hat, self.betas, self.alphas), 164 | sigma_x_t(v, t, self.betas_hat, self.betas)) 165 | # Compute log probability of x_0 under the distribution p 166 | # and add it to the variational lower bound 167 | vlb += - p.log_prob(x_0) * t_eq_0.float() 168 | 169 | t_eq_last = (t == (self.T - 1)).reshape(-1, 1, 1, 1) 170 | 171 | # Compute variational loss for t=T-1 (i.e., last time step) 172 | if torch.any(t_eq_last): 173 | p = torch.distributions.Normal(0, 1) 174 | q = torch.distributions.Normal(sqrt(self.alphas_hat[t]) * x_0, (1 - self.alphas_hat[t])) 175 | # Compute KL divergence between distributions p and q 176 | # and add it to the variational lower bound 177 | vlb += torch.distributions.kl_divergence(q, p) * t_eq_last.float() 178 | 179 | # Compute variational loss for all other time steps 180 | mu_hat = mu_hat_xt_x0(x_t, x_0, t, self.alphas_hat, self.alphas, self.betas) 181 | sigma_hat = sigma_hat_xt_x0(t, self.betas_hat) 182 | q = torch.distributions.Normal(mu_hat, sigma_hat) # q(x_{t-1} | x_t, x_0) 183 | mu = mu_x_t(x_t, t, model_noise, self.alphas_hat, self.betas, self.alphas).detach() 184 | sigma = sigma_x_t(v, t, self.betas_hat, self.betas) 185 | p = torch.distributions.Normal(mu, sigma) # p(x_t | x_{t-1}) 186 | # Compute KL divergence between distributions p and q 187 | # and add it to the variational lower bound 188 | vlb += torch.distributions.kl_divergence(q, p) * (~t_eq_last).float() * (~t_eq_0).float() 189 | 190 | return vlb 191 | 192 | def configure_optimizers(self): 193 | return self.opt_class(params=self.parameters()) 194 | 195 | def generate(self, batch_size: Optional[int] = None, T: Optional[int] = None, 196 | get_intermediate_steps: bool = False) -> Union[torch.Tensor, List[torch.Tensor]]: 197 | """ 198 | Generate a batch of images via denoising diffusion probabilistic model 199 | :param batch_size: batch size of generated images. The default value is 1 200 | :param T: number of diffusion steps to generated images. 
The default value is the training diffusion steps 201 | :param get_intermediate_steps: return all the denoising steps instead of the final step output 202 | :return: The tensor [bs, c, w, h] of generated images or a list of tensors [bs, c, w, h] if get_intermediate_steps=True 203 | """ 204 | batch_size = batch_size or 1 205 | T = T or self.T 206 | if get_intermediate_steps: 207 | steps = [] 208 | X_noise = torch.randn(batch_size, self.input_channels, # start with random noise sampled from N(0, 1) 209 | self.width, self.height, device=self.device) 210 | beta_sqrt = torch.sqrt(self.betas) 211 | for t in range(T - 1, -1, -1): 212 | if get_intermediate_steps: 213 | steps.append(X_noise) 214 | t = torch.LongTensor([t]).to(self.device) 215 | eps, v = self.denoiser_module(X_noise, t) # predict via nn the noise 216 | # if the variational lower bound is part of the loss function, then v (the logit of the variance) is trained 217 | # otherwise the variance is taken as fixed, as in the original DDPM paper 218 | sigma = sigma_x_t(v, t, self.betas_hat, self.betas) if self.vlb else beta_sqrt[t].reshape(-1, 1, 1, 1) 219 | z = torch.randn_like(X_noise) 220 | if t == 0: 221 | z.fill_(0) 222 | alpha_t = self.alphas[t].reshape(-1, 1, 1, 1) 223 | alpha_hat_t = self.alphas_hat[t].reshape(-1, 1, 1, 1) 224 | X_noise = 1 / (torch.sqrt(alpha_t)) * \ 225 | (X_noise - ((1 - alpha_t) / torch.sqrt(1 - alpha_hat_t)) * eps) + sigma * z # denoise step from x_t to x_{t-1} following the DDPM paper. Differently from the 226 | # original paper, the variance can be estimated via the nn instead of being fixed, as in the Improved DDPM paper 227 | X_noise = (X_noise + 1) / 2 # rescale from [-1, 1] to [0, 1] 228 | if get_intermediate_steps: 229 | steps.append(X_noise) 230 | return steps 231 | return X_noise 232 | 233 | def on_fit_start(self) -> None: 234 | self.alphas_hat: torch.FloatTensor = self.alphas_hat.to(self.device) 235 | self.alphas: torch.FloatTensor = self.alphas.to(self.device) 236 | self.betas = self.betas.to(self.device) 237 | self.betas_hat = self.betas_hat.to(self.device) 238 | -------------------------------------------------------------------------------- /model/distributions.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | 5 | 6 | def mu_x_t(x_t: torch.Tensor, t: torch.Tensor, model_noise: torch.Tensor, alphas_hat: torch.Tensor, betas: torch.Tensor, alphas: torch.Tensor, eps: float = 1e-5) -> torch.Tensor: 7 | """ 8 | 9 | :param x_t: the noised image 10 | :param t: the time step of $x_t$ 11 | :param model_noise: the model estimated noise 12 | :param alphas_hat: sequence of $\hat{\alpha}$ used for variance scheduling 13 | :param betas: sequence of $\beta$ used for variance scheduling 14 | :param alphas: sequence of $\alpha$ used for variance scheduling 15 | :return: the mean of $p(x_{t-1} | x_t)$, i.e. the mean of the denoising step 16 | """ 17 | alpha_t = alphas[t].reshape(-1, 1, 1, 1) 18 | beta_t = betas[t].reshape(-1, 1, 1, 1) 19 | alpha_hat_t = alphas_hat[t].reshape(-1, 1, 1, 1) 20 | x = 1 / (torch.sqrt(alpha_t) + eps) * (x_t - (beta_t / (torch.sqrt(1 - alpha_hat_t) + eps) * model_noise)) 21 | # tg.guard(x, "B, C, W, H") 22 | return x 23 | 24 | 25 | 26 | 27 | def sigma_x_t(v: torch.Tensor, t: torch.Tensor, betas_hat: torch.Tensor, betas: torch.Tensor, eps: float = 1e-5) -> torch.Tensor: 28 | """ 29 | Compute the variance at time step t as defined in "Improving Denoising Diffusion probabilistic Models", eqn 15 page 4 30 | :param v: the neural network "logits" used to compute
the variance [BS, C, W, H] 31 | :param t: the target time step 32 | :param betas_hat: sequence of $\hat{\beta}$ used for variance scheduling 33 | :param betas: sequence of $\beta$ used for variance scheduling 34 | :return: the estimated variance at time step t 35 | """ 36 | x = torch.exp(v * torch.log(betas[t].reshape(-1, 1, 1, 1) + eps) + (1 - v) * torch.log(betas_hat[t].reshape(-1, 1, 1, 1) + eps)) 37 | # tg.guard(x, "B, C, W, H") 38 | return x 39 | 40 | 41 | def mu_hat_xt_x0(x_t: torch.Tensor, x_0: torch.Tensor, t: torch.Tensor, alphas_hat: torch.Tensor, alphas: torch.Tensor, betas: torch.Tensor, eps: float = 1e-5) -> torch.Tensor: 42 | """ 43 | Compute $\hat{\mu}(x_t, x_0)$ of $q(x_{t-1} | x_t, x_0)$ from "Improving Denoising Diffusion probabilistic Models", eqn 11 page 2 44 | :param x_t: The noised image at step t 45 | :param x_0: the original image 46 | :param t: the time step of $x_t$ [batch_size] 47 | :param alphas_hat: sequence of $\hat{\alpha}$ used for variance scheduling [T] 48 | :param alphas: sequence of $\alpha$ used for variance scheduling [T] 49 | :param betas: sequence of $\beta$ used for variance scheduling [T] 50 | :return: the mean of distribution $q(x_{t-1} | x_t, x_0)$ 51 | """ 52 | alpha_hat_t = alphas_hat[t].reshape(-1, 1, 1, 1) 53 | one_min_alpha_hat = (1 - alpha_hat_t) + eps 54 | alpha_t = alphas[t].reshape(-1, 1, 1, 1) 55 | alpha_hat_t_1 = alphas_hat[t - 1].reshape(-1, 1, 1, 1) 56 | beta_t = betas[t].reshape(-1, 1, 1, 1) 57 | x = torch.sqrt(alpha_hat_t_1 + eps) * beta_t / one_min_alpha_hat * x_0 + \ 58 | torch.sqrt(alpha_t + eps) * ( 59 | 1 - alpha_hat_t_1) / one_min_alpha_hat * x_t 60 | # tg.guard(x, "B, C, W, H") 61 | return x 62 | 63 | 64 | def sigma_hat_xt_x0(t: torch.Tensor, betas_hat: torch.Tensor, eps: float = 1e-5) -> torch.Tensor: 65 | """ 66 | Compute the variance of $q(x_{t-1} | x_t, x_0)$ from "Improving Denoising Diffusion probabilistic Models", eqn 12 page 2 67 | :param t: the time step [batch_size] 68 | :param betas_hat: the array of beta hats [T] 69 | :return: the variance at time step t, with shape [batch_size, 1, 1, 1] 70 | """ 71 | return (betas_hat[t].reshape(-1, 1, 1, 1) + eps) 72 | 73 | 74 | def x0_to_xt(x_0: torch.Tensor, alpha_hat_t: torch.Tensor, eps: Optional[torch.Tensor] = None) -> torch.Tensor: 75 | """ 76 | Compute x_t from x_0 in closed form, using the theorem from the original DDPM paper (Ho et al.)
77 | :param x_0: the image without noise 78 | :param alpha_hat_t: the cumulated variance schedule at time t 79 | :param eps: pure noise from N(0, 1) 80 | :return: the noised image x_t at step t 81 | """ 82 | if eps is None: 83 | eps = torch.randn_like(x_0) 84 | return torch.sqrt(alpha_hat_t) * x_0 + torch.sqrt(1 - alpha_hat_t) * eps 85 | -------------------------------------------------------------------------------- /model/unet.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Tuple 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | 9 | # import tensorguard as tg 10 | 11 | 12 | def positional_embedding_vector(t: int, dim: int) -> torch.FloatTensor: 13 | """ 14 | 15 | Args: 16 | t (int): time step 17 | dim (int): embedding size 18 | 19 | Returns: the transformer sinusoidal positional embedding vector 20 | 21 | """ 22 | two_i = 2 * torch.arange(0, dim) 23 | return torch.sin(t / torch.pow(10_000, two_i / dim)).unsqueeze(0) 24 | 25 | 26 | def timestep_embedding(timesteps: torch.Tensor, dim: int, max_period=10000): 27 | """ 28 | Create sinusoidal timestep embeddings. 29 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 30 | These may be fractional. 31 | :param dim: the dimension of the output. 32 | :param max_period: controls the minimum frequency of the embeddings. 33 | :return: an [N x dim] Tensor of positional embeddings. 34 | """ 35 | half = dim // 2 36 | freqs = torch.exp( 37 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 38 | ).to(device=timesteps.device) 39 | args = timesteps[:, None].float() * freqs[None] 40 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 41 | if dim % 2: 42 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 43 | return embedding.to(timesteps.device) 44 | 45 | 46 | @torch.no_grad() 47 | def init_zero(module: nn.Module) -> nn.Module: 48 | for p in module.parameters(): 49 | torch.nn.init.zeros_(p) 50 | return module 51 | 52 | 53 | class ResBlockTimeEmbed(nn.Module): 54 | 55 | def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int, padding: int, 56 | time_embed_size: int, p_dropout: float): 57 | """ 58 | Args: 59 | in_channels (int): number of input channels 60 | out_channels (int): number of output channels 61 | kernel_size (int): size of the convolution kernel 62 | stride (int): stride of the convolution 63 | padding (int): padding of the convolution 64 | time_embed_size (int): size of the time embedding 65 | p_dropout (float): dropout probability 66 | """ 67 | super().__init__() 68 | num_groups_in = self.find_max_num_groups(in_channels) 69 | self.conv = nn.Sequential( 70 | nn.GroupNorm(num_groups_in, in_channels), 71 | nn.GELU(), 72 | nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)) 73 | self.relu = nn.ReLU() 74 | self.l_embedding = nn.Sequential( 75 | nn.GELU(), 76 | nn.Linear(time_embed_size, out_channels) 77 | ) 78 | num_groups_out = self.find_max_num_groups(out_channels) 79 | self.out_layer = nn.Sequential( 80 | nn.GroupNorm(num_groups_out, out_channels), 81 | nn.GELU(), 82 | nn.Dropout(p_dropout), 83 | nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding), 84 | ) 85 | self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) 86 | 87 | def find_max_num_groups(self, in_channels: int) -> int: 88 | for i in range(4, 0, -1): 89 | 
if in_channels % i == 0: 90 | return i 91 | raise Exception() 92 | 93 | def forward(self, x, time_embed): 94 | h = self.conv(x) 95 | time_embed = self.l_embedding(time_embed) 96 | time_embed = time_embed.view(time_embed.shape[0], time_embed.shape[1], 1, 1) 97 | h = h + time_embed 98 | return self.out_layer(h) + self.skip_connection(x) 99 | 100 | 101 | class ImageSelfAttention(nn.Module): 102 | 103 | def __init__(self, num_channels: int, num_heads: int = 1): 104 | """ 105 | Args: 106 | num_channels (int): Number of channels in the input. 107 | num_heads (int, optional): Number of attention heads. Default: 1. 108 | 109 | Shape: 110 | - Input: :math:`(N, C, L)` 111 | - Output: :math:`(N, C, L)` 112 | 113 | Examples: 114 | >>> attention = ImageSelfAttention(3) 115 | >>> input = torch.randn(1, 3, 64) 116 | >>> output = attention(input) 117 | >>> output.shape 118 | torch.Size([1, 3, 64]) """ 119 | 120 | super().__init__() 121 | self.channels = num_channels 122 | self.heads = num_heads 123 | 124 | self.attn_layer = nn.MultiheadAttention(num_channels, num_heads=num_heads) 125 | 126 | def forward(self, x): 127 | """ 128 | 129 | :param x: tensor with shape [batch_size, channels, width, height] 130 | :return: the attention output applied to the image with the shape [batch_size, channels, width, height] 131 | """ 132 | b, c, w, h = x.shape 133 | x = x.reshape(b, w * h, c) 134 | 135 | attn_output, _ = self.attn_layer(x, x, x) 136 | return attn_output.reshape(b, c, w, h) 137 | 138 | 139 | class UNetTimeStep(nn.Module): 140 | 141 | def __init__(self, channels: List[int], kernel_sizes: List[int], strides: List[int], paddings: List[int], 142 | downsample: bool, p_dropouts: List[float], time_embed_size: int): 143 | super().__init__() 144 | assert len(channels) == (len(kernel_sizes) + 1) == (len(strides) + 1) == (len(paddings) + 1) == \ 145 | (len(p_dropouts) + 1), f'{len(channels)} == {(len(kernel_sizes) + 1)} == ' \ 146 | f'{(len(strides) + 1)} == {(len(paddings) + 1)} == \ 147 | {(len(p_dropouts) + 1)}' 148 | self.channels = channels 149 | self.time_embed_size = time_embed_size 150 | self.downsample_blocks = nn.ModuleList([ 151 | ResBlockTimeEmbed(channels[i], channels[i + 1], kernel_sizes[i], strides[i], paddings[i], time_embed_size, p_dropouts[i]) for i in range(len(channels) - 1) 152 | ]) 153 | 154 | self.use_downsample = downsample 155 | self.downsample_op = nn.MaxPool2d(kernel_size=2) 156 | self.middle_block = ResBlockTimeEmbed(channels[-1], channels[-1], kernel_sizes[-1], strides[-1], paddings[-1], time_embed_size, p_dropouts[-1]) 157 | channels[0] *= 2 # because the output is the image plus the estimated variance coefficients 158 | self.upsample_blocks = nn.ModuleList([ 159 | ResBlockTimeEmbed((2 if i != 0 else 1) * channels[-i - 1], channels[-i - 2], kernel_sizes[-i - 1], 160 | strides[-i - 1], 161 | paddings[-i - 1], time_embed_size, p_dropouts[-i - 1]) for i in range(len(channels) - 1) 162 | ]) 163 | self.dropouts = nn.ModuleList([nn.Dropout2d(p) for p in p_dropouts]) 164 | self.p_dropouts = p_dropouts 165 | self.self_attn = ImageSelfAttention(channels[3]) 166 | self.time_embed = nn.Sequential( 167 | nn.Linear(self.time_embed_size, self.time_embed_size), 168 | nn.SiLU(), 169 | nn.Linear(self.time_embed_size, self.time_embed_size), 170 | ) 171 | 172 | def forward(self, x: torch.FloatTensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 173 | """Forward pass of the model. 174 | 175 | Args: 176 | x (torch.FloatTensor): Input tensor of shape (batch_size, num_channels, width, height). 
177 | t (torch.Tensor): Time tensor of shape (batch_size, 1). 178 | 179 | Returns: 180 | Tuple[torch.Tensor, torch.Tensor]: Output tensor x_recon of shape (batch_size, num_channels, width, height) 181 | and output tensor v of shape (batch_size, num_channels, width, height). 182 | """ 183 | x_channels = x.shape[1] 184 | # tg.guard(x, "B, C, W, H") 185 | time_embedding = self.time_embed(timestep_embedding(t, self.time_embed_size)) 186 | # tg.guard(time_embedding, "B, TE") 187 | hs = [] 188 | h = x 189 | for i, downsample_block in enumerate(self.downsample_blocks): 190 | h = downsample_block(h, time_embedding) 191 | if i == 2: 192 | h = self.self_attn(h) 193 | h = self.dropouts[i](h) 194 | if i != (len(self.downsample_blocks) - 1): hs.append(h) 195 | if self.use_downsample and i != (len(self.downsample_blocks) - 1): 196 | h = self.downsample_op(h) 197 | h = self.middle_block(h, time_embedding) 198 | for i, upsample_block in enumerate(self.upsample_blocks): 199 | if i != 0: 200 | h = torch.cat([h, hs[-i]], dim=1) 201 | h = upsample_block(h, time_embedding) 202 | if self.use_downsample and (i != (len(self.upsample_blocks) - 1)): 203 | h = F.interpolate(h, size=hs[-i - 1].shape[-2:], mode='nearest') 204 | x_recon, v = h[:, :x_channels], h[:, x_channels:] 205 | # tg.guard(x_recon, "B, C, W, H") 206 | # tg.guard(v, "B, C, W, H") 207 | return x_recon, v 208 | 209 | 210 | -------------------------------------------------------------------------------- /model/unet_class.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from model.unet import ResBlockTimeEmbed, ImageSelfAttention 8 | import tensorguard as tg 9 | 10 | 11 | class ResBlockTimeEmbedClassConditioned(ResBlockTimeEmbed): 12 | 13 | def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int, padding: int, 14 | time_embed_size: int, p_dropout: float, num_classes: int, class_embed_size: int, 15 | assert_shapes: bool = True): 16 | super().__init__(in_channels + class_embed_size, out_channels, kernel_size, stride, padding, 17 | time_embed_size, p_dropout) 18 | self.linear_map_class = nn.Sequential( 19 | nn.Linear(num_classes, class_embed_size), 20 | nn.ReLU(), 21 | nn.Linear(class_embed_size, class_embed_size) 22 | 23 | ) 24 | 25 | self.assert_shapes = assert_shapes 26 | 27 | def forward(self, x, time_embed, c): 28 | emb_c = self.linear_map_class(c) 29 | emb_c = emb_c.view(*emb_c.shape, 1, 1) 30 | emb_c = emb_c.expand(-1, -1, x.shape[-2], x.shape[-1]) 31 | if self.assert_shapes: tg.guard(emb_c, "B, C, W, H") 32 | x = torch.cat([x, emb_c], dim=1) 33 | return super().forward(x, time_embed) 34 | 35 | 36 | class UNetTimeStepClassConditioned(nn.Module): 37 | """ 38 | UNet architecture with class and time embedding injected in every residual block, both in downsample and upsample. 39 | Both information are mapped via an 2-layers MLP to a fixed embedding size. 40 | After the third downsample block a self-attention layer is applied. 
41 | """ 42 | 43 | def __init__(self, channels: List[int], kernel_sizes: List[int], strides: List[int], paddings: List[int], 44 | downsample: bool, p_dropouts: List[float], time_embed_size: int, num_classes: int, 45 | class_embed_size: int, 46 | assert_shapes: bool = True): 47 | super().__init__() 48 | assert len(channels) == (len(kernel_sizes) + 1) == (len(strides) + 1) == (len(paddings) + 1) == \ 49 | (len(p_dropouts) + 1), f'{len(channels)} == {(len(kernel_sizes) + 1)} == ' \ 50 | f'{(len(strides) + 1)} == {(len(paddings) + 1)} == \ 51 | {(len(p_dropouts) + 1)}' 52 | self.channels = channels 53 | self.kernel_sizes = kernel_sizes 54 | self.strides = strides 55 | self.paddings = paddings 56 | self.assert_shapes = assert_shapes 57 | self.num_classes = num_classes 58 | self.time_embed_size = time_embed_size 59 | self.class_embed_size = class_embed_size 60 | self.downsample_blocks = nn.ModuleList([ 61 | ResBlockTimeEmbedClassConditioned(channels[i], channels[i + 1], kernel_sizes[i], strides[i], 62 | paddings[i], time_embed_size, p_dropouts[i], num_classes, 63 | class_embed_size, assert_shapes) for i in range(len(channels) - 1) 64 | ]) 65 | 66 | self.use_downsample = downsample 67 | self.downsample_op = nn.MaxPool2d(kernel_size=2) 68 | self.middle_block = ResBlockTimeEmbedClassConditioned(channels[-1], channels[-1], kernel_sizes[-1], strides[-1], 69 | paddings[-1], time_embed_size, p_dropouts[-1], 70 | num_classes, class_embed_size, assert_shapes) 71 | self.upsample_blocks = nn.ModuleList([ 72 | ResBlockTimeEmbedClassConditioned((2 if i != 0 else 1) * channels[-i - 1], channels[-i - 2], 73 | kernel_sizes[-i - 1], 74 | strides[-i - 1], paddings[-i - 1], time_embed_size, p_dropouts[-i - 1], 75 | num_classes, 76 | class_embed_size, assert_shapes) for i in range(len(channels) - 1) 77 | ]) 78 | self.dropouts = nn.ModuleList([nn.Dropout(p) for p in p_dropouts]) 79 | self.p_dropouts = p_dropouts 80 | self.self_attn = ImageSelfAttention(channels[2]) 81 | self.time_embed = nn.Sequential( 82 | nn.Linear(1, self.time_embed_size), 83 | nn.GELU(), 84 | nn.Linear(self.time_embed_size, self.time_embed_size), 85 | ) 86 | 87 | def forward(self, x: torch.FloatTensor, t: torch.Tensor, c: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 88 | x_channels = x.shape[1] 89 | if self.assert_shapes: tg.guard(x, "B, C, W, H") 90 | if self.assert_shapes: tg.guard(c, "B, NUMCLASSES") 91 | time_embedding = self.time_embed(t) 92 | if self.assert_shapes: tg.guard(time_embedding, "B, TE") 93 | h = self.forward_unet(x, time_embedding, c) 94 | x_recon = h 95 | if self.assert_shapes: tg.guard(x_recon, "B, C, W, H") 96 | return x_recon 97 | 98 | def forward_unet(self, x, time_embedding, c): 99 | hs = [] 100 | h = x 101 | for i, downsample_block in enumerate(self.downsample_blocks): 102 | h = downsample_block(h, time_embedding, c) 103 | if i == 2: 104 | h = self.self_attn(h) 105 | h = self.dropouts[i](h) 106 | if i != (len(self.downsample_blocks) - 1): hs.append(h) 107 | if self.use_downsample and i != (len(self.downsample_blocks) - 1): 108 | h = self.downsample_op(h) 109 | h = self.middle_block(h, time_embedding, c) 110 | for i, upsample_block in enumerate(self.upsample_blocks): 111 | if i != 0: 112 | h = torch.cat([h, hs[-i]], dim=1) 113 | h = upsample_block(h, time_embedding, c) 114 | if self.use_downsample and (i != (len(self.upsample_blocks) - 1)): 115 | h = F.interpolate(h, size=hs[-i - 1].shape[-1], mode='nearest') 116 | return h 117 | -------------------------------------------------------------------------------- 
/pyproject.toml: -------------------------------------------------------------------------------- 1 | # Note: This pyproject.toml file is intended to publish model/ folder as pip package 2 | # and is not intended to be used for the whole project. For the project, use anaconda-project.yml 3 | [build-system] 4 | requires = ["hatchling"] 5 | build-backend = "hatchling.build" 6 | 7 | [project] 8 | name = "ddpm" 9 | description = "Pytorch implementation of 'Improved Denoising Diffusion Probabilistic Models', 'Denoising Diffusion Probabilistic Models' and 'Classifier-free Diffusion Guidance'" 10 | requires-python = ">=3.7,<3.12" 11 | keywords = ["pytorch", "ddpm", "denoising diffusion probabilistic model", "generative", ] 12 | license = {text = "MIT"} 13 | readme = "readme_pip.md" 14 | classifiers = [ 15 | "Development Status :: 5 - Production/Stable", 16 | "Intended Audience :: Developers", 17 | "Intended Audience :: Science/Research", 18 | "License :: OSI Approved :: MIT License", 19 | "Operating System :: OS Independent", 20 | "Programming Language :: Python", 21 | "Programming Language :: Python :: 3", 22 | "Programming Language :: Python :: 3.7", 23 | "Programming Language :: Python :: 3.8", 24 | "Programming Language :: Python :: 3.9", 25 | "Programming Language :: Python :: 3.10", 26 | "Programming Language :: Python :: 3.11", 27 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 28 | "Topic :: Software Development :: Libraries :: Python Modules", 29 | ] 30 | 31 | dependencies = [ 32 | "torch>=1.8", 33 | "pytorch-lightning >= 1.8", 34 | "torchvision", 35 | "tensorguard==1.0.0", 36 | "path>=16.0" 37 | ] 38 | 39 | version = "1.0.0" 40 | 41 | 42 | [project.optional-dependencies] 43 | dev = [ 44 | "pytest", 45 | ] 46 | 47 | [tool.hatch.build] 48 | include = [ 49 | "ddpm", 50 | ] 51 | 52 | [tool.hatch.envs.default] 53 | python = "3.10" 54 | dependencies = [ 55 | "torch==2.0.1", 56 | "torchvision", 57 | "pytorch-lightning", 58 | "torchmetrics==1.6.3", 59 | "torchsummary", 60 | "path==17.1.0", 61 | "numpy<2.0.0", 62 | "hydra-core", 63 | "tensorboard", 64 | "seaborn", 65 | "matplotlib", 66 | "einops", 67 | "pytest", 68 | ] 69 | 70 | [tool.hatch.envs.default.scripts] 71 | train = "python train.py accelerator=gpu devices=1" 72 | test = "pytest {args:tests}" 73 | compress-runs = "tar cfz saved_models.tar.gz saved_models/" 74 | run-tensorboard = "tensorboard --logdir=saved_models/" 75 | clean-empty-runs = """python -c '\nfrom path import Path\nfor run in Path(\"saved_models\").dirs():\n\ 76 | if not run.joinpath(\"best.ckpt\").exists():\n print(f\"Removing\ 77 | {run}\")\n run.rmtree()'\n""" 78 | generate = "python generate.py {args:generate}" 79 | plot-cosine-scheduler = "python plot_cosine_scheduler.py {args:plot-cosine-scheduler}" 80 | 81 | 82 | [tool.hatch.envs.cpu] 83 | python = "3.10" 84 | dependencies = [ 85 | "torch==2.0.1", 86 | "torchvision", 87 | "pytorch-lightning", 88 | "torchmetrics==1.6.3", 89 | "path==17.1.0", 90 | "numpy<2.0.0", 91 | "hydra-core", 92 | "tensorboard", 93 | "seaborn", 94 | "matplotlib", 95 | "einops", 96 | "pytest", 97 | ] 98 | 99 | [tool.hatch.envs.cpu.env-vars] 100 | PIP_EXTRA_INDEX_URL = "https://download.pytorch.org/whl/cpu" 101 | PIP_VERBOSE = "1" 102 | 103 | 104 | [tool.hatch.envs.cpu.scripts] 105 | train = "python train.py accelerator=cpu" 106 | 107 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | [![PyPI - 
Downloads](https://img.shields.io/pypi/dm/ddpm)](https://pypi.org/project/ddpm/) 2 | [![PyPI](https://img.shields.io/pypi/v/ddpm)](https://pypi.org/project/ddpm/) 3 | [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/ddpm)](https://pypi.org/project/ddpm/) 4 | 5 | # DDPM Pytorch 6 | 7 | Pytorch implementation of "_Improved Denoising Diffusion Probabilistic Models_", 8 | "_Denoising Diffusion Probabilistic Models_" and "_Classifier-free Diffusion Guidance_" 9 | 10 | ![](https://hojonathanho.github.io/diffusion/assets/img/pgm_diagram_xarrow.png) 11 | 12 | 13 | # How to use 14 | 15 | There are two ways to use this repository: 16 | 17 | 1. Install the pip package containing the pytorch lightning model, which also includes the training step 18 | 19 | pip install ddpm 20 | 21 | 2. Clone the repository to have full control over the training 22 | 23 | git clone https://github.com/Michedev/DDPMs-Pytorch 24 | 25 | # How to train 26 | 27 | 1. Install the project environment via hatch (`pip install hatch`). There are two environments: _default_ has torch with cuda support, _cpu_ without it. 28 | 29 | hatch env create 30 | or 31 | hatch env create cpu 32 | 33 | 34 | 2. Train the model 35 | 36 | hatch run train 37 | 38 | or for the cpu environment 39 | 40 | hatch run cpu:train 41 | 42 | Note that this is valid for any `hatch run [env:]{command}` command 43 | 44 | By default, the trained DDPM is the one from the "Improved Denoising Diffusion Probabilistic Models" paper, on the MNIST dataset. 45 | You can switch to the original DDPM by disabling the variational lower bound with the following command: 46 | 47 | hatch run train model.vlb=False 48 | 49 | You can also train the DDPM with Classifier-free Diffusion Guidance by changing the model: 50 | 51 | hatch run train model=unet_class_conditioned 52 | 53 | or via the shortcut 54 | 55 | hatch run train-class-conditioned 56 | 57 | Finally, under saved_models/{train-datetime} you can find the trained model, the tensorboard logs, and the training config. 58 | 59 | # How to generate 60 | 61 | 1. Train a model (see previous section) 62 | 63 | 2. Generate a new batch of images 64 | 65 | hatch run generate -r RUN 66 | 67 | The other options are: `[--seed SEED] [--device DEVICE] [--batch-size BATCH_SIZE] [-w W] [--scheduler {linear,cosine,tan}] [-T T]` 68 | 69 | # Configure the training 70 | 71 | Under _config_ there are several yaml files containing the training parameters 72 | such as model class and parameters, noise steps, scheduler and so on. 73 | Note that the hyperparameters in such files are taken from 74 | the papers "_Improved Denoising Diffusion Probabilistic Models_" 75 | and "_Denoising Diffusion Probabilistic Models_".
Below is an explanation of the config file used to train the model: 76 | 77 | defaults: 78 | - model: unet_paper # take the model config from model/unet_paper.yaml 79 | - scheduler: cosine # use the cosine scheduler from scheduler/cosine.yaml 80 | - dataset: mnist 81 | - optional model_dataset: ${model}-${dataset} # set specific hyperparameters for particular (model, dataset) pairs 82 | - optional model_scheduler: ${model}-${scheduler} # set specific hyperparameters for particular (model, scheduler) pairs 83 | 84 | batch_size: 128 # train batch size 85 | noise_steps: 4_000 # noising steps; the T in "Improved Denoising Diffusion Probabilistic Models" and "Denoising Diffusion Probabilistic Models" 86 | accelerator: null # training hardware; for more details see pytorch lightning 87 | devices: null # training devices to use; for more details see pytorch lightning 88 | gradient_clip_val: 0.0 # 0.0 means gradient clip disabled 89 | gradient_clip_algorithm: norm # gradient clip has two values: 'norm' or 'value' 90 | ema: true # use Exponential Moving Average implemented in ema.py 91 | ema_decay: 0.99 # decay factor of EMA 92 | 93 | hydra: 94 | run: 95 | dir: saved_models/${now:%Y_%m_%d_%H_%M_%S} 96 | 97 | # Project structure 98 | 99 | . 100 | ├── callbacks # Pytorch Lightning callbacks for training 101 | │   ├── ema.py # exponential moving average callback 102 | ├── config # config files for training for hydra 103 | │   ├── dataset # dataset config files 104 | │   ├── model # model config files 105 | │   ├── model_dataset # specific (model, dataset) config 106 | │   ├── model_scheduler # specific (model, scheduler) config 107 | │   ├── scheduler # scheduler config files 108 | │   └── train.yaml # training config file 109 | ├── generate.py # script for generating images 110 | ├── model # model files 111 | │   ├── classifier_free_ddpm.py # Classifier-free Diffusion Guidance 112 | │   ├── ddpm.py # Denoising Diffusion Probabilistic Models 113 | │   ├── distributions.py # distribution functions for diffusion 114 | │   ├── unet_class.py # UNet model for Classifier-free Diffusion Guidance 115 | │   └── unet.py # UNet model for Denoising Diffusion Probabilistic Models 116 | ├── pyproject.toml # packaging file to publish model/ to pypi and to manage the envs 117 | ├── readme.md # this file 118 | ├── readme_pip.md # readme for pypi 119 | ├── train.py # script for training 120 | ├── utils # utility functions 121 | └── variance_scheduler # variance scheduler files 122 | ├── cosine.py # cosine variance scheduler 123 | └── linear.py # linear variance scheduler 124 | 125 | ### Add custom dataset 126 | 127 | To add a custom dataset, you need to create a new class that inherits from torch.utils.data.Dataset 128 | and implements the __len__ and __getitem__ methods. 129 | Then, you need to add a config file to the _config/dataset_ folder with a structure 130 | similar to mnist.yaml: 131 | 132 | width: 28 # meta info about the dataset 133 | height: 28 134 | channels: 1 # number of image channels 135 | num_classes: 10 # number of classes 136 | files_location: ~/.cache/torchvision_dataset # location where the dataset is stored, in case it has to be downloaded 137 | train: #dataset.train is instantiated with this config 138 | _target_: torchvision.datasets.MNIST # Dataset class.
The following arguments are passed to the dataset class constructor 139 | root: ${dataset.files_location} 140 | train: true 141 | download: true 142 | transform: 143 | _target_: torchvision.transforms.ToTensor 144 | val: #dataset.val is instantiated with this config 145 | _target_: torchvision.datasets.MNIST # Same dataset as train, but the validation split 146 | root: ${dataset.files_location} 147 | train: false 148 | download: true 149 | transform: 150 | _target_: torchvision.transforms.ToTensor 151 | 152 | # Examples of custom training 153 | 154 | ### Disable the variational lower bound, use the linear scheduler, use 1000 noise steps, train on GPU 155 | 156 | hatch run train scheduler=linear accelerator='gpu' model.vlb=False noise_steps=1000 157 | 158 | 159 | ## Classifier-free Guidance 160 | 161 | Use the labels for __Diffusion Guidance__, as in "_Classifier-free Diffusion Guidance_", with the following command 162 | 163 | hatch run train model=unet_class_conditioned noise_steps=1000 164 | 165 | ## Add your scheduler 166 | 167 | 1. Add a new class (preferably under `variance_scheduler/`) which subclasses the `Scheduler` class, or just copies the same method signatures as `Scheduler` 168 | 2. Define a new config under `config/scheduler` with the name _my-scheduler.yaml_ containing the following fields 169 | 170 | ``` 171 | _target_: {your scheduler import path} (e.g. variance_scheduler.Linear) 172 | ... // your scheduler additional parameters 173 | ``` 174 | 175 | Finally train with the following command 176 | 177 | hatch run train scheduler=my-scheduler 178 | 179 | ## Add your dataset 180 | 181 | 1. Add a new class which subclasses `torch.utils.data.Dataset` 182 | 183 | 2. Define a new config under `config/dataset` with the name _my-dataset.yaml_ containing the following fields 184 | 185 | ``` 186 | width: ??? 187 | height: ??? 188 | channels: ??? 189 | train: 190 | _target_: {your dataset import path} (e.g.
torchvision.datasets.MNIST) 194 | // your dataset additional parameters 195 | ``` 196 | 197 | Finally train with the following command 198 | 199 | hatch run train dataset=my-dataset 200 | -------------------------------------------------------------------------------- /readme_pip.md: -------------------------------------------------------------------------------- 1 | # DDPMs Pytorch Implementation 2 | 3 | Pytorch implementation of "_Improved Denoising Diffusion Probabilistic Models_", 4 | "_Denoising Diffusion Probabilistic Models_" and "_Classifier-free Diffusion Guidance_" 5 | 6 | ## Install 7 | 8 | ```bash 9 | pip install ddpm 10 | ``` 11 | 12 | # Usage 13 | 14 | ## Gaussian plain DDPM 15 | ```python 16 | 17 | from ddpm import GaussianDDPM, UNetTimeStep 18 | from ddpm.variance_scheduler import LinearScheduler 19 | 20 | T = 1_000 21 | width = 32 22 | height = 32 23 | channels = 3 24 | 25 | # Create a Gaussian DDPM with 1000 noise steps 26 | scheduler = LinearScheduler(T=T, beta_start=1e-5, beta_end=1e-2) 27 | denoiser = UNetTimeStep(channels=[3, 128, 256, 256, 384], 28 | kernel_sizes=[3, 3, 3, 3], 29 | strides=[1, 1, 1, 1], 30 | paddings=[1, 1, 1, 1], 31 | p_dropouts=[0.1, 0.1, 0.1, 0.1], 32 | time_embed_size=100, 33 | downsample=True) 34 | model = GaussianDDPM(denoiser, T, scheduler, vlb=False, lambda_variational=1.0, width=width, 35 | height=height, input_channels=channels, logging_freq=1_000) # pytorch lightning module 36 | 37 | ``` 38 | 39 | ## Gaussian "Improved" DDPM 40 | 41 | ```python 42 | 43 | from ddpm import GaussianDDPM, UNetTimeStep 44 | from ddpm.variance_scheduler import CosineScheduler 45 | 46 | T = 1_000 47 | width = 32 48 | height = 32 49 | channels = 3 50 | 51 | # Create a Gaussian DDPM with 1000 noise steps 52 | scheduler = CosineScheduler(T=T) 53 | denoiser = UNetTimeStep(channels=[3, 128, 256, 256, 384], 54 | kernel_sizes=[3, 3, 3, 3], 55 | strides=[1, 1, 1, 1], 56 | paddings=[1, 1, 1, 1], 57 | p_dropouts=[0.1, 0.1, 0.1, 0.1], 58 | time_embed_size=100, 59 | downsample=True) 60 | model = GaussianDDPM(denoiser, T, scheduler, vlb=True, lambda_variational=0.0001, width=width, 61 | height=height, input_channels=channels, logging_freq=1_000) # pytorch lightning module 62 | 63 | ``` 64 | 65 | ## Classifier-free Diffusion Guidance 66 | 67 | ```python 68 | 69 | from ddpm import GaussianDDPMClassifierFreeGuidance, UNetTimeStep 70 | from ddpm.variance_scheduler import CosineScheduler 71 | 72 | T = 1_000 73 | width = 32 74 | height = 32 75 | channels = 3 76 | num_classes = 10 77 | 78 | # Create a Gaussian DDPM with 1000 noise steps 79 | scheduler = CosineScheduler(T=T) 80 | denoiser = UNetTimeStep(channels=[3, 128, 256, 256, 384], 81 | kernel_sizes=[3, 3, 3, 3], 82 | strides=[1, 1, 1, 1], 83 | paddings=[1, 1, 1, 1], 84 | p_dropouts=[0.1, 0.1, 0.1, 0.1], 85 | time_embed_size=100, 86 | downsample=True) 87 | model = GaussianDDPMClassifierFreeGuidance(denoiser, T, w=0.3, v=0.2, variance_scheduler=scheduler, width=width, 88 | height=height, input_channels=channels, logging_freq=1_000, p_uncond=0.2, 89 | num_classes=num_classes) # pytorch lightning module 90 | 91 | ``` 92 | 93 | -------------------------------------------------------------------------------- /tests/test_unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from model import UNetTimeStep 3 | 4 | 5 | def test_unet_same_width_height(): 6 | x = torch.rand(2, 1, 28, 28) 7 | unet = UNetTimeStep([1, 32, 32, 64, 128], [3, 3, 3, 3], [1, 1, 1, 1], [1, 1, 1, 1], True, 
[0.0, 0.0, 0.0, 0.0], 100) 8 | x_recon, v = unet(x, torch.tensor([0.25] * 2).view(-1)) 9 | assert x.shape == x_recon.shape 10 | 11 | 12 | def test_unet_different_width_height(): 13 | x = torch.rand(2, 3, 32, 35) 14 | unet = UNetTimeStep([3, 32, 64, 128], [3, 3, 3], [1, 1, 1], [1, 1, 1], True, [0.0, 0.0, 0.0], 10) 15 | x_recon, v = unet(x, torch.tensor([0.25, 0.5]).view(-1)) 16 | assert x.shape == x_recon.shape 17 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from omegaconf import DictConfig, OmegaConf 3 | from path import Path 4 | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping 5 | 6 | from torch.utils.data import Dataset, DataLoader 7 | import pytorch_lightning as pl 8 | import omegaconf 9 | import os 10 | 11 | from callbacks.ema import EMA 12 | from callbacks.logger import LoggerCallback 13 | from utils.paths import MODEL 14 | 15 | 16 | # This function is the entry point for the training script. It takes a DictConfig object as an argument, which contains 17 | # the configuration for the training run. The configuration is loaded from a YAML file using Hydra. 18 | @hydra.main('config', 'train.yaml') 19 | def train(config: DictConfig): 20 | # Initialize checkpoint to None 21 | ckpt = None 22 | 23 | # Set random seeds for reproducibility 24 | pl.seed_everything(config.seed) 25 | 26 | # If a checkpoint is specified in the config, load it and update the config accordingly 27 | if config.ckpt is not None: 28 | # Change the current working directory to the parent directory of the checkpoint file 29 | os.chdir(os.path.dirname(config.ckpt)) 30 | 31 | # Assert that the checkpoint file exists 32 | assert os.path.exists(config.ckpt) 33 | 34 | # Set ckpt to the path of the checkpoint file 35 | ckpt = config.ckpt 36 | 37 | # Load the configuration file associated with the checkpoint file 38 | config = OmegaConf.load(os.path.join(os.path.dirname(ckpt), 'config.yaml')) 39 | 40 | # Save the updated configuration to a file called 'config.yaml' 41 | with open('config.yaml', 'w') as f: 42 | omegaconf.OmegaConf.save(config, f) 43 | 44 | Path.cwd().joinpath('gen_images').makedirs_p() 45 | # copy paste model/ folder into the checkpoint folder 46 | MODEL.copytree(Path.cwd().joinpath('model')) 47 | 48 | # Create the variance scheduler and a deep generative model using Hydra 49 | scheduler = hydra.utils.instantiate(config.scheduler) 50 | opt = hydra.utils.instantiate(config.optimizer) 51 | model: pl.LightningModule = hydra.utils.instantiate(config.model, variance_scheduler=scheduler, opt=opt) 52 | 53 | # Create training and validation datasets using Hydra 54 | train_dataset: Dataset = hydra.utils.instantiate(config.dataset.train) 55 | val_dataset: Dataset = hydra.utils.instantiate(config.dataset.val) 56 | 57 | # If a checkpoint is specified, load the model weights from the checkpoint 58 | if ckpt is not None: 59 | model.load_from_checkpoint(ckpt, variance_scheduler=scheduler) 60 | 61 | # Save the hyperparameters of the model to a file called 'hparams.yaml' 62 | model.save_hyperparameters(OmegaConf.to_object(config)['model']) 63 | 64 | # Create PyTorch dataloaders for the training and validation datasets 65 | pin_memory = 'gpu' in config.accelerator 66 | train_dl = DataLoader(train_dataset, batch_size=config.batch_size, pin_memory=pin_memory) 67 | val_dl = DataLoader(val_dataset, batch_size=config.batch_size, pin_memory=pin_memory) 68 | 69 | # 
Create a ModelCheckpoint callback that saves the model weights to disk during training 70 | ckpt_callback = ModelCheckpoint('./', 'epoch={epoch}-valid_loss={val/loss_epoch}', 71 | monitor='val/loss_epoch', auto_insert_metric_name=False, save_last=True) 72 | ddpm_logger = LoggerCallback(config.freq_logging, config.freq_logging_norm_grad, config.batch_size_gen_images) 73 | callbacks = [ckpt_callback, ddpm_logger] 74 | 75 | # Add additional callbacks if specified in the configuration file 76 | if config.ema: 77 | # Create an Exponential Moving Average callback 78 | callbacks.append(EMA(config.ema_decay)) 79 | if config.early_stop: 80 | callbacks.append(EarlyStopping('val/loss_epoch', min_delta=config.min_delta, patience=config.patience)) 81 | 82 | trainer = pl.Trainer(callbacks=callbacks, accelerator=config.accelerator, devices=config.devices, gradient_clip_val=config.gradient_clip_val, gradient_clip_algorithm=config.gradient_clip_algorithm) 83 | trainer.fit(model, train_dl, val_dl) 84 | 85 | if __name__ == '__main__': 86 | import sys 87 | sys.path.append(Path(__file__).parent.absolute()) 88 | train() 89 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Michedev/DDPMs-Pytorch/4cdb52d5cee7070ddc243ecb781c9fe5e68d0ba8/utils/__init__.py -------------------------------------------------------------------------------- /utils/paths.py: -------------------------------------------------------------------------------- 1 | from path import Path 2 | 3 | 4 | ROOT = Path(__file__).parent.parent 5 | CONFIG = ROOT / 'config' 6 | SCHEDULER = CONFIG / 'scheduler' 7 | MODEL = ROOT / 'model' -------------------------------------------------------------------------------- /variance_scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | # Auto-generated by initgen - Made by Mikedev 2 | from .abs_var_scheduler import Scheduler 3 | from .cosine import CosineScheduler 4 | from .hyperbolic_secant import HyperbolicSecant 5 | from .linear import LinearScheduler 6 | 7 | __all__ = ['Scheduler', 'CosineScheduler', 'HyperbolicSecant', 'LinearScheduler'] -------------------------------------------------------------------------------- /variance_scheduler/abs_var_scheduler.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod, ABC 2 | 3 | 4 | class Scheduler(ABC): 5 | 6 | @abstractmethod 7 | def get_alpha_hat(self): 8 | pass 9 | 10 | @abstractmethod 11 | def get_alphas(self): 12 | pass 13 | 14 | @abstractmethod 15 | def get_betas(self): 16 | pass 17 | 18 | @abstractmethod 19 | def get_betas_hat(self): 20 | pass 21 | -------------------------------------------------------------------------------- /variance_scheduler/cosine.py: -------------------------------------------------------------------------------- 1 | from math import pi 2 | import torch 3 | 4 | from variance_scheduler.abs_var_scheduler import Scheduler 5 | 6 | class CosineScheduler(Scheduler): 7 | 8 | 9 | clip_max_value = torch.Tensor([0.999]) 10 | 11 | def __init__(self, T: int, s: float = 0.008): 12 | """ 13 | Cosine variance scheduler.
14 | The equation for the variance is: 15 | alpha_hat = min(cos((t / T + s) / (1 + s) * pi / 2)^2, 0.999) 16 | The equation for the beta is: 17 | beta = 1 - (alpha_hat(t) / alpha_hat(t - 1)) 18 | The equation for the beta_hat is: 19 | beta_hat = (1 - alpha_hat(t - 1)) / (1 - alpha_hat(t)) * beta(t) 20 | """ 21 | self.T = T 22 | self._alpha_hats = self._alpha_hat_function(torch.arange(self.T), T, s) 23 | self._alpha_hats_t_minus_1 = torch.roll(self._alpha_hats, shifts=1, dims=0) # shift forward by 1 so that alpha_first[t] = alpha[t-1] 24 | self._alpha_hats_t_minus_1[0] = self._alpha_hats_t_minus_1[1] # to remove first NaN value 25 | self._betas = 1.0 - self._alpha_hats / self._alpha_hats_t_minus_1 26 | self._betas = torch.minimum(self._betas, self.clip_max_value) 27 | self._alphas = 1.0 - self._betas 28 | self._betas_hat = (1 - self._alpha_hats_t_minus_1) / (1 - self._alpha_hats) * self._betas 29 | self._betas_hat[torch.isnan(self._betas_hat)] = 0.0 30 | 31 | def _alpha_hat_function(self, t: torch.Tensor, T: int, s: float): 32 | """ 33 | Compute the alpha_hat value for a given t value. 34 | :param t: the t value 35 | :param T: the total amount of noising steps 36 | :param s: smoothing parameter 37 | """ 38 | cos_value = torch.pow(torch.cos((t / T + s) / (1 + s) * pi / 2.0), 2) 39 | return cos_value 40 | 41 | def get_alpha_hat(self): 42 | return self._alpha_hats 43 | 44 | def get_alphas(self): 45 | return self._alphas 46 | 47 | def get_betas(self): 48 | return self._betas 49 | 50 | def get_betas_hat(self): 51 | return self._betas_hat 52 | 53 | 54 | if __name__ == '__main__': 55 | scheduler = CosineScheduler(1000) 56 | import matplotlib.pyplot as plt 57 | plt.plot(scheduler.get_alpha_hat().numpy()) 58 | plt.ylabel('$\\alpha_t$') 59 | plt.xlabel('t') 60 | plt.show() -------------------------------------------------------------------------------- /variance_scheduler/hyperbolic_secant.py: -------------------------------------------------------------------------------- 1 | from math import exp 2 | 3 | import torch 4 | from numpy import arctan 5 | 6 | from variance_scheduler.abs_var_scheduler import Scheduler 7 | 8 | 9 | class HyperbolicSecant(Scheduler): 10 | 11 | def __init__(self, T: int, lambda_min: float, lambda_max: float): 12 | self.lambda_min = lambda_min 13 | self.lambda_max = lambda_max 14 | # pg 3 section 2 for the details about the following eqns 15 | self.b = arctan(exp(-lambda_max / 2)) 16 | self.a = arctan(exp(-lambda_min/2)) - self.b 17 | self._beta = - 2 * torch.log(torch.tan(self.a * torch.linspace(0, 1, T, dtype=torch.float) + self.b)) 18 | self._alpha = 1.0 - self._beta 19 | self._alpha_hat = torch.cumprod(self._alpha, dim=0) 20 | self._alpha_hat_t_minus_1 = torch.roll(self._alpha_hat, shifts=1, dims=0) 21 | self._alpha_hat_t_minus_1[0] = self._alpha_hat_t_minus_1[1] 22 | self._beta_hat = (1 - self._alpha_hat_t_minus_1) / (1 - self._alpha_hat) * self._beta 23 | 24 | def get_alpha_hat(self): 25 | return self._alpha_hat 26 | 27 | def get_alphas(self): 28 | return self._alpha 29 | 30 | def get_betas(self): 31 | return self._beta 32 | 33 | def get_betas_hat(self): 34 | return self._beta_hat 35 | -------------------------------------------------------------------------------- /variance_scheduler/linear.py: -------------------------------------------------------------------------------- 1 | from variance_scheduler.abs_var_scheduler import Scheduler 2 | import torch 3 | 4 | 5 | class LinearScheduler(Scheduler): 6 | """ 7 | A scheduler that linearly interpolates between two values of 
beta over T steps. 8 | """ 9 | 10 | def __init__(self, T: int, beta_start: float, beta_end: float): 11 | """ 12 | Initializes the LinearScheduler. 13 | 14 | Args: 15 | T (int): The number of steps to interpolate over. 16 | beta_start (float): The starting value of beta. 17 | beta_end (float): The ending value of beta. 18 | """ 19 | self.T = T 20 | self.beta_start = beta_start 21 | self.beta_end = beta_end 22 | self._beta = torch.linspace(beta_start, beta_end, T) 23 | self._alpha = 1.0 - self._beta 24 | self._alpha_hat = torch.cumprod(self._alpha, dim=0) 25 | self._alpha_hat_t_minus_1 = torch.roll(self._alpha_hat, shifts=1, dims=0) 26 | self._alpha_hat_t_minus_1[0] = self._alpha_hat_t_minus_1[1] 27 | self._beta_hat = (1 - self._alpha_hat_t_minus_1) / (1 - self._alpha_hat) * self._beta 28 | 29 | def get_alpha_hat(self): 30 | """ 31 | Returns the cumulative product of (1 - beta) over time. 32 | """ 33 | return self._alpha_hat 34 | 35 | def get_alphas(self): 36 | """ 37 | Returns the values of alpha over time. 38 | """ 39 | return self._alpha 40 | 41 | def get_betas(self): 42 | """ 43 | Returns the values of beta over time. 44 | """ 45 | return self._beta 46 | 47 | def get_betas_hat(self): 48 | """ 49 | Returns the values of beta_hat over time. 50 | """ 51 | return self._beta_hat --------------------------------------------------------------------------------
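The `Scheduler` subclasses above all follow the same recipe: precompute `beta`, `alpha = 1 - beta`, `alpha_hat = cumprod(alpha)` and `beta_hat`, then expose them through the four getters. A new scheduler for the readme's "Add your scheduler" workflow only needs to repeat this pattern; below is a sketch of a hypothetical `QuadraticScheduler` (its name, module and beta formula are not part of the repository) that mirrors `LinearScheduler`:

```python
import torch

from variance_scheduler.abs_var_scheduler import Scheduler


class QuadraticScheduler(Scheduler):
    """Hypothetical scheduler: beta grows quadratically from beta_start to beta_end over T steps."""

    def __init__(self, T: int, beta_start: float, beta_end: float):
        self.T = T
        # interpolate in sqrt-space and square, so beta grows quadratically
        self._beta = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, T) ** 2
        self._alpha = 1.0 - self._beta
        self._alpha_hat = torch.cumprod(self._alpha, dim=0)
        # shift forward by one so that index t holds alpha_hat(t - 1); duplicate the first entry
        self._alpha_hat_t_minus_1 = torch.roll(self._alpha_hat, shifts=1, dims=0)
        self._alpha_hat_t_minus_1[0] = self._alpha_hat_t_minus_1[1]
        self._beta_hat = (1 - self._alpha_hat_t_minus_1) / (1 - self._alpha_hat) * self._beta

    def get_alpha_hat(self):
        return self._alpha_hat

    def get_alphas(self):
        return self._alpha

    def get_betas(self):
        return self._beta

    def get_betas_hat(self):
        return self._beta_hat
```

Saved, for example, as `variance_scheduler/quadratic.py`, it would then be referenced from a new `config/scheduler/my-scheduler.yaml` whose `_target_` points at the class, as described in the readme.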