├── dispersiveloss.png
├── environment.yml
├── transport
│   ├── utils.py
│   ├── __init__.py
│   ├── integrators.py
│   ├── path.py
│   └── transport.py
├── LICENSE.txt
├── download.py
├── wandb_utils.py
├── train_utils.py
├── README.md
├── sample.py
├── run_SiT.ipynb
├── sample_ddp.py
├── train.py
└── models.py
--------------------------------------------------------------------------------
/dispersiveloss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raywang4/DispLoss/HEAD/dispersiveloss.png
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: SiT
2 | channels:
3 |   - pytorch
4 |   - nvidia
5 | dependencies:
6 |   - python >= 3.8
7 |   - pytorch >= 1.13
8 |   - torchvision
9 |   - pytorch-cuda >= 11.7
10 |   - pip
11 |   - pip:
12 |     - timm
13 |     - diffusers
14 |     - accelerate
15 |     - torchdiffeq
16 |     - wandb
17 | 
--------------------------------------------------------------------------------
/transport/utils.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 | 
3 | class EasyDict:
4 | 
5 |     def __init__(self, sub_dict):
6 |         for k, v in sub_dict.items():
7 |             setattr(self, k, v)
8 | 
9 |     def __getitem__(self, key):
10 |         return getattr(self, key)
11 | 
12 | def mean_flat(x):
13 |     """
14 |     Take the mean over all non-batch dimensions.
15 |     """
16 |     return th.mean(x, dim=list(range(1, len(x.size()))))
17 | 
18 | def log_state(state):
19 |     result = []
20 | 
21 |     sorted_state = dict(sorted(state.items()))
22 |     for key, value in sorted_state.items():
23 |         # Check if the value is an instance of a class
24 |         if "<object" in str(value) or "object at" in str(value):
25 |             result.append(f"{key}: [{value.__class__.__name__}]")
26 |         else:
27 |             result.append(f"{key}: {value}")
28 | 
29 |     return '\n'.join(result)
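
A quick sanity check of the two helpers above (hypothetical usage, not part of the repo): `mean_flat` averages over every non-batch axis, and `EasyDict` exposes dict entries as attributes.

```python
import torch as th
from transport.utils import EasyDict, mean_flat  # assumes the repo root is on PYTHONPATH

x = th.ones(4, 3, 8, 8)
assert mean_flat(x).shape == (4,)  # one scalar mean per batch element

d = EasyDict({"lr": 1e-4, "disp": True})
assert d.lr == 1e-4 and d["disp"] is True  # attribute and item access agree
```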
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
3 | This repo contains the official PyTorch implementation of Dispersive Loss.
4 | 
5 | > [**Diffuse and Disperse: Image Generation with Representation Regularization**](https://arxiv.org/abs/2506.09027)<br>
6 | > [Runqian Wang](https://raywang4.github.io/), [Kaiming He](https://people.csail.mit.edu/kaiming/)<br>
7 | > MIT
8 | 
9 | We propose Dispersive Loss, a simple plug-and-play regularizer that effectively improves diffusion-based generative models.
10 | Our loss function encourages internal representations to disperse in the hidden space, analogous to contrastive self-supervised learning, with the key distinction that it requires no positive sample pairs and therefore does not interfere with the sampling process used for regression.
11 | 
12 | We implement our Dispersive Loss on top of the [SiT](https://github.com/willisma/SiT) codebase. The core implementation of Dispersive Loss is highlighted below:
13 | ```python
14 | def disp_loss(self, z):  # Dispersive Loss implementation (InfoNCE-L2 variant)
15 |     z = z.reshape((z.shape[0], -1))  # flatten
16 |     diff = th.nn.functional.pdist(z).pow(2) / z.shape[1]  # pairwise squared distance, normalized by dimension
17 |     diff = th.concat((diff, diff, th.zeros(z.shape[0]).cuda()))  # match JAX implementation of full BxB matrix
18 |     return th.log(th.exp(-diff).mean())  # calculate loss
19 | ```
20 | 
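For intuition: `pdist` returns each of the B(B-1)/2 off-diagonal pairwise distances once, so concatenating `diff` with itself and appending B zeros (the self-distances) reproduces the mean over the full BxB distance matrix, matching the JAX implementation. Below is a minimal sketch of how the regularizer composes with the denoising objective; the weight `lambda_disp`, the toy shapes, and the placeholder regression loss are illustrative assumptions (the values actually used are listed in the table further down):

```python
import torch as th

def disp_loss(z):  # standalone version of the snippet above (CPU-friendly: no .cuda())
    z = z.reshape((z.shape[0], -1))                       # flatten to (B, D)
    diff = th.nn.functional.pdist(z).pow(2) / z.shape[1]  # normalized squared distances
    diff = th.concat((diff, diff, th.zeros(z.shape[0])))  # full BxB matrix entries
    return th.log(th.exp(-diff).mean())

act = th.randn(16, 256, 1152)    # toy intermediate activations (B, T, D)
denoising_loss = th.tensor(0.5)  # placeholder for the usual regression loss
lambda_disp = 0.25               # regularization weight; cf. the PyTorch row of the table below
loss = denoising_loss + lambda_disp * disp_loss(act)
```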
21 | ## Setup
22 | 
23 | Run the following commands to set up the environment:
24 | 
25 | ```bash
26 | git clone https://github.com/raywang4/DispLoss.git
27 | cd DispLoss
28 | conda env create -f environment.yml
29 | conda activate SiT
30 | ```
31 | 
32 | 
33 | ## Training With Dispersive Loss
34 | 
35 | To train with Dispersive Loss, simply add the `--disp` flag to the training command:
36 | 
37 | ```bash
38 | torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-XL/2 --data-path /path/to/imagenet/train --disp
39 | ```
40 | 
41 | **Logging.** To enable `wandb`, first set `WANDB_KEY`, `ENTITY`, and `PROJECT` as environment variables:
42 | 
43 | ```bash
44 | export WANDB_KEY="key"
45 | export ENTITY="entity name"
46 | export PROJECT="project name"
47 | ```
48 | Then add the `--wandb` flag to the training command:
49 | 
50 | ```bash
51 | torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-XL/2 --data-path /path/to/imagenet/train --disp --wandb
52 | ```
53 | **Resume training.** To resume training from a custom checkpoint:
54 | 
55 | ```bash
56 | torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-L/2 --data-path /path/to/imagenet/train --disp --ckpt /path/to/model.pt
57 | ```
58 | 
59 | ## Sampling
60 | 
61 | **Pre-trained checkpoints.** We provide a [SiT-B/2 checkpoint](https://drive.google.com/file/d/18OeryruY-P4KuqJeKB6_EXtHUkQ5Cy7u/view?usp=sharing) and a [SiT-XL/2 checkpoint](https://drive.google.com/file/d/1NR_R6wYXS6dwCwYCM8h8EmeLpJjtiwVr/view?usp=sharing), both trained with Dispersive Loss for 80 epochs on ImageNet 256x256.
62 | 
63 | **Sampling from a checkpoint.** To sample from the EMA weights of a 256x256 SiT-XL/2 checkpoint with the ODE sampler, run:
64 | ```bash
65 | python sample.py ODE --model SiT-XL/2 --image-size 256 --ckpt /path/to/model.pt
66 | ```
67 | **More sampling options.** For more sampling options such as SDE sampling, please refer to [`train_utils.py`](train_utils.py).
68 | 
69 | ## Evaluation
70 | 
71 | The [`sample_ddp.py`](sample_ddp.py) script samples a large number of images from a pre-trained model in parallel. It
72 | generates a folder of samples as well as a `.npz` file which can be used directly with [ADM's TensorFlow
73 | evaluation suite](https://github.com/openai/guided-diffusion/tree/main/evaluations) to compute FID, Inception Score, and
74 | other metrics. To sample 50K images from a pre-trained SiT-XL/2 model over `N` GPUs with the default ODE sampler settings, run:
75 | 
76 | ```bash
77 | torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py ODE --model SiT-XL/2 --num-fid-samples 50000 --ckpt /path/to/model.pt
78 | ```
79 | 
80 | ## Differences from JAX
81 | Our original implementation is in JAX; this repo contains our re-implementation in PyTorch.
82 | Results from running this repo may therefore show minor numerical differences from those reported in our paper.
83 | In our JAX experiments, we used 16 devices with a local batch size of 16, whereas in our PyTorch experiments we used 8 devices with a local batch size of 32.
84 | We have adjusted the hyperparameter choices slightly for best performance.
85 | We report our reproduction results (FID) below.
86 | | implementation | config | local batch size | B/2 80 ep | XL/2 80 ep (cfg=1.5) |
87 | |-|:-:|:-:|:-:|:-:|
88 | | baseline | - | 16 | 36.49 | 6.02 |
89 | | JAX | $\lambda$=0.5, $\tau$=0.5, depth=num_layers//4 | 16 | 32.35 | 5.09 |
90 | | PyTorch | $\lambda$=0.25, $\tau$=1, depth=num_layers | 32 | 32.64 | 4.74 |
91 | 
92 | 
93 | ## License
94 | This project is under the MIT license. See [LICENSE](LICENSE.txt) for details.
95 | 
96 | 
97 | 
98 | 
--------------------------------------------------------------------------------
/sample.py:
--------------------------------------------------------------------------------
1 | # This source code is licensed under the license found in the
2 | # LICENSE file in the root directory of this source tree.
3 | 
4 | """
5 | Sample new images from a pre-trained SiT.
6 | """
7 | import torch
8 | torch.backends.cuda.matmul.allow_tf32 = True
9 | torch.backends.cudnn.allow_tf32 = True
10 | from torchvision.utils import save_image
11 | from diffusers.models import AutoencoderKL
12 | from download import find_model
13 | from models import SiT_models
14 | from train_utils import parse_ode_args, parse_sde_args, parse_transport_args
15 | from transport import create_transport, Sampler
16 | import argparse
17 | import sys
18 | from time import time
19 | 
20 | 
21 | def main(mode, args):
22 |     # Setup PyTorch:
23 |     torch.manual_seed(args.seed)
24 |     torch.set_grad_enabled(False)
25 |     device = "cuda" if torch.cuda.is_available() else "cpu"
26 | 
27 |     if args.ckpt is None:
28 |         assert args.model == "SiT-XL/2", "Only SiT-XL/2 models are available for auto-download."
29 |         assert args.image_size in [256, 512]
30 |         assert args.num_classes == 1000
31 |         assert args.image_size == 256, "512x512 models are not yet available for auto-download."  # remove this line when 512x512 models are available
32 |         learn_sigma = args.image_size == 256
33 |     else:
34 |         learn_sigma = False
35 | 
36 |     # Load model:
37 |     latent_size = args.image_size // 8
38 |     model = SiT_models[args.model](
39 |         input_size=latent_size,
40 |         num_classes=args.num_classes,
41 |         learn_sigma=learn_sigma,
42 |     ).to(device)
43 |     # Auto-download a pre-trained model or load a custom SiT checkpoint from train.py:
44 |     ckpt_path = args.ckpt or f"SiT-XL-2-{args.image_size}x{args.image_size}.pt"
45 |     state_dict = find_model(ckpt_path)
46 |     model.load_state_dict(state_dict)
47 |     model.eval()  # important!
48 | transport = create_transport( 49 | args.path_type, 50 | args.prediction, 51 | args.loss_weight, 52 | args.train_eps, 53 | args.sample_eps 54 | ) 55 | sampler = Sampler(transport) 56 | if mode == "ODE": 57 | if args.likelihood: 58 | assert args.cfg_scale == 1, "Likelihood is incompatible with guidance" 59 | sample_fn = sampler.sample_ode_likelihood( 60 | sampling_method=args.sampling_method, 61 | num_steps=args.num_sampling_steps, 62 | atol=args.atol, 63 | rtol=args.rtol, 64 | ) 65 | else: 66 | sample_fn = sampler.sample_ode( 67 | sampling_method=args.sampling_method, 68 | num_steps=args.num_sampling_steps, 69 | atol=args.atol, 70 | rtol=args.rtol, 71 | reverse=args.reverse 72 | ) 73 | 74 | elif mode == "SDE": 75 | sample_fn = sampler.sample_sde( 76 | sampling_method=args.sampling_method, 77 | diffusion_form=args.diffusion_form, 78 | diffusion_norm=args.diffusion_norm, 79 | last_step=args.last_step, 80 | last_step_size=args.last_step_size, 81 | num_steps=args.num_sampling_steps, 82 | ) 83 | 84 | 85 | vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device) 86 | 87 | # Labels to condition the model with (feel free to change): 88 | class_labels = [207, 360, 387, 974, 88, 979, 417, 279] 89 | 90 | # Create sampling noise: 91 | n = len(class_labels) 92 | z = torch.randn(n, 4, latent_size, latent_size, device=device) 93 | y = torch.tensor(class_labels, device=device) 94 | 95 | # Setup classifier-free guidance: 96 | z = torch.cat([z, z], 0) 97 | y_null = torch.tensor([1000] * n, device=device) 98 | y = torch.cat([y, y_null], 0) 99 | model_kwargs = dict(y=y, cfg_scale=args.cfg_scale) 100 | 101 | # Sample images: 102 | start_time = time() 103 | samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1] 104 | samples, _ = samples.chunk(2, dim=0) # Remove null class samples 105 | samples = vae.decode(samples / 0.18215).sample 106 | print(f"Sampling took {time() - start_time:.2f} seconds.") 107 | 108 | # Save and display images: 109 | save_image(samples, "sample.png", nrow=4, normalize=True, value_range=(-1, 1)) 110 | 111 | 112 | if __name__ == "__main__": 113 | parser = argparse.ArgumentParser() 114 | 115 | if len(sys.argv) < 2: 116 | print("Usage: program.py [options]") 117 | sys.exit(1) 118 | 119 | mode = sys.argv[1] 120 | 121 | assert mode[:2] != "--", "Usage: program.py [options]" 122 | assert mode in ["ODE", "SDE"], "Invalid mode. 
Please choose 'ODE' or 'SDE'"
123 | 
124 |     parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2")
125 |     parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="mse")
126 |     parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
127 |     parser.add_argument("--num-classes", type=int, default=1000)
128 |     parser.add_argument("--cfg-scale", type=float, default=4.0)
129 |     parser.add_argument("--num-sampling-steps", type=int, default=250)
130 |     parser.add_argument("--seed", type=int, default=0)
131 |     parser.add_argument("--ckpt", type=str, default=None,
132 |                         help="Optional path to a SiT checkpoint (default: auto-download a pre-trained SiT-XL/2 model).")
133 | 
134 | 
135 |     parse_transport_args(parser)
136 |     if mode == "ODE":
137 |         parse_ode_args(parser)
138 |         # Further processing for ODE
139 |     elif mode == "SDE":
140 |         parse_sde_args(parser)
141 |         # Further processing for SDE
142 | 
143 |     args = parser.parse_known_args()[0]
144 |     main(mode, args)
145 | 
--------------------------------------------------------------------------------
/run_SiT.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {
6 |     "id": "355UKMUQJxFd"
7 |    },
8 |    "source": [
9 |     "# SiT: Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers\n",
10 |     "\n",
11 |     "This notebook samples from pre-trained SiT models. SiTs are class-conditional latent interpolant models trained on ImageNet, unifying flow and diffusion methods.\n",
12 |     "\n",
13 |     "[Paper](https://arxiv.org/abs/2401.08740) | [GitHub](https://github.com/willisma/SiT)"
14 |    ]
15 |   },
16 |   {
17 |    "cell_type": "markdown",
18 |    "metadata": {
19 |     "id": "zJlgLkSaKn7u"
20 |    },
21 |    "source": [
22 |     "# 1. Setup\n",
23 |     "\n",
24 |     "We recommend using GPUs (Runtime > Change runtime type > Hardware accelerator > GPU). Run this cell to clone the SiT GitHub repo and set up PyTorch. You only have to run this once."
25 |    ]
26 |   },
27 |   {
28 |    "cell_type": "code",
29 |    "execution_count": null,
30 |    "metadata": {},
31 |    "outputs": [],
32 |    "source": [
33 |     "!git clone https://github.com/willisma/SiT.git\n",
34 |     "import SiT, os\n",
35 |     "os.chdir('SiT')\n",
36 |     "os.environ['PYTHONPATH'] = '/env/python:/content/SiT'\n",
37 |     "!pip install diffusers timm torchdiffeq --upgrade\n",
38 |     "# SiT imports:\n",
39 |     "import torch\n",
40 |     "from torchvision.utils import save_image\n",
41 |     "from transport import create_transport, Sampler\n",
42 |     "from diffusers.models import AutoencoderKL\n",
43 |     "from download import find_model\n",
44 |     "from models import SiT_XL_2\n",
45 |     "from PIL import Image\n",
46 |     "from IPython.display import display\n",
47 |     "torch.set_grad_enabled(False)\n",
48 |     "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
49 |     "if device == \"cpu\":\n",
50 |     "    print(\"GPU not found. 
Using CPU instead.\")" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "metadata": { 56 | "id": "AXpziRkoOvV9" 57 | }, 58 | "source": [ 59 | "# Download SiT-XL/2 Models" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": { 66 | "id": "EWG-WNimO59K" 67 | }, 68 | "outputs": [], 69 | "source": [ 70 | "image_size = \"256\"\n", 71 | "vae_model = \"stabilityai/sd-vae-ft-ema\" #@param [\"stabilityai/sd-vae-ft-mse\", \"stabilityai/sd-vae-ft-ema\"]\n", 72 | "latent_size = int(image_size) // 8\n", 73 | "# Load model:\n", 74 | "model = SiT_XL_2(input_size=latent_size).to(device)\n", 75 | "state_dict = find_model(f\"SiT-XL-2-{image_size}x{image_size}.pt\")\n", 76 | "model.load_state_dict(state_dict)\n", 77 | "model.eval() # important!\n", 78 | "vae = AutoencoderKL.from_pretrained(vae_model).to(device)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": { 84 | "id": "5JTNyzNZKb9E" 85 | }, 86 | "source": [ 87 | "# 2. Sample from Pre-trained SiT Models\n", 88 | "\n", 89 | "You can customize several sampling options. For the full list of ImageNet classes, [check out this](https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a)." 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": { 96 | "id": "-Hw7B5h4Kk4p" 97 | }, 98 | "outputs": [], 99 | "source": [ 100 | "# Set user inputs:\n", 101 | "seed = 0 #@param {type:\"number\"}\n", 102 | "torch.manual_seed(seed)\n", 103 | "num_sampling_steps = 250 #@param {type:\"slider\", min:0, max:1000, step:1}\n", 104 | "cfg_scale = 4 #@param {type:\"slider\", min:1, max:10, step:0.1}\n", 105 | "class_labels = 207, 360, 387, 974, 88, 979, 417, 279 #@param {type:\"raw\"}\n", 106 | "samples_per_row = 4 #@param {type:\"number\"}\n", 107 | "sampler_type = \"ODE\" #@param [\"ODE\", \"SDE\"]\n", 108 | "\n", 109 | "\n", 110 | "# Create diffusion object:\n", 111 | "transport = create_transport()\n", 112 | "sampler = Sampler(transport)\n", 113 | "\n", 114 | "# Create sampling noise:\n", 115 | "n = len(class_labels)\n", 116 | "z = torch.randn(n, 4, latent_size, latent_size, device=device)\n", 117 | "y = torch.tensor(class_labels, device=device)\n", 118 | "\n", 119 | "# Setup classifier-free guidance:\n", 120 | "z = torch.cat([z, z], 0)\n", 121 | "y_null = torch.tensor([1000] * n, device=device)\n", 122 | "y = torch.cat([y, y_null], 0)\n", 123 | "model_kwargs = dict(y=y, cfg_scale=cfg_scale)\n", 124 | "\n", 125 | "# Sample images:\n", 126 | "if sampler_type == \"SDE\":\n", 127 | " SDE_sampling_method = \"Euler\" #@param [\"Euler\", \"Heun\"]\n", 128 | " diffusion_form = \"linear\" #@param [\"constant\", \"SBDM\", \"sigma\", \"linear\", \"decreasing\", \"increasing-decreasing\"]\n", 129 | " diffusion_norm = 1 #@param {type:\"slider\", min:0, max:10.0, step:0.1}\n", 130 | " last_step = \"Mean\" #@param [\"Mean\", \"Tweedie\", \"Euler\"]\n", 131 | " last_step_size = 0.4 #@param {type:\"slider\", min:0, max:1.0, step:0.01}\n", 132 | " sample_fn = sampler.sample_sde(\n", 133 | " sampling_method=SDE_sampling_method,\n", 134 | " diffusion_form=diffusion_form, \n", 135 | " diffusion_norm=diffusion_norm,\n", 136 | " last_step_size=last_step_size, \n", 137 | " num_steps=num_sampling_steps,\n", 138 | " ) \n", 139 | "elif sampler_type == \"ODE\":\n", 140 | " # default to Adaptive Solver\n", 141 | " ODE_sampling_method = \"dopri5\" #@param [\"dopri5\", \"euler\", \"rk4\"]\n", 142 | " atol = 1e-6\n", 143 | " rtol = 1e-3\n", 144 | " sample_fn = sampler.sample_ode(\n", 145 | " 
sampling_method=ODE_sampling_method,\n",
146 |     "        atol=atol,\n",
147 |     "        rtol=rtol,\n",
148 |     "        num_steps=num_sampling_steps\n",
149 |     "    ) \n",
150 |     "samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]\n",
151 |     "samples = vae.decode(samples / 0.18215).sample\n",
152 |     "\n",
153 |     "# Save and display images:\n",
154 |     "save_image(samples, \"sample.png\", nrow=int(samples_per_row), \n",
155 |     "           normalize=True, value_range=(-1, 1))\n",
156 |     "samples = Image.open(\"sample.png\")\n",
157 |     "display(samples)"
158 |    ]
159 |   }
160 |  ],
161 |  "metadata": {
162 |   "colab": {
163 |    "provenance": []
164 |   },
165 |   "kernelspec": {
166 |    "display_name": "Python 3.8.10 64-bit",
167 |    "language": "python",
168 |    "name": "python3"
169 |   },
170 |   "language_info": {
171 |    "name": "python",
172 |    "version": "3.8.10"
173 |   },
174 |   "vscode": {
175 |    "interpreter": {
176 |     "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
177 |    }
178 |   }
179 |  },
180 |  "nbformat": 4,
181 |  "nbformat_minor": 0
182 | }
183 | 
--------------------------------------------------------------------------------
/transport/path.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 | import numpy as np
3 | from functools import partial
4 | 
5 | def expand_t_like_x(t, x):
6 |     """Function to reshape time t to broadcastable dimension of x
7 |     Args:
8 |       t: [batch_dim,], time vector
9 |       x: [batch_dim,...], data point
10 |     """
11 |     dims = [1] * (len(x.size()) - 1)
12 |     t = t.view(t.size(0), *dims)
13 |     return t
14 | 
15 | 
16 | #################### Coupling Plans ####################
17 | 
18 | class ICPlan:
19 |     """Linear Coupling Plan"""
20 |     def __init__(self, sigma=0.0):
21 |         self.sigma = sigma
22 | 
23 |     def compute_alpha_t(self, t):
24 |         """Compute the data coefficient along the path"""
25 |         return t, 1
26 | 
27 |     def compute_sigma_t(self, t):
28 |         """Compute the noise coefficient along the path"""
29 |         return 1 - t, -1
30 | 
31 |     def compute_d_alpha_alpha_ratio_t(self, t):
32 |         """Compute the ratio between d_alpha and alpha"""
33 |         return 1 / t
34 | 
35 |     def compute_drift(self, x, t):
36 |         """We always output the drift of the SDE according to the score parametrization"""
37 |         t = expand_t_like_x(t, x)
38 |         alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t)
39 |         sigma_t, d_sigma_t = self.compute_sigma_t(t)
40 |         drift = alpha_ratio * x
41 |         diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t
42 | 
43 |         return -drift, diffusion
44 | 
45 |     def compute_diffusion(self, x, t, form="constant", norm=1.0):
46 |         """Compute the diffusion term of the SDE
47 |         Args:
48 |           x: [batch_dim, ...], data point
49 |           t: [batch_dim,], time vector
50 |           form: str, form of the diffusion term
51 |           norm: float, norm of the diffusion term
52 |         """
53 |         t = expand_t_like_x(t, x)
54 |         choices = {
55 |             "constant": norm,
56 |             "SBDM": norm * self.compute_drift(x, t)[1],
57 |             "sigma": norm * self.compute_sigma_t(t)[0],
58 |             "linear": norm * (1 - t),
59 |             "decreasing": 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2,
60 |             "increasing-decreasing": norm * th.sin(np.pi * t) ** 2,
61 |         }
62 | 
63 |         try:
64 |             diffusion = choices[form]
65 |         except KeyError:
66 |             raise NotImplementedError(f"Diffusion form {form} not implemented")
67 | 
68 |         return diffusion
69 | 
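    # On this linear plan, alpha_t = t and sigma_t = 1 - t, so compute_mu_t below
    # interpolates x_t = t * x1 + (1 - t) * x0 between noise x0 and data x1, and
    # compute_ut returns the constant target field u_t = x1 - x0
    # (since d_alpha_t = 1 and d_sigma_t = -1).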
70 |     def get_score_from_velocity(self, velocity, x, t):
71 |         """Wrapper function: transform velocity prediction model to score
72 |         Args:
73 |           velocity: [batch_dim, ...] shaped tensor; velocity model output
74 |           x: [batch_dim, ...] shaped tensor; x_t data point
75 |           t: [batch_dim,] time tensor
76 |         """
77 |         t = expand_t_like_x(t, x)
78 |         alpha_t, d_alpha_t = self.compute_alpha_t(t)
79 |         sigma_t, d_sigma_t = self.compute_sigma_t(t)
80 |         mean = x
81 |         reverse_alpha_ratio = alpha_t / d_alpha_t
82 |         var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
83 |         score = (reverse_alpha_ratio * velocity - mean) / var
84 |         return score
85 | 
86 |     def get_noise_from_velocity(self, velocity, x, t):
87 |         """Wrapper function: transform velocity prediction model to denoiser
88 |         Args:
89 |           velocity: [batch_dim, ...] shaped tensor; velocity model output
90 |           x: [batch_dim, ...] shaped tensor; x_t data point
91 |           t: [batch_dim,] time tensor
92 |         """
93 |         t = expand_t_like_x(t, x)
94 |         alpha_t, d_alpha_t = self.compute_alpha_t(t)
95 |         sigma_t, d_sigma_t = self.compute_sigma_t(t)
96 |         mean = x
97 |         reverse_alpha_ratio = alpha_t / d_alpha_t
98 |         var = reverse_alpha_ratio * d_sigma_t - sigma_t
99 |         noise = (reverse_alpha_ratio * velocity - mean) / var
100 |         return noise
101 | 
102 |     def get_velocity_from_score(self, score, x, t):
103 |         """Wrapper function: transform score prediction model to velocity
104 |         Args:
105 |           score: [batch_dim, ...] shaped tensor; score model output
106 |           x: [batch_dim, ...] shaped tensor; x_t data point
107 |           t: [batch_dim,] time tensor
108 |         """
109 |         t = expand_t_like_x(t, x)
110 |         drift, var = self.compute_drift(x, t)
111 |         velocity = var * score - drift
112 |         return velocity
113 | 
114 |     def compute_mu_t(self, t, x0, x1):
115 |         """Compute the mean of time-dependent density p_t"""
116 |         t = expand_t_like_x(t, x1)
117 |         alpha_t, _ = self.compute_alpha_t(t)
118 |         sigma_t, _ = self.compute_sigma_t(t)
119 |         return alpha_t * x1 + sigma_t * x0
120 | 
121 |     def compute_xt(self, t, x0, x1):
122 |         """Sample xt from time-dependent density p_t; rng is required"""
123 |         xt = self.compute_mu_t(t, x0, x1)
124 |         return xt
125 | 
126 |     def compute_ut(self, t, x0, x1, xt):
127 |         """Compute the vector field corresponding to p_t"""
128 |         t = expand_t_like_x(t, x1)
129 |         _, d_alpha_t = self.compute_alpha_t(t)
130 |         _, d_sigma_t = self.compute_sigma_t(t)
131 |         return d_alpha_t * x1 + d_sigma_t * x0
132 | 
133 |     def plan(self, t, x0, x1):
134 |         xt = self.compute_xt(t, x0, x1)
135 |         ut = self.compute_ut(t, x0, x1, xt)
136 |         return t, xt, ut
137 | 
138 | 
139 | class VPCPlan(ICPlan):
140 |     """class for VP path flow matching"""
141 | 
142 |     def __init__(self, sigma_min=0.1, sigma_max=20.0):
143 |         self.sigma_min = sigma_min
144 |         self.sigma_max = sigma_max
145 |         self.log_mean_coeff = lambda t: -0.25 * ((1 - t) ** 2) * (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min
146 |         self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min
147 | 
148 | 
149 |     def compute_alpha_t(self, t):
150 |         """Compute coefficient of x1"""
151 |         alpha_t = self.log_mean_coeff(t)
152 |         alpha_t = th.exp(alpha_t)
153 |         d_alpha_t = alpha_t * self.d_log_mean_coeff(t)
154 |         return alpha_t, d_alpha_t
155 | 
156 |     def compute_sigma_t(self, t):
157 |         """Compute coefficient of x0"""
158 |         p_sigma_t = 2 * self.log_mean_coeff(t)
159 |         sigma_t = th.sqrt(1 - th.exp(p_sigma_t))
160 |         d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t)
161 |         return sigma_t, d_sigma_t
162 | 
163 |     def compute_d_alpha_alpha_ratio_t(self, t):
164 |         """Special-purpose function for computing the numerically stable d_alpha_t / alpha_t"""
165 |         return self.d_log_mean_coeff(t)
166 | 
167 |     def compute_drift(self, x, t):
168 |         """Compute the drift term of the SDE"""
169 |         t = expand_t_like_x(t, x)
170 |         beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min)
171 |         return -0.5 * beta_t * x, beta_t / 2
172 | 
173 | 
174 | class GVPCPlan(ICPlan):
175 |     def __init__(self, sigma=0.0):
176 |         super().__init__(sigma)
177 | 
178 |     def compute_alpha_t(self, t):
179 |         """Compute coefficient of x1"""
180 |         alpha_t = th.sin(t * np.pi / 2)
181 |         d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2)
182 |         return alpha_t, d_alpha_t
183 | 
184 |     def compute_sigma_t(self, t):
185 |         """Compute coefficient of x0"""
186 |         sigma_t = th.cos(t * np.pi / 2)
187 |         d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2)
188 |         return sigma_t, d_sigma_t
189 | 
190 |     def compute_d_alpha_alpha_ratio_t(self, t):
191 |         """Special-purpose function for computing the numerically stable d_alpha_t / alpha_t"""
192 |         return np.pi / (2 * th.tan(t * np.pi / 2))
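
A small, hypothetical round-trip check of the wrapper functions above (not part of the repo): on the linear `ICPlan`, converting a velocity prediction to a score and back should be the identity, by the formulas in `get_score_from_velocity` and `get_velocity_from_score`.

```python
import torch as th
from transport.path import ICPlan  # assumes the repo root is on PYTHONPATH

plan = ICPlan()
x = th.randn(4, 3)      # toy x_t
v = th.randn(4, 3)      # toy velocity prediction
t = th.full((4,), 0.7)  # interior time, away from the t=0 and t=1 endpoints
score = plan.get_score_from_velocity(v, x, t)
v_back = plan.get_velocity_from_score(score, x, t)
assert th.allclose(v, v_back, atol=1e-5)
```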
--------------------------------------------------------------------------------
/sample_ddp.py:
--------------------------------------------------------------------------------
1 | # This source code is licensed under the license found in the
2 | # LICENSE file in the root directory of this source tree.
3 | 
4 | """
5 | Samples a large number of images from a pre-trained SiT model using DDP.
6 | Subsequently saves a .npz file that can be used to compute FID and other
7 | evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations
8 | 
9 | For a simple single-GPU/CPU sampling script, see sample.py.
10 | """
11 | import torch
12 | import torch.distributed as dist
13 | from models import SiT_models
14 | from download import find_model
15 | from transport import create_transport, Sampler
16 | from diffusers.models import AutoencoderKL
17 | from train_utils import parse_ode_args, parse_sde_args, parse_transport_args
18 | from tqdm import tqdm
19 | import os
20 | from PIL import Image
21 | import numpy as np
22 | import math
23 | import argparse
24 | import sys
25 | 
26 | 
27 | def create_npz_from_sample_folder(sample_dir, num=50_000):
28 |     """
29 |     Builds a single .npz file from a folder of .png samples.
30 |     """
31 |     samples = []
32 |     for i in tqdm(range(num), desc="Building .npz file from samples"):
33 |         sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
34 |         sample_np = np.asarray(sample_pil).astype(np.uint8)
35 |         samples.append(sample_np)
36 |     samples = np.stack(samples)
37 |     assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
38 |     npz_path = f"{sample_dir}.npz"
39 |     np.savez(npz_path, arr_0=samples)
40 |     print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
41 |     return npz_path
42 | 
43 | 
44 | def main(mode, args):
45 |     """
46 |     Run sampling.
47 |     """
48 |     torch.backends.cuda.matmul.allow_tf32 = args.tf32  # True: fast but may lead to some small numerical differences
49 |     assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage."
50 |     torch.set_grad_enabled(False)
51 | 
52 |     # Setup DDP:
53 |     dist.init_process_group("nccl")
54 |     rank = dist.get_rank()
55 |     device = rank % torch.cuda.device_count()
56 |     seed = args.global_seed * dist.get_world_size() + rank
57 |     torch.manual_seed(seed)
58 |     torch.cuda.set_device(device)
59 |     print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
60 | 
61 |     if args.ckpt is None:
62 |         assert args.model == "SiT-XL/2", "Only SiT-XL/2 models are available for auto-download."
63 |         assert args.image_size in [256, 512]
64 |         assert args.num_classes == 1000
65 |         assert args.image_size == 256, "512x512 models are not yet available for auto-download."  # remove this line when 512x512 models are available
66 |         learn_sigma = args.image_size == 256
67 |     else:
68 |         learn_sigma = False
69 | 
70 |     # Load model:
71 |     latent_size = args.image_size // 8
72 |     model = SiT_models[args.model](
73 |         input_size=latent_size,
74 |         num_classes=args.num_classes,
75 |         learn_sigma=learn_sigma,
76 |     ).to(device)
77 |     # Auto-download a pre-trained model or load a custom SiT checkpoint from train.py:
78 |     ckpt_path = args.ckpt or f"SiT-XL-2-{args.image_size}x{args.image_size}.pt"
79 |     state_dict = find_model(ckpt_path)
80 |     model.load_state_dict(state_dict)
81 |     model.eval()  # important!
82 | 
83 | 
84 |     transport = create_transport(
85 |         args.path_type,
86 |         args.prediction,
87 |         args.loss_weight,
88 |         args.train_eps,
89 |         args.sample_eps
90 |     )
91 |     sampler = Sampler(transport)
92 |     if mode == "ODE":
93 |         if args.likelihood:
94 |             assert args.cfg_scale == 1, "Likelihood is incompatible with guidance"
95 |             sample_fn = sampler.sample_ode_likelihood(
96 |                 sampling_method=args.sampling_method,
97 |                 num_steps=args.num_sampling_steps,
98 |                 atol=args.atol,
99 |                 rtol=args.rtol,
100 |             )
101 |         else:
102 |             sample_fn = sampler.sample_ode(
103 |                 sampling_method=args.sampling_method,
104 |                 num_steps=args.num_sampling_steps,
105 |                 atol=args.atol,
106 |                 rtol=args.rtol,
107 |                 reverse=args.reverse
108 |             )
109 |     elif mode == "SDE":
110 |         sample_fn = sampler.sample_sde(
111 |             sampling_method=args.sampling_method,
112 |             diffusion_form=args.diffusion_form,
113 |             diffusion_norm=args.diffusion_norm,
114 |             last_step=args.last_step,
115 |             last_step_size=args.last_step_size,
116 |             num_steps=args.num_sampling_steps,
117 |         )
118 |     vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
119 |     assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale should be >= 1.0"
120 |     using_cfg = args.cfg_scale > 1.0
121 | 
122 |     # Create folder to save samples:
123 |     model_string_name = args.model.replace("/", "-")
124 |     ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained"
125 |     if mode == "ODE":
126 |         folder_name = f"{model_string_name}-{ckpt_string_name}-" \
127 |                       f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\
128 |                       f"{mode}-{args.num_sampling_steps}-{args.sampling_method}"
129 |     elif mode == "SDE":
130 |         folder_name = f"{model_string_name}-{ckpt_string_name}-" \
131 |                       f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\
132 |                       f"{mode}-{args.num_sampling_steps}-{args.sampling_method}-"\
133 |                       f"{args.diffusion_form}-{args.last_step}-{args.last_step_size}"
134 |     sample_folder_dir = f"{args.sample_dir}/{folder_name}"
135 |     if rank == 0:
136 |         os.makedirs(sample_folder_dir, exist_ok=True)
137 |         print(f"Saving .png samples at {sample_folder_dir}")
138 |     dist.barrier()
139 | 
140 |     # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
141 |     n = args.per_proc_batch_size
142 |     global_batch_size = n * dist.get_world_size()
143 |     # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
144 |     num_samples = len([name for name in os.listdir(sample_folder_dir) if (os.path.isfile(os.path.join(sample_folder_dir, name)) and ".png" in name)])
145 |     total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
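    # For example (hypothetical run): with 8 GPUs and --per-proc-batch-size 32,
    # global_batch_size is 256, so --num-fid-samples 50000 is rounded up to
    # math.ceil(50000 / 256) * 256 = 50176; the 176 extra images are ignored
    # when the .npz file is built from the first 50000 samples.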
146 |     if rank == 0:
147 |         print(f"Total number of images that will be sampled: {total_samples}")
148 |     assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
149 |     samples_needed_this_gpu = int(total_samples // dist.get_world_size())
150 |     assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
151 |     iterations = int(samples_needed_this_gpu // n)
152 |     done_iterations = int(int(num_samples // dist.get_world_size()) // n)
153 |     pbar = range(iterations)
154 |     pbar = tqdm(pbar) if rank == 0 else pbar
155 |     total = 0
156 | 
157 |     for _ in pbar:
158 |         # Sample inputs:
159 |         z = torch.randn(n, model.in_channels, latent_size, latent_size, device=device)
160 |         y = torch.randint(0, args.num_classes, (n,), device=device)
161 | 
162 |         # Setup classifier-free guidance:
163 |         if using_cfg:
164 |             z = torch.cat([z, z], 0)
165 |             y_null = torch.tensor([1000] * n, device=device)
166 |             y = torch.cat([y, y_null], 0)
167 |             model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)
168 |             model_fn = model.forward_with_cfg
169 |         else:
170 |             model_kwargs = dict(y=y)
171 |             model_fn = model.forward
172 | 
173 |         samples = sample_fn(z, model_fn, **model_kwargs)[-1]
174 |         if using_cfg:
175 |             samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
176 | 
177 |         samples = vae.decode(samples / 0.18215).sample
178 |         samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
179 | 
180 |         # Save samples to disk as individual .png files
181 |         for i, sample in enumerate(samples):
182 |             index = i * dist.get_world_size() + rank + total
183 |             Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
184 |         total += global_batch_size
185 |         dist.barrier()
186 | 
187 |     # Make sure all processes have finished saving their samples before attempting to convert to .npz
188 |     dist.barrier()
189 |     if rank == 0:
190 |         create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
191 |         print("Done.")
192 |     dist.barrier()
193 |     dist.destroy_process_group()
194 | 
195 | 
196 | if __name__ == "__main__":
197 | 
198 |     parser = argparse.ArgumentParser()
199 | 
200 |     if len(sys.argv) < 2:
201 |         print("Usage: program.py [options]")
202 |         sys.exit(1)
203 | 
204 |     mode = sys.argv[1]
205 | 
206 |     assert mode[:2] != "--", "Usage: program.py [options]"
207 |     assert mode in ["ODE", "SDE"], "Invalid mode. Please choose 'ODE' or 'SDE'"
208 | 
209 |     parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2")
210 |     parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")
211 |     parser.add_argument("--sample-dir", type=str, default="samples")
212 |     parser.add_argument("--per-proc-batch-size", type=int, default=4)
213 |     parser.add_argument("--num-fid-samples", type=int, default=50_000)
214 |     parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
215 |     parser.add_argument("--num-classes", type=int, default=1000)
216 |     parser.add_argument("--cfg-scale", type=float, default=1.0)
217 |     parser.add_argument("--num-sampling-steps", type=int, default=250)
218 |     parser.add_argument("--global-seed", type=int, default=0)
219 |     parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True,
220 |                         help="By default, use TF32 matmuls. 
This massively accelerates sampling on Ampere GPUs.") 221 | parser.add_argument("--ckpt", type=str, default=None, 222 | help="Optional path to a SiT checkpoint (default: auto-download a pre-trained SiT-XL/2 model).") 223 | 224 | parse_transport_args(parser) 225 | if mode == "ODE": 226 | parse_ode_args(parser) 227 | # Further processing for ODE 228 | elif mode == "SDE": 229 | parse_sde_args(parser) 230 | # Further processing for SDE 231 | 232 | args = parser.parse_known_args()[0] 233 | main(mode, args) 234 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the license found in the 2 | # LICENSE file in the root directory of this source tree. 3 | 4 | """ 5 | A minimal training script for SiT using PyTorch DDP. 6 | """ 7 | import torch 8 | # the first flag below was False when we tested this script but True makes A100 training a lot faster: 9 | torch.backends.cuda.matmul.allow_tf32 = True 10 | torch.backends.cudnn.allow_tf32 = True 11 | import torch.distributed as dist 12 | from torch.nn.parallel import DistributedDataParallel as DDP 13 | from torch.utils.data import DataLoader 14 | from torch.utils.data.distributed import DistributedSampler 15 | from torchvision.datasets import ImageFolder 16 | from torchvision import transforms 17 | import numpy as np 18 | from collections import OrderedDict 19 | from PIL import Image 20 | from copy import deepcopy 21 | from glob import glob 22 | from time import time 23 | import argparse 24 | import logging 25 | import os 26 | 27 | from models import SiT_models 28 | from download import find_model 29 | from transport import create_transport, Sampler 30 | from diffusers.models import AutoencoderKL 31 | from train_utils import parse_transport_args 32 | import wandb_utils 33 | 34 | 35 | ################################################################################# 36 | # Training Helper Functions # 37 | ################################################################################# 38 | 39 | @torch.no_grad() 40 | def update_ema(ema_model, model, decay=0.9999): 41 | """ 42 | Step the EMA model towards the current model. 43 | """ 44 | ema_params = OrderedDict(ema_model.named_parameters()) 45 | model_params = OrderedDict(model.named_parameters()) 46 | 47 | for name, param in model_params.items(): 48 | # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed 49 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) 50 | 51 | 52 | def requires_grad(model, flag=True): 53 | """ 54 | Set requires_grad flag for all parameters in a model. 55 | """ 56 | for p in model.parameters(): 57 | p.requires_grad = flag 58 | 59 | 60 | def cleanup(): 61 | """ 62 | End DDP training. 63 | """ 64 | dist.destroy_process_group() 65 | 66 | 67 | def create_logger(logging_dir): 68 | """ 69 | Create a logger that writes to a log file and stdout. 
70 | """ 71 | if dist.get_rank() == 0: # real logger 72 | logging.basicConfig( 73 | level=logging.INFO, 74 | format='[\033[34m%(asctime)s\033[0m] %(message)s', 75 | datefmt='%Y-%m-%d %H:%M:%S', 76 | handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] 77 | ) 78 | logger = logging.getLogger(__name__) 79 | else: # dummy logger (does nothing) 80 | logger = logging.getLogger(__name__) 81 | logger.addHandler(logging.NullHandler()) 82 | return logger 83 | 84 | 85 | def center_crop_arr(pil_image, image_size): 86 | """ 87 | Center cropping implementation from ADM. 88 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 89 | """ 90 | while min(*pil_image.size) >= 2 * image_size: 91 | pil_image = pil_image.resize( 92 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 93 | ) 94 | 95 | scale = image_size / min(*pil_image.size) 96 | pil_image = pil_image.resize( 97 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 98 | ) 99 | 100 | arr = np.array(pil_image) 101 | crop_y = (arr.shape[0] - image_size) // 2 102 | crop_x = (arr.shape[1] - image_size) // 2 103 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) 104 | 105 | 106 | ################################################################################# 107 | # Training Loop # 108 | ################################################################################# 109 | 110 | def main(args): 111 | """ 112 | Trains a new SiT model. 113 | """ 114 | assert torch.cuda.is_available(), "Training currently requires at least one GPU." 115 | 116 | # Setup DDP: 117 | dist.init_process_group("nccl") 118 | assert args.global_batch_size % dist.get_world_size() == 0, f"Batch size must be divisible by world size." 119 | rank = dist.get_rank() 120 | device = rank % torch.cuda.device_count() 121 | seed = args.global_seed * dist.get_world_size() + rank 122 | torch.manual_seed(seed) 123 | torch.cuda.set_device(device) 124 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") 125 | local_batch_size = int(args.global_batch_size // dist.get_world_size()) 126 | 127 | # Setup an experiment folder: 128 | if rank == 0: 129 | os.makedirs(args.results_dir, exist_ok=True) # Make results folder (holds all experiment subfolders) 130 | experiment_index = len(glob(f"{args.results_dir}/*")) 131 | model_string_name = args.model.replace("/", "-") # e.g., SiT-XL/2 --> SiT-XL-2 (for naming folders) 132 | experiment_name = f"{experiment_index:03d}-{model_string_name}-" \ 133 | f"{args.path_type}-{args.prediction}-{args.loss_weight}" 134 | experiment_dir = f"{args.results_dir}/{experiment_name}" # Create an experiment folder 135 | checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints 136 | os.makedirs(checkpoint_dir, exist_ok=True) 137 | logger = create_logger(experiment_dir) 138 | logger.info(f"Experiment directory created at {experiment_dir}") 139 | 140 | entity = os.environ["ENTITY"] 141 | project = os.environ["PROJECT"] 142 | if args.wandb: 143 | wandb_utils.initialize(args, entity, experiment_name, project) 144 | else: 145 | logger = create_logger(None) 146 | 147 | # Create model: 148 | assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)." 
149 |     latent_size = args.image_size // 8
150 |     model = SiT_models[args.model](
151 |         input_size=latent_size,
152 |         num_classes=args.num_classes
153 |     )
154 | 
155 |     # Note that parameter initialization is done within the SiT constructor
156 |     ema = deepcopy(model).to(device)  # Create an EMA of the model for use after training
157 | 
158 |     opt_state = None
159 |     if args.ckpt is not None:
160 |         ckpt_path = args.ckpt
161 |         state_dict = find_model(ckpt_path)
162 |         model.load_state_dict(state_dict["model"])
163 |         ema.load_state_dict(state_dict["ema"])
164 |         opt_state = state_dict["opt"]  # restored below, once the optimizer has been created
165 |         args = state_dict["args"]
166 | 
167 |     requires_grad(ema, False)
168 | 
169 |     model = DDP(model.to(device), device_ids=[rank])
170 |     transport = create_transport(
171 |         args.path_type,
172 |         args.prediction,
173 |         args.loss_weight,
174 |         args.train_eps,
175 |         args.sample_eps
176 |     )  # default: velocity;
177 |     transport_sampler = Sampler(transport)
178 |     vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
179 |     logger.info(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}")
180 | 
181 |     # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
182 |     opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0)
183 |     if opt_state is not None:
184 |         opt.load_state_dict(opt_state)  # resume optimizer state from the checkpoint
185 | 
186 |     # Setup data:
187 |     transform = transforms.Compose([
188 |         transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
189 |         transforms.RandomHorizontalFlip(),
190 |         transforms.ToTensor(),
191 |         transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
192 |     ])
193 |     dataset = ImageFolder(args.data_path, transform=transform)
194 |     sampler = DistributedSampler(
195 |         dataset,
196 |         num_replicas=dist.get_world_size(),
197 |         rank=rank,
198 |         shuffle=True,
199 |         seed=args.global_seed
200 |     )
201 |     loader = DataLoader(
202 |         dataset,
203 |         batch_size=local_batch_size,
204 |         shuffle=False,
205 |         sampler=sampler,
206 |         num_workers=args.num_workers,
207 |         pin_memory=True,
208 |         drop_last=True
209 |     )
210 |     logger.info(f"Dataset contains {len(dataset):,} images ({args.data_path})")
211 | 
212 |     # Prepare models for training:
213 |     update_ema(ema, model.module, decay=0)  # Ensure EMA is initialized with synced weights
214 |     model.train()  # important! 
This enables embedding dropout for classifier-free guidance 212 | ema.eval() # EMA model should always be in eval mode 213 | 214 | # Variables for monitoring/logging purposes: 215 | train_steps = 0 216 | log_steps = 0 217 | running_loss = 0 218 | start_time = time() 219 | 220 | # Labels to condition the model with (feel free to change): 221 | ys = torch.randint(1000, size=(local_batch_size,), device=device) 222 | use_cfg = args.cfg_scale > 1.0 223 | # Create sampling noise: 224 | n = ys.size(0) 225 | zs = torch.randn(n, 4, latent_size, latent_size, device=device) 226 | 227 | # Setup classifier-free guidance: 228 | if use_cfg: 229 | zs = torch.cat([zs, zs], 0) 230 | y_null = torch.tensor([1000] * n, device=device) 231 | ys = torch.cat([ys, y_null], 0) 232 | sample_model_kwargs = dict(y=ys, cfg_scale=args.cfg_scale) 233 | model_fn = ema.forward_with_cfg 234 | else: 235 | sample_model_kwargs = dict(y=ys) 236 | model_fn = ema.forward 237 | 238 | logger.info(f"Training for {args.epochs} epochs...") 239 | for epoch in range(args.epochs): 240 | sampler.set_epoch(epoch) 241 | logger.info(f"Beginning epoch {epoch}...") 242 | for x, y in loader: 243 | x = x.to(device) 244 | y = y.to(device) 245 | with torch.no_grad(): 246 | # Map input images to latent space + normalize latents: 247 | x = vae.encode(x).latent_dist.sample().mul_(0.18215) 248 | model_kwargs = dict(y=y, return_act=args.disp) 249 | loss_dict = transport.training_losses(model, x, model_kwargs) 250 | loss = loss_dict["loss"].mean() 251 | opt.zero_grad() 252 | loss.backward() 253 | opt.step() 254 | update_ema(ema, model.module) 255 | 256 | # Log loss values: 257 | running_loss += loss.item() 258 | log_steps += 1 259 | train_steps += 1 260 | if train_steps % args.log_every == 0: 261 | # Measure training speed: 262 | torch.cuda.synchronize() 263 | end_time = time() 264 | steps_per_sec = log_steps / (end_time - start_time) 265 | # Reduce loss history over all processes: 266 | avg_loss = torch.tensor(running_loss / log_steps, device=device) 267 | dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM) 268 | avg_loss = avg_loss.item() / dist.get_world_size() 269 | logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}") 270 | if args.wandb: 271 | wandb_utils.log( 272 | { "train loss": avg_loss, "train steps/sec": steps_per_sec }, 273 | step=train_steps 274 | ) 275 | # Reset monitoring variables: 276 | running_loss = 0 277 | log_steps = 0 278 | start_time = time() 279 | 280 | # Save SiT checkpoint: 281 | if train_steps % args.ckpt_every == 0 and train_steps > 0: 282 | if rank == 0: 283 | checkpoint = { 284 | "model": model.module.state_dict(), 285 | "ema": ema.state_dict(), 286 | "opt": opt.state_dict(), 287 | "args": args 288 | } 289 | checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt" 290 | torch.save(checkpoint, checkpoint_path) 291 | logger.info(f"Saved checkpoint to {checkpoint_path}") 292 | dist.barrier() 293 | 294 | if train_steps % args.sample_every == 0 and train_steps > 0: 295 | logger.info("Generating EMA samples...") 296 | sample_fn = transport_sampler.sample_ode() # default to ode sampling 297 | samples = sample_fn(zs, model_fn, **sample_model_kwargs)[-1] 298 | dist.barrier() 299 | 300 | if use_cfg: #remove null samples 301 | samples, _ = samples.chunk(2, dim=0) 302 | samples = vae.decode(samples / 0.18215).sample 303 | out_samples = torch.zeros((args.global_batch_size, 3, args.image_size, args.image_size), device=device) 304 | dist.all_gather_into_tensor(out_samples, samples) 305 | 
if args.wandb: 306 | wandb_utils.log_image(out_samples, train_steps) 307 | logging.info("Generating EMA samples done.") 308 | 309 | model.eval() # important! This disables randomized embedding dropout 310 | # do any sampling/FID calculation/etc. with ema (or model) in eval mode ... 311 | 312 | logger.info("Done!") 313 | cleanup() 314 | 315 | 316 | if __name__ == "__main__": 317 | # Default args here will train SiT-XL/2 with the hyperparameters we used in our paper (except training iters). 318 | parser = argparse.ArgumentParser() 319 | parser.add_argument("--data-path", type=str, required=True) 320 | parser.add_argument("--results-dir", type=str, default="results") 321 | parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2") 322 | parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) 323 | parser.add_argument("--num-classes", type=int, default=1000) 324 | parser.add_argument("--epochs", type=int, default=1400) 325 | parser.add_argument("--global-batch-size", type=int, default=256) 326 | parser.add_argument("--global-seed", type=int, default=0) 327 | parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema") # Choice doesn't affect training 328 | parser.add_argument("--num-workers", type=int, default=4) 329 | parser.add_argument("--log-every", type=int, default=100) 330 | parser.add_argument("--ckpt-every", type=int, default=50_000) 331 | parser.add_argument("--sample-every", type=int, default=10_000) 332 | parser.add_argument("--cfg-scale", type=float, default=4.0) 333 | parser.add_argument("--wandb", action="store_true") 334 | parser.add_argument("--ckpt", type=str, default=None, 335 | help="Optional path to a custom SiT checkpoint") 336 | parser.add_argument("--disp", action="store_true", 337 | help="Toggle to enable Dispersive Loss") 338 | parse_transport_args(parser) 339 | args = parser.parse_args() 340 | main(args) 341 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the license found in the 2 | # LICENSE file in the root directory of this source tree. 3 | # -------------------------------------------------------- 4 | # References: 5 | # GLIDE: https://github.com/openai/glide-text2im 6 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py 7 | # -------------------------------------------------------- 8 | 9 | import torch 10 | import torch.nn as nn 11 | import numpy as np 12 | import math 13 | from timm.models.vision_transformer import PatchEmbed, Attention, Mlp 14 | 15 | 16 | def modulate(x, shift, scale): 17 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 18 | 19 | 20 | ################################################################################# 21 | # Embedding Layers for Timesteps and Class Labels # 22 | ################################################################################# 23 | 24 | class TimestepEmbedder(nn.Module): 25 | """ 26 | Embeds scalar timesteps into vector representations. 
27 | """ 28 | def __init__(self, hidden_size, frequency_embedding_size=256): 29 | super().__init__() 30 | self.mlp = nn.Sequential( 31 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 32 | nn.SiLU(), 33 | nn.Linear(hidden_size, hidden_size, bias=True), 34 | ) 35 | self.frequency_embedding_size = frequency_embedding_size 36 | 37 | @staticmethod 38 | def timestep_embedding(t, dim, max_period=10000): 39 | """ 40 | Create sinusoidal timestep embeddings. 41 | :param t: a 1-D Tensor of N indices, one per batch element. 42 | These may be fractional. 43 | :param dim: the dimension of the output. 44 | :param max_period: controls the minimum frequency of the embeddings. 45 | :return: an (N, D) Tensor of positional embeddings. 46 | """ 47 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 48 | half = dim // 2 49 | freqs = torch.exp( 50 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 51 | ).to(device=t.device) 52 | args = t[:, None].float() * freqs[None] 53 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 54 | if dim % 2: 55 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 56 | return embedding 57 | 58 | def forward(self, t): 59 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 60 | t_emb = self.mlp(t_freq) 61 | return t_emb 62 | 63 | 64 | class LabelEmbedder(nn.Module): 65 | """ 66 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 67 | """ 68 | def __init__(self, num_classes, hidden_size, dropout_prob): 69 | super().__init__() 70 | use_cfg_embedding = dropout_prob > 0 71 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) 72 | self.num_classes = num_classes 73 | self.dropout_prob = dropout_prob 74 | 75 | def token_drop(self, labels, force_drop_ids=None): 76 | """ 77 | Drops labels to enable classifier-free guidance. 78 | """ 79 | if force_drop_ids is None: 80 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob 81 | else: 82 | drop_ids = force_drop_ids == 1 83 | labels = torch.where(drop_ids, self.num_classes, labels) 84 | return labels 85 | 86 | def forward(self, labels, train, force_drop_ids=None): 87 | use_dropout = self.dropout_prob > 0 88 | if (train and use_dropout) or (force_drop_ids is not None): 89 | labels = self.token_drop(labels, force_drop_ids) 90 | embeddings = self.embedding_table(labels) 91 | return embeddings 92 | 93 | 94 | ################################################################################# 95 | # Core SiT Model # 96 | ################################################################################# 97 | 98 | class SiTBlock(nn.Module): 99 | """ 100 | A SiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 
101 | """ 102 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): 103 | super().__init__() 104 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 105 | self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) 106 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 107 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 108 | approx_gelu = lambda: nn.GELU(approximate="tanh") 109 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) 110 | self.adaLN_modulation = nn.Sequential( 111 | nn.SiLU(), 112 | nn.Linear(hidden_size, 6 * hidden_size, bias=True) 113 | ) 114 | 115 | def forward(self, x, c): 116 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) 117 | x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) 118 | x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) 119 | return x 120 | 121 | 122 | class FinalLayer(nn.Module): 123 | """ 124 | The final layer of SiT. 125 | """ 126 | def __init__(self, hidden_size, patch_size, out_channels): 127 | super().__init__() 128 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 129 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) 130 | self.adaLN_modulation = nn.Sequential( 131 | nn.SiLU(), 132 | nn.Linear(hidden_size, 2 * hidden_size, bias=True) 133 | ) 134 | 135 | def forward(self, x, c): 136 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) 137 | x = modulate(self.norm_final(x), shift, scale) 138 | x = self.linear(x) 139 | return x 140 | 141 | 142 | class SiT(nn.Module): 143 | """ 144 | Diffusion model with a Transformer backbone. 
145 | """ 146 | def __init__( 147 | self, 148 | input_size=32, 149 | patch_size=2, 150 | in_channels=4, 151 | hidden_size=1152, 152 | depth=28, 153 | num_heads=16, 154 | mlp_ratio=4.0, 155 | class_dropout_prob=0.1, 156 | num_classes=1000, 157 | learn_sigma=True, 158 | ): 159 | super().__init__() 160 | self.learn_sigma = learn_sigma 161 | self.in_channels = in_channels 162 | self.out_channels = in_channels * 2 if learn_sigma else in_channels 163 | self.patch_size = patch_size 164 | self.num_heads = num_heads 165 | 166 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) 167 | self.t_embedder = TimestepEmbedder(hidden_size) 168 | self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) 169 | num_patches = self.x_embedder.num_patches 170 | # Will use fixed sin-cos embedding: 171 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) 172 | 173 | self.blocks = nn.ModuleList([ 174 | SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth) 175 | ]) 176 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) 177 | self.initialize_weights() 178 | 179 | def initialize_weights(self): 180 | # Initialize transformer layers: 181 | def _basic_init(module): 182 | if isinstance(module, nn.Linear): 183 | torch.nn.init.xavier_uniform_(module.weight) 184 | if module.bias is not None: 185 | nn.init.constant_(module.bias, 0) 186 | self.apply(_basic_init) 187 | 188 | # Initialize (and freeze) pos_embed by sin-cos embedding: 189 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5)) 190 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 191 | 192 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): 193 | w = self.x_embedder.proj.weight.data 194 | nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 195 | nn.init.constant_(self.x_embedder.proj.bias, 0) 196 | 197 | # Initialize label embedding table: 198 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) 199 | 200 | # Initialize timestep embedding MLP: 201 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 202 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 203 | 204 | # Zero-out adaLN modulation layers in SiT blocks: 205 | for block in self.blocks: 206 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 207 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 208 | 209 | # Zero-out output layers: 210 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 211 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 212 | nn.init.constant_(self.final_layer.linear.weight, 0) 213 | nn.init.constant_(self.final_layer.linear.bias, 0) 214 | 215 | def unpatchify(self, x): 216 | """ 217 | x: (N, T, patch_size**2 * C) 218 | imgs: (N, H, W, C) 219 | """ 220 | c = self.out_channels 221 | p = self.x_embedder.patch_size[0] 222 | h = w = int(x.shape[1] ** 0.5) 223 | assert h * w == x.shape[1] 224 | 225 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) 226 | x = torch.einsum('nhwpqc->nchpwq', x) 227 | imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) 228 | return imgs 229 | 230 | def forward(self, x, t, y, return_act=False): 231 | """ 232 | Forward pass of SiT. 
233 |         x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
234 |         t: (N,) tensor of diffusion timesteps
235 |         y: (N,) tensor of class labels
236 |         """
237 |         act = []
238 |         x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
239 |         t = self.t_embedder(t)                   # (N, D)
240 |         y = self.y_embedder(y, self.training)    # (N, D)
241 |         c = t + y                                # (N, D)
242 |         for block in self.blocks:
243 |             x = block(x, c)                      # (N, T, D)
244 |             act.append(x)
245 |         x = self.final_layer(x, c)               # (N, T, patch_size ** 2 * out_channels)
246 |         x = self.unpatchify(x)                   # (N, out_channels, H, W)
247 |         if self.learn_sigma:
248 |             x, _ = x.chunk(2, dim=1)
249 |         if return_act:
250 |             return x, act
251 |         return x
252 | 
253 |     def forward_with_cfg(self, x, t, y, cfg_scale):
254 |         """
255 |         Forward pass of SiT, but also batches the unconditional forward pass for classifier-free guidance.
256 |         """
257 |         # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
258 |         half = x[: len(x) // 2]
259 |         combined = torch.cat([half, half], dim=0)
260 |         model_out = self.forward(combined, t, y)
261 |         # For exact reproducibility reasons, we apply classifier-free guidance on only
262 |         # three channels by default. The standard approach to cfg applies it to all channels.
263 |         # This can be done by uncommenting the following line and commenting-out the line following that.
264 |         # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
265 |         eps, rest = model_out[:, :3], model_out[:, 3:]
266 |         cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
267 |         half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
268 |         eps = torch.cat([half_eps, half_eps], dim=0)
269 |         return torch.cat([eps, rest], dim=1)
270 | 
271 | 
272 | #################################################################################
273 | #                  Sine/Cosine Positional Embedding Functions                   #
274 | #################################################################################
275 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
276 | 
277 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
278 |     """
279 |     grid_size: int of the grid height and width
280 |     return:
281 |     pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
282 |     """
283 |     grid_h = np.arange(grid_size, dtype=np.float32)
284 |     grid_w = np.arange(grid_size, dtype=np.float32)
285 |     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
286 |     grid = np.stack(grid, axis=0)
287 | 
288 |     grid = grid.reshape([2, 1, grid_size, grid_size])
289 |     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
290 |     if cls_token and extra_tokens > 0:
291 |         pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
292 |     return pos_embed
293 | 
294 | 
295 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
296 |     assert embed_dim % 2 == 0
297 | 
298 |     # use half of dimensions to encode grid_h
299 |     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
300 |     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
301 | 
302 |     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
303 |     return emb
304 | 
305 | 
306 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
307 |     """
308 |     embed_dim: output dimension for each position
309 |     pos: a list of positions to be encoded: size (M,)
310 |     out: (M, D)
311 |     """
312 |     assert embed_dim % 2 == 0
313 |     omega = np.arange(embed_dim // 2, dtype=np.float64)
314 |     omega /= embed_dim / 2.
315 |     omega = 1. / 10000**omega  # (D/2,)
316 | 
317 |     pos = pos.reshape(-1)  # (M,)
318 |     out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
319 | 
320 |     emb_sin = np.sin(out)  # (M, D/2)
321 |     emb_cos = np.cos(out)  # (M, D/2)
322 | 
323 |     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
324 |     return emb
325 | 
326 | 
327 | #################################################################################
328 | #                                  SiT Configs                                  #
329 | #################################################################################
330 | 
331 | def SiT_XL_2(**kwargs):
332 |     return SiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
333 | 
334 | def SiT_XL_4(**kwargs):
335 |     return SiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
336 | 
337 | def SiT_XL_8(**kwargs):
338 |     return SiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
339 | 
340 | def SiT_L_2(**kwargs):
341 |     return SiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
342 | 
343 | def SiT_L_4(**kwargs):
344 |     return SiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
345 | 
346 | def SiT_L_8(**kwargs):
347 |     return SiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
348 | 
349 | def SiT_B_2(**kwargs):
350 |     return SiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
351 | 
352 | def SiT_B_4(**kwargs):
353 |     return SiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
354 | 
355 | def SiT_B_8(**kwargs):
356 |     return SiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
357 | 
358 | def SiT_S_2(**kwargs):
359 |     return SiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
360 | 
361 | def SiT_S_4(**kwargs):
362 |     return SiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
363 | 
364 | def SiT_S_8(**kwargs):
365 |     return SiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
366 | 
367 | 
368 | SiT_models = {
369 |     'SiT-XL/2': SiT_XL_2,  'SiT-XL/4': SiT_XL_4,  'SiT-XL/8': SiT_XL_8,
370 |     'SiT-L/2':  SiT_L_2,   'SiT-L/4':  SiT_L_4,   'SiT-L/8':  SiT_L_8,
371 |     'SiT-B/2':  SiT_B_2,   'SiT-B/4':  SiT_B_4,   'SiT-B/8':  SiT_B_8,
372 |     'SiT-S/2':  SiT_S_2,   'SiT-S/4':  SiT_S_4,   'SiT-S/8':  SiT_S_8,
373 | }
374 | 
--------------------------------------------------------------------------------
/transport/transport.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 | import numpy as np
3 | import logging
4 | 
5 | import enum
6 | 
7 | from . import path
8 | from .utils import EasyDict, log_state, mean_flat
9 | from .integrators import ode, sde
10 | 
11 | class ModelType(enum.Enum):
12 |     """
13 |     Which type of output the model predicts.
14 |     """
15 | 
16 |     NOISE = enum.auto()  # the model predicts epsilon
17 |     SCORE = enum.auto()  # the model predicts \nabla \log p(x)
18 |     VELOCITY = enum.auto()  # the model predicts v(x)
19 | 
20 | class PathType(enum.Enum):
21 |     """
22 |     Which type of path to use.
23 |     """
24 | 
25 |     LINEAR = enum.auto()
26 |     GVP = enum.auto()
27 |     VP = enum.auto()
28 | 
29 | class WeightType(enum.Enum):
30 |     """
31 |     Which type of weighting to use.
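    NONE leaves the regression unweighted; VELOCITY and LIKELIHOOD rescale the
    noise/score objectives in Transport.training_losses below.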
32 | """ 33 | 34 | NONE = enum.auto() 35 | VELOCITY = enum.auto() 36 | LIKELIHOOD = enum.auto() 37 | 38 | 39 | class Transport: 40 | 41 | def __init__( 42 | self, 43 | *, 44 | model_type, 45 | path_type, 46 | loss_type, 47 | train_eps, 48 | sample_eps, 49 | ): 50 | path_options = { 51 | PathType.LINEAR: path.ICPlan, 52 | PathType.GVP: path.GVPCPlan, 53 | PathType.VP: path.VPCPlan, 54 | } 55 | 56 | self.loss_type = loss_type 57 | self.model_type = model_type 58 | self.path_sampler = path_options[path_type]() 59 | self.train_eps = train_eps 60 | self.sample_eps = sample_eps 61 | 62 | def prior_logp(self, z): 63 | ''' 64 | Standard multivariate normal prior 65 | Assume z is batched 66 | ''' 67 | shape = th.tensor(z.size()) 68 | N = th.prod(shape[1:]) 69 | _fn = lambda x: -N / 2. * np.log(2 * np.pi) - th.sum(x ** 2) / 2. 70 | return th.vmap(_fn)(z) 71 | 72 | 73 | def check_interval( 74 | self, 75 | train_eps, 76 | sample_eps, 77 | *, 78 | diffusion_form="SBDM", 79 | sde=False, 80 | reverse=False, 81 | eval=False, 82 | last_step_size=0.0, 83 | ): 84 | t0 = 0 85 | t1 = 1 86 | eps = train_eps if not eval else sample_eps 87 | if (type(self.path_sampler) in [path.VPCPlan]): 88 | 89 | t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size 90 | 91 | elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) \ 92 | and (self.model_type != ModelType.VELOCITY or sde): # avoid numerical issue by taking a first semi-implicit step 93 | 94 | t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0 95 | t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size 96 | 97 | if reverse: 98 | t0, t1 = 1 - t0, 1 - t1 99 | 100 | return t0, t1 101 | 102 | 103 | def sample(self, x1): 104 | """Sampling x0 & t based on shape of x1 (if needed) 105 | Args: 106 | x1 - data point; [batch, *dim] 107 | """ 108 | 109 | x0 = th.randn_like(x1) 110 | t0, t1 = self.check_interval(self.train_eps, self.sample_eps) 111 | t = th.rand((x1.shape[0],)) * (t1 - t0) + t0 112 | t = t.to(x1) 113 | return t, x0, x1 114 | 115 | def disp_loss(self, z): # Dispersive Loss implementation (InfoNCE-L2 variant) 116 | z = z.reshape((z.shape[0],-1)) # flatten 117 | diff = th.nn.functional.pdist(z).pow(2)/z.shape[1] # pairwise distance 118 | diff = th.concat((diff, diff, th.zeros(z.shape[0]).cuda())) # match JAX implementation of full BxB matrix 119 | return th.log(th.exp(-diff).mean()) # calculate loss 120 | 121 | def training_losses( 122 | self, 123 | model, 124 | x1, 125 | model_kwargs=None 126 | ): 127 | """Loss for training the score model 128 | Args: 129 | - model: backbone model; could be score, noise, or velocity 130 | - x1: datapoint 131 | - model_kwargs: additional arguments for the model 132 | """ 133 | if model_kwargs == None: 134 | model_kwargs = {} 135 | 136 | t, x0, x1 = self.sample(x1) 137 | t, xt, ut = self.path_sampler.plan(t, x0, x1) 138 | model_output = model(xt, t, **model_kwargs) 139 | 140 | disp_loss = 0 141 | if "return_act" in model_kwargs and model_kwargs['return_act']: 142 | model_output, act = model_output 143 | disp_loss = self.disp_loss(act[-1]) 144 | 145 | B, *_, C = xt.shape 146 | assert model_output.size() == (B, *xt.size()[1:-1], C) 147 | 148 | terms = {} 149 | terms['pred'] = model_output 150 | if self.model_type == ModelType.VELOCITY: 151 | terms['loss'] = mean_flat(((model_output - ut) ** 2)) 152 | else: 153 | _, drift_var = self.path_sampler.compute_drift(xt, t) 154 | sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, xt)) 
155 |             if self.loss_type in [WeightType.VELOCITY]:
156 |                 weight = (drift_var / sigma_t) ** 2
157 |             elif self.loss_type in [WeightType.LIKELIHOOD]:
158 |                 weight = drift_var / (sigma_t ** 2)
159 |             elif self.loss_type in [WeightType.NONE]:
160 |                 weight = 1
161 |             else:
162 |                 raise NotImplementedError()
163 | 
164 |             if self.model_type == ModelType.NOISE:
165 |                 terms['loss'] = mean_flat(weight * ((model_output - x0) ** 2))
166 |             else:
167 |                 terms['loss'] = mean_flat(weight * ((model_output * sigma_t + x0) ** 2))
168 |         terms['loss'] += 0.25 * disp_loss
169 |         return terms
170 | 
171 | 
172 |     def get_drift(
173 |         self
174 |     ):
175 |         """member function for obtaining the drift of the probability flow ODE"""
176 |         def score_ode(x, t, model, **model_kwargs):
177 |             drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
178 |             model_output = model(x, t, **model_kwargs)
179 |             return (-drift_mean + drift_var * model_output)  # by change of variable
180 | 
181 |         def noise_ode(x, t, model, **model_kwargs):
182 |             drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
183 |             sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))
184 |             model_output = model(x, t, **model_kwargs)
185 |             score = model_output / -sigma_t
186 |             return (-drift_mean + drift_var * score)
187 | 
188 |         def velocity_ode(x, t, model, **model_kwargs):
189 |             model_output = model(x, t, **model_kwargs)
190 |             return model_output
191 | 
192 |         if self.model_type == ModelType.NOISE:
193 |             drift_fn = noise_ode
194 |         elif self.model_type == ModelType.SCORE:
195 |             drift_fn = score_ode
196 |         else:
197 |             drift_fn = velocity_ode
198 | 
199 |         def body_fn(x, t, model, **model_kwargs):
200 |             model_output = drift_fn(x, t, model, **model_kwargs)
201 |             assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape"
202 |             return model_output
203 | 
204 |         return body_fn
205 | 
206 | 
207 |     def get_score(
208 |         self,
209 |     ):
210 |         """member function for obtaining score of
211 |         x_t = alpha_t * x + sigma_t * eps"""
212 |         if self.model_type == ModelType.NOISE:
213 |             score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs) / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]
214 |         elif self.model_type == ModelType.SCORE:
215 |             score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs)
216 |         elif self.model_type == ModelType.VELOCITY:
217 |             score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(model(x, t, **kwargs), x, t)
218 |         else:
219 |             raise NotImplementedError()
220 | 
221 |         return score_fn
222 | 
223 | 
224 | class Sampler:
225 |     """Sampler class for the transport model"""
226 |     def __init__(
227 |         self,
228 |         transport,
229 |     ):
230 |         """Constructor for a general sampler supporting different sampling methods
231 |         Args:
232 |         - transport: a transport object specifying model prediction & interpolant type
233 |         """
234 | 
235 |         self.transport = transport
236 |         self.drift = self.transport.get_drift()
237 |         self.score = self.transport.get_score()
238 | 
239 |     def __get_sde_diffusion_and_drift(
240 |         self,
241 |         *,
242 |         diffusion_form="SBDM",
243 |         diffusion_norm=1.0,
244 |     ):
245 | 
246 |         def diffusion_fn(x, t):
247 |             diffusion = self.transport.path_sampler.compute_diffusion(x, t, form=diffusion_form, norm=diffusion_norm)
248 |             return diffusion
249 | 
250 |         sde_drift = \
251 |             lambda x, t, model, **kwargs: \
252 |                 self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs)
253 | 
254 |         sde_diffusion = diffusion_fn
255 | 
256 |         return sde_drift, sde_diffusion
257 | 
258 |     def __get_last_step(
259 |         self,
260 |         sde_drift,
261 |         *,
262 |         last_step,
263 |         last_step_size,
264 |     ):
265 |         """Get the last step function of the SDE solver"""
266 | 
267 |         if last_step is None:
268 |             last_step_fn = \
269 |                 lambda x, t, model, **model_kwargs: \
270 |                     x
271 |         elif last_step == "Mean":
272 |             last_step_fn = \
273 |                 lambda x, t, model, **model_kwargs: \
274 |                     x + sde_drift(x, t, model, **model_kwargs) * last_step_size
275 |         elif last_step == "Tweedie":
276 |             alpha = self.transport.path_sampler.compute_alpha_t  # simple aliasing; the original name was too long
277 |             sigma = self.transport.path_sampler.compute_sigma_t
278 |             last_step_fn = \
279 |                 lambda x, t, model, **model_kwargs: \
280 |                     x / alpha(t)[0][0] + (sigma(t)[0][0] ** 2) / alpha(t)[0][0] * self.score(x, t, model, **model_kwargs)
281 |         elif last_step == "Euler":
282 |             last_step_fn = \
283 |                 lambda x, t, model, **model_kwargs: \
284 |                     x + self.drift(x, t, model, **model_kwargs) * last_step_size
285 |         else:
286 |             raise NotImplementedError()
287 | 
288 |         return last_step_fn
289 | 
290 |     def sample_sde(
291 |         self,
292 |         *,
293 |         sampling_method="Euler",
294 |         diffusion_form="SBDM",
295 |         diffusion_norm=1.0,
296 |         last_step="Mean",
297 |         last_step_size=0.04,
298 |         num_steps=250,
299 |     ):
300 |         """returns a sampling function with given SDE settings
301 |         Args:
302 |         - sampling_method: type of sampler used in solving the SDE; defaults to Euler-Maruyama
303 |         - diffusion_form: functional form of the diffusion coefficient; defaults to matching SBDM
304 |         - diffusion_norm: magnitude of the diffusion coefficient; defaults to 1
305 |         - last_step: type of the last step; defaults to "Mean" (a final noise-free drift step)
306 |         - last_step_size: size of the last step; default matches the stride of 250 steps over [0,1]
307 |         - num_steps: total number of integration steps of the SDE
308 |         """
309 | 
310 |         if last_step is None:
311 |             last_step_size = 0.0
312 | 
313 |         sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift(
314 |             diffusion_form=diffusion_form,
315 |             diffusion_norm=diffusion_norm,
316 |         )
317 | 
318 |         t0, t1 = self.transport.check_interval(
319 |             self.transport.train_eps,
320 |             self.transport.sample_eps,
321 |             diffusion_form=diffusion_form,
322 |             sde=True,
323 |             eval=True,
324 |             reverse=False,
325 |             last_step_size=last_step_size,
326 |         )
327 | 
328 |         _sde = sde(
329 |             sde_drift,
330 |             sde_diffusion,
331 |             t0=t0,
332 |             t1=t1,
333 |             num_steps=num_steps,
334 |             sampler_type=sampling_method
335 |         )
336 | 
337 |         last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size)
338 | 
339 | 
340 |         def _sample(init, model, **model_kwargs):
341 |             xs = _sde.sample(init, model, **model_kwargs)
342 |             ts = th.ones(init.size(0), device=init.device) * t1
343 |             x = last_step_fn(xs[-1], ts, model, **model_kwargs)
344 |             xs.append(x)
345 | 
346 |             assert len(xs) == num_steps, "Number of samples does not match the number of steps"
347 | 
348 |             return xs
349 | 
350 |         return _sample
351 | 
352 |     def sample_ode(
353 |         self,
354 |         *,
355 |         sampling_method="dopri5",
356 |         num_steps=50,
357 |         atol=1e-6,
358 |         rtol=1e-3,
359 |         reverse=False,
360 |     ):
361 |         """returns a sampling function with given ODE settings
362 |         Args:
363 |         - sampling_method: type of sampler used in solving the ODE; defaults to dopri5
364 |         - num_steps:
365 |             - fixed solver (Euler, Heun): the actual number of integration steps performed
366 |             - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
367 |         - atol: absolute error tolerance for the solver
368 |         - rtol: relative error tolerance for the solver
369 |         - reverse: whether to solve the ODE in reverse (data to noise); defaults to False
370 |         """
371 |         if reverse:
372 |             drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs)
373 |         else:
374 |             drift = self.drift
375 | 
376 |         t0, t1 = self.transport.check_interval(
377 |             self.transport.train_eps,
378 |             self.transport.sample_eps,
379 |             sde=False,
380 |             eval=True,
381 |             reverse=reverse,
382 |             last_step_size=0.0,
383 |         )
384 | 
385 |         _ode = ode(
386 |             drift=drift,
387 |             t0=t0,
388 |             t1=t1,
389 |             sampler_type=sampling_method,
390 |             num_steps=num_steps,
391 |             atol=atol,
392 |             rtol=rtol,
393 |         )
394 | 
395 |         return _ode.sample
396 | 
397 |     def sample_ode_likelihood(
398 |         self,
399 |         *,
400 |         sampling_method="dopri5",
401 |         num_steps=50,
402 |         atol=1e-6,
403 |         rtol=1e-3,
404 |     ):
405 | 
406 |         """returns a sampling function for calculating likelihood with given ODE settings
407 |         Args:
408 |         - sampling_method: type of sampler used in solving the ODE; defaults to dopri5
409 |         - num_steps:
410 |             - fixed solver (Euler, Heun): the actual number of integration steps performed
411 |             - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
412 |         - atol: absolute error tolerance for the solver
413 |         - rtol: relative error tolerance for the solver
414 |         """
415 |         def _likelihood_drift(x, t, model, **model_kwargs):
416 |             x, _ = x
417 |             eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1
418 |             t = th.ones_like(t) * (1 - t)
419 |             with th.enable_grad():
420 |                 x.requires_grad = True
421 |                 grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0]
422 |                 logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size()))))
423 |                 drift = self.drift(x, t, model, **model_kwargs)
424 |             return (-drift, logp_grad)
425 | 
426 |         t0, t1 = self.transport.check_interval(
427 |             self.transport.train_eps,
428 |             self.transport.sample_eps,
429 |             sde=False,
430 |             eval=True,
431 |             reverse=False,
432 |             last_step_size=0.0,
433 |         )
434 | 
435 |         _ode = ode(
436 |             drift=_likelihood_drift,
437 |             t0=t0,
438 |             t1=t1,
439 |             sampler_type=sampling_method,
440 |             num_steps=num_steps,
441 |             atol=atol,
442 |             rtol=rtol,
443 |         )
444 | 
445 |         def _sample_fn(x, model, **model_kwargs):
446 |             init_logp = th.zeros(x.size(0)).to(x)
447 |             inputs = (x, init_logp)
448 |             drift, delta_logp = _ode.sample(inputs, model, **model_kwargs)
449 |             drift, delta_logp = drift[-1], delta_logp[-1]
450 |             prior_logp = self.transport.prior_logp(drift)
451 |             logp = prior_logp - delta_logp
452 |             return logp, drift
453 | 
454 |         return _sample_fn
455 | 
--------------------------------------------------------------------------------
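**Usage sketch.** The pieces above compose as follows: `models.py` provides the backbone, and `transport.py` turns a model prediction (velocity, noise, or score) into training losses and ODE/SDE samplers. The snippet below is a minimal, hypothetical smoke test of that flow, not part of the repo: `ToyVelocity` is a stand-in for SiT with a SiT-like calling convention (class conditioning omitted); a CUDA device is assumed because `disp_loss` calls `.cuda()`; and `sample_fn(init, model)` follows the `_ode.sample(inputs, model, **model_kwargs)` interface used in `sample_ode_likelihood` above.

```python
import torch as th
from transport.transport import ModelType, PathType, WeightType, Transport, Sampler

class ToyVelocity(th.nn.Module):
    """Hypothetical stand-in for SiT; same (x, t, ..., return_act) convention."""
    def __init__(self, channels=4):
        super().__init__()
        self.net = th.nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x, t, return_act=False):
        out = self.net(x)  # plays the role of the transformer trunk
        return (out, [out]) if return_act else out

device = "cuda"  # assumed: disp_loss() calls .cuda() on an intermediate tensor
model = ToyVelocity().to(device)

transport = Transport(
    model_type=ModelType.VELOCITY,  # regress the velocity field u_t
    path_type=PathType.LINEAR,      # linear interpolant (path.ICPlan)
    loss_type=WeightType.NONE,      # velocity regression uses no reweighting
    train_eps=0,
    sample_eps=0,
)

# Training step: return_act=True routes the last block's activations into
# disp_loss(), which is added to the loss with coefficient 0.25.
x1 = th.randn(8, 4, 32, 32, device=device)  # a batch of "data" latents
terms = transport.training_losses(model, x1, model_kwargs={"return_act": True})
terms["loss"].mean().backward()

# Sampling: integrate the probability-flow ODE from noise toward data;
# the returned trajectory's last entry is the final sample.
sampler = Sampler(transport)
sample_fn = sampler.sample_ode(sampling_method="dopri5", num_steps=50)
with th.no_grad():
    xs = sample_fn(th.randn(8, 4, 32, 32, device=device), model)
samples = xs[-1]
```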