├── .gitignore ├── LICENSE.txt ├── README.md ├── download.py ├── environment.yml ├── models.py ├── run_SiT.ipynb ├── sample.py ├── sample_ddp.py ├── train.py ├── train_utils.py ├── transport ├── __init__.py ├── integrators.py ├── path.py ├── transport.py └── utils.py ├── visuals ├── .DS_Store ├── visual.png └── visual_2.png └── wandb_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | wandb 3 | 4 | .DS_store 5 | samples 6 | results 7 | pretrained_models -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Meta Platforms, Inc. and affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers (SiT)
Official PyTorch Implementation 2 | 3 | ### [Paper](https://arxiv.org/pdf/2401.08740.pdf) | [Project Page](https://scalable-interpolant.github.io/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/willisma/SiT/blob/main/run_SiT.ipynb) 4 | 5 | ![SiT samples](visuals/visual.png) 6 | 7 | This repo contains PyTorch model definitions, pre-trained weights and training/sampling code for our paper exploring 8 | interpolant models with scalable transformers (SiTs). 9 | 10 | > [**Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers**](https://arxiv.org/pdf/2401.08740.pdf)
11 | > [Nanye Ma](https://willisma.github.io), [Mark Goldstein](https://marikgoldstein.github.io/), [Michael Albergo](http://malbergo.me/), [Nicholas Boffi](https://nmboffi.github.io/), [Eric Vanden-Eijnden](https://wp.nyu.edu/courantinstituteofmathematicalsciences-eve2/), [Saining Xie](https://www.sainingxie.com) 12 | >
New York University
13 | 14 | We present Scalable Interpolant Transformers (SiT), a family of generative models built on the backbone of Diffusion Transformers (DiT). The interpolant framework, which allows for connecting two distributions in a more flexible way than standard diffusion models, makes possible a modular study of various design choices impacting generative models built on dynamical transport: using discrete vs. continuous time learning, deciding the model to learn, choosing the interpolant connecting the distributions, and deploying a deterministic or stochastic sampler. By carefully introducing the above ingredients, SiT surpasses DiT uniformly across model sizes on the conditional ImageNet 256x256 benchmark using the exact same backbone, number of parameters, and GFLOPs. By exploring various diffusion coefficients, which can be tuned separately from learning, SiT achieves an FID-50K score of 2.06. 15 | 16 | This repository contains: 17 | 18 | * 🪐 A simple PyTorch [implementation](models.py) of SiT 19 | * ⚡️ Pre-trained class-conditional SiT models trained on ImageNet 256x256 20 | * 🛸 A SiT [training script](train.py) using PyTorch DDP 21 | 22 | ## Setup 23 | 24 | First, download and set up the repo: 25 | 26 | ```bash 27 | git clone https://github.com/willisma/SiT.git 28 | cd SiT 29 | ``` 30 | 31 | We provide an [`environment.yml`](environment.yml) file that can be used to create a Conda environment. If you only want 32 | to run pre-trained models locally on CPU, you can remove the `cudatoolkit` and `pytorch-cuda` requirements from the file. 33 | 34 | ```bash 35 | conda env create -f environment.yml 36 | conda activate SiT 37 | ``` 38 | 39 | 40 | ## Sampling [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/willisma/SiT/blob/main/run_SiT.ipynb) 41 | ![More SiT samples](visuals/visual_2.png) 42 | 43 | **Pre-trained SiT checkpoints.** You can sample from our pre-trained SiT models with [`sample.py`](sample.py). 
Weights for our pre-trained SiT model will be 44 | automatically downloaded depending on the model you use. The script has various arguments to adjust the sampler configuration (ODE or SDE), the number of sampling steps, the classifier-free guidance scale, etc. For example, to sample from 45 | our 256x256 SiT-XL model with the default ODE settings, you can use: 46 | 47 | ```bash 48 | python sample.py ODE --image-size 256 --seed 1 49 | ``` 50 | 51 | For convenience, our pre-trained SiT models can be downloaded directly here as well: 52 | 53 | | SiT Model | Image Resolution | FID-50K | Inception Score | GFLOPs | 54 | |---------------|------------------|---------|-----------------|--------| 55 | | [XL/2](https://www.dl.dropboxusercontent.com/scl/fi/as9oeomcbub47de5g4be0/SiT-XL-2-256.pt?rlkey=uxzxmpicu46coq3msb17b9ofa&dl=0) | 256x256 | 2.06 | 270.27 | 119 | 56 | 57 | 58 | 59 | **Custom SiT checkpoints.** If you've trained a new SiT model with [`train.py`](train.py) (see [below](#training-SiT)), you can add the `--ckpt` 60 | argument to use your own checkpoint instead. For example, to sample from the EMA weights of a custom 61 | 256x256 SiT-L/4 model with the ODE sampler, run: 62 | 63 | ```bash 64 | python sample.py ODE --model SiT-L/4 --image-size 256 --ckpt /path/to/model.pt 65 | ``` 66 | 67 | ### Advanced sampler settings 68 | 69 | | | | | | 70 | |-----|----------|----------|--------------------------| 71 | | ODE | `--atol` | `float` | Absolute error tolerance | 72 | | | `--rtol` | `float` | Relative error tolerance | 73 | | | `--sampling-method` | `str` | Sampling methods (refer to [`torchdiffeq`](https://github.com/rtqichen/torchdiffeq)) | 74 | 75 | | | | | | 76 | |-----|----------|----------|--------------------------| 77 | | SDE | `--diffusion-form` | `str` | Form of SDE's diffusion coefficient (refer to Tab. 
2 in [paper]()) | 78 | | | `--diffusion-norm` | `float` | Magnitude of SDE's diffusion coefficient | 79 | | | `--last-step` | `str` | Form of SDE's last step | 80 | | | | | None - Single SDE integration step | 81 | | | | | "Mean" - SDE integration step without diffusion coefficient | 82 | | | | | "Tweedie" - [Tweedie's denoising](https://efron.ckirby.su.domains/papers/2011TweediesFormula.pdf) step | 83 | | | | | "Euler" - Single ODE integration step | 84 | | | `--sampling-method` | `str` | Sampling methods | 85 | | | | | "Euler" - First order integration | 86 | | | | | "Heun" - Second order integration | 87 | 88 | There are some more options; refer to [`train_utils.py`](train_utils.py) for details. 89 | 90 | ## Training SiT 91 | 92 | We provide a training script for SiT in [`train.py`](train.py). To launch SiT-XL/2 (256x256) training with `N` GPUs on 93 | one node: 94 | 95 | ```bash 96 | torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-XL/2 --data-path /path/to/imagenet/train 97 | ``` 98 | 99 | **Logging.** To enable `wandb`, first set `WANDB_KEY`, `ENTITY`, and `PROJECT` as environment variables: 100 | 101 | ```bash 102 | export WANDB_KEY="key" 103 | export ENTITY="entity name" 104 | export PROJECT="project name" 105 | ``` 106 | 107 | Then add the `--wandb` flag to the training command: 108 | 109 | ```bash 110 | torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-XL/2 --data-path /path/to/imagenet/train --wandb 111 | ``` 112 | 113 | **Interpolant settings.** We also support different choices of interpolant and model predictions. 
For example, to launch SiT-XL/2 (256x256) with the `Linear` interpolant and `noise` prediction: 114 | 115 | ```bash 116 | torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-XL/2 --data-path /path/to/imagenet/train --path-type Linear --prediction noise 117 | ``` 118 | 119 | **Resume training.** To resume training from a custom checkpoint: 120 | 121 | ```bash 122 | torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-L/2 --data-path /path/to/imagenet/train --ckpt /path/to/model.pt 123 | ``` 124 | 125 | **Caution.** Resuming training automatically restores the model, EMA, and optimizer states, as well as the training configs, from the checkpoint. 126 | 127 | ## Evaluation (FID, Inception Score, etc.) 128 | 129 | We include a [`sample_ddp.py`](sample_ddp.py) script which samples a large number of images from a SiT model in parallel. This script 130 | generates a folder of samples as well as a `.npz` file which can be directly used with [ADM's TensorFlow 131 | evaluation suite](https://github.com/openai/guided-diffusion/tree/main/evaluations) to compute FID, Inception Score and 132 | other metrics. For example, to sample 50K images from our pre-trained SiT-XL/2 model over `N` GPUs under the default ODE sampler settings, run: 133 | 134 | ```bash 135 | torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py ODE --model SiT-XL/2 --num-fid-samples 50000 136 | ``` 137 | 138 | **Likelihood.** Likelihood evaluation is supported. To calculate likelihood, add the `--likelihood` flag to the ODE sampler: 139 | 140 | ```bash 141 | torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py ODE --model SiT-XL/2 --likelihood 142 | ``` 143 | 144 | Note that likelihood can only be calculated with the ODE sampler; see [`sample_ddp.py`](sample_ddp.py) for more details and settings. 
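For readers who want to build the `.npz` file mentioned above from their own images, the packing step can be sketched roughly as follows. This is a minimal illustration, not the code in `sample_ddp.py`: the helper name `pack_samples_to_npz` is hypothetical, the random images stand in for real decoded samples, and storing the array under the key `arr_0` (NumPy's default name for the first positional array) is an assumption about what the evaluation suite expects.

```python
import os
import tempfile

import numpy as np


def pack_samples_to_npz(images, npz_path):
    """Stack same-sized uint8 HxWx3 images into one (N, H, W, 3) array and save it.

    Hypothetical helper; the `arr_0` key is an assumption about the evaluator.
    """
    samples = np.stack([np.asarray(img, dtype=np.uint8) for img in images])
    assert samples.ndim == 4 and samples.shape[-1] == 3, "expected (N, H, W, 3)"
    np.savez(npz_path, arr_0=samples)
    return samples.shape


# Stand-in for decoded samples: four random 256x256 RGB images (made-up data).
rng = np.random.default_rng(0)
fake_images = [
    rng.integers(0, 256, size=(256, 256, 3), dtype=np.uint8) for _ in range(4)
]
npz_path = os.path.join(tempfile.mkdtemp(), "samples.npz")
packed_shape = pack_samples_to_npz(fake_images, npz_path)
```

In practice the list of images would come from the generated sample folder rather than a random generator, and the array would hold all 50K samples.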
145 | 146 | ### Enhancements 147 | Training (and sampling) could likely be sped up significantly by: 148 | - [ ] using [Flash Attention](https://github.com/HazyResearch/flash-attention) in the SiT model 149 | - [ ] using `torch.compile` in PyTorch 2.0 150 | 151 | Basic features that would be nice to add: 152 | - [ ] Monitor FID and other metrics 153 | - [ ] AMP/bfloat16 support 154 | 155 | Precision in likelihood calculation could likely be improved by: 156 | - [ ] Uniform / Gaussian Dequantization 157 | 158 | 159 | ## Differences from JAX 160 | 161 | Our models were originally trained in JAX on TPUs. The weights in this repo are ported directly from the JAX models. 162 | There may be minor differences in results stemming from sampling on different platforms (TPU vs. GPU). We observed that sampling on TPU performs marginally worse than GPU (2.15 FID 163 | versus 2.06 in the paper). 164 | 165 | 166 | ## License 167 | This project is under the MIT license. See [LICENSE](LICENSE.txt) for details. 168 | 169 | 170 | -------------------------------------------------------------------------------- /download.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the license found in the 2 | # LICENSE file in the root directory of this source tree. 3 | 4 | """ 5 | Functions for downloading pre-trained SiT models 6 | """ 7 | from torchvision.datasets.utils import download_url 8 | import torch 9 | import os 10 | 11 | 12 | pretrained_models = {'SiT-XL-2-256x256.pt'} 13 | 14 | 15 | def find_model(model_name): 16 | """ 17 | Finds a pre-trained SiT model, downloading it if necessary. Alternatively, loads a model from a local path. 
18 | """ 19 | if model_name in pretrained_models: 20 | return download_model(model_name) 21 | else: 22 | assert os.path.isfile(model_name), f'Could not find SiT checkpoint at {model_name}' 23 | checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage) 24 | if "ema" in checkpoint: # supports checkpoints from train.py 25 | checkpoint = checkpoint["ema"] 26 | return checkpoint 27 | 28 | 29 | def download_model(model_name): 30 | """ 31 | Downloads a pre-trained SiT model from the web. 32 | """ 33 | assert model_name in pretrained_models 34 | local_path = f'pretrained_models/{model_name}' 35 | if not os.path.isfile(local_path): 36 | os.makedirs('pretrained_models', exist_ok=True) 37 | web_path = f'https://www.dl.dropboxusercontent.com/scl/fi/as9oeomcbub47de5g4be0/SiT-XL-2-256.pt?rlkey=uxzxmpicu46coq3msb17b9ofa&dl=0' 38 | download_url(web_path, 'pretrained_models', filename=model_name) 39 | model = torch.load(local_path, map_location=lambda storage, loc: storage) 40 | return model 41 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: SiT 2 | channels: 3 | - pytorch 4 | - nvidia 5 | dependencies: 6 | - python >= 3.8 7 | - pytorch >= 1.13 8 | - torchvision 9 | - pytorch-cuda >=11.7 10 | - pip 11 | - pip: 12 | - timm 13 | - diffusers 14 | - accelerate 15 | - torchdiffeq 16 | - wandb 17 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the license found in the 2 | # LICENSE file in the root directory of this source tree. 
3 | # -------------------------------------------------------- 4 | # References: 5 | # GLIDE: https://github.com/openai/glide-text2im 6 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py 7 | # -------------------------------------------------------- 8 | 9 | import torch 10 | import torch.nn as nn 11 | import numpy as np 12 | import math 13 | from timm.models.vision_transformer import PatchEmbed, Attention, Mlp 14 | 15 | 16 | def modulate(x, shift, scale): 17 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 18 | 19 | 20 | ################################################################################# 21 | # Embedding Layers for Timesteps and Class Labels # 22 | ################################################################################# 23 | 24 | class TimestepEmbedder(nn.Module): 25 | """ 26 | Embeds scalar timesteps into vector representations. 27 | """ 28 | def __init__(self, hidden_size, frequency_embedding_size=256): 29 | super().__init__() 30 | self.mlp = nn.Sequential( 31 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 32 | nn.SiLU(), 33 | nn.Linear(hidden_size, hidden_size, bias=True), 34 | ) 35 | self.frequency_embedding_size = frequency_embedding_size 36 | 37 | @staticmethod 38 | def timestep_embedding(t, dim, max_period=10000): 39 | """ 40 | Create sinusoidal timestep embeddings. 41 | :param t: a 1-D Tensor of N indices, one per batch element. 42 | These may be fractional. 43 | :param dim: the dimension of the output. 44 | :param max_period: controls the minimum frequency of the embeddings. 45 | :return: an (N, D) Tensor of positional embeddings. 
46 | """ 47 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 48 | half = dim // 2 49 | freqs = torch.exp( 50 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 51 | ).to(device=t.device) 52 | args = t[:, None].float() * freqs[None] 53 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 54 | if dim % 2: 55 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 56 | return embedding 57 | 58 | def forward(self, t): 59 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 60 | t_emb = self.mlp(t_freq) 61 | return t_emb 62 | 63 | 64 | class LabelEmbedder(nn.Module): 65 | """ 66 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 67 | """ 68 | def __init__(self, num_classes, hidden_size, dropout_prob): 69 | super().__init__() 70 | use_cfg_embedding = dropout_prob > 0 71 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) 72 | self.num_classes = num_classes 73 | self.dropout_prob = dropout_prob 74 | 75 | def token_drop(self, labels, force_drop_ids=None): 76 | """ 77 | Drops labels to enable classifier-free guidance. 
78 | """ 79 | if force_drop_ids is None: 80 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob 81 | else: 82 | drop_ids = force_drop_ids == 1 83 | labels = torch.where(drop_ids, self.num_classes, labels) 84 | return labels 85 | 86 | def forward(self, labels, train, force_drop_ids=None): 87 | use_dropout = self.dropout_prob > 0 88 | if (train and use_dropout) or (force_drop_ids is not None): 89 | labels = self.token_drop(labels, force_drop_ids) 90 | embeddings = self.embedding_table(labels) 91 | return embeddings 92 | 93 | 94 | ################################################################################# 95 | # Core SiT Model # 96 | ################################################################################# 97 | 98 | class SiTBlock(nn.Module): 99 | """ 100 | A SiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 101 | """ 102 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): 103 | super().__init__() 104 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 105 | self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) 106 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 107 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 108 | approx_gelu = lambda: nn.GELU(approximate="tanh") 109 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) 110 | self.adaLN_modulation = nn.Sequential( 111 | nn.SiLU(), 112 | nn.Linear(hidden_size, 6 * hidden_size, bias=True) 113 | ) 114 | 115 | def forward(self, x, c): 116 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) 117 | x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) 118 | x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) 119 | return x 120 | 121 | 122 | class FinalLayer(nn.Module): 123 | 
""" 124 | The final layer of SiT. 125 | """ 126 | def __init__(self, hidden_size, patch_size, out_channels): 127 | super().__init__() 128 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 129 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) 130 | self.adaLN_modulation = nn.Sequential( 131 | nn.SiLU(), 132 | nn.Linear(hidden_size, 2 * hidden_size, bias=True) 133 | ) 134 | 135 | def forward(self, x, c): 136 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) 137 | x = modulate(self.norm_final(x), shift, scale) 138 | x = self.linear(x) 139 | return x 140 | 141 | 142 | class SiT(nn.Module): 143 | """ 144 | Diffusion model with a Transformer backbone. 145 | """ 146 | def __init__( 147 | self, 148 | input_size=32, 149 | patch_size=2, 150 | in_channels=4, 151 | hidden_size=1152, 152 | depth=28, 153 | num_heads=16, 154 | mlp_ratio=4.0, 155 | class_dropout_prob=0.1, 156 | num_classes=1000, 157 | learn_sigma=True, 158 | ): 159 | super().__init__() 160 | self.learn_sigma = learn_sigma 161 | self.in_channels = in_channels 162 | self.out_channels = in_channels * 2 if learn_sigma else in_channels 163 | self.patch_size = patch_size 164 | self.num_heads = num_heads 165 | 166 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) 167 | self.t_embedder = TimestepEmbedder(hidden_size) 168 | self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) 169 | num_patches = self.x_embedder.num_patches 170 | # Will use fixed sin-cos embedding: 171 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) 172 | 173 | self.blocks = nn.ModuleList([ 174 | SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth) 175 | ]) 176 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) 177 | self.initialize_weights() 178 | 179 | def initialize_weights(self): 180 | # Initialize transformer 
layers: 181 | def _basic_init(module): 182 | if isinstance(module, nn.Linear): 183 | torch.nn.init.xavier_uniform_(module.weight) 184 | if module.bias is not None: 185 | nn.init.constant_(module.bias, 0) 186 | self.apply(_basic_init) 187 | 188 | # Initialize (and freeze) pos_embed by sin-cos embedding: 189 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5)) 190 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 191 | 192 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): 193 | w = self.x_embedder.proj.weight.data 194 | nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 195 | nn.init.constant_(self.x_embedder.proj.bias, 0) 196 | 197 | # Initialize label embedding table: 198 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) 199 | 200 | # Initialize timestep embedding MLP: 201 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 202 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 203 | 204 | # Zero-out adaLN modulation layers in SiT blocks: 205 | for block in self.blocks: 206 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 207 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 208 | 209 | # Zero-out output layers: 210 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 211 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 212 | nn.init.constant_(self.final_layer.linear.weight, 0) 213 | nn.init.constant_(self.final_layer.linear.bias, 0) 214 | 215 | def unpatchify(self, x): 216 | """ 217 | x: (N, T, patch_size**2 * C) 218 | imgs: (N, C, H, W) 219 | """ 220 | c = self.out_channels 221 | p = self.x_embedder.patch_size[0] 222 | h = w = int(x.shape[1] ** 0.5) 223 | assert h * w == x.shape[1] 224 | 225 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) 226 | x = torch.einsum('nhwpqc->nchpwq', x) 227 | imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p)) 228 | return imgs 229 | 230 | def 
forward(self, x, t, y): 231 | """ 232 | Forward pass of SiT. 233 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) 234 | t: (N,) tensor of diffusion timesteps 235 | y: (N,) tensor of class labels 236 | """ 237 | x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 238 | t = self.t_embedder(t) # (N, D) 239 | y = self.y_embedder(y, self.training) # (N, D) 240 | c = t + y # (N, D) 241 | for block in self.blocks: 242 | x = block(x, c) # (N, T, D) 243 | x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) 244 | x = self.unpatchify(x) # (N, out_channels, H, W) 245 | if self.learn_sigma: 246 | x, _ = x.chunk(2, dim=1) 247 | return x 248 | 249 | def forward_with_cfg(self, x, t, y, cfg_scale): 250 | """ 251 | Forward pass of SiT, but also batches the unconditional forward pass for classifier-free guidance. 252 | """ 253 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb 254 | half = x[: len(x) // 2] 255 | combined = torch.cat([half, half], dim=0) 256 | model_out = self.forward(combined, t, y) 257 | # For exact reproducibility reasons, we apply classifier-free guidance on only 258 | # three channels by default. The standard approach to cfg applies it to all channels. 259 | # This can be done by uncommenting the following line and commenting out the line after it. 
260 | # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] 261 | eps, rest = model_out[:, :3], model_out[:, 3:] 262 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 263 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) 264 | eps = torch.cat([half_eps, half_eps], dim=0) 265 | return torch.cat([eps, rest], dim=1) 266 | 267 | 268 | ################################################################################# 269 | # Sine/Cosine Positional Embedding Functions # 270 | ################################################################################# 271 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py 272 | 273 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): 274 | """ 275 | grid_size: int of the grid height and width 276 | return: 277 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 278 | """ 279 | grid_h = np.arange(grid_size, dtype=np.float32) 280 | grid_w = np.arange(grid_size, dtype=np.float32) 281 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 282 | grid = np.stack(grid, axis=0) 283 | 284 | grid = grid.reshape([2, 1, grid_size, grid_size]) 285 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 286 | if cls_token and extra_tokens > 0: 287 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 288 | return pos_embed 289 | 290 | 291 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 292 | assert embed_dim % 2 == 0 293 | 294 | # use half of dimensions to encode grid_h 295 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 296 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 297 | 298 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 299 | return emb 300 | 301 | 302 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 303 | """ 304 | embed_dim: output dimension 
for each position 305 | pos: a list of positions to be encoded: size (M,) 306 | out: (M, D) 307 | """ 308 | assert embed_dim % 2 == 0 309 | omega = np.arange(embed_dim // 2, dtype=np.float64) 310 | omega /= embed_dim / 2. 311 | omega = 1. / 10000**omega # (D/2,) 312 | 313 | pos = pos.reshape(-1) # (M,) 314 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 315 | 316 | emb_sin = np.sin(out) # (M, D/2) 317 | emb_cos = np.cos(out) # (M, D/2) 318 | 319 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 320 | return emb 321 | 322 | 323 | ################################################################################# 324 | # SiT Configs # 325 | ################################################################################# 326 | 327 | def SiT_XL_2(**kwargs): 328 | return SiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) 329 | 330 | def SiT_XL_4(**kwargs): 331 | return SiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs) 332 | 333 | def SiT_XL_8(**kwargs): 334 | return SiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs) 335 | 336 | def SiT_L_2(**kwargs): 337 | return SiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs) 338 | 339 | def SiT_L_4(**kwargs): 340 | return SiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs) 341 | 342 | def SiT_L_8(**kwargs): 343 | return SiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs) 344 | 345 | def SiT_B_2(**kwargs): 346 | return SiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs) 347 | 348 | def SiT_B_4(**kwargs): 349 | return SiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs) 350 | 351 | def SiT_B_8(**kwargs): 352 | return SiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs) 353 | 354 | def SiT_S_2(**kwargs): 355 | return SiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs) 356 | 357 | def SiT_S_4(**kwargs): 358 | return SiT(depth=12, 
hidden_size=384, patch_size=4, num_heads=6, **kwargs) 359 | 360 | def SiT_S_8(**kwargs): 361 | return SiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs) 362 | 363 | 364 | SiT_models = { 365 | 'SiT-XL/2': SiT_XL_2, 'SiT-XL/4': SiT_XL_4, 'SiT-XL/8': SiT_XL_8, 366 | 'SiT-L/2': SiT_L_2, 'SiT-L/4': SiT_L_4, 'SiT-L/8': SiT_L_8, 367 | 'SiT-B/2': SiT_B_2, 'SiT-B/4': SiT_B_4, 'SiT-B/8': SiT_B_8, 368 | 'SiT-S/2': SiT_S_2, 'SiT-S/4': SiT_S_4, 'SiT-S/8': SiT_S_8, 369 | } 370 | -------------------------------------------------------------------------------- /run_SiT.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "355UKMUQJxFd" 7 | }, 8 | "source": [ 9 | "# SiT: Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers\n", 10 | "\n", 11 | "This notebook samples from pre-trained SiT models. SiTs are class-conditional latent interpolant models trained on ImageNet, unifying flow and diffusion methods. \n", 12 | "\n", 13 | "[Paper]() | [GitHub](https://github.com/willisma/SiT)" 14 | ] 15 | }, 16 | { 17 | "cell_type": "markdown", 18 | "metadata": { 19 | "id": "zJlgLkSaKn7u" 20 | }, 21 | "source": [ 22 | "# 1. Setup\n", 23 | "\n", 24 | "We recommend using GPUs (Runtime > Change runtime type > Hardware accelerator > GPU). Run this cell to clone the SiT GitHub repo and set up PyTorch. You only have to run this once." 
25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "!git clone https://github.com/willisma/SiT.git\n", 34 | "import SiT, os\n", 35 | "os.chdir('SiT')\n", 36 | "os.environ['PYTHONPATH'] = '/env/python:/content/SiT'\n", 37 | "!pip install diffusers timm torchdiffeq --upgrade\n", 38 | "# SiT imports:\n", 39 | "import torch\n", 40 | "from torchvision.utils import save_image\n", 41 | "from transport import create_transport, Sampler\n", 42 | "from diffusers.models import AutoencoderKL\n", 43 | "from download import find_model\n", 44 | "from models import SiT_XL_2\n", 45 | "from PIL import Image\n", 46 | "from IPython.display import display\n", 47 | "torch.set_grad_enabled(False)\n", 48 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", 49 | "if device == \"cpu\":\n", 50 | " print(\"GPU not found. Using CPU instead.\")" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "metadata": { 56 | "id": "AXpziRkoOvV9" 57 | }, 58 | "source": [ 59 | "# Download SiT-XL/2 Models" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": { 66 | "id": "EWG-WNimO59K" 67 | }, 68 | "outputs": [], 69 | "source": [ 70 | "image_size = \"256\"\n", 71 | "vae_model = \"stabilityai/sd-vae-ft-ema\" #@param [\"stabilityai/sd-vae-ft-mse\", \"stabilityai/sd-vae-ft-ema\"]\n", 72 | "latent_size = int(image_size) // 8\n", 73 | "# Load model:\n", 74 | "model = SiT_XL_2(input_size=latent_size).to(device)\n", 75 | "state_dict = find_model(f\"SiT-XL-2-{image_size}x{image_size}.pt\")\n", 76 | "model.load_state_dict(state_dict)\n", 77 | "model.eval() # important!\n", 78 | "vae = AutoencoderKL.from_pretrained(vae_model).to(device)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": { 84 | "id": "5JTNyzNZKb9E" 85 | }, 86 | "source": [ 87 | "# 2. Sample from Pre-trained SiT Models\n", 88 | "\n", 89 | "You can customize several sampling options. 
For the full list of ImageNet classes, [check out this](https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a)." 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": { 96 | "id": "-Hw7B5h4Kk4p" 97 | }, 98 | "outputs": [], 99 | "source": [ 100 | "# Set user inputs:\n", 101 | "seed = 0 #@param {type:\"number\"}\n", 102 | "torch.manual_seed(seed)\n", 103 | "num_sampling_steps = 250 #@param {type:\"slider\", min:0, max:1000, step:1}\n", 104 | "cfg_scale = 4 #@param {type:\"slider\", min:1, max:10, step:0.1}\n", 105 | "class_labels = 207, 360, 387, 974, 88, 979, 417, 279 #@param {type:\"raw\"}\n", 106 | "samples_per_row = 4 #@param {type:\"number\"}\n", 107 | "sampler_type = \"ODE\" #@param [\"ODE\", \"SDE\"]\n", 108 | "\n", 109 | "\n", 110 | "# Create diffusion object:\n", 111 | "transport = create_transport()\n", 112 | "sampler = Sampler(transport)\n", 113 | "\n", 114 | "# Create sampling noise:\n", 115 | "n = len(class_labels)\n", 116 | "z = torch.randn(n, 4, latent_size, latent_size, device=device)\n", 117 | "y = torch.tensor(class_labels, device=device)\n", 118 | "\n", 119 | "# Setup classifier-free guidance:\n", 120 | "z = torch.cat([z, z], 0)\n", 121 | "y_null = torch.tensor([1000] * n, device=device)\n", 122 | "y = torch.cat([y, y_null], 0)\n", 123 | "model_kwargs = dict(y=y, cfg_scale=cfg_scale)\n", 124 | "\n", 125 | "# Sample images:\n", 126 | "if sampler_type == \"SDE\":\n", 127 | " SDE_sampling_method = \"Euler\" #@param [\"Euler\", \"Heun\"]\n", 128 | " diffusion_form = \"linear\" #@param [\"constant\", \"SBDM\", \"sigma\", \"linear\", \"decreasing\", \"increasing-decreasing\"]\n", 129 | " diffusion_norm = 1 #@param {type:\"slider\", min:0, max:10.0, step:0.1}\n", 130 | " last_step = \"Mean\" #@param [\"Mean\", \"Tweedie\", \"Euler\"]\n", 131 | " last_step_size = 0.4 #@param {type:\"slider\", min:0, max:1.0, step:0.01}\n", 132 | " sample_fn = sampler.sample_sde(\n", 133 | " sampling_method=SDE_sampling_method,\n", 134 
| " diffusion_form=diffusion_form, \n", 135 | " diffusion_norm=diffusion_norm,\n", 136 | " last_step_size=last_step_size, \n", 137 | " num_steps=num_sampling_steps,\n", 138 | " ) \n", 139 | "elif sampler_type == \"ODE\":\n", 140 | " # default to Adaptive Solver\n", 141 | " ODE_sampling_method = \"dopri5\" #@param [\"dopri5\", \"euler\", \"rk4\"]\n", 142 | " atol = 1e-6\n", 143 | " rtol = 1e-3\n", 144 | " sample_fn = sampler.sample_ode(\n", 145 | " sampling_method=ODE_sampling_method,\n", 146 | " atol=atol,\n", 147 | " rtol=rtol,\n", 148 | " num_steps=num_sampling_steps\n", 149 | " ) \n", 150 | "samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]\n", 151 | "samples = vae.decode(samples / 0.18215).sample\n", 152 | "\n", 153 | "# Save and display images:\n", 154 | "save_image(samples, \"sample.png\", nrow=int(samples_per_row), \n", 155 | " normalize=True, value_range=(-1, 1))\n", 156 | "samples = Image.open(\"sample.png\")\n", 157 | "display(samples)" 158 | ] 159 | } 160 | ], 161 | "metadata": { 162 | "colab": { 163 | "provenance": [] 164 | }, 165 | "kernelspec": { 166 | "display_name": "Python 3.8.10 64-bit", 167 | "language": "python", 168 | "name": "python3" 169 | }, 170 | "language_info": { 171 | "name": "python", 172 | "version": "3.8.10" 173 | }, 174 | "vscode": { 175 | "interpreter": { 176 | "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" 177 | } 178 | } 179 | }, 180 | "nbformat": 4, 181 | "nbformat_minor": 0 182 | } 183 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the license found in the 2 | # LICENSE file in the root directory of this source tree. 3 | 4 | """ 5 | Sample new images from a pre-trained SiT. 
6 | """ 7 | import torch 8 | torch.backends.cuda.matmul.allow_tf32 = True 9 | torch.backends.cudnn.allow_tf32 = True 10 | from torchvision.utils import save_image 11 | from diffusers.models import AutoencoderKL 12 | from download import find_model 13 | from models import SiT_models 14 | from train_utils import parse_ode_args, parse_sde_args, parse_transport_args 15 | from transport import create_transport, Sampler 16 | import argparse 17 | import sys 18 | from time import time 19 | 20 | 21 | def main(mode, args): 22 | # Setup PyTorch: 23 | torch.manual_seed(args.seed) 24 | torch.set_grad_enabled(False) 25 | device = "cuda" if torch.cuda.is_available() else "cpu" 26 | 27 | if args.ckpt is None: 28 | assert args.model == "SiT-XL/2", "Only SiT-XL/2 models are available for auto-download." 29 | assert args.image_size in [256, 512] 30 | assert args.num_classes == 1000 31 | assert args.image_size == 256, "512x512 models are not yet available for auto-download." # remove this line when 512x512 models are available 32 | learn_sigma = args.image_size == 256 33 | else: 34 | learn_sigma = False 35 | 36 | # Load model: 37 | latent_size = args.image_size // 8 38 | model = SiT_models[args.model]( 39 | input_size=latent_size, 40 | num_classes=args.num_classes, 41 | learn_sigma=learn_sigma, 42 | ).to(device) 43 | # Auto-download a pre-trained model or load a custom SiT checkpoint from train.py: 44 | ckpt_path = args.ckpt or f"SiT-XL-2-{args.image_size}x{args.image_size}.pt" 45 | state_dict = find_model(ckpt_path) 46 | model.load_state_dict(state_dict) 47 | model.eval() # important! 
48 | transport = create_transport( 49 | args.path_type, 50 | args.prediction, 51 | args.loss_weight, 52 | args.train_eps, 53 | args.sample_eps 54 | ) 55 | sampler = Sampler(transport) 56 | if mode == "ODE": 57 | if args.likelihood: 58 | assert args.cfg_scale == 1, "Likelihood is incompatible with guidance" 59 | sample_fn = sampler.sample_ode_likelihood( 60 | sampling_method=args.sampling_method, 61 | num_steps=args.num_sampling_steps, 62 | atol=args.atol, 63 | rtol=args.rtol, 64 | ) 65 | else: 66 | sample_fn = sampler.sample_ode( 67 | sampling_method=args.sampling_method, 68 | num_steps=args.num_sampling_steps, 69 | atol=args.atol, 70 | rtol=args.rtol, 71 | reverse=args.reverse 72 | ) 73 | 74 | elif mode == "SDE": 75 | sample_fn = sampler.sample_sde( 76 | sampling_method=args.sampling_method, 77 | diffusion_form=args.diffusion_form, 78 | diffusion_norm=args.diffusion_norm, 79 | last_step=args.last_step, 80 | last_step_size=args.last_step_size, 81 | num_steps=args.num_sampling_steps, 82 | ) 83 | 84 | 85 | vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device) 86 | 87 | # Labels to condition the model with (feel free to change): 88 | class_labels = [207, 360, 387, 974, 88, 979, 417, 279] 89 | 90 | # Create sampling noise: 91 | n = len(class_labels) 92 | z = torch.randn(n, 4, latent_size, latent_size, device=device) 93 | y = torch.tensor(class_labels, device=device) 94 | 95 | # Setup classifier-free guidance: 96 | z = torch.cat([z, z], 0) 97 | y_null = torch.tensor([1000] * n, device=device) 98 | y = torch.cat([y, y_null], 0) 99 | model_kwargs = dict(y=y, cfg_scale=args.cfg_scale) 100 | 101 | # Sample images: 102 | start_time = time() 103 | samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1] 104 | samples, _ = samples.chunk(2, dim=0) # Remove null class samples 105 | samples = vae.decode(samples / 0.18215).sample 106 | print(f"Sampling took {time() - start_time:.2f} seconds.") 107 | 108 | # Save and display images: 109 | 
save_image(samples, "sample.png", nrow=4, normalize=True, value_range=(-1, 1)) 110 | 111 | 112 | if __name__ == "__main__": 113 | parser = argparse.ArgumentParser() 114 | 115 | if len(sys.argv) < 2: 116 | print("Usage: program.py {ODE,SDE} [options]") 117 | sys.exit(1) 118 | 119 | mode = sys.argv[1] 120 | 121 | assert mode[:2] != "--", "Usage: program.py {ODE,SDE} [options]" 122 | assert mode in ["ODE", "SDE"], "Invalid mode. Please choose 'ODE' or 'SDE'" 123 | 124 | parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2") 125 | parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="mse") 126 | parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) 127 | parser.add_argument("--num-classes", type=int, default=1000) 128 | parser.add_argument("--cfg-scale", type=float, default=4.0) 129 | parser.add_argument("--num-sampling-steps", type=int, default=250) 130 | parser.add_argument("--seed", type=int, default=0) 131 | parser.add_argument("--ckpt", type=str, default=None, 132 | help="Optional path to a SiT checkpoint (default: auto-download a pre-trained SiT-XL/2 model).") 133 | 134 | 135 | parse_transport_args(parser) 136 | if mode == "ODE": 137 | parse_ode_args(parser) 138 | # Further processing for ODE 139 | elif mode == "SDE": 140 | parse_sde_args(parser) 141 | # Further processing for SDE 142 | 143 | args = parser.parse_known_args()[0] 144 | main(mode, args) 145 | -------------------------------------------------------------------------------- /sample_ddp.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the license found in the 2 | # LICENSE file in the root directory of this source tree. 3 | 4 | """ 5 | Samples a large number of images from a pre-trained SiT model using DDP. 
6 | Subsequently saves a .npz file that can be used to compute FID and other 7 | evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations 8 | 9 | For a simple single-GPU/CPU sampling script, see sample.py. 10 | """ 11 | import torch 12 | import torch.distributed as dist 13 | from models import SiT_models 14 | from download import find_model 15 | from transport import create_transport, Sampler 16 | from diffusers.models import AutoencoderKL 17 | from train_utils import parse_ode_args, parse_sde_args, parse_transport_args 18 | from tqdm import tqdm 19 | import os 20 | from PIL import Image 21 | import numpy as np 22 | import math 23 | import argparse 24 | import sys 25 | 26 | 27 | def create_npz_from_sample_folder(sample_dir, num=50_000): 28 | """ 29 | Builds a single .npz file from a folder of .png samples. 30 | """ 31 | samples = [] 32 | for i in tqdm(range(num), desc="Building .npz file from samples"): 33 | sample_pil = Image.open(f"{sample_dir}/{i:06d}.png") 34 | sample_np = np.asarray(sample_pil).astype(np.uint8) 35 | samples.append(sample_np) 36 | samples = np.stack(samples) 37 | assert samples.shape == (num, samples.shape[1], samples.shape[2], 3) 38 | npz_path = f"{sample_dir}.npz" 39 | np.savez(npz_path, arr_0=samples) 40 | print(f"Saved .npz file to {npz_path} [shape={samples.shape}].") 41 | return npz_path 42 | 43 | 44 | def main(mode, args): 45 | """ 46 | Run sampling. 47 | """ 48 | torch.backends.cuda.matmul.allow_tf32 = args.tf32 # True: fast but may lead to some small numerical differences 49 | assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. 
sample.py supports CPU-only usage" 50 | torch.set_grad_enabled(False) 51 | 52 | # Setup DDP: 53 | dist.init_process_group("nccl") 54 | rank = dist.get_rank() 55 | device = rank % torch.cuda.device_count() 56 | seed = args.global_seed * dist.get_world_size() + rank 57 | torch.manual_seed(seed) 58 | torch.cuda.set_device(device) 59 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") 60 | 61 | if args.ckpt is None: 62 | assert args.model == "SiT-XL/2", "Only SiT-XL/2 models are available for auto-download." 63 | assert args.image_size in [256, 512] 64 | assert args.num_classes == 1000 65 | assert args.image_size == 256, "512x512 models are not yet available for auto-download." # remove this line when 512x512 models are available 66 | learn_sigma = args.image_size == 256 67 | else: 68 | learn_sigma = False 69 | 70 | # Load model: 71 | latent_size = args.image_size // 8 72 | model = SiT_models[args.model]( 73 | input_size=latent_size, 74 | num_classes=args.num_classes, 75 | learn_sigma=learn_sigma, 76 | ).to(device) 77 | # Auto-download a pre-trained model or load a custom SiT checkpoint from train.py: 78 | ckpt_path = args.ckpt or f"SiT-XL-2-{args.image_size}x{args.image_size}.pt" 79 | state_dict = find_model(ckpt_path) 80 | model.load_state_dict(state_dict) 81 | model.eval() # important! 
82 | 83 | 84 | transport = create_transport( 85 | args.path_type, 86 | args.prediction, 87 | args.loss_weight, 88 | args.train_eps, 89 | args.sample_eps 90 | ) 91 | sampler = Sampler(transport) 92 | if mode == "ODE": 93 | if args.likelihood: 94 | assert args.cfg_scale == 1, "Likelihood is incompatible with guidance" 95 | sample_fn = sampler.sample_ode_likelihood( 96 | sampling_method=args.sampling_method, 97 | num_steps=args.num_sampling_steps, 98 | atol=args.atol, 99 | rtol=args.rtol, 100 | ) 101 | else: 102 | sample_fn = sampler.sample_ode( 103 | sampling_method=args.sampling_method, 104 | num_steps=args.num_sampling_steps, 105 | atol=args.atol, 106 | rtol=args.rtol, 107 | reverse=args.reverse 108 | ) 109 | elif mode == "SDE": 110 | sample_fn = sampler.sample_sde( 111 | sampling_method=args.sampling_method, 112 | diffusion_form=args.diffusion_form, 113 | diffusion_norm=args.diffusion_norm, 114 | last_step=args.last_step, 115 | last_step_size=args.last_step_size, 116 | num_steps=args.num_sampling_steps, 117 | ) 118 | vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device) 119 | assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale should be >= 1.0" 120 | using_cfg = args.cfg_scale > 1.0 121 | 122 | # Create folder to save samples: 123 | model_string_name = args.model.replace("/", "-") 124 | ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained" 125 | if mode == "ODE": 126 | folder_name = f"{model_string_name}-{ckpt_string_name}-" \ 127 | f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\ 128 | f"{mode}-{args.num_sampling_steps}-{args.sampling_method}" 129 | elif mode == "SDE": 130 | folder_name = f"{model_string_name}-{ckpt_string_name}-" \ 131 | f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\ 132 | f"{mode}-{args.num_sampling_steps}-{args.sampling_method}-"\ 133 | f"{args.diffusion_form}-{args.last_step}-{args.last_step_size}" 134 | sample_folder_dir = f"{args.sample_dir}/{folder_name}" 
135 | if rank == 0: 136 | os.makedirs(sample_folder_dir, exist_ok=True) 137 | print(f"Saving .png samples at {sample_folder_dir}") 138 | dist.barrier() 139 | 140 | # Figure out how many samples we need to generate on each GPU and how many iterations we need to run: 141 | n = args.per_proc_batch_size 142 | global_batch_size = n * dist.get_world_size() 143 | # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples: 144 | num_samples = len([name for name in os.listdir(sample_folder_dir) if (os.path.isfile(os.path.join(sample_folder_dir, name)) and ".png" in name)]) 145 | total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size) 146 | if rank == 0: 147 | print(f"Total number of images that will be sampled: {total_samples}") 148 | assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size" 149 | samples_needed_this_gpu = int(total_samples // dist.get_world_size()) 150 | assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size" 151 | iterations = int(samples_needed_this_gpu // n) 152 | done_iterations = int( int(num_samples // dist.get_world_size()) // n) 153 | pbar = range(iterations) 154 | pbar = tqdm(pbar) if rank == 0 else pbar 155 | total = 0 156 | 157 | for i in pbar: 158 | # Sample inputs: 159 | z = torch.randn(n, model.in_channels, latent_size, latent_size, device=device) 160 | y = torch.randint(0, args.num_classes, (n,), device=device) 161 | 162 | # Setup classifier-free guidance: 163 | if using_cfg: 164 | z = torch.cat([z, z], 0) 165 | y_null = torch.tensor([1000] * n, device=device) 166 | y = torch.cat([y, y_null], 0) 167 | model_kwargs = dict(y=y, cfg_scale=args.cfg_scale) 168 | model_fn = model.forward_with_cfg 169 | else: 170 | model_kwargs = dict(y=y) 171 | model_fn = model.forward 172 | 173 | samples = sample_fn(z, model_fn, **model_kwargs)[-1] 174 | if using_cfg: 175 | samples, _ 
= samples.chunk(2, dim=0) # Remove null class samples 176 | 177 | samples = vae.decode(samples / 0.18215).sample 178 | samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy() 179 | 180 | # Save samples to disk as individual .png files (j avoids shadowing the outer loop variable i) 181 | for j, sample in enumerate(samples): 182 | index = j * dist.get_world_size() + rank + total 183 | Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png") 184 | total += global_batch_size 185 | dist.barrier() 186 | 187 | # Make sure all processes have finished saving their samples before attempting to convert to .npz 188 | dist.barrier() 189 | if rank == 0: 190 | create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples) 191 | print("Done.") 192 | dist.barrier() 193 | dist.destroy_process_group() 194 | 195 | 196 | if __name__ == "__main__": 197 | 198 | parser = argparse.ArgumentParser() 199 | 200 | if len(sys.argv) < 2: 201 | print("Usage: program.py {ODE,SDE} [options]") 202 | sys.exit(1) 203 | 204 | mode = sys.argv[1] 205 | 206 | assert mode[:2] != "--", "Usage: program.py {ODE,SDE} [options]" 207 | assert mode in ["ODE", "SDE"], "Invalid mode. 
Please choose 'ODE' or 'SDE'" 208 | 209 | parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2") 210 | parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema") 211 | parser.add_argument("--sample-dir", type=str, default="samples") 212 | parser.add_argument("--per-proc-batch-size", type=int, default=4) 213 | parser.add_argument("--num-fid-samples", type=int, default=50_000) 214 | parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) 215 | parser.add_argument("--num-classes", type=int, default=1000) 216 | parser.add_argument("--cfg-scale", type=float, default=1.0) 217 | parser.add_argument("--num-sampling-steps", type=int, default=250) 218 | parser.add_argument("--global-seed", type=int, default=0) 219 | parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True, 220 | help="By default, use TF32 matmuls. This massively accelerates sampling on Ampere GPUs.") 221 | parser.add_argument("--ckpt", type=str, default=None, 222 | help="Optional path to a SiT checkpoint (default: auto-download a pre-trained SiT-XL/2 model).") 223 | 224 | parse_transport_args(parser) 225 | if mode == "ODE": 226 | parse_ode_args(parser) 227 | # Further processing for ODE 228 | elif mode == "SDE": 229 | parse_sde_args(parser) 230 | # Further processing for SDE 231 | 232 | args = parser.parse_known_args()[0] 233 | main(mode, args) 234 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the license found in the 2 | # LICENSE file in the root directory of this source tree. 3 | 4 | """ 5 | A minimal training script for SiT using PyTorch DDP. 
6 | """ 7 | import torch 8 | # the first flag below was False when we tested this script but True makes A100 training a lot faster: 9 | torch.backends.cuda.matmul.allow_tf32 = True 10 | torch.backends.cudnn.allow_tf32 = True 11 | import torch.distributed as dist 12 | from torch.nn.parallel import DistributedDataParallel as DDP 13 | from torch.utils.data import DataLoader 14 | from torch.utils.data.distributed import DistributedSampler 15 | from torchvision.datasets import ImageFolder 16 | from torchvision import transforms 17 | import numpy as np 18 | from collections import OrderedDict 19 | from PIL import Image 20 | from copy import deepcopy 21 | from glob import glob 22 | from time import time 23 | import argparse 24 | import logging 25 | import os 26 | 27 | from models import SiT_models 28 | from download import find_model 29 | from transport import create_transport, Sampler 30 | from diffusers.models import AutoencoderKL 31 | from train_utils import parse_transport_args 32 | import wandb_utils 33 | 34 | 35 | ################################################################################# 36 | # Training Helper Functions # 37 | ################################################################################# 38 | 39 | @torch.no_grad() 40 | def update_ema(ema_model, model, decay=0.9999): 41 | """ 42 | Step the EMA model towards the current model. 43 | """ 44 | ema_params = OrderedDict(ema_model.named_parameters()) 45 | model_params = OrderedDict(model.named_parameters()) 46 | 47 | for name, param in model_params.items(): 48 | # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed 49 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) 50 | 51 | 52 | def requires_grad(model, flag=True): 53 | """ 54 | Set requires_grad flag for all parameters in a model. 55 | """ 56 | for p in model.parameters(): 57 | p.requires_grad = flag 58 | 59 | 60 | def cleanup(): 61 | """ 62 | End DDP training. 
63 | """ 64 | dist.destroy_process_group() 65 | 66 | 67 | def create_logger(logging_dir): 68 | """ 69 | Create a logger that writes to a log file and stdout. 70 | """ 71 | if dist.get_rank() == 0: # real logger 72 | logging.basicConfig( 73 | level=logging.INFO, 74 | format='[\033[34m%(asctime)s\033[0m] %(message)s', 75 | datefmt='%Y-%m-%d %H:%M:%S', 76 | handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] 77 | ) 78 | logger = logging.getLogger(__name__) 79 | else: # dummy logger (does nothing) 80 | logger = logging.getLogger(__name__) 81 | logger.addHandler(logging.NullHandler()) 82 | return logger 83 | 84 | 85 | def center_crop_arr(pil_image, image_size): 86 | """ 87 | Center cropping implementation from ADM. 88 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 89 | """ 90 | while min(*pil_image.size) >= 2 * image_size: 91 | pil_image = pil_image.resize( 92 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 93 | ) 94 | 95 | scale = image_size / min(*pil_image.size) 96 | pil_image = pil_image.resize( 97 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 98 | ) 99 | 100 | arr = np.array(pil_image) 101 | crop_y = (arr.shape[0] - image_size) // 2 102 | crop_x = (arr.shape[1] - image_size) // 2 103 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) 104 | 105 | 106 | ################################################################################# 107 | # Training Loop # 108 | ################################################################################# 109 | 110 | def main(args): 111 | """ 112 | Trains a new SiT model. 113 | """ 114 | assert torch.cuda.is_available(), "Training currently requires at least one GPU." 115 | 116 | # Setup DDP: 117 | dist.init_process_group("nccl") 118 | assert args.global_batch_size % dist.get_world_size() == 0, f"Batch size must be divisible by world size." 
119 | rank = dist.get_rank() 120 | device = rank % torch.cuda.device_count() 121 | seed = args.global_seed * dist.get_world_size() + rank 122 | torch.manual_seed(seed) 123 | torch.cuda.set_device(device) 124 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") 125 | local_batch_size = int(args.global_batch_size // dist.get_world_size()) 126 | 127 | # Setup an experiment folder: 128 | if rank == 0: 129 | os.makedirs(args.results_dir, exist_ok=True) # Make results folder (holds all experiment subfolders) 130 | experiment_index = len(glob(f"{args.results_dir}/*")) 131 | model_string_name = args.model.replace("/", "-") # e.g., SiT-XL/2 --> SiT-XL-2 (for naming folders) 132 | experiment_name = f"{experiment_index:03d}-{model_string_name}-" \ 133 | f"{args.path_type}-{args.prediction}-{args.loss_weight}" 134 | experiment_dir = f"{args.results_dir}/{experiment_name}" # Create an experiment folder 135 | checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints 136 | os.makedirs(checkpoint_dir, exist_ok=True) 137 | logger = create_logger(experiment_dir) 138 | logger.info(f"Experiment directory created at {experiment_dir}") 139 | 140 | if args.wandb: # read ENTITY/PROJECT only when W&B logging is enabled 141 | entity = os.environ["ENTITY"] 142 | project = os.environ["PROJECT"] 143 | wandb_utils.initialize(args, entity, experiment_name, project) 144 | else: # ranks other than 0 get a no-op logger 145 | logger = create_logger(None) 146 | 147 | # Create model: 148 | assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)." 
149 | latent_size = args.image_size // 8 150 | model = SiT_models[args.model]( 151 | input_size=latent_size, 152 | num_classes=args.num_classes 153 | ) 154 | 155 | # Note that parameter initialization is done within the SiT constructor 156 | ema = deepcopy(model).to(device) # Create an EMA of the model for use after training 157 | state_dict = None # Set only when resuming from --ckpt 158 | if args.ckpt is not None: 159 | ckpt_path = args.ckpt 160 | state_dict = find_model(ckpt_path) 161 | model.load_state_dict(state_dict["model"]) 162 | ema.load_state_dict(state_dict["ema"]) 163 | # The optimizer state is restored further down, once opt has been created 164 | args = state_dict["args"] 165 | 166 | requires_grad(ema, False) 167 | 168 | model = DDP(model.to(device), device_ids=[device]) 169 | transport = create_transport( 170 | args.path_type, 171 | args.prediction, 172 | args.loss_weight, 173 | args.train_eps, 174 | args.sample_eps 175 | ) # default: velocity; 176 | transport_sampler = Sampler(transport) 177 | vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device) 178 | logger.info(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}") 179 | 180 | # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper): 181 | opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0) 182 | if state_dict is not None: opt.load_state_dict(state_dict["opt"]) # Resume optimizer state from the checkpoint 183 | # Setup data: 184 | transform = transforms.Compose([ 185 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)), 186 | transforms.RandomHorizontalFlip(), 187 | transforms.ToTensor(), 188 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 189 | ]) 190 | dataset = ImageFolder(args.data_path, transform=transform) 191 | sampler = DistributedSampler( 192 | dataset, 193 | num_replicas=dist.get_world_size(), 194 | rank=rank, 195 | shuffle=True, 196 | seed=args.global_seed 197 | ) 198 | loader = DataLoader( 199 | dataset, 200 | batch_size=local_batch_size, 201 | shuffle=False, 202 | sampler=sampler, 203 | 
num_workers=args.num_workers, 204 | pin_memory=True, 205 | drop_last=True 206 | ) 207 | logger.info(f"Dataset contains {len(dataset):,} images ({args.data_path})") 208 | 209 | # Prepare models for training: 210 | update_ema(ema, model.module, decay=0) # Ensure EMA is initialized with synced weights 211 | model.train() # important! This enables embedding dropout for classifier-free guidance 212 | ema.eval() # EMA model should always be in eval mode 213 | 214 | # Variables for monitoring/logging purposes: 215 | train_steps = 0 216 | log_steps = 0 217 | running_loss = 0 218 | start_time = time() 219 | 220 | # Labels to condition the model with (feel free to change): 221 | ys = torch.randint(1000, size=(local_batch_size,), device=device) 222 | use_cfg = args.cfg_scale > 1.0 223 | # Create sampling noise: 224 | n = ys.size(0) 225 | zs = torch.randn(n, 4, latent_size, latent_size, device=device) 226 | 227 | # Setup classifier-free guidance: 228 | if use_cfg: 229 | zs = torch.cat([zs, zs], 0) 230 | y_null = torch.tensor([1000] * n, device=device) 231 | ys = torch.cat([ys, y_null], 0) 232 | sample_model_kwargs = dict(y=ys, cfg_scale=args.cfg_scale) 233 | model_fn = ema.forward_with_cfg 234 | else: 235 | sample_model_kwargs = dict(y=ys) 236 | model_fn = ema.forward 237 | 238 | logger.info(f"Training for {args.epochs} epochs...") 239 | for epoch in range(args.epochs): 240 | sampler.set_epoch(epoch) 241 | logger.info(f"Beginning epoch {epoch}...") 242 | for x, y in loader: 243 | x = x.to(device) 244 | y = y.to(device) 245 | with torch.no_grad(): 246 | # Map input images to latent space + normalize latents: 247 | x = vae.encode(x).latent_dist.sample().mul_(0.18215) 248 | model_kwargs = dict(y=y) 249 | loss_dict = transport.training_losses(model, x, model_kwargs) 250 | loss = loss_dict["loss"].mean() 251 | opt.zero_grad() 252 | loss.backward() 253 | opt.step() 254 | update_ema(ema, model.module) 255 | 256 | # Log loss values: 257 | running_loss += loss.item() 258 | log_steps += 
1 259 | train_steps += 1 260 | if train_steps % args.log_every == 0: 261 | # Measure training speed: 262 | torch.cuda.synchronize() 263 | end_time = time() 264 | steps_per_sec = log_steps / (end_time - start_time) 265 | # Reduce loss history over all processes: 266 | avg_loss = torch.tensor(running_loss / log_steps, device=device) 267 | dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM) 268 | avg_loss = avg_loss.item() / dist.get_world_size() 269 | logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}") 270 | if args.wandb: 271 | wandb_utils.log( 272 | { "train loss": avg_loss, "train steps/sec": steps_per_sec }, 273 | step=train_steps 274 | ) 275 | # Reset monitoring variables: 276 | running_loss = 0 277 | log_steps = 0 278 | start_time = time() 279 | 280 | # Save SiT checkpoint: 281 | if train_steps % args.ckpt_every == 0 and train_steps > 0: 282 | if rank == 0: 283 | checkpoint = { 284 | "model": model.module.state_dict(), 285 | "ema": ema.state_dict(), 286 | "opt": opt.state_dict(), 287 | "args": args 288 | } 289 | checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt" 290 | torch.save(checkpoint, checkpoint_path) 291 | logger.info(f"Saved checkpoint to {checkpoint_path}") 292 | dist.barrier() 293 | 294 | if train_steps % args.sample_every == 0 and train_steps > 0: 295 | logger.info("Generating EMA samples...") 296 | sample_fn = transport_sampler.sample_ode() # default to ode sampling 297 | samples = sample_fn(zs, model_fn, **sample_model_kwargs)[-1] 298 | dist.barrier() 299 | 300 | if use_cfg: # remove null samples 301 | samples, _ = samples.chunk(2, dim=0) 302 | samples = vae.decode(samples / 0.18215).sample 303 | out_samples = torch.zeros((args.global_batch_size, 3, args.image_size, args.image_size), device=device) 304 | dist.all_gather_into_tensor(out_samples, samples) 305 | if args.wandb: 306 | wandb_utils.log_image(out_samples, train_steps) 307 | logger.info("Generating EMA samples done.") 308 | 309 | 
model.eval() # important! This disables randomized embedding dropout 310 | # do any sampling/FID calculation/etc. with ema (or model) in eval mode ... 311 | 312 | logger.info("Done!") 313 | cleanup() 314 | 315 | 316 | if __name__ == "__main__": 317 | # Default args here will train SiT-XL/2 with the hyperparameters we used in our paper (except training iters). 318 | parser = argparse.ArgumentParser() 319 | parser.add_argument("--data-path", type=str, required=True) 320 | parser.add_argument("--results-dir", type=str, default="results") 321 | parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2") 322 | parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) 323 | parser.add_argument("--num-classes", type=int, default=1000) 324 | parser.add_argument("--epochs", type=int, default=1400) 325 | parser.add_argument("--global-batch-size", type=int, default=256) 326 | parser.add_argument("--global-seed", type=int, default=0) 327 | parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema") # Choice doesn't affect training 328 | parser.add_argument("--num-workers", type=int, default=4) 329 | parser.add_argument("--log-every", type=int, default=100) 330 | parser.add_argument("--ckpt-every", type=int, default=50_000) 331 | parser.add_argument("--sample-every", type=int, default=10_000) 332 | parser.add_argument("--cfg-scale", type=float, default=4.0) 333 | parser.add_argument("--wandb", action="store_true") 334 | parser.add_argument("--ckpt", type=str, default=None, 335 | help="Optional path to a custom SiT checkpoint") 336 | 337 | parse_transport_args(parser) 338 | args = parser.parse_args() 339 | main(args) 340 | -------------------------------------------------------------------------------- /train_utils.py: -------------------------------------------------------------------------------- 1 | def none_or_str(value): 2 | if value == 'None': 3 | return None 4 | return value 5 | 6 | def 
parse_transport_args(parser): 7 | group = parser.add_argument_group("Transport arguments") 8 | group.add_argument("--path-type", type=str, default="Linear", choices=["Linear", "GVP", "VP"]) 9 | group.add_argument("--prediction", type=str, default="velocity", choices=["velocity", "score", "noise"]) 10 | group.add_argument("--loss-weight", type=none_or_str, default=None, choices=[None, "velocity", "likelihood"]) 11 | group.add_argument("--sample-eps", type=float) 12 | group.add_argument("--train-eps", type=float) 13 | 14 | def parse_ode_args(parser): 15 | group = parser.add_argument_group("ODE arguments") 16 | group.add_argument("--sampling-method", type=str, default="dopri5", help="blackbox ODE solver methods; for full list check https://github.com/rtqichen/torchdiffeq") 17 | group.add_argument("--atol", type=float, default=1e-6, help="Absolute tolerance") 18 | group.add_argument("--rtol", type=float, default=1e-3, help="Relative tolerance") 19 | group.add_argument("--reverse", action="store_true") 20 | group.add_argument("--likelihood", action="store_true") 21 | 22 | def parse_sde_args(parser): 23 | group = parser.add_argument_group("SDE arguments") 24 | group.add_argument("--sampling-method", type=str, default="Euler", choices=["Euler", "Heun"]) 25 | group.add_argument("--diffusion-form", type=str, default="sigma", \ 26 | choices=["constant", "SBDM", "sigma", "linear", "decreasing", "increasing-decreasing"],\ 27 | help="form of diffusion coefficient in the SDE") 28 | group.add_argument("--diffusion-norm", type=float, default=1.0) 29 | group.add_argument("--last-step", type=none_or_str, default="Mean", choices=[None, "Mean", "Tweedie", "Euler"],\ 30 | help="form of last step taken in the SDE") 31 | group.add_argument("--last-step-size", type=float, default=0.04, \ 32 | help="size of the last step taken") -------------------------------------------------------------------------------- /transport/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .transport import Transport, ModelType, WeightType, PathType, Sampler 2 | 3 | def create_transport( 4 | path_type='Linear', 5 | prediction="velocity", 6 | loss_weight=None, 7 | train_eps=None, 8 | sample_eps=None, 9 | ): 10 | """function for creating Transport object 11 | **Note**: model prediction defaults to velocity 12 | Args: 13 | - path_type: type of path to use; default to linear 14 | - prediction: type of model prediction; defaults to velocity 15 | - "score" or "noise" selects score or noise prediction instead 16 | - loss_weight: type of loss weighting; defaults to None (unweighted) 17 | - "velocity" or "likelihood" selects the corresponding weight 18 | - train_eps: small epsilon for avoiding instability during training 19 | - sample_eps: small epsilon for avoiding instability during sampling 20 | """ 21 | 22 | if prediction == "noise": 23 | model_type = ModelType.NOISE 24 | elif prediction == "score": 25 | model_type = ModelType.SCORE 26 | else: 27 | model_type = ModelType.VELOCITY 28 | 29 | if loss_weight == "velocity": 30 | loss_type = WeightType.VELOCITY 31 | elif loss_weight == "likelihood": 32 | loss_type = WeightType.LIKELIHOOD 33 | else: 34 | loss_type = WeightType.NONE 35 | 36 | path_choice = { 37 | "Linear": PathType.LINEAR, 38 | "GVP": PathType.GVP, 39 | "VP": PathType.VP, 40 | } 41 | 42 | path_type = path_choice[path_type] 43 | 44 | if (path_type in [PathType.VP]): 45 | train_eps = 1e-5 if train_eps is None else train_eps 46 | sample_eps = 1e-3 if sample_eps is None else sample_eps 47 | elif (path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY): 48 | train_eps = 1e-3 if train_eps is None else train_eps 49 | sample_eps = 1e-3 if sample_eps is None else sample_eps 50 | else: # velocity & [GVP, LINEAR] is stable everywhere 51 | train_eps = 0 52 | sample_eps = 0 53 | 54 | # create flow state 55 | state = Transport( 56 | model_type=model_type, 57 | path_type=path_type, 58 | 
loss_type=loss_type, 59 | train_eps=train_eps, 60 | sample_eps=sample_eps, 61 | ) 62 | 63 | return state -------------------------------------------------------------------------------- /transport/integrators.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | import torch.nn as nn 4 | from torchdiffeq import odeint 5 | from functools import partial 6 | from tqdm import tqdm 7 | 8 | class sde: 9 | """SDE solver class""" 10 | def __init__( 11 | self, 12 | drift, 13 | diffusion, 14 | *, 15 | t0, 16 | t1, 17 | num_steps, 18 | sampler_type, 19 | ): 20 | assert t0 < t1, "SDE sampler has to be in forward time" 21 | 22 | self.num_timesteps = num_steps 23 | self.t = th.linspace(t0, t1, num_steps) 24 | self.dt = self.t[1] - self.t[0] 25 | self.drift = drift 26 | self.diffusion = diffusion 27 | self.sampler_type = sampler_type 28 | 29 | def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs): 30 | w_cur = th.randn(x.size()).to(x) 31 | t = th.ones(x.size(0)).to(x) * t 32 | dw = w_cur * th.sqrt(self.dt) 33 | drift = self.drift(x, t, model, **model_kwargs) 34 | diffusion = self.diffusion(x, t) 35 | mean_x = x + drift * self.dt 36 | x = mean_x + th.sqrt(2 * diffusion) * dw 37 | return x, mean_x 38 | 39 | def __Heun_step(self, x, _, t, model, **model_kwargs): 40 | w_cur = th.randn(x.size()).to(x) 41 | dw = w_cur * th.sqrt(self.dt) 42 | t_cur = th.ones(x.size(0)).to(x) * t 43 | diffusion = self.diffusion(x, t_cur) 44 | xhat = x + th.sqrt(2 * diffusion) * dw 45 | K1 = self.drift(xhat, t_cur, model, **model_kwargs) 46 | xp = xhat + self.dt * K1 47 | K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs) 48 | return xhat + 0.5 * self.dt * (K1 + K2), xhat # at last time point we do not perform the heun step 49 | 50 | def __forward_fn(self): 51 | """TODO: generalize here by adding all private functions ending with steps to it""" 52 | sampler_dict = { 53 | "Euler": 
self.__Euler_Maruyama_step, 54 | "Heun": self.__Heun_step, 55 | } 56 | 57 | try: 58 | sampler = sampler_dict[self.sampler_type] 59 | except KeyError: 60 | raise NotImplementedError("Sampler type not implemented.") 61 | 62 | return sampler 63 | 64 | def sample(self, init, model, **model_kwargs): 65 | """forward loop of sde""" 66 | x = init 67 | mean_x = init 68 | samples = [] 69 | sampler = self.__forward_fn() 70 | for ti in self.t[:-1]: 71 | with th.no_grad(): 72 | x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs) 73 | samples.append(x) 74 | 75 | return samples 76 | 77 | class ode: 78 | """ODE solver class""" 79 | def __init__( 80 | self, 81 | drift, 82 | *, 83 | t0, 84 | t1, 85 | sampler_type, 86 | num_steps, 87 | atol, 88 | rtol, 89 | ): 90 | assert t0 < t1, "ODE sampler has to be in forward time" 91 | 92 | self.drift = drift 93 | self.t = th.linspace(t0, t1, num_steps) 94 | self.atol = atol 95 | self.rtol = rtol 96 | self.sampler_type = sampler_type 97 | 98 | def sample(self, x, model, **model_kwargs): 99 | 100 | device = x[0].device if isinstance(x, tuple) else x.device 101 | def _fn(t, x): 102 | t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t 103 | model_output = self.drift(x, t, model, **model_kwargs) 104 | return model_output 105 | 106 | t = self.t.to(device) 107 | atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol] 108 | rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol] 109 | samples = odeint( 110 | _fn, 111 | x, 112 | t, 113 | method=self.sampler_type, 114 | atol=atol, 115 | rtol=rtol 116 | ) 117 | return samples -------------------------------------------------------------------------------- /transport/path.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import numpy as np 3 | from functools import partial 4 | 5 | def expand_t_like_x(t, x): 6 | """Function to reshape time t to broadcastable dimension of 
x 7 | Args: 8 | t: [batch_dim,], time vector 9 | x: [batch_dim,...], data point 10 | """ 11 | dims = [1] * (len(x.size()) - 1) 12 | t = t.view(t.size(0), *dims) 13 | return t 14 | 15 | 16 | #################### Coupling Plans #################### 17 | 18 | class ICPlan: 19 | """Linear Coupling Plan""" 20 | def __init__(self, sigma=0.0): 21 | self.sigma = sigma 22 | 23 | def compute_alpha_t(self, t): 24 | """Compute the data coefficient along the path""" 25 | return t, 1 26 | 27 | def compute_sigma_t(self, t): 28 | """Compute the noise coefficient along the path""" 29 | return 1 - t, -1 30 | 31 | def compute_d_alpha_alpha_ratio_t(self, t): 32 | """Compute the ratio between d_alpha and alpha""" 33 | return 1 / t 34 | 35 | def compute_drift(self, x, t): 36 | """We always output sde according to score parametrization; """ 37 | t = expand_t_like_x(t, x) 38 | alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t) 39 | sigma_t, d_sigma_t = self.compute_sigma_t(t) 40 | drift = alpha_ratio * x 41 | diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t 42 | 43 | return -drift, diffusion 44 | 45 | def compute_diffusion(self, x, t, form="constant", norm=1.0): 46 | """Compute the diffusion term of the SDE 47 | Args: 48 | x: [batch_dim, ...], data point 49 | t: [batch_dim,], time vector 50 | form: str, form of the diffusion term 51 | norm: float, norm of the diffusion term 52 | """ 53 | t = expand_t_like_x(t, x) 54 | choices = { 55 | "constant": norm, 56 | "SBDM": norm * self.compute_drift(x, t)[1], 57 | "sigma": norm * self.compute_sigma_t(t)[0], 58 | "linear": norm * (1 - t), 59 | "decreasing": 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2, 60 | "increasing-decreasing": norm * th.sin(np.pi * t) ** 2, 61 | } 62 | 63 | try: 64 | diffusion = choices[form] 65 | except KeyError: 66 | raise NotImplementedError(f"Diffusion form {form} not implemented") 67 | 68 | return diffusion 69 | 70 | def get_score_from_velocity(self, velocity, x, t): 71 | """Wrapper function: transform 
velocity prediction model to score 72 | Args: 73 | velocity: [batch_dim, ...] shaped tensor; velocity model output 74 | x: [batch_dim, ...] shaped tensor; x_t data point 75 | t: [batch_dim,] time tensor 76 | """ 77 | t = expand_t_like_x(t, x) 78 | alpha_t, d_alpha_t = self.compute_alpha_t(t) 79 | sigma_t, d_sigma_t = self.compute_sigma_t(t) 80 | mean = x 81 | reverse_alpha_ratio = alpha_t / d_alpha_t 82 | var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t 83 | score = (reverse_alpha_ratio * velocity - mean) / var 84 | return score 85 | 86 | def get_noise_from_velocity(self, velocity, x, t): 87 | """Wrapper function: transform velocity prediction model to denoiser 88 | Args: 89 | velocity: [batch_dim, ...] shaped tensor; velocity model output 90 | x: [batch_dim, ...] shaped tensor; x_t data point 91 | t: [batch_dim,] time tensor 92 | """ 93 | t = expand_t_like_x(t, x) 94 | alpha_t, d_alpha_t = self.compute_alpha_t(t) 95 | sigma_t, d_sigma_t = self.compute_sigma_t(t) 96 | mean = x 97 | reverse_alpha_ratio = alpha_t / d_alpha_t 98 | var = reverse_alpha_ratio * d_sigma_t - sigma_t 99 | noise = (reverse_alpha_ratio * velocity - mean) / var 100 | return noise 101 | 102 | def get_velocity_from_score(self, score, x, t): 103 | """Wrapper function: transform score prediction model to velocity 104 | Args: 105 | score: [batch_dim, ...] shaped tensor; score model output 106 | x: [batch_dim, ...] 
shaped tensor; x_t data point 107 | t: [batch_dim,] time tensor 108 | """ 109 | t = expand_t_like_x(t, x) 110 | drift, var = self.compute_drift(x, t) 111 | velocity = var * score - drift 112 | return velocity 113 | 114 | def compute_mu_t(self, t, x0, x1): 115 | """Compute the mean of time-dependent density p_t""" 116 | t = expand_t_like_x(t, x1) 117 | alpha_t, _ = self.compute_alpha_t(t) 118 | sigma_t, _ = self.compute_sigma_t(t) 119 | return alpha_t * x1 + sigma_t * x0 120 | 121 | def compute_xt(self, t, x0, x1): 122 | """Sample xt from time-dependent density p_t; rng is required""" 123 | xt = self.compute_mu_t(t, x0, x1) 124 | return xt 125 | 126 | def compute_ut(self, t, x0, x1, xt): 127 | """Compute the vector field corresponding to p_t""" 128 | t = expand_t_like_x(t, x1) 129 | _, d_alpha_t = self.compute_alpha_t(t) 130 | _, d_sigma_t = self.compute_sigma_t(t) 131 | return d_alpha_t * x1 + d_sigma_t * x0 132 | 133 | def plan(self, t, x0, x1): 134 | xt = self.compute_xt(t, x0, x1) 135 | ut = self.compute_ut(t, x0, x1, xt) 136 | return t, xt, ut 137 | 138 | 139 | class VPCPlan(ICPlan): 140 | """class for VP path flow matching""" 141 | 142 | def __init__(self, sigma_min=0.1, sigma_max=20.0): 143 | self.sigma_min = sigma_min 144 | self.sigma_max = sigma_max 145 | self.log_mean_coeff = lambda t: -0.25 * ((1 - t) ** 2) * (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min 146 | self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min 147 | 148 | 149 | def compute_alpha_t(self, t): 150 | """Compute coefficient of x1""" 151 | alpha_t = self.log_mean_coeff(t) 152 | alpha_t = th.exp(alpha_t) 153 | d_alpha_t = alpha_t * self.d_log_mean_coeff(t) 154 | return alpha_t, d_alpha_t 155 | 156 | def compute_sigma_t(self, t): 157 | """Compute coefficient of x0""" 158 | p_sigma_t = 2 * self.log_mean_coeff(t) 159 | sigma_t = th.sqrt(1 - th.exp(p_sigma_t)) 160 | d_sigma_t = th.exp(p_sigma_t) * (2 * 
self.d_log_mean_coeff(t)) / (-2 * sigma_t) 161 | return sigma_t, d_sigma_t 162 | 163 | def compute_d_alpha_alpha_ratio_t(self, t): 164 | """Special-purpose function for computing a numerically stable d_alpha_t / alpha_t""" 165 | return self.d_log_mean_coeff(t) 166 | 167 | def compute_drift(self, x, t): 168 | """Compute the drift term of the SDE""" 169 | t = expand_t_like_x(t, x) 170 | beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min) 171 | return -0.5 * beta_t * x, beta_t / 2 172 | 173 | 174 | class GVPCPlan(ICPlan): 175 | def __init__(self, sigma=0.0): 176 | super().__init__(sigma) 177 | 178 | def compute_alpha_t(self, t): 179 | """Compute coefficient of x1""" 180 | alpha_t = th.sin(t * np.pi / 2) 181 | d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2) 182 | return alpha_t, d_alpha_t 183 | 184 | def compute_sigma_t(self, t): 185 | """Compute coefficient of x0""" 186 | sigma_t = th.cos(t * np.pi / 2) 187 | d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2) 188 | return sigma_t, d_sigma_t 189 | 190 | def compute_d_alpha_alpha_ratio_t(self, t): 191 | """Special-purpose function for computing a numerically stable d_alpha_t / alpha_t""" 192 | return np.pi / (2 * th.tan(t * np.pi / 2)) -------------------------------------------------------------------------------- /transport/transport.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import numpy as np 3 | import logging 4 | 5 | import enum 6 | 7 | from . import path 8 | from .utils import EasyDict, log_state, mean_flat 9 | from .integrators import ode, sde 10 | 11 | class ModelType(enum.Enum): 12 | """ 13 | Which type of output the model predicts. 14 | """ 15 | 16 | NOISE = enum.auto() # the model predicts epsilon 17 | SCORE = enum.auto() # the model predicts \nabla \log p(x) 18 | VELOCITY = enum.auto() # the model predicts v(x) 19 | 20 | class PathType(enum.Enum): 21 | """ 22 | Which type of path to use. 
23 | """ 24 | 25 | LINEAR = enum.auto() 26 | GVP = enum.auto() 27 | VP = enum.auto() 28 | 29 | class WeightType(enum.Enum): 30 | """ 31 | Which type of weighting to use. 32 | """ 33 | 34 | NONE = enum.auto() 35 | VELOCITY = enum.auto() 36 | LIKELIHOOD = enum.auto() 37 | 38 | 39 | class Transport: 40 | 41 | def __init__( 42 | self, 43 | *, 44 | model_type, 45 | path_type, 46 | loss_type, 47 | train_eps, 48 | sample_eps, 49 | ): 50 | path_options = { 51 | PathType.LINEAR: path.ICPlan, 52 | PathType.GVP: path.GVPCPlan, 53 | PathType.VP: path.VPCPlan, 54 | } 55 | 56 | self.loss_type = loss_type 57 | self.model_type = model_type 58 | self.path_sampler = path_options[path_type]() 59 | self.train_eps = train_eps 60 | self.sample_eps = sample_eps 61 | 62 | def prior_logp(self, z): 63 | ''' 64 | Standard multivariate normal prior 65 | Assume z is batched 66 | ''' 67 | shape = th.tensor(z.size()) 68 | N = th.prod(shape[1:]) 69 | _fn = lambda x: -N / 2. * np.log(2 * np.pi) - th.sum(x ** 2) / 2. 70 | return th.vmap(_fn)(z) 71 | 72 | 73 | def check_interval( 74 | self, 75 | train_eps, 76 | sample_eps, 77 | *, 78 | diffusion_form="SBDM", 79 | sde=False, 80 | reverse=False, 81 | eval=False, 82 | last_step_size=0.0, 83 | ): 84 | t0 = 0 85 | t1 = 1 86 | eps = train_eps if not eval else sample_eps 87 | if (type(self.path_sampler) in [path.VPCPlan]): 88 | 89 | t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size 90 | 91 | elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) \ 92 | and (self.model_type != ModelType.VELOCITY or sde): # avoid numerical issue by taking a first semi-implicit step 93 | 94 | t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0 95 | t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size 96 | 97 | if reverse: 98 | t0, t1 = 1 - t0, 1 - t1 99 | 100 | return t0, t1 101 | 102 | 103 | def sample(self, x1): 104 | """Sampling x0 & t based on shape of x1 (if needed) 105 | Args: 
106 | x1 - data point; [batch, *dim] 107 | """ 108 | 109 | x0 = th.randn_like(x1) 110 | t0, t1 = self.check_interval(self.train_eps, self.sample_eps) 111 | t = th.rand((x1.shape[0],)) * (t1 - t0) + t0 112 | t = t.to(x1) 113 | return t, x0, x1 114 | 115 | 116 | def training_losses( 117 | self, 118 | model, 119 | x1, 120 | model_kwargs=None 121 | ): 122 | """Loss for training the score model 123 | Args: 124 | - model: backbone model; could be score, noise, or velocity 125 | - x1: datapoint 126 | - model_kwargs: additional arguments for the model 127 | """ 128 | if model_kwargs is None: 129 | model_kwargs = {} 130 | 131 | t, x0, x1 = self.sample(x1) 132 | t, xt, ut = self.path_sampler.plan(t, x0, x1) 133 | model_output = model(xt, t, **model_kwargs) 134 | B, *_, C = xt.shape 135 | assert model_output.size() == (B, *xt.size()[1:-1], C) 136 | 137 | terms = {} 138 | terms['pred'] = model_output 139 | if self.model_type == ModelType.VELOCITY: 140 | terms['loss'] = mean_flat(((model_output - ut) ** 2)) 141 | else: 142 | _, drift_var = self.path_sampler.compute_drift(xt, t) 143 | sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, xt)) 144 | if self.loss_type in [WeightType.VELOCITY]: 145 | weight = (drift_var / sigma_t) ** 2 146 | elif self.loss_type in [WeightType.LIKELIHOOD]: 147 | weight = drift_var / (sigma_t ** 2) 148 | elif self.loss_type in [WeightType.NONE]: 149 | weight = 1 150 | else: 151 | raise NotImplementedError() 152 | 153 | if self.model_type == ModelType.NOISE: 154 | terms['loss'] = mean_flat(weight * ((model_output - x0) ** 2)) 155 | else: 156 | terms['loss'] = mean_flat(weight * ((model_output * sigma_t + x0) ** 2)) 157 | 158 | return terms 159 | 160 | 161 | def get_drift( 162 | self 163 | ): 164 | """member function for obtaining the drift of the probability flow ODE""" 165 | def score_ode(x, t, model, **model_kwargs): 166 | drift_mean, drift_var = self.path_sampler.compute_drift(x, t) 167 | model_output = model(x, t, **model_kwargs) 
168 | return (-drift_mean + drift_var * model_output) # by change of variable 169 | 170 | def noise_ode(x, t, model, **model_kwargs): 171 | drift_mean, drift_var = self.path_sampler.compute_drift(x, t) 172 | sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x)) 173 | model_output = model(x, t, **model_kwargs) 174 | score = model_output / -sigma_t 175 | return (-drift_mean + drift_var * score) 176 | 177 | def velocity_ode(x, t, model, **model_kwargs): 178 | model_output = model(x, t, **model_kwargs) 179 | return model_output 180 | 181 | if self.model_type == ModelType.NOISE: 182 | drift_fn = noise_ode 183 | elif self.model_type == ModelType.SCORE: 184 | drift_fn = score_ode 185 | else: 186 | drift_fn = velocity_ode 187 | 188 | def body_fn(x, t, model, **model_kwargs): 189 | model_output = drift_fn(x, t, model, **model_kwargs) 190 | assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape" 191 | return model_output 192 | 193 | return body_fn 194 | 195 | 196 | def get_score( 197 | self, 198 | ): 199 | """member function for obtaining score of 200 | x_t = alpha_t * x + sigma_t * eps""" 201 | if self.model_type == ModelType.NOISE: 202 | score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs) / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0] 203 | elif self.model_type == ModelType.SCORE: 204 | score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs) 205 | elif self.model_type == ModelType.VELOCITY: 206 | score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(model(x, t, **kwargs), x, t) 207 | else: 208 | raise NotImplementedError() 209 | 210 | return score_fn 211 | 212 | 213 | class Sampler: 214 | """Sampler class for the transport model""" 215 | def __init__( 216 | self, 217 | transport, 218 | ): 219 | """Constructor for a general sampler; supporting different sampling methods 220 | Args: 221 | - transport: a Transport object specifying the model prediction & 
interpolant type 222 | """ 223 | 224 | self.transport = transport 225 | self.drift = self.transport.get_drift() 226 | self.score = self.transport.get_score() 227 | 228 | def __get_sde_diffusion_and_drift( 229 | self, 230 | *, 231 | diffusion_form="SBDM", 232 | diffusion_norm=1.0, 233 | ): 234 | 235 | def diffusion_fn(x, t): 236 | diffusion = self.transport.path_sampler.compute_diffusion(x, t, form=diffusion_form, norm=diffusion_norm) 237 | return diffusion 238 | 239 | sde_drift = \ 240 | lambda x, t, model, **kwargs: \ 241 | self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs) 242 | 243 | sde_diffusion = diffusion_fn 244 | 245 | return sde_drift, sde_diffusion 246 | 247 | def __get_last_step( 248 | self, 249 | sde_drift, 250 | *, 251 | last_step, 252 | last_step_size, 253 | ): 254 | """Get the last step function of the SDE solver""" 255 | 256 | if last_step is None: 257 | last_step_fn = \ 258 | lambda x, t, model, **model_kwargs: \ 259 | x 260 | elif last_step == "Mean": 261 | last_step_fn = \ 262 | lambda x, t, model, **model_kwargs: \ 263 | x + sde_drift(x, t, model, **model_kwargs) * last_step_size 264 | elif last_step == "Tweedie": 265 | alpha = self.transport.path_sampler.compute_alpha_t # simple aliasing; the original name was too long 266 | sigma = self.transport.path_sampler.compute_sigma_t 267 | last_step_fn = \ 268 | lambda x, t, model, **model_kwargs: \ 269 | x / alpha(t)[0][0] + (sigma(t)[0][0] ** 2) / alpha(t)[0][0] * self.score(x, t, model, **model_kwargs) 270 | elif last_step == "Euler": 271 | last_step_fn = \ 272 | lambda x, t, model, **model_kwargs: \ 273 | x + self.drift(x, t, model, **model_kwargs) * last_step_size 274 | else: 275 | raise NotImplementedError() 276 | 277 | return last_step_fn 278 | 279 | def sample_sde( 280 | self, 281 | *, 282 | sampling_method="Euler", 283 | diffusion_form="SBDM", 284 | diffusion_norm=1.0, 285 | last_step="Mean", 286 | last_step_size=0.04, 287 | num_steps=250, 288 | ): 289 | 
"""returns a sampling function with given SDE settings 290 | Args: 291 | - sampling_method: type of sampler used in solving the SDE; default to be Euler-Maruyama 292 | - diffusion_form: function form of diffusion coefficient; default to be matching SBDM 293 | - diffusion_norm: function magnitude of diffusion coefficient; default to 1 294 | - last_step: type of the last step; default to "Mean" (None gives the identity) 295 | - last_step_size: size of the last step; default to match the stride of 250 steps over [0,1] 296 | - num_steps: total integration step of SDE 297 | """ 298 | 299 | if last_step is None: 300 | last_step_size = 0.0 301 | 302 | sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift( 303 | diffusion_form=diffusion_form, 304 | diffusion_norm=diffusion_norm, 305 | ) 306 | 307 | t0, t1 = self.transport.check_interval( 308 | self.transport.train_eps, 309 | self.transport.sample_eps, 310 | diffusion_form=diffusion_form, 311 | sde=True, 312 | eval=True, 313 | reverse=False, 314 | last_step_size=last_step_size, 315 | ) 316 | 317 | _sde = sde( 318 | sde_drift, 319 | sde_diffusion, 320 | t0=t0, 321 | t1=t1, 322 | num_steps=num_steps, 323 | sampler_type=sampling_method 324 | ) 325 | 326 | last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size) 327 | 328 | 329 | def _sample(init, model, **model_kwargs): 330 | xs = _sde.sample(init, model, **model_kwargs) 331 | ts = th.ones(init.size(0), device=init.device) * t1 332 | x = last_step_fn(xs[-1], ts, model, **model_kwargs) 333 | xs.append(x) 334 | 335 | assert len(xs) == num_steps, "Number of samples does not match the number of steps" 336 | 337 | return xs 338 | 339 | return _sample 340 | 341 | def sample_ode( 342 | self, 343 | *, 344 | sampling_method="dopri5", 345 | num_steps=50, 346 | atol=1e-6, 347 | rtol=1e-3, 348 | reverse=False, 349 | ): 350 | """returns a sampling function with given ODE settings 351 | Args: 352 | - sampling_method: type of sampler used in solving the ODE; default to be 
Dopri5 353 | - num_steps: 354 | - fixed solver (Euler, Heun): the actual number of integration steps performed 355 | - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation 356 | - atol: absolute error tolerance for the solver 357 | - rtol: relative error tolerance for the solver 358 | - reverse: whether solving the ODE in reverse (data to noise); default to False 359 | """ 360 | if reverse: 361 | drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs) 362 | else: 363 | drift = self.drift 364 | 365 | t0, t1 = self.transport.check_interval( 366 | self.transport.train_eps, 367 | self.transport.sample_eps, 368 | sde=False, 369 | eval=True, 370 | reverse=reverse, 371 | last_step_size=0.0, 372 | ) 373 | 374 | _ode = ode( 375 | drift=drift, 376 | t0=t0, 377 | t1=t1, 378 | sampler_type=sampling_method, 379 | num_steps=num_steps, 380 | atol=atol, 381 | rtol=rtol, 382 | ) 383 | 384 | return _ode.sample 385 | 386 | def sample_ode_likelihood( 387 | self, 388 | *, 389 | sampling_method="dopri5", 390 | num_steps=50, 391 | atol=1e-6, 392 | rtol=1e-3, 393 | ): 394 | 395 | """returns a sampling function for calculating likelihood with given ODE settings 396 | Args: 397 | - sampling_method: type of sampler used in solving the ODE; default to be Dopri5 398 | - num_steps: 399 | - fixed solver (Euler, Heun): the actual number of integration steps performed 400 | - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation 401 | - atol: absolute error tolerance for the solver 402 | - rtol: relative error tolerance for the solver 403 | """ 404 | def _likelihood_drift(x, t, model, **model_kwargs): 405 | x, _ = x 406 | eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1 407 | t = th.ones_like(t) * (1 - t) 408 | with th.enable_grad(): 409 | x.requires_grad = True 410 | grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * 
eps), x)[0] 411 | logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size())))) 412 | drift = self.drift(x, t, model, **model_kwargs) 413 | return (-drift, logp_grad) 414 | 415 | t0, t1 = self.transport.check_interval( 416 | self.transport.train_eps, 417 | self.transport.sample_eps, 418 | sde=False, 419 | eval=True, 420 | reverse=False, 421 | last_step_size=0.0, 422 | ) 423 | 424 | _ode = ode( 425 | drift=_likelihood_drift, 426 | t0=t0, 427 | t1=t1, 428 | sampler_type=sampling_method, 429 | num_steps=num_steps, 430 | atol=atol, 431 | rtol=rtol, 432 | ) 433 | 434 | def _sample_fn(x, model, **model_kwargs): 435 | init_logp = th.zeros(x.size(0)).to(x) 436 | input = (x, init_logp) 437 | drift, delta_logp = _ode.sample(input, model, **model_kwargs) 438 | drift, delta_logp = drift[-1], delta_logp[-1] 439 | prior_logp = self.transport.prior_logp(drift) 440 | logp = prior_logp - delta_logp 441 | return logp, drift 442 | 443 | return _sample_fn -------------------------------------------------------------------------------- /transport/utils.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | 3 | class EasyDict: 4 | 5 | def __init__(self, sub_dict): 6 | for k, v in sub_dict.items(): 7 | setattr(self, k, v) 8 | 9 | def __getitem__(self, key): 10 | return getattr(self, key) 11 | 12 | def mean_flat(x): 13 | """ 14 | Take the mean over all non-batch dimensions. 15 | """ 16 | return th.mean(x, dim=list(range(1, len(x.size())))) 17 | 18 | def log_state(state): 19 | result = [] 20 | 21 | sorted_state = dict(sorted(state.items())) 22 | for key, value in sorted_state.items(): 23 | # Check if the value is an instance of a class 24 | if "