├── README.md ├── config └── vdm_cifar10.yaml ├── src ├── data.py ├── module │ ├── attention.py │ ├── conv.py │ ├── embedding.py │ └── monotonic.py ├── schedule.py ├── unet.py ├── utils.py └── vdm.py ├── test ├── test_attention.py ├── test_schedule.py ├── test_unet.py └── test_vdm.py └── train.py /README.md: -------------------------------------------------------------------------------- 1 | # Variational Diffusion Models in Easy PyTorch 2 | 3 | This repo is an _unofficial_ implementation of `Variational Diffusion Models` as introduced originally in [Kingma et al., (2021)](https://arxiv.org/abs/2107.00630) (revised in 2023). The authors provided an [official implementation](https://github.com/google-research/vdm) in JAX. Other PyTorch implementations exist (see [this nice example](https://github.com/addtt/variational-diffusion-models/tree/main)), so I developed this repo mainly for didactic purposes and as a gentle introduction to [`Bayesian Flow Networks`](https://arxiv.org/abs/2308.07037) that share similar principles. 4 | 5 | ## Usage 6 | 7 | ```python 8 | import torch 9 | 10 | from src.unet import UNet 11 | from src.vdm import VariationalDiffusion 12 | from src.schedule import LinearSchedule 13 | 14 | vdm = VariationalDiffusion( 15 | backbone=UNet( 16 | net_dim=4, 17 | ctrl_dim=None, 18 | use_cond=False, 19 | use_attn=True, 20 | num_group=4, 21 | adapter='b c h w -> b (h w) c', 22 | ), 23 | schedule=LinearSchedule(), # Linear schedule with learnable endpoints 24 | img_shape=(32, 32), 25 | vocab_size=256, 26 | ) 27 | 28 | # Get some fake imgs for testing 29 | imgs = torch.randn(16, 3, 32, 32) 30 | 31 | # Compute the VDM loss, which is a combination of 32 | # diffusion + latent + reconstruction loss 33 | loss, stats = vdm.compute_loss(imgs) 34 | 35 | # Once the model is trained, we can sample from the learnt 36 | # inverse diffusion process by simply doing 37 | num_imgs = 4 38 | num_step = 100 39 | 40 | samples = vdm(num_imgs, num_step) 41 | ``` 42 | 43 | We now support the learnable noise schedule (the $\gamma_\eta(t)$ network in the paper) via the `LearnableSchedule` module. This is implemented via a monotonic linear network (which uses the `MonotonicLinear` module) as described in *Constrained Monotonic Neural Networks* [Runje & Shankaranarayana, ICML (2023)](https://arxiv.org/abs/2205.11775). Moreover, we added preliminary support for optimizing the noise schedule to reduce the variance of the diffusion loss (as discussed in `Appendix I.2` of the main paper). This is achieved via the `reduce_variance` call, which re-uses the already-computed gradient needed for the VLB to reduce computational overhead. 
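In the current implementation (see `src/schedule.py`), `LearnableSchedule` feeds the time $t \in [0, 1]$ through a small monotonic MLP $g_\eta$ and rescales its output so that the endpoints coincide with the (learnable) extremes $\gamma_{\min}$ and $\gamma_{\max}$:

$$
\gamma_\eta(t) = \gamma_{\min} + (\gamma_{\max} - \gamma_{\min}) \, \frac{g_\eta(t) - g_\eta(0)}{g_\eta(1) - g_\eta(0)},
$$

which guarantees $\gamma_\eta(0) = \gamma_{\min}$ and $\gamma_\eta(1) = \gamma_{\max}$ while preserving monotonicity in $t$.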
44 |
45 | ```python
46 | import torch
47 |
48 | from src.unet import UNet
49 | from src.vdm import VariationalDiffusion
50 | from src.schedule import LearnableSchedule
51 |
52 | vdm = VariationalDiffusion(
53 |     backbone=UNet(
54 |         net_dim=4,
55 |         ctrl_dim=None,
56 |         use_cond=False,
57 |         use_attn=True,
58 |         num_group=4,
59 |         adapter='b c h w -> b (h w) c',
60 |     ),
61 |     schedule=LearnableSchedule(
62 |         hid_dim=[50, 50],
63 |         gate_func='relu',
64 |     ), # Fully learnable schedule with support for reduced variance
65 |     img_shape=(32, 32),
66 |     vocab_size=256,
67 | )
68 |
69 | # Get some fake imgs for testing
70 | imgs = torch.randn(16, 3, 32, 32)
71 |
72 | # Initialize the optimizer of choice
73 | optim = torch.optim.AdamW(vdm.parameters(), lr=1e-3)
74 | optim.zero_grad()
75 |
76 | # First we compute the VLB loss
77 | loss, stats = vdm.compute_loss(imgs)
78 |
79 | # Then we call .backward() to populate the gradients
80 | # NOTE: We need to retain the graph to access the
81 | # gradients, otherwise they are freed
82 | loss.backward(retain_graph=True)
83 |
84 | # Next we adjust the noise-schedule gradients to
85 | # reduce the diffusion-loss variance (faster training)
86 | vdm.reduce_variance(*stats['var_args'])
87 |
88 | # Finally we update the model parameters
89 | optim.step()
90 |
91 | # We can manually delete both loss and stats to put
92 | # the grad graph out-of-scope so it gets freed
93 | del loss, stats
94 | ```
95 |
96 | We now support training the model via [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/latest/). This is implemented in `src/vdm.py` and can be used from the command line by providing a configuration file (see the example in `config/vdm_cifar10.yaml`); a minimal sketch of what `train.py` does under the hood is shown right after the roadmap:
97 |
98 | ```bash
99 | python train.py -config vdm_cifar10.yaml # all the pytorch-trainer arguments are supported
100 | ```
101 |
102 | ## Roadmap
103 |
104 | - [x] Put all the essential pieces together: UNet, VDM, a noise schedule.
105 | - [x] Add fully learnable schedule (monotonic neural network). Implement the gradient trick described in Appendix I.2.
106 | - [x] Add functioning training script (Lightning).
107 | - [ ] Show some results.
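For reference, here is a minimal sketch of what `train.py` boils down to (simplified: the actual script also wires in the checkpoint callback, the TensorBoard logger and the resume path from the `MISC` section of the config, and it assumes you run it from the repository root):

```python
import yaml

from lightning import Trainer

from src.data import CIFAR10DM
from src.vdm import VariationalDiffusion

# Parse the YAML configuration (UNET, VDM, SCHEDULE, OPTIMIZER, ... sections)
with open('config/vdm_cifar10.yaml', 'r') as f:
    conf_file = yaml.safe_load(f)

# Build the diffusion model and the CIFAR-10 datamodule from the configuration
vdm   = VariationalDiffusion.from_conf(conf_file)
cifar = CIFAR10DM(**conf_file['DATASET'])

# Fit with a Lightning Trainer configured from the TRAINER section
trainer = Trainer(**conf_file['TRAINER'])
trainer.fit(vdm, datamodule=cifar)
```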
108 |
109 | ## Citations
110 |
111 | ```bibtex
112 | @article{kingma2021variational,
113 |   title={Variational diffusion models},
114 |   author={Kingma, Diederik and Salimans, Tim and Poole, Ben and Ho, Jonathan},
115 |   journal={Advances in neural information processing systems},
116 |   volume={34},
117 |   pages={21696--21707},
118 |   year={2021}
119 | }
120 | ```
121 |
122 | ```bibtex
123 | @inproceedings{runje2023constrained,
124 |   title={Constrained monotonic neural networks},
125 |   author={Runje, Davor and Shankaranarayana, Sharath M},
126 |   booktitle={International Conference on Machine Learning},
127 |   pages={29338--29353},
128 |   year={2023},
129 |   organization={PMLR}
130 | }
131 | ```
--------------------------------------------------------------------------------
/config/vdm_cifar10.yaml:
--------------------------------------------------------------------------------
 1 | UNET:
 2 |   net_dim   : 32
 3 |   out_dim   : null
 4 |   inp_chn   : 3
 5 |   dropout   : 0.1
 6 |   adapter   : 'q c h w -> q (h w) c'
 7 |   attn_dim  : 128
 8 |   ctrl_dim  : null
 9 |   use_cond  : False
10 |   use_attn  : True
11 |   chn_mult  : [2, 1, 1, 1]
12 |   n_fourier : [7, 8, 1] # n_min=7, n_max=8, step=1
13 |   num_group : 8
14 |   num_heads : 8
15 |
16 | VDM:
17 |   data_key      : 'imgs'
18 |   vocab_size    : 256
19 |   sampling_step : 50
20 |   img_shape     : [32, 32]
21 |
22 | SCHEDULE:
23 |   name       : 'learnable'
24 |   gamma_min  : -13.3
25 |   gamma_max  : 5.0
26 |   hid_dim    : [100]
27 |   gate_func  : 'relu'
28 |   act_weight : [7, 7, 2]
29 |
30 | OPTIMIZER:
31 |   name         : AdamW
32 |   lr           : 0.0001
33 |   weight_decay : 0.05
34 |
35 | DATASET:
36 |   root           : ''
37 |   download       : True
38 |   num_workers    : 4
39 |   batch_size     : 128
40 |   val_batch_size : 128
41 |   train_shuffle  : True
42 |   val_shuffle    : True
43 |
44 | TRAINER:
45 |   max_epochs  : 1000
46 |   accelerator : gpu
47 |   devices     : 4
48 |   # strategy : ddp_find_unused_parameters_false
49 |   # accumulate_grad_batches : 2
50 |   val_check_interval : 1
51 |   # limit_val_batches : 1
52 |   log_every_n_steps  : 1
53 |
54 | MISC:
55 |   logs_dir    : ''
56 |   ckpt_dir    : ''
57 |   run_name    : 'VDM-cifar10'
58 |   monitor     : 'val_loss'
59 |   version     : null
60 |   save_last   : True
61 |   resume_ckpt : null
62 |
--------------------------------------------------------------------------------
/src/data.py:
--------------------------------------------------------------------------------
 1 |
 2 | from torch.utils.data import Dataset
 3 | from torch.utils.data import DataLoader
 4 | from torchvision.datasets import CIFAR10
 5 | from torchvision import transforms
 6 |
 7 | from typing import Callable
 8 |
 9 | from lightning import LightningDataModule
10 |
11 | class CIFAR10DM(LightningDataModule):
12 |
13 |     def __init__(
14 |         self,
15 |         root : str,
16 |         download : bool = False,
17 |
18 |         batch_size : int = 16,
19 |         num_workers : int = 0,
20 |         train_shuffle : bool = True,
21 |         val_shuffle : bool = False,
22 |         val_batch_size : int | None = None,
23 |         worker_init_fn : Callable | None = None,
24 |         collate_fn : Callable | None = None,
25 |         train_sampler : Callable | None = None,
26 |         val_sampler : Callable | None = None,
27 |         test_sampler : Callable | None = None,
28 |     ) -> None:
29 |         super().__init__()
30 |
31 |         transform = transforms.Compose([
32 |             transforms.RandomHorizontalFlip(),
33 |             transforms.ToTensor(),
34 |             transforms.Lambda(lambda x : 2 * x - 1.)
35 | ]) 36 | 37 | self.root = root 38 | self.download = download 39 | self.transform = transform 40 | 41 | self.num_workers = num_workers 42 | self.batch_size = batch_size 43 | self.train_shuffle = train_shuffle 44 | self.val_shuffle = val_shuffle 45 | self.train_sampler = train_sampler 46 | self.valid_sampler = val_sampler 47 | self.test__sampler = test_sampler 48 | self.collate_fn = collate_fn 49 | self.worker_init_fn = worker_init_fn 50 | self.val_batch_size = val_batch_size 51 | 52 | def setup(self, stage = None): 53 | cifar_train = CIFAR10(self.root, 54 | train = True, 55 | transform = self.transform, 56 | download=self.download, 57 | ) 58 | 59 | cifar_val = CIFAR10(self.root, 60 | train = False, 61 | transform = self.transform, 62 | download=self.download, 63 | ) 64 | 65 | # Assign train/val datasets for use in dataloader 66 | if stage == "fit" or stage is None: 67 | self.train_dataset = cifar_train 68 | self.valid_dataset = cifar_val 69 | 70 | # Assign test dataset for use in dataloader(s) 71 | if stage == "test" or stage is None: 72 | self.test_dataset = cifar_val 73 | 74 | def train_dataloader(self) -> DataLoader: 75 | return DataLoader( 76 | self.train_dataset, 77 | sampler = self.train_sampler, 78 | batch_size = self.batch_size, 79 | shuffle = self.train_shuffle, 80 | collate_fn = self.collate_fn, 81 | num_workers = self.num_workers, 82 | worker_init_fn = self.worker_init_fn, 83 | ) 84 | 85 | def val_dataloader(self) -> DataLoader: 86 | return DataLoader( 87 | self.valid_dataset, 88 | sampler = self.valid_sampler, 89 | batch_size = self.val_batch_size, 90 | shuffle = self.val_shuffle, 91 | collate_fn = self.collate_fn, 92 | num_workers = self.num_workers, 93 | worker_init_fn = self.worker_init_fn, 94 | ) 95 | 96 | def test_dataloader(self) -> DataLoader: 97 | return DataLoader( 98 | self.test_dataset, 99 | sampler = self.test__sampler, 100 | batch_size = self.val_batch_size, 101 | shuffle = self.val_shuffle, 102 | collate_fn = self.collate_fn, 103 | num_workers = self.num_workers, 104 | worker_init_fn = self.worker_init_fn, 105 | ) -------------------------------------------------------------------------------- /src/module/attention.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from torch.nn import MultiheadAttention 4 | 5 | from itertools import starmap 6 | 7 | from einops import rearrange 8 | from typing import Tuple 9 | from torch import Tensor 10 | 11 | from ..utils import default 12 | 13 | class Adapter(nn.Module): 14 | def __init__( 15 | self, 16 | pattern : str | Tuple[str, ...], 17 | qry_dim : int, 18 | key_dim : int | None = None, 19 | val_dim : int | None = None, 20 | emb_dim : int | None = None, 21 | ) -> None: 22 | super(Adapter, self).__init__() 23 | 24 | # If no adapter was provided for key and values, we 25 | # assume self-attention is going to be computed, so 26 | # we simply replicate the pattern for both key and values 27 | if isinstance(pattern, str): pattern = [pattern] * 3 28 | if len(pattern) == 2: pattern = (*pattern, pattern[-1]) 29 | 30 | key_dim = default(key_dim, qry_dim) 31 | val_dim = default(val_dim, qry_dim) 32 | emb_dim = default(emb_dim, qry_dim) 33 | 34 | self.pattern = pattern 35 | self.emb_dim = emb_dim 36 | 37 | self.inv_pattern = ' -> '.join(pattern[0].split('->')[::-1]) 38 | 39 | self.to_q = self._get_adapter(pattern[0], chn_inp=qry_dim, chn_out=emb_dim) 40 | self.to_k = self._get_adapter(pattern[1], chn_inp=key_dim, chn_out=emb_dim) 41 | self.to_v = 
self._get_adapter(pattern[2], chn_inp=val_dim, chn_out=emb_dim) 42 | 43 | self.from_q = self._get_adapter(pattern[0], chn_inp=emb_dim, chn_out=qry_dim) 44 | 45 | def forward( 46 | self, 47 | qry : Tensor, 48 | key : Tensor | None = None, 49 | val : Tensor | None = None, 50 | ) -> Tensor: 51 | ''' 52 | ''' 53 | 54 | q = self.to_q(qry) 55 | k = self.to_k(key) 56 | v = self.to_v(val) 57 | 58 | self.q_shape = q.shape 59 | 60 | q, k, v = starmap(lambda t, adapt : rearrange(t, adapt), zip((q, k, v), self.pattern)) 61 | 62 | return q, k, v 63 | 64 | def restore(self, attn : Tensor) -> Tensor: 65 | ''' 66 | ''' 67 | 68 | if not hasattr(self, 'q_shape'): 69 | raise ValueError('Cannot restore before forward pass has been called') 70 | 71 | # Prepare the appropriate kwargs for rearrange by composing the known 72 | # inverse adapter with the stored qry shape from forward pass 73 | names = [c for c in self.inv_pattern.split('->')[-1] if c.isalpha()] 74 | kwargs = {k : v for k, v in zip(names, self.q_shape)} 75 | 76 | attn = rearrange(attn, self.inv_pattern, **kwargs) 77 | 78 | return self.from_q(attn) 79 | 80 | def _get_adapter( 81 | self, 82 | pattern : str | Tuple[str, ...], 83 | chn_inp : int, 84 | chn_out : int, 85 | ) -> nn.Module: 86 | if chn_inp == chn_out: return nn.Identity() 87 | 88 | dim_out = sum([c.isalpha() for c in pattern.split('->')[-1]]) 89 | 90 | match dim_out: 91 | case 0: return nn.Linear(chn_inp, chn_out, bias=False) 92 | case 3: return nn.Conv1d(chn_inp, chn_out, 1, bias=False) 93 | case 4: return nn.Conv2d(chn_inp, chn_out, 1, bias=False) 94 | case 5: return nn.Conv3d(chn_inp, chn_out, 1, bias=False) 95 | case _: pass 96 | 97 | raise ValueError(f'Input shape not supported. Got {dim_out}') 98 | 99 | class AdaptiveAttention(MultiheadAttention): 100 | def __init__( 101 | self, 102 | emb_dim, 103 | n_heads, 104 | pattern : str, 105 | qry_dim : int | None = None, 106 | key_dim : int | None = None, 107 | val_dim : int | None = None, 108 | batch_first : bool = True, 109 | **kwargs 110 | ) -> None: 111 | super(AdaptiveAttention, self).__init__(emb_dim, n_heads, batch_first=batch_first, **kwargs) 112 | 113 | qry_dim = default(qry_dim, emb_dim) 114 | key_dim = default(key_dim, qry_dim) 115 | val_dim = default(val_dim, key_dim) 116 | 117 | # Build the attention adepter 118 | self.adapter = Adapter( 119 | pattern=pattern, 120 | qry_dim=qry_dim, 121 | key_dim=key_dim, 122 | val_dim=val_dim, 123 | emb_dim=emb_dim, 124 | ) 125 | 126 | def forward( 127 | self, 128 | qry : Tensor, 129 | key : Tensor | None = None, 130 | val : Tensor | None = None, 131 | return_weights : bool = False, 132 | **kwargs 133 | ) -> Tensor | Tuple[Tensor, Tensor]: 134 | ''' 135 | 136 | ''' 137 | 138 | key = default(key, qry) 139 | val = default(val, key) 140 | 141 | # Adapt the inputs to the expected format by the MHA module 142 | qry, key, val = self.adapter(qry, key, val) 143 | 144 | # Compute the attention output 145 | attn, attn_weights = super().forward(qry, key, val, **kwargs) 146 | 147 | # Restore the correct output format 148 | attn = self.adapter.restore(attn) 149 | 150 | return (attn, attn_weights) if return_weights else attn 151 | -------------------------------------------------------------------------------- /src/module/conv.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | from torch import Tensor 4 | 5 | from ..utils import exists 6 | from ..utils import default 7 | from ..utils import enlarge_as 8 | 9 | def Upscale(dim_in, dim_out : 
int = None, factor : int = 2): 10 | return nn.Sequential( 11 | nn.Upsample(scale_factor = factor, mode = 'nearest'), 12 | nn.Conv2d(dim_in, default(dim_out, dim_in), 3, padding = 1) 13 | ) if factor > 1 else nn.Identity() 14 | 15 | def Downscale(dim_in, dim_out : int = None, factor : int = 2): 16 | return nn.Conv2d(dim_in, default(dim_out, dim_in), 2 * factor, factor, 1)\ 17 | if factor > 1 else nn.Identity() 18 | 19 | class ContextRes(nn.Module): 20 | ''' 21 | Convolutional Residual Block with context embedding 22 | injection support, used by Diffusion Models. It is 23 | composed of two convolutional layers with normalization. 24 | The context embedding signal is injected between the two 25 | convolutions (optionally) and is added to the input to 26 | the second one. 27 | ''' 28 | 29 | def __init__( 30 | self, 31 | inp_dim : int, 32 | out_dim : int | None = None, 33 | hid_dim : int | None = None, 34 | ctx_dim : int | None = None, 35 | num_group : int = 8, 36 | dropout : float = 0., 37 | ) -> None: 38 | super().__init__() 39 | 40 | out_dim = default(out_dim, inp_dim) 41 | hid_dim = default(hid_dim, out_dim) 42 | ctx_dim = default(ctx_dim, out_dim) 43 | 44 | self.time_emb = nn.Sequential( 45 | nn.SiLU(inplace = False), 46 | nn.Linear(ctx_dim, hid_dim), 47 | ) 48 | 49 | self.conv1 = nn.Sequential( 50 | nn.Conv2d(inp_dim, hid_dim, kernel_size = 3, padding = 1), 51 | nn.GroupNorm(num_group, hid_dim), 52 | nn.SiLU(inplace = False), 53 | ) 54 | 55 | self.conv2 = nn.Sequential( 56 | *([nn.Dropout(dropout)] * (dropout > 0.)), 57 | nn.Conv2d(hid_dim, out_dim, kernel_size = 3, padding = 1), 58 | nn.GroupNorm(num_group, out_dim), 59 | nn.SiLU(inplace = False), 60 | ) 61 | 62 | self.skip = nn.Conv2d(inp_dim, out_dim, 1) if inp_dim != out_dim else nn.Identity() 63 | 64 | def forward( 65 | self, 66 | inp : Tensor, 67 | ctx : Tensor | None = None, 68 | ) -> Tensor: 69 | 70 | # Perform first convolution block 71 | h = self.conv1(inp) 72 | 73 | if exists(ctx): 74 | # Add embedded time signal with appropriate 75 | # broadcasting to match image-like tensors 76 | ctx = self.time_emb(ctx) 77 | h += enlarge_as(ctx, h) 78 | 79 | h = self.conv2(h) 80 | 81 | return self.skip(inp) + h -------------------------------------------------------------------------------- /src/module/embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from torch import Tensor 5 | 6 | from einops import einsum 7 | from einops import rearrange 8 | 9 | class TimeEmbedding(nn.Module): 10 | ''' 11 | Embedding for time-like data used by diffusion models. 12 | ''' 13 | 14 | def __init__( 15 | self, 16 | emb_dim : int, 17 | base : int = 10000 18 | ) -> None: 19 | super().__init__() 20 | 21 | self.emb_dim = emb_dim 22 | self.base = base 23 | 24 | def forward(self, time : Tensor) -> Tensor: 25 | # NOTE: We multiply by 1000 because time in variational models 26 | # is formalized in [0, 1], so we just upscale to 1000 for 27 | # proper embedding computation 28 | time *= 1000 29 | 30 | # Check for correct time shape 31 | bs, _ = time.shape 32 | 33 | half_dim = self.emb_dim // 2 34 | emb_time = torch.empty((bs, self.emb_dim), device = time.device) 35 | 36 | pos_n = torch.arange(half_dim, device = time.device) 37 | inv_f = 1. 
/ (self.base ** (pos_n / (half_dim - 1))) 38 | 39 | emb_v = einsum(time, inv_f, 'b _, f -> b f') 40 | 41 | emb_time[..., 0::2] = emb_v.sin() 42 | emb_time[..., 1::2] = emb_v.cos() 43 | 44 | return emb_time 45 | 46 | class FourierEmbedding(nn.Module): 47 | ''' 48 | Set of Fourier Features to add to the input latent 49 | code "z" of the noise-predictor UNet model to ease 50 | its handling of the high-frequency components of the 51 | input, which have a significant impact on the likelihood 52 | (despite not that much for the visual appearance). 53 | ''' 54 | 55 | def __init__( 56 | self, 57 | n_min : int = 7, 58 | n_max : int = 8, 59 | n_step : int = 1, 60 | ) -> None: 61 | ''' 62 | 63 | ''' 64 | super().__init__() 65 | 66 | self.n_exp = torch.arange(n_min, n_max, n_step) 67 | self.n_feat = len(self.n_exp) 68 | 69 | def forward(self, z : Tensor) -> Tensor: 70 | ''' 71 | Add the Fourier features to the input latent code. 72 | The features are concatenated to the channel dimension 73 | of the input vector, which is expected to have shape 74 | [batch_size, chn_dim, ...]. 75 | 76 | The Fourier features are defined as: 77 | - f^n_ijk = sin(2^n pi z_ijk) 78 | - g^n_ijk = cos(2^n pi z_ijk) 79 | ''' 80 | 81 | (bs, chn, *_), device = z.shape, z.device 82 | 83 | freq = einsum(2 ** self.n_exp.to(device), z, 'n, b c ... -> b n c ...') 84 | freq = rearrange(freq, 'b n c ... -> b (n c) ...') 85 | 86 | f = freq.sin() 87 | g = freq.cos() 88 | 89 | return torch.cat([z, f, g], dim = 1) 90 | 91 | 92 | -------------------------------------------------------------------------------- /src/module/monotonic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch import Tensor 6 | from typing import Tuple, Callable 7 | 8 | from einops import repeat 9 | from functools import partial 10 | 11 | from ..utils import default 12 | 13 | def saturating_func( 14 | x : Tensor, 15 | conv_f : Callable[[Tensor], Tensor] = None, 16 | conc_f : Callable[[Tensor], Tensor] = None, 17 | slope : float = 1., 18 | const : float = 1., 19 | ) -> Tensor: 20 | conv = conv_f(+torch.ones_like(x) * const) 21 | 22 | return slope * torch.where( 23 | x <= 0, 24 | conv_f(x + const) - conv, 25 | conc_f(x - const) + conv, 26 | ) 27 | 28 | class MonotonicLinear(nn.Linear): 29 | ''' 30 | Monotonic Linear Layer as introduced in: 31 | `Constrained Monotonic Neural Networks` ICML (2023). 32 | 33 | Code is a PyTorch implementation of the official repository: 34 | https://github.com/airtai/mono-dense-keras/ 35 | ''' 36 | 37 | def __init__( 38 | self, 39 | in_features : int, 40 | out_features : int, 41 | bias : bool = True, 42 | gate_func : str = 'elu', 43 | indicator : int | Tensor | None = None, 44 | act_weight : str | Tuple[float, float, float] = (7, 7, 2), 45 | ) -> None: 46 | # Assume positive monotonicity in all input features 47 | indicator = default(indicator, torch.ones(in_features)) 48 | 49 | if isinstance(indicator, int): 50 | indicator = torch.ones(in_features) * indicator 51 | 52 | assert indicator.dim() == 1, 'Indicator tensor must be 1-dimensional.' 53 | assert indicator.size(-1) == in_features, 'Indicator tensor must have the same number of elements as the input features.' 54 | assert len(act_weight) == 3, f'Relative activation weights should have len = 3. 
Got {len(act_weight)}' 55 | if isinstance(act_weight, str): assert act_weight in ('concave', 'convex') 56 | 57 | self.indicator = indicator 58 | 59 | # Compute the three activation functions: concave|convex|saturating 60 | match gate_func: 61 | case 'elu' : self.act_conv = F.elu 62 | case 'silu': self.act_conv = F.silu 63 | case 'gelu': self.act_conv = F.gelu 64 | case 'relu': self.act_conv = F.relu 65 | case 'selu': self.act_conv = F.selu 66 | case _: raise ValueError(f'Unknown gating function {gate_func}') 67 | 68 | self.act_conc = lambda t : -self.act_conv(-t) 69 | self.act_sat = partial( 70 | saturating_func, 71 | conv_f=self.act_conv, 72 | conc_f=self.act_conc, 73 | ) 74 | 75 | match act_weight: 76 | case 'concave': self.act_weight = torch.tensor((1, 0, 0)) 77 | case 'convex' : self.act_weight = torch.tensor((0, 1, 0)) 78 | case _: self.act_weight = torch.tensor(act_weight) / sum(act_weight) 79 | 80 | # Build the layer weights and bias 81 | super(MonotonicLinear, self).__init__(in_features, out_features, bias) 82 | 83 | def forward(self, x : Tensor) -> Tensor: 84 | ''' 85 | ''' 86 | 87 | # Get the absolute values of the weights 88 | abs_weights = self.weight.data.abs() 89 | 90 | # * Use monotonicity indicator T to adjust the layer weights 91 | # * T_i = +1 -> W_ij <= || W_ij || 92 | # * T_i = -1 -> W_ij <= -|| W_ij || 93 | # * T_i = 0 -> do nothing 94 | mask_pos = self.indicator == +1 95 | mask_neg = self.indicator == -1 96 | 97 | self.weight.data[..., mask_pos] = +abs_weights[..., mask_pos] 98 | self.weight.data[..., mask_neg] = -abs_weights[..., mask_neg] 99 | 100 | # Get the output of linear layer 101 | out = super().forward(x) 102 | 103 | # Compute output by adding non-linear gating according to 104 | # relative importance of activations 105 | s_conv, s_conc, _ = (self.act_weight * self.out_features).round() 106 | s_conv = int(s_conv) 107 | s_conc = int(s_conc) 108 | s_sat = self.out_features - s_conv - s_conc 109 | 110 | i_conv, i_conc, i_sat = torch.split( 111 | out, (s_conv, s_conc, s_sat), dim=-1 112 | ) 113 | 114 | out = torch.cat(( 115 | self.act_conv(i_conv), 116 | self.act_conc(i_conc), 117 | self.act_sat (i_sat), 118 | ), 119 | dim=-1, 120 | ) 121 | 122 | return out -------------------------------------------------------------------------------- /src/schedule.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from torch import Tensor 5 | from typing import List, Tuple 6 | 7 | from itertools import pairwise 8 | 9 | from .module.monotonic import MonotonicLinear 10 | 11 | class LinearSchedule(nn.Module): 12 | ''' 13 | Simple Linear schedule with learnable endpoints. 14 | ''' 15 | 16 | def __init__( 17 | self, 18 | gamma_min : float = -13.3, 19 | gamma_max : float = 5.0, 20 | ) -> None: 21 | super().__init__() 22 | 23 | self.q : Tensor = nn.Parameter(torch.tensor(gamma_min)) 24 | self.m : Tensor = nn.Parameter(torch.tensor(gamma_max - gamma_min)) 25 | 26 | def forward(self, t : float) -> Tensor: 27 | return self.m.abs() * t + self.q 28 | 29 | class LearnableSchedule(nn.Module): 30 | ''' 31 | Monotonic schedule represented by a MLP that 32 | learns the optimal schedule to follow to minimize 33 | the VLB variance (granting faster more stable training). 34 | 35 | Monotonicity is ensured by using MonotonicLinear layers 36 | as introduced in: 37 | `Constrained Monotonic Neural Networks` ICML (2023). 
38 | ''' 39 | 40 | def __init__( 41 | self, 42 | gamma_min : float = -13.3, 43 | gamma_max : float = 5.0, 44 | hid_dim : int | List[int] = 3, 45 | gate_func : str = 'relu', 46 | act_weight : Tuple[float, float, float] = (7, 7, 2), 47 | ) -> None: 48 | super().__init__() 49 | 50 | if isinstance(hid_dim, int): hid_dim = [hid_dim] 51 | dims = [1, *hid_dim, 1] 52 | 53 | # Create the MLP 54 | self.layers = nn.Sequential( 55 | *(MonotonicLinear( 56 | inp_dim, out_dim, bias=True, 57 | gate_func=gate_func, 58 | indicator=+1 if layer > 0 else -1, 59 | act_weight=act_weight, 60 | ) for layer, (inp_dim, out_dim) in enumerate(pairwise(dims))) 61 | ) 62 | 63 | self.gamma_min = nn.Parameter(torch.tensor(gamma_min)) 64 | self.gamma_max = nn.Parameter(torch.tensor(gamma_max)) 65 | 66 | def forward(self, t : Tensor) -> Tensor: 67 | # Compute output for intermediate times in [0, 1] 68 | gamma_t = self.layers(t) 69 | 70 | # * Rescale the output to lie between SNR_min, SNR_max 71 | # * where gamma_0 = -log(SNR_max), gamma_1 = -log(SNR_min) 72 | gamma_0 = self.layers(torch.zeros_like(t)) 73 | gamma_1 = self.layers(torch.ones_like (t)) 74 | 75 | out = self.gamma_min + (self.gamma_max - self.gamma_min) * ( 76 | (gamma_t - gamma_0) / (gamma_1 - gamma_0) 77 | ) 78 | 79 | return out 80 | -------------------------------------------------------------------------------- /src/unet.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from torch import Tensor 5 | from typing import Tuple, List 6 | 7 | from .utils import exists 8 | from .utils import default 9 | 10 | from .module.conv import Upscale 11 | from .module.conv import Downscale 12 | from .module.conv import ContextRes 13 | from .module.embedding import TimeEmbedding 14 | from .module.embedding import FourierEmbedding 15 | from .module.attention import AdaptiveAttention 16 | 17 | class UNet(nn.Module): 18 | ''' 19 | U-Net model as introduced in: 20 | "U-Net: Convolutional Networks for Biomedical Image Segmentation". 21 | It is a common choice as network backbone for diffusion models. 22 | ''' 23 | 24 | def __init__( 25 | self, 26 | net_dim : int = 4, 27 | out_dim : int | None = None, 28 | inp_chn : int = 3, 29 | dropout : float = 0., 30 | adapter : str | Tuple[str, ...] = 'b c h w -> b (h w) c', 31 | attn_dim : int = 128, 32 | ctrl_dim : int | None = None, 33 | use_cond : bool = False, 34 | use_attn : bool = False, 35 | chn_mult : List[int] = (1, 2, 4, 8), 36 | n_fourier : Tuple[int, ...] 
| None = None, 37 | num_group : int = 8, 38 | num_heads : int = 4, 39 | ) -> None: 40 | super().__init__() 41 | 42 | out_dim = default(out_dim, inp_chn) 43 | 44 | self.inp_chn = inp_chn 45 | self.use_cond = use_cond 46 | self.use_attn = use_attn 47 | 48 | # * Build the input embeddings 49 | # Optional Fourier Feature Embeddings 50 | self.fourier_emb = FourierEmbedding(*n_fourier) if exists(n_fourier) else nn.Identity() 51 | 52 | # Time Embeddings 53 | ctx_dim = net_dim * 4 54 | self.time_emb = nn.Sequential( 55 | TimeEmbedding(net_dim), 56 | nn.Linear(net_dim, ctx_dim), 57 | nn.GELU(), 58 | nn.Linear(ctx_dim, ctx_dim) 59 | ) 60 | 61 | # NOTE: We need channels * 2 to accommodate for the self-conditioning 62 | tot_chn = inp_chn * (1 + use_cond + (2 * self.fourier_emb.n_feat if exists(n_fourier) else 0)) 63 | 64 | self.proj_inp = nn.Conv2d(tot_chn, net_dim, 7, padding = 3) 65 | 66 | dims = [net_dim, *map(lambda m: net_dim * m, chn_mult)] 67 | mid_dim = dims[-1] 68 | 69 | dims = list(zip(dims, dims[1:])) 70 | 71 | # * Building the model. It has three main components: 72 | # * 1) The downscale modules 73 | # * 2) The bottleneck modules 74 | # * 3) The upscale modules 75 | self.downs = nn.ModuleList([]) 76 | self.ups = nn.ModuleList([]) 77 | num_resolutions = len(dims) 78 | 79 | # Build up the downscale module part 80 | for idx, (dim_in, dim_out) in enumerate(dims): 81 | is_last = idx >= (num_resolutions - 1) 82 | 83 | self.downs.append(nn.ModuleList([ 84 | ContextRes(dim_in, dim_in, ctx_dim=ctx_dim, num_group=num_group, dropout=dropout), 85 | ContextRes(dim_in, dim_in, ctx_dim=ctx_dim, num_group=num_group, dropout=dropout), 86 | AdaptiveAttention(attn_dim, num_heads, adapter, qry_dim=dim_in, key_dim=ctrl_dim) if use_attn else nn.Identity(), 87 | nn.Conv2d(dim_in, dim_out, 3, padding = 1) if is_last else Downscale(dim_in, dim_out) 88 | ])) 89 | 90 | # Buildup the bottleneck module 91 | self.mid_block1 = ContextRes(mid_dim, mid_dim, ctx_dim=ctx_dim, num_group=num_group) 92 | self.mid_attn = AdaptiveAttention(attn_dim, num_heads, adapter, qry_dim=mid_dim, key_dim=ctrl_dim) 93 | self.mid_block2 = ContextRes(mid_dim, mid_dim, ctx_dim=ctx_dim, num_group=num_group) 94 | 95 | # Build the upscale module part 96 | # NOTE: We need to make rooms for incoming residual connections from the downscale layers 97 | for idx, (dim_in, dim_out) in enumerate(reversed(dims)): 98 | is_last = idx >= (num_resolutions - 1) 99 | 100 | self.ups.append(nn.ModuleList([ 101 | ContextRes(dim_in + dim_out, dim_out, ctx_dim=ctx_dim, num_group=num_group, dropout=dropout), 102 | ContextRes(dim_in + dim_out, dim_out, ctx_dim=ctx_dim, num_group=num_group, dropout=dropout), 103 | AdaptiveAttention(attn_dim, num_heads, adapter, qry_dim=dim_out, key_dim=ctrl_dim) if use_attn else nn.Identity(), 104 | nn.Conv2d(dim_out, dim_in, 3, padding = 1) if is_last else Upscale(dim_out, dim_in) 105 | ])) 106 | 107 | self.final = ContextRes(net_dim * 2, net_dim, ctx_dim = ctx_dim, num_group = num_group) 108 | self.proj_out = nn.Conv2d(net_dim, out_dim, 1) 109 | 110 | def forward( 111 | self, 112 | imgs : Tensor, 113 | time : Tensor, 114 | cond : Tensor | None = None, 115 | ctrl : Tensor | None = None, 116 | ) -> Tensor: 117 | ''' 118 | Compute forward pass of the U-Net module. Expect input 119 | to be image-like and expects an auxiliary time signal 120 | (1D-like) to be provided as well. 
An optional contextual 121 | signal can be provided and will be used by the attention 122 | gates that will function as cross-attention as opposed 123 | to self-attentions. 124 | 125 | Params: 126 | - imgs: Tensor of shape [batch, channel, H, W] 127 | - time: Tensor of shape [batch, 1] 128 | - context[optional]: Tensor of shape [batch, seq_len, emb_dim] 129 | 130 | Returns: 131 | - imgs: Processed images, tensor of shape [batch, channel, H, W] 132 | ''' 133 | 134 | # Optional self-conditioning to the model (we default to original 135 | # input size before fourier embeddings are added) 136 | cond = default(cond, torch.zeros_like(imgs)) 137 | 138 | # Add (optional) Fourier Embeddings 139 | imgs = self.fourier_emb(imgs) 140 | 141 | if self.use_cond: imgs = torch.cat((imgs, cond), dim = 1) 142 | 143 | x : Tensor = self.proj_inp(imgs) 144 | t : Tensor = self.time_emb(time) 145 | 146 | h = [x.clone()] 147 | 148 | for conv1, conv2, attn, down in self.downs: 149 | x = conv1(x, t) 150 | h += [x] 151 | 152 | x = conv2(x, t) 153 | x = attn(x, ctrl) if self.use_attn else x 154 | h += [x] 155 | 156 | x = down(x) 157 | 158 | x = self.mid_block1(x, t) 159 | x = self.mid_attn(x, ctrl) 160 | x = self.mid_block2(x, t) 161 | 162 | for conv1, conv2, attn, up in self.ups: 163 | x = torch.cat((x, h.pop()), dim = 1) 164 | x = conv1(x, t) 165 | 166 | x = torch.cat((x, h.pop()), dim = 1) 167 | x = conv2(x, t) 168 | x = attn(x, ctrl) if self.use_attn else x 169 | 170 | x = up(x) 171 | 172 | x = torch.cat((x, h.pop()), dim = 1) 173 | 174 | x = self.final(x, t) 175 | 176 | return self.proj_out(x) -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from typing import Any 5 | from torch import Tensor 6 | 7 | from einops import rearrange 8 | 9 | def exists(var : Any | None) -> bool: 10 | return var is not None 11 | 12 | def default(var : Any | None, val : Any) -> Any: 13 | return var if exists(var) else val 14 | 15 | def enlarge_as(a : Tensor, b : Tensor) -> Tensor: 16 | ''' 17 | Add sufficient number of singleton dimensions 18 | to tensor a **to the right** so to match the 19 | shape of tensor b. NOTE that simple broadcasting 20 | works in the opposite direction. 21 | ''' 22 | return rearrange(a, f'... 
-> ...{" 1" * (b.dim() - a.dim())}').contiguous() -------------------------------------------------------------------------------- /src/vdm.py: -------------------------------------------------------------------------------- 1 | from lightning.pytorch.utilities.types import STEP_OUTPUT 2 | import torch 3 | import warnings 4 | 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | 8 | from torch import Tensor 9 | from torch import autograd 10 | from torch import sqrt, sigmoid, prod 11 | from torch import exp, expm1, log 12 | from torch import log_softmax 13 | from typing import Any, Tuple, Dict, List 14 | from itertools import pairwise 15 | 16 | from torchvision.utils import make_grid 17 | 18 | from lightning import LightningModule 19 | 20 | from tqdm.auto import tqdm 21 | 22 | from einops import reduce 23 | from einops import rearrange 24 | 25 | from .utils import exists 26 | from .utils import default 27 | from .utils import enlarge_as 28 | 29 | from .unet import UNet 30 | from .schedule import LinearSchedule 31 | from .schedule import LearnableSchedule 32 | 33 | loge2 = torch.log(torch.tensor(2)) 34 | 35 | class VariationalDiffusion(LightningModule): 36 | ''' 37 | 38 | ''' 39 | 40 | @classmethod 41 | def from_conf(cls, conf_file : Dict[str, Any]) -> 'VariationalDiffusion': 42 | 43 | vdm_conf = conf_file['VDM'] 44 | unet_conf = conf_file['UNET'] 45 | optim_conf= conf_file['OPTIMIZER'] 46 | schedule_conf = conf_file['SCHEDULE'] 47 | 48 | schedule_name = schedule_conf.pop('name') 49 | 50 | match schedule_name: 51 | case 'linear': Schedule = LinearSchedule 52 | case 'learnable': Schedule = LearnableSchedule 53 | case _: raise ValueError(f'Unknown schedule: {schedule_name}') 54 | 55 | # Build the VDM model 56 | return cls( 57 | backbone=UNet(**unet_conf), 58 | schedule=Schedule(**schedule_conf), 59 | optim_conf=optim_conf, 60 | **vdm_conf, 61 | ) 62 | 63 | def __init__( 64 | self, 65 | backbone : nn.Module, 66 | schedule : nn.Module | None = None, 67 | img_shape : Tuple[int, int] = (64, 64), 68 | vocab_size : int = 256, 69 | data_key : str = 'imgs', 70 | ctrl_key : str | None = None, 71 | sampling_step : int = 50, 72 | optimizer_conf : Dict[str, Any] | None = None, 73 | ) -> None: 74 | super().__init__() 75 | 76 | self.backbone : nn.Module = backbone 77 | self.schedule : nn.Module = default(schedule, LinearSchedule()) 78 | 79 | img_chn = self.backbone.inp_chn 80 | 81 | self.img_shape = (img_chn, *img_shape) 82 | self.vocab_size = vocab_size 83 | 84 | self.data_key = data_key 85 | self.ctrl_key = ctrl_key 86 | self.opt_conf : Dict[str, Any] = default(optimizer_conf, {'lr' : 1e-3}) 87 | 88 | self.num_step = sampling_step 89 | 90 | @property 91 | def device(self): 92 | return next(self.backbone.parameters()).device 93 | 94 | def training_step(self, batch : Dict[str, Tensor], batch_idx : int) -> Tensor: 95 | # Extract the starting images from data batch 96 | x_0 = batch[self.data_key] 97 | ctrl = batch[self.ctrl_key] if exists(self.ctrl_key) else None 98 | 99 | # Compute the VLB loss 100 | loss, stat = self.compute_loss(x_0) 101 | 102 | self.log_dict({'train_loss' : loss}, logger = True, on_step = True, sync_dist = True) 103 | self.log_dict({f'train_{k}' : v for k, v in stat.item()}, logger = True, on_step = True, sync_dist = True) 104 | 105 | return loss 106 | 107 | def validation_step(self, batch : Dict[str, Tensor], batch_idx : int) -> Tensor: 108 | # Extract the starting images from data batch 109 | x_0 = batch[self.data_key] 110 | ctrl = batch[self.ctrl_key] if exists(self.ctrl_key) 
else None 111 | 112 | # Compute the VLB loss 113 | loss, stat = self.compute_loss(x_0) 114 | 115 | self.log_dict({'val_loss' : loss}, logger=True, on_step=True, sync_dist=True) 116 | self.log_dict({f'val_{k}' : v for k, v in stat.item()}, logger=True, on_step=True, sync_dist=True) 117 | 118 | return x_0, ctrl 119 | 120 | @torch.no_grad() 121 | def validation_epoch_end(self, val_outs : List[Tuple[Tensor, Tensor | None]]) -> None: 122 | ''' 123 | At the end of the validation cycle, we inspect how the training 124 | procedure is doing by sampling novel images from the learn distribution. 125 | ''' 126 | 127 | # Collect the input shapes 128 | (x_0, ctrl), *_ = val_outs 129 | 130 | # Produce 8 samples and log them 131 | imgs = self( 132 | num_imgs=8, 133 | num_step=self.num_step, 134 | # ctrl = ctrl, 135 | verbose = False, 136 | ) 137 | 138 | assert not torch.isnan(imgs).any(), 'NaNs detected in imgs!' 139 | 140 | imgs = make_grid(imgs, nrow = 4) 141 | 142 | # Log images using the default TensorBoard logger 143 | self.logger.experiment.add_image('VDM', imgs, global_step=self.global_step) 144 | 145 | def configure_optimizers(self): 146 | opt_name = self.opt_conf.pop('name') 147 | match opt_name: 148 | case 'AdamW': Optim = optim.AdamW 149 | case 'SGD' : Optim = optim.SGD 150 | case _: raise ValueError(f'Unknown optimizer: {opt_name}') 151 | 152 | params = list(self.backbone.parameters()) +\ 153 | list(self.schedule.parameters()) 154 | 155 | opt_kw = self.opt_conf 156 | 157 | opt = Optim(params, **opt_kw) 158 | 159 | return opt 160 | 161 | @torch.no_grad() 162 | def forward( 163 | self, 164 | num_imgs : int, 165 | num_steps : int, 166 | seed_noise : Tensor | None = None, 167 | verbose : bool = False, 168 | ) -> Tensor: 169 | ''' 170 | We reserve the forward call of the model to the posterior sampling, 171 | that is used with a fully trained model to generate samples, hence 172 | the torch.no_grad() decorator. 173 | ''' 174 | device = self.device 175 | 176 | z_s = default(seed_noise, torch.randn((num_imgs, *self.img_shape), device=device)) 177 | 178 | # Sample the reverse time-steps and compute the corresponding 179 | # noise schedule values (gammas) 180 | time = torch.linspace(1., 0., num_steps + 1, device=device) 181 | gamma = self.schedule(time) 182 | 183 | iterator = pairwise(gamma) 184 | iterator = tqdm(iterator, total=num_steps) if verbose else iterator 185 | for gamma_t, gamma_s in iterator: 186 | # Sample from the backward diffusion process 187 | z_s = self._coalesce(z_s, gamma_t, gamma_s) 188 | 189 | # After the backward process we are left with the z_0 latent from 190 | # which we should estimate the probabilities of the data x via 191 | # p (x | z_0) =~ q (z_0 | x), which is a good approximation whenever 192 | # SNR(t = 0) is high-enough (as we basically don't corrupt x). 193 | # NOTE: We pass a rescaled z_0 by the mean which is alpha_0, 194 | # we re-use gamma_s as the last step corresponds to t=0 195 | alpha_0 = sqrt(sigmoid(gamma_s)) 196 | 197 | # Decode the probability for each data bin, expected prob shape is: 198 | # [batch_size, C, H, W, vocal_size] 199 | prob = self._data_prob(z_s / alpha_0, gamma_s) 200 | 201 | # Our sample is obtained by taking the highest probability bin among 202 | # all the possible data values 203 | img = torch.argmax(prob, dim=-1) 204 | 205 | # Normalize image to be in [0, 1] 206 | return img.float() / (self.vocab_size - 1) 207 | 208 | def compute_loss(self, imgs : Tensor) -> Tensor: 209 | ''' 210 | L_∞ = L_diffusion + L_latent + L_reconstruction. 
211 | 212 | This loss comes from minimizing the Variational Lower Bound (VLB), 213 | which is: -log p(x) < -VLB(x). 214 | ''' 215 | 216 | bs, *img_shape = imgs.shape 217 | bpd = 1. / prod(torch.tensor(img_shape)) * loge2 218 | 219 | # Rescale image tensor (expected in range [0, 1]) to [-1 + 1/vs, +1 - 1/vs] 220 | # (vs = vocab-size) 221 | idxs = torch.round(imgs * (self.vocab_size - 1)).long() 222 | imgs = 2 * ((idxs + .5) / self.vocab_size) - 1 223 | 224 | # Compute the gamma at the time endpoints: gamma_0 | gamma_1 225 | gamma_0 : Tensor = self.schedule(torch.tensor([0.], device=self.device)) 226 | gamma_1 : Tensor = self.schedule(torch.tensor([1.], device=self.device)) 227 | 228 | diffusion_loss, SNR_t = self._diffusion_loss(imgs) 229 | latent_loss = self._latent_loss(imgs, gamma_1) 230 | recon_loss = self._recon_loss(imgs, gamma_0, idxs) 231 | 232 | # Compute the total loss as the sum of the three losses 233 | loss = (diffusion_loss + latent_loss + recon_loss).mean() * bpd 234 | 235 | stat = { 236 | 'tot_loss' : loss.item(), 237 | 'var_args' : (SNR_t, diffusion_loss), 238 | 'gamma_0' : gamma_0.item(), 239 | 'gamma_1' : gamma_1.item(), 240 | 'recon_loss' : bpd * recon_loss.mean(), 241 | 'latent_loss' : bpd * latent_loss.mean(), 242 | 'diffusion_loss' : bpd * diffusion_loss.mean(), 243 | } 244 | 245 | return loss, stat 246 | 247 | def reduce_variance( 248 | self, 249 | SNR_t : Tensor, 250 | diff_loss : Tensor, 251 | ): 252 | ''' 253 | This function computes the gradients of the variance of the 254 | M.C. estimate of the diffusion loss (L_∞) w.r.t. the noise 255 | schedule so to optimize its overall shape. 256 | NOTE 1: Only the star|end-points contribute to the VLB, which 257 | is what we are optimizing when computing the loss, so 258 | we need an additional objective that can explicitly 259 | train the noise schedule shape. 260 | NOTE 2: Following Appendix I.2 of the main paper, note that the 261 | gradient w.r.t. the SNR is already computed when doing 262 | back-prop of the VLB. 263 | ''' 264 | # NOTE: This function should be called after backward on the loss 265 | # has already been called. We check that schedule parameters 266 | # have non-zero gradients 267 | msg = '''Noise schedule parameters have zero gradient. This is probably due 268 | to the function `reduce_variance` been called before `backward` has 269 | been called to the VLB loss. Reduce variance need the gradients and 270 | is thus now ineffective. Please only call `reduce_variance` after 271 | loss.backward() has been called. 
272 | ''' 273 | 274 | for par in self.schedule.parameters(): 275 | if torch.all(par.grad == 0): warnings.warn(msg) 276 | 277 | # Grad already contains derivative of L_∞^MC w.r.t SNR 278 | par.grad *= autograd.grad( 279 | outputs=SNR_t, 280 | inputs=par, 281 | grad_outputs=2 * diff_loss, 282 | create_graph=True, 283 | retain_graph=True, 284 | )[0] 285 | 286 | def _diffusion_loss(self, x_0 : Tensor) -> Tensor: 287 | ''' 288 | Compute the (continuous) L_∞ loss (T -> ∞), which is defined as: 289 | 290 | L_∞ = 1/2 gamma'(t) E_{t~U(0, 1)} || eps_theta(z_t ; t) - eps ||^2 291 | 292 | NOTE: We use autograd to estimate gamma'(t) = d gamma(t) / dt 293 | ''' 294 | bs, *img_shape = x_0.shape 295 | 296 | # Sample a set of times for forward diffusion q(z_t | x_0) 297 | # and convert them to gammas using the noise schedule 298 | times = self._get_times(bs).requires_grad_(True) 299 | gamma = self.schedule(times) 300 | 301 | SNR_t = exp(-gamma) 302 | 303 | # Sample from the forward diffusion process (with known noise as we need it 304 | # to compute the diffusion loss) 305 | eps = torch.randn_like(x_0) 306 | z_t = self._diffuse(x_0, gamma, noise=eps) 307 | 308 | # Compute the latent noise eps_theta using the backbone model 309 | eps_theta = self.backbone(z_t, time=gamma) # NOTE: We should add here conditioning if needed 310 | 311 | # Compute the continuous loss by estimating the expectation values via 312 | # Monte Carlo estimates (we sample the times and simply compute the expected values) 313 | dgamma_dt, *_ = autograd.grad( 314 | outputs=gamma, 315 | inputs=times, 316 | grad_outputs=torch.ones_like(gamma), 317 | create_graph=True, 318 | retain_graph=True, 319 | ) 320 | 321 | loss = .5 * dgamma_dt * reduce(((eps - eps_theta) ** 2), 'b ... -> b 1', 'sum') 322 | 323 | # Return loss with dimension [batch_size] 324 | return loss, SNR_t 325 | 326 | def _recon_loss(self, x_0 : Tensor, gamma_0 : Tensor, idxs : Tensor) -> Tensor: 327 | ''' 328 | Compute the reconstruction loss, which is defined as: 329 | 330 | L_rec = - E_{q(z_0 | x_0)} [log p(x | z_0)] 331 | ''' 332 | 333 | # Compute z_0 / alpha_0 from x_0 334 | z_0 = x_0 + exp(.5 * gamma_0) * torch.randn_like(x_0) 335 | 336 | # Get the probabilities for each data value, we get 337 | # prob shape: [batch_size, *img_shape, vocab_size] 338 | prob = self._data_prob(z_0, gamma_0) 339 | 340 | # Grab the probability of the data values 341 | idxs = rearrange(idxs, '... -> ... 1') 342 | prob = torch.gather(prob, dim=-1, index=idxs) 343 | 344 | # Compute the reconstruction loss 345 | loss = -reduce(prob, 'b ... -> b', 'sum') 346 | 347 | return loss 348 | 349 | def _latent_loss(self, x_0 : Tensor, gamma_1 : Tensor) -> Tensor: 350 | ''' 351 | Compute the latent loss, which is defined as: 352 | 353 | L_latent = D_KL(q(z_1 | x_0) || p(z_1)), 354 | 355 | which is the D_KL from a standard normal N(0, 1) (the 356 | desired p(z_1), from which we know how to sample), and 357 | the measured q(z_1 | x_0). 358 | ''' 359 | 360 | # Compute the mean (alpha_1) and std (sigma_1) of q(z_1 | x_0) 361 | # NOTE: For variance-preserving diffusion process we have: 362 | # alpha_t = sqrt(1 - sigma_t ** 2) 363 | sigma_1_sq = sigmoid(+gamma_1) 364 | alpha_1_sq = 1 - sigma_1_sq 365 | 366 | mu_sq = alpha_1_sq * x_0 ** 2 367 | 368 | # Compute the D_KL between a reference N(0, 1) and N(mu, sig) 369 | loss = .5 * (sigma_1_sq + mu_sq - log(sigma_1_sq.clamp(min=1e-15)) - 1.) 370 | 371 | return reduce(loss, 'b ... 
-> b', 'sum') 372 | 373 | def _diffuse(self, x_0 : Tensor, gamma_t : Tensor, noise : Tensor | None = None) -> Tensor: 374 | ''' 375 | Forward diffusion: we sample from q(z_t | x_0). This is 376 | easy sampling as we only need to sample from a standard 377 | normal with known SNR(t). We have: 378 | q(z_t | x_0) = N(alpha_t x_0 | sigma_t**2 * I), with 379 | 380 | SNR(t) = alpha_t ** 2 / sigma_t ** 2 381 | 382 | NOTE: Time is effectively parametrized via the SNR which 383 | in turn is computed via the noise schedule that can 384 | either be linear of a monotonic network. 385 | ''' 386 | 387 | noise = default(noise, torch.randn_like(x_0)) 388 | 389 | # Compute the alpha_t and sigma_t using the noise schedule 390 | alpha_t = enlarge_as(sqrt(sigmoid(-gamma_t)), x_0) 391 | sigma_t = enlarge_as(sqrt(sigmoid(+gamma_t)), x_0) 392 | 393 | return alpha_t * x_0 + sigma_t * noise 394 | 395 | def _coalesce(self, z_t : Tensor, gamma_t : Tensor, gamma_s : Tensor) -> Tensor: 396 | ''' 397 | Backward diffusion: we sample from p(z_s | z_t, x = x_theta(z_t ; t)). 398 | This is a bit more involved as we need to sample from a 399 | distribution that depends on the previous sample. We have: 400 | p(z_s | z_t, x = x_theta(z_t ; t)) = N(mu_theta, sigma_Q ** 2 * I), 401 | 402 | where we eventually have (see Eq.(32-33) in Appendix A of paper): 403 | 404 | mu_theta = alpha_s / alpha_t * (z_t + sigma_t * expm1(gamma_s - gamma_t))\ 405 | * eps_theta(z_t ; t) 406 | 407 | sigma_Q = sigma_s ** 2 * (-expm1(gamma_s - gamma_t)) 408 | ''' 409 | 410 | alpha_s_sq = sigmoid(-gamma_s) 411 | alpha_t_sq = sigmoid(-gamma_t) 412 | 413 | sigma_t = sqrt(sigmoid(+gamma_t)) 414 | c = -expm1(gamma_s - gamma_t) 415 | 416 | # Predict latent noise eps_theta using backbone model 417 | eps = self.backbone(z_t, gamma_t) # NOTE: We should add here conditioning if needed 418 | 419 | # Compute new latent z_s mean and std 420 | scale = sqrt((1 - alpha_s_sq) * c) 421 | mean = sqrt(alpha_s_sq / alpha_t_sq) * (z_t - c * sigma_t * eps) 422 | 423 | return mean + scale * torch.randn_like(z_t) 424 | 425 | def _data_prob(self, z_0 : Tensor, gamma_0 : Tensor) -> Tensor: 426 | ''' 427 | Compute the probability distribution for p(x | z_0). This distribution is 428 | approximated by p(x | z_0) ~ Prod_i q(z_0(i) | x(i)), which is sensible 429 | whenever SNR(t=0) is high enough. Here q(z_0(i) | x(i)) represent the 430 | latent code for pixel i-th given the original pixel value. The logits 431 | are estimated as: 432 | 433 | -1/2 SNR(t=0) * (z_0 / alpha_0 - k) ** 2, 434 | 435 | where k takes all possible data values (hence k take vocab_size values). 436 | 437 | NOTE: We assume input z_0 has already been normalized, so we actually 438 | expect z_0 / alpha_0. Moreover, we have: SNR(t=0) = exp(-gamma_0) 439 | ''' 440 | 441 | # Add vocab_size dimension 442 | z_0 = rearrange(z_0, '... -> ... 1') 443 | 444 | x_lim = 1 - 1 / self.vocab_size 445 | x_val = torch.linspace(-x_lim, +x_lim, self.vocab_size, device = self.device) 446 | 447 | logits = -.5 * exp(-gamma_0) * (z_0 - x_val) ** 2 448 | 449 | # Normalize along the vocab_size dimension 450 | return log_softmax(logits, dim=-1) 451 | 452 | def _get_times(self, batch_size : int, sampler : str = 'low-var') -> Tensor: 453 | ''' 454 | Sample the diffusion time steps. We can choose the sampler to 455 | be either are low-variance or naive. 
456 | ''' 457 | 458 | samplers = ('low-var', 'naive') 459 | 460 | match sampler: 461 | case 'low-var': 462 | t_0 = torch.rand(1).item() / batch_size 463 | ts = torch.arange(t_0, 1., 1 / batch_size, device=self.device) 464 | 465 | # Add single channel dimension 466 | return rearrange(ts, 'b -> b 1') 467 | 468 | case 'naive': 469 | return torch.rand((batch_size, 1), device=self.device) 470 | 471 | raise ValueError(f'Unknown sampler: {sampler}. Available samplers are: {samplers}') 472 | 473 | 474 | 475 | 476 | 477 | -------------------------------------------------------------------------------- /test/test_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import unittest 3 | 4 | from torch import Tensor 5 | 6 | from src.module.attention import AdaptiveAttention 7 | 8 | class AdaptiveAttentionTest(unittest.TestCase): 9 | def setUp(self) -> None: 10 | 11 | batch_size = 2 12 | chn_dim = 8 13 | h = w = 16 14 | l = h * w 15 | 16 | self.n_heads = 4 17 | self.emb_dim = 32 18 | 19 | self.img_h = h 20 | self.img_w = w 21 | self.chn_dim = chn_dim 22 | 23 | self.input_shape_txt = (batch_size, l, chn_dim) 24 | self.input_shape_mp3 = (batch_size, chn_dim, l) 25 | self.input_shape_img = (batch_size, chn_dim, h, w) 26 | self.input_shape_mp4 = (batch_size, chn_dim, h, w, l) 27 | 28 | 29 | def test_forward_self_attn_image(self): 30 | 31 | attn = AdaptiveAttention( 32 | emb_dim=self.emb_dim, 33 | n_heads=self.n_heads, 34 | 35 | # Test with adapter just for the query, it then expects 36 | # key|val to be already sequence-like 37 | pattern='b c h w -> b (h w) c', 38 | qry_dim=self.chn_dim, 39 | batch_first=True, 40 | ) 41 | 42 | qry : Tensor = torch.randn(self.input_shape_img) 43 | key : Tensor = torch.randn(self.input_shape_img) 44 | val : Tensor = torch.randn(self.input_shape_img) 45 | 46 | out_tensor, _ = attn(qry, key, val) 47 | 48 | self.assertEqual(out_tensor.shape, self.input_shape_img) 49 | 50 | def test_forward_attn_img_txt(self): 51 | 52 | attn = AdaptiveAttention( 53 | emb_dim=self.emb_dim, 54 | n_heads=self.n_heads, 55 | 56 | # Test with adapter just for the query, it then expects 57 | # key|val to be already sequence-like 58 | pattern=( 59 | 'b c h w -> b (h w) c', 60 | '... -> ...', 61 | '... 
-> ...', 62 | ), 63 | qry_dim=self.chn_dim, 64 | batch_first=True, 65 | ) 66 | 67 | qry : Tensor = torch.randn(self.input_shape_img) 68 | key : Tensor = torch.randn(self.input_shape_txt) 69 | val : Tensor = torch.randn(self.input_shape_txt) 70 | 71 | out_tensor, _ = attn(qry, key, val) 72 | 73 | self.assertEqual(out_tensor.shape, self.input_shape_img) -------------------------------------------------------------------------------- /test/test_schedule.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import unittest 3 | 4 | from torch import Tensor 5 | 6 | from src.module.monotonic import MonotonicLinear 7 | from src.schedule import LearnableSchedule 8 | 9 | class ScheduleTest(unittest.TestCase): 10 | def setUp(self) -> None: 11 | pass 12 | 13 | def test_monotonic_linear(self): 14 | 15 | mono = MonotonicLinear( 16 | 3, 3, bias=True, 17 | gate_func='relu', 18 | indicator=torch.tensor((+1, +1, -1)), 19 | act_weight=(1, 1, 1), 20 | ) 21 | 22 | time_zero = torch.rand((10, 3)) 23 | 24 | time_add1 = time_zero + torch.tensor((1, 0, 0)) 25 | time_add2 = time_zero + torch.tensor((0, 1, 0)) 26 | time_add3 = time_zero + torch.tensor((0, 0, 1)) 27 | 28 | time_sub1 = time_zero - torch.tensor((1, 0, 0)) 29 | time_sub2 = time_zero - torch.tensor((0, 1, 0)) 30 | time_sub3 = time_zero - torch.tensor((0, 0, 1)) 31 | 32 | out_zero = mono(time_zero) 33 | 34 | out_add1 = mono(time_add1) 35 | out_add2 = mono(time_add2) 36 | out_add3 = mono(time_add3) 37 | 38 | out_sub1 = mono(time_sub1) 39 | out_sub2 = mono(time_sub2) 40 | out_sub3 = mono(time_sub3) 41 | 42 | # Layer should be positive monotonic along dim=[0, 1], 43 | # while being negative monotonic along dim=2 44 | mono_pos_dim0 = torch.all(out_add1 >= out_zero) 45 | mono_pos_dim1 = torch.all(out_add2 >= out_zero) 46 | mono_neg_dim2 = torch.all(out_add3 <= out_zero) 47 | 48 | mono_neg_dim0 = torch.all(out_sub1 <= out_zero) 49 | mono_neg_dim1 = torch.all(out_sub2 <= out_zero) 50 | mono_pos_dim2 = torch.all(out_sub3 >= out_zero) 51 | 52 | # Check monotonicity 53 | self.assertTrue(torch.all(mono_pos_dim0)) 54 | self.assertTrue(torch.all(mono_pos_dim1)) 55 | self.assertTrue(torch.all(mono_neg_dim2)) 56 | 57 | self.assertTrue(torch.all(mono_neg_dim0)) 58 | self.assertTrue(torch.all(mono_neg_dim1)) 59 | self.assertTrue(torch.all(mono_pos_dim2)) 60 | 61 | def test_schedule(self): 62 | 63 | schedule = LearnableSchedule( 64 | hid_dim=[50, 50], 65 | gate_func='relu', 66 | ) 67 | 68 | time = torch.rand((10, 1)) 69 | time_p1 = time + 1 70 | 71 | out_t0 : Tensor = schedule(time) 72 | out_t1 : Tensor = schedule(time_p1) 73 | 74 | # Check monotonicity 75 | self.assertTrue(torch.all(out_t0 >= out_t1)) -------------------------------------------------------------------------------- /test/test_unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import unittest 3 | 4 | from torch import Tensor 5 | 6 | from src.unet import UNet 7 | 8 | class UNetTest(unittest.TestCase): 9 | def setUp(self) -> None: 10 | 11 | chn_dim = 3 12 | h = w = 16 13 | 14 | self.n_heads = 4 15 | self.emb_dim = 32 16 | 17 | self.img_h = h 18 | self.img_w = w 19 | self.chn_dim = chn_dim 20 | self.batch_size = 2 21 | 22 | self.input_shape_img = (self.batch_size, chn_dim, h, w) 23 | 24 | self.time = torch.rand(self.batch_size) 25 | 26 | def test_forward_with_img_attn(self): 27 | 28 | out_dim = 6 29 | out_shape = (self.batch_size, out_dim, self.img_h, self.img_w) 30 | 31 | unet = UNet( 32 | 
net_dim=4, 33 | out_dim=out_dim, 34 | ctrl_dim=3, 35 | use_cond=True, 36 | use_attn=True, 37 | num_group=4, 38 | adapter='b c h w -> b (h w) c', 39 | ) 40 | 41 | 42 | inp_img = torch.randn(self.input_shape_img) 43 | inp_ctx = torch.randn(self.input_shape_img) 44 | 45 | out_tensor = unet( 46 | inp_img, 47 | self.time, 48 | ctrl=inp_ctx, 49 | ) 50 | 51 | self.assertEqual(out_tensor.shape, out_shape) 52 | 53 | def test_forward_with_fourier(self): 54 | 55 | out_dim = 6 56 | out_shape = (self.batch_size, out_dim, self.img_h, self.img_w) 57 | 58 | unet = UNet( 59 | net_dim=4, 60 | out_dim=out_dim, 61 | adapter='b c h w -> b (h w) c', 62 | ctrl_dim=3, 63 | use_cond=True, 64 | use_attn=True, 65 | n_fourier=(7, 9, 1), 66 | num_group=4, 67 | ) 68 | 69 | 70 | inp_img = torch.randn(self.input_shape_img) 71 | inp_ctx = torch.randn(self.input_shape_img) 72 | 73 | out_tensor = unet( 74 | inp_img, 75 | self.time, 76 | ctrl=inp_ctx, 77 | ) 78 | 79 | self.assertEqual(out_tensor.shape, out_shape) 80 | 81 | -------------------------------------------------------------------------------- /test/test_vdm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import unittest 3 | 4 | from torch import Tensor 5 | 6 | from src.unet import UNet 7 | from src.vdm import VariationalDiffusion 8 | from src.schedule import LinearSchedule 9 | from src.schedule import LearnableSchedule 10 | 11 | class VariationalDiffusionTest(unittest.TestCase): 12 | def setUp(self) -> None: 13 | 14 | chn_dim = 3 15 | h = w = 16 16 | 17 | self.n_heads = 4 18 | self.emb_dim = 32 19 | 20 | self.img_h = h 21 | self.img_w = w 22 | self.chn_dim = chn_dim 23 | self.batch_size = 2 24 | self.vocab_size = 256 25 | 26 | self.img_shape = (chn_dim, h, w) 27 | 28 | def test_forward(self): 29 | 30 | num_imgs = 4 31 | num_steps = 5 32 | 33 | vdm = VariationalDiffusion( 34 | backbone=UNet( 35 | net_dim=4, 36 | ctrl_dim=None, 37 | use_cond=False, 38 | use_attn=True, 39 | num_group=4, 40 | adapter='b c h w -> b (h w) c', 41 | ), 42 | schedule=LinearSchedule(), 43 | img_shape=(self.img_h, self.img_w), 44 | vocab_size=256, 45 | ) 46 | 47 | imgs : Tensor = vdm( 48 | num_imgs, 49 | num_steps, 50 | ) 51 | 52 | self.assertEqual(imgs.shape, (num_imgs, *self.img_shape)) 53 | 54 | def test_compute_loss(self): 55 | 56 | imgs = torch.rand((self.batch_size, *self.img_shape)) 57 | 58 | vdm = VariationalDiffusion( 59 | backbone=UNet( 60 | net_dim=4, 61 | ctrl_dim=None, 62 | use_cond=False, 63 | use_attn=True, 64 | num_group=4, 65 | adapter='b c h w -> b (h w) c', 66 | ), 67 | schedule=LinearSchedule(), 68 | img_shape=(self.img_h, self.img_w), 69 | vocab_size=256, 70 | ) 71 | 72 | loss, stat = vdm.compute_loss( 73 | imgs 74 | ) 75 | 76 | print(stat) 77 | 78 | self.assertTrue(loss > 0) 79 | 80 | def test_reduce_variance(self): 81 | 82 | imgs = torch.rand((self.batch_size, *self.img_shape)) 83 | 84 | vdm = VariationalDiffusion( 85 | backbone=UNet( 86 | net_dim=4, 87 | ctrl_dim=None, 88 | use_cond=False, 89 | use_attn=True, 90 | num_group=4, 91 | adapter='b c h w -> b (h w) c', 92 | ), 93 | schedule=LearnableSchedule( 94 | hid_dim=[50, 50], 95 | gate_func='relu', 96 | ), 97 | img_shape=(self.img_h, self.img_w), 98 | vocab_size=256, 99 | ) 100 | 101 | loss, stat = vdm.compute_loss(imgs) 102 | 103 | # Call loss.backward() to populate the schedule gradients 104 | loss.backward(retain_graph=True) 105 | 106 | # Now we can call reduce_variance to update the gradients 107 | vdm.reduce_variance(*stat['var_args']) 108 | 109 | # If we 
get here we consider it success
110 |         self.assertTrue(loss > 0)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
 1 | import yaml
 2 | import os.path as path
 3 |
 4 | from lightning import Trainer
 5 | from argparse import ArgumentParser
 6 |
 7 | from lightning.pytorch.callbacks import ModelCheckpoint
 8 | from lightning.pytorch.loggers import TensorBoardLogger
 9 |
10 | from src.data import CIFAR10DM
11 | from src.vdm import VariationalDiffusion
12 |
13 | def main(args):
14 |
15 |     # Configuration files live in the `config` folder at the repository root
16 |     config = path.join('config', args.config)
17 |
18 |     with open(config, 'r') as f:
19 |         conf_file = yaml.safe_load(f)
20 |
21 |     # Create the variational diffusion model
22 |     vdm = VariationalDiffusion.from_conf(conf_file)
23 |     cifar = CIFAR10DM(**conf_file['DATASET'])
24 |
25 |     ckpt_dir = conf_file['MISC']['ckpt_dir']
26 |     logs_dir = conf_file['MISC']['logs_dir']
27 |     run_name = conf_file['MISC']['run_name']
28 |     monitor  = conf_file['MISC']['monitor']
29 |     resume   = conf_file['MISC']['resume_ckpt']
30 |
31 |     ckpter = ModelCheckpoint(dirpath=ckpt_dir, monitor=monitor)
32 |     logger = TensorBoardLogger(logs_dir, name=run_name)
33 |
34 |     args = {**vars(args), **conf_file['TRAINER'], 'logger' : logger, 'callbacks' : ckpter}
35 |     args.pop('config')
36 |
37 |     trainer = Trainer(**args)
38 |
39 |     trainer.fit(vdm, datamodule=cifar, ckpt_path=resume)
40 |
41 |
42 | if __name__ == '__main__':
43 |     parser = ArgumentParser(
44 |         prog = 'Variational-Diffusion Model Training Script',
45 |         description = 'Training of the Variational-Diffusion Model on the CIFAR-10 Dataset',
46 |     )
47 |
48 |     parser = Trainer.add_argparse_args(parser)
49 |     parser.add_argument('-config', type = str, default = 'vdm_cifar10.yaml', help='Configuration file name')
50 |
51 |     args = parser.parse_args()
52 |
53 |     main(args)
--------------------------------------------------------------------------------