├── .gitignore
├── LICENSE.txt
├── README.md
├── download.py
├── environment.yml
├── models.py
├── run_SiT.ipynb
├── sample.py
├── sample_ddp.py
├── train.py
├── train_utils.py
├── transport
│   ├── __init__.py
│   ├── integrators.py
│   ├── path.py
│   ├── transport.py
│   └── utils.py
├── visuals
│   ├── .DS_Store
│   ├── visual.png
│   └── visual_2.png
└── wandb_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | wandb
3 |
4 | .DS_store
5 | samples
6 | results
7 | pretrained_models
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Meta Platforms, Inc. and affiliates.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers (SiT)<br><sub>Official PyTorch Implementation</sub>
2 |
3 | ### [Paper](https://arxiv.org/pdf/2401.08740.pdf) | [Project Page](https://scalable-interpolant.github.io/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/willisma/SiT/blob/main/run_SiT.ipynb)
4 |
5 | ![SiT samples](visuals/visual.png)
6 |
7 | This repo contains PyTorch model definitions, pre-trained weights and training/sampling code for our paper exploring
8 | interpolant models with scalable transformers (SiTs).
9 |
10 | > [**Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers**](https://arxiv.org/pdf/2401.08740.pdf)
11 | > [Nanye Ma](https://willisma.github.io), [Mark Goldstein](https://marikgoldstein.github.io/), [Michael Albergo](http://malbergo.me/), [Nicholas Boffi](https://nmboffi.github.io/), [Eric Vanden-Eijnden](https://wp.nyu.edu/courantinstituteofmathematicalsciences-eve2/), [Saining Xie](https://www.sainingxie.com)
12 | > New York University
13 |
14 | We present Scalable Interpolant Transformers (SiT), a family of generative models built on the backbone of Diffusion Transformers (DiT). The interpolant framework, which allows for connecting two distributions in a more flexible way than standard diffusion models, makes possible a modular study of various design choices impacting generative models built on dynamical transport: using discrete vs. continuous time learning, deciding the model to learn, choosing the interpolant connecting the distributions, and deploying a deterministic or stochastic sampler. By carefully introducing the above ingredients, SiT surpasses DiT uniformly across model sizes on the conditional ImageNet 256x256 benchmark using the exact same backbone, number of parameters, and GFLOPs. By exploring various diffusion coefficients, which can be tuned separately from learning, SiT achieves an FID-50K score of 2.06.
15 |
16 | This repository contains:
17 |
18 | * 🪐 A simple PyTorch [implementation](models.py) of SiT
19 | * ⚡️ Pre-trained class-conditional SiT models trained on ImageNet 256x256
20 | * 🛸 A SiT [training script](train.py) using PyTorch DDP
21 |
22 | ## Setup
23 |
24 | First, download and set up the repo:
25 |
26 | ```bash
27 | git clone https://github.com/willisma/SiT.git
28 | cd SiT
29 | ```
30 |
31 | We provide an [`environment.yml`](environment.yml) file that can be used to create a Conda environment. If you only want
32 | to run pre-trained models locally on CPU, you can remove the `cudatoolkit` and `pytorch-cuda` requirements from the file.
33 |
34 | ```bash
35 | conda env create -f environment.yml
36 | conda activate SiT
37 | ```
38 |
39 |
40 | ## Sampling [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/willisma/SiT/blob/main/run_SiT.ipynb)
41 | ![More SiT samples](visuals/visual_2.png)
42 |
43 | **Pre-trained SiT checkpoints.** You can sample from our pre-trained SiT models with [`sample.py`](sample.py). Weights for our pre-trained SiT model will be
44 | automatically downloaded depending on the model you use. The script has various arguments to adjust the sampler configuration (ODE & SDE), the number of sampling steps, the classifier-free guidance scale, etc. For example, to sample from
45 | our 256x256 SiT-XL model with the default ODE settings, you can use:
46 |
47 | ```bash
48 | python sample.py ODE --image-size 256 --seed 1
49 | ```
50 |
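51 | If you prefer to sample from Python directly, the gist of `sample.py` is roughly the following (a minimal sketch built from the repo's own API, using the notebook's default ODE settings; the class label and output handling are illustrative):
52 | 
53 | ```python
54 | import torch
55 | from diffusers.models import AutoencoderKL
56 | from models import SiT_XL_2
57 | from download import find_model
58 | from transport import create_transport, Sampler
59 | 
60 | torch.set_grad_enabled(False)
61 | device = "cuda" if torch.cuda.is_available() else "cpu"
62 | latent_size = 256 // 8
63 | model = SiT_XL_2(input_size=latent_size).to(device)
64 | model.load_state_dict(find_model("SiT-XL-2-256x256.pt"))
65 | model.eval()  # important!
66 | 
67 | sampler = Sampler(create_transport())
68 | sample_fn = sampler.sample_ode(sampling_method="dopri5", atol=1e-6, rtol=1e-3, num_steps=250)
69 | 
70 | # Classifier-free guidance: duplicate the noise and append the null class (index 1000).
71 | z = torch.randn(1, 4, latent_size, latent_size, device=device)
72 | z = torch.cat([z, z], 0)
73 | y = torch.tensor([207, 1000], device=device)  # 207 is an ImageNet class (golden retriever)
74 | latents = sample_fn(z, model.forward_with_cfg, y=y, cfg_scale=4.0)[-1]
75 | 
76 | # Keep the conditional half of the batch and decode it with the SD VAE.
77 | vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(device)
78 | images = vae.decode(latents.chunk(2, dim=0)[0] / 0.18215).sample
79 | ```
80 | 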
51 | For convenience, our pre-trained SiT models can be downloaded directly here as well:
52 |
53 | | SiT Model | Image Resolution | FID-50K | Inception Score | GFLOPs |
54 | |---------------|------------------|---------|-----------------|--------|
55 | | [XL/2](https://www.dl.dropboxusercontent.com/scl/fi/as9oeomcbub47de5g4be0/SiT-XL-2-256.pt?rlkey=uxzxmpicu46coq3msb17b9ofa&dl=0) | 256x256 | 2.06 | 270.27 | 119 |
56 |
57 |
58 |
59 | **Custom SiT checkpoints.** If you've trained a new SiT model with [`train.py`](train.py) (see [below](#training-SiT)), you can add the `--ckpt`
60 | argument to use your own checkpoint instead. For example, to sample from the EMA weights of a custom
61 | 256x256 SiT-L/4 model with ODE sampler, run:
62 |
63 | ```bash
64 | python sample.py ODE --model SiT-L/4 --image-size 256 --ckpt /path/to/model.pt
65 | ```
66 |
67 | ### Advanced sampler settings
68 |
69 | | Sampler | Argument | Type | Description |
70 | |---------|----------|------|--------------------------|
71 | | ODE | `--atol` | `float` | Absolute error tolerance |
72 | |     | `--rtol` | `float` | Relative error tolerance |
73 | |     | `--sampling-method` | `str` | Sampling method (refer to [`torchdiffeq`](https://github.com/rtqichen/torchdiffeq)) |
74 |
75 | | Sampler | Argument | Type | Description |
76 | |---------|----------|------|--------------------------|
77 | | SDE | `--diffusion-form` | `str` | Form of the SDE's diffusion coefficient (refer to Tab. 2 in the [paper](https://arxiv.org/pdf/2401.08740.pdf)) |
78 | |     | `--diffusion-norm` | `float` | Magnitude of the SDE's diffusion coefficient |
79 | |     | `--last-step` | `str` | Form of the SDE's last step |
80 | |     | | | None - Single SDE integration step |
81 | |     | | | "Mean" - SDE integration step without diffusion coefficient |
82 | |     | | | "Tweedie" - [Tweedie's denoising](https://efron.ckirby.su.domains/papers/2011TweediesFormula.pdf) step |
83 | |     | | | "Euler" - Single ODE integration step |
84 | |     | `--sampling-method` | `str` | Sampling method |
85 | | | | | "Euler" - First order integration |
86 | | | | | "Heun" - Second order integration |
87 |
88 | There are some more options; refer to [`train_utils.py`](train_utils.py) for details.
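89 | 
90 | For instance, wiring the SDE options above into the Python API might look like this (a sketch mirroring `run_SiT.ipynb` and `sample.py`; the particular values are illustrative):
91 | 
92 | ```python
93 | from transport import create_transport, Sampler
94 | 
95 | sampler = Sampler(create_transport())
96 | sample_fn = sampler.sample_sde(
97 |     sampling_method="Euler",   # or "Heun" (second order)
98 |     diffusion_form="sigma",    # see Tab. 2 in the paper
99 |     diffusion_norm=1.0,        # magnitude of the diffusion coefficient
100 |     last_step="Mean",          # None / "Mean" / "Tweedie" / "Euler"
101 |     last_step_size=0.4,
102 |     num_steps=250,
103 | )
104 | ```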
89 |
90 | ## Training SiT
91 |
92 | We provide a training script for SiT in [`train.py`](train.py). To launch SiT-XL/2 (256x256) training with `N` GPUs on
93 | one node:
94 |
95 | ```bash
96 | torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-XL/2 --data-path /path/to/imagenet/train
97 | ```
98 |
99 | **Logging.** To enable `wandb`, first set `WANDB_KEY`, `ENTITY`, and `PROJECT` as environment variables:
100 |
101 | ```bash
102 | export WANDB_KEY="key"
103 | export ENTITY="entity name"
104 | export PROJECT="project name"
105 | ```
106 |
107 | Then add the `--wandb` flag to the training command:
108 |
109 | ```bash
110 | torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-XL/2 --data-path /path/to/imagenet/train --wandb
111 | ```
112 |
113 | **Interpolant settings.** We also support different choices of interpolant and model predictions. For example, to launch SiT-XL/2 (256x256) with `Linear` interpolant and `noise` prediction:
114 |
115 | ```bash
116 | torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-XL/2 --data-path /path/to/imagenet/train --path-type Linear --prediction noise
117 | ```
118 |
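119 | Under the hood, each training step reduces to a standard loss/backprop loop over VAE latents. A minimal sketch of the loop in [`train.py`](train.py), assuming `model`, `ema`, `opt`, and an `(x, y)` batch of latents and labels from the surrounding script:
120 | 
121 | ```python
122 | from transport import create_transport
123 | 
124 | transport = create_transport(args.path_type, args.prediction, args.loss_weight,
125 |                              args.train_eps, args.sample_eps)  # default: velocity
126 | loss = transport.training_losses(model, x, dict(y=y))["loss"].mean()
127 | opt.zero_grad()
128 | loss.backward()
129 | opt.step()
130 | update_ema(ema, model.module)  # update_ema() as defined in train.py
131 | ```
132 | 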
119 | **Resume training.** To resume training from a custom checkpoint:
120 |
121 | ```bash
122 | torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-L/2 --data-path /path/to/imagenet/train --ckpt /path/to/model.pt
123 | ```
124 |
125 | **Caution.** Resuming training automatically restores the model, EMA, and optimizer states, as well as the training configs, from the checkpoint.
126 |
127 | ## Evaluation (FID, Inception Score, etc.)
128 |
129 | We include a [`sample_ddp.py`](sample_ddp.py) script which samples a large number of images from a SiT model in parallel. This script
130 | generates a folder of samples as well as a `.npz` file which can be directly used with [ADM's TensorFlow
131 | evaluation suite](https://github.com/openai/guided-diffusion/tree/main/evaluations) to compute FID, Inception Score and
132 | other metrics. For example, to sample 50K images from our pre-trained SiT-XL/2 model over `N` GPUs under default ODE sampler settings, run:
133 |
134 | ```bash
135 | torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py ODE --model SiT-XL/2 --num-fid-samples 50000
136 | ```
137 |
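138 | The resulting `.npz` packs all samples into a single `uint8` array (see `create_npz_from_sample_folder` in [`sample_ddp.py`](sample_ddp.py)); the path below is a placeholder:
139 | 
140 | ```python
141 | import numpy as np
142 | 
143 | samples = np.load("samples/<folder_name>.npz")["arr_0"]  # shape (num_fid_samples, H, W, 3), dtype uint8
144 | ```
145 | 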
138 | **Likelihood.** Likelihood evaluation is supported. To calculate likelihood, add the `--likelihood` flag to the ODE sampler:
139 |
140 | ```bash
141 | torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py ODE --model SiT-XL/2 --likelihood
142 | ```
143 |
144 | Note that likelihood can only be calculated with the ODE sampler; see [`sample_ddp.py`](sample_ddp.py) for more details and settings.
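145 | 
146 | In Python, the likelihood sampler is constructed analogously to the ODE sampler (a sketch based on `sample.py`; the keyword values shown are assumptions matching the ODE defaults):
147 | 
148 | ```python
149 | from transport import create_transport, Sampler
150 | 
151 | sampler = Sampler(create_transport())
152 | likelihood_fn = sampler.sample_ode_likelihood(
153 |     sampling_method="dopri5", num_steps=250, atol=1e-6, rtol=1e-3,
154 | )
155 | # Note: cfg_scale must be 1 (no guidance), as asserted in sample.py and sample_ddp.py.
156 | ```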
145 |
146 | ### Enhancements
147 | Training (and sampling) could likely be sped up significantly by:
148 | - [ ] using [Flash Attention](https://github.com/HazyResearch/flash-attention) in the SiT model
149 | - [ ] using `torch.compile` in PyTorch 2.0
150 |
151 | Basic features that would be nice to add:
152 | - [ ] Monitor FID and other metrics
153 | - [ ] AMP/bfloat16 support
154 |
155 | Precision in likelihood calculation could likely be improved by:
156 | - [ ] Uniform / Gaussian Dequantization
157 |
158 |
159 | ## Differences from JAX
160 |
161 | Our models were originally trained in JAX on TPUs. The weights in this repo are ported directly from the JAX models.
162 | There may be minor differences in results stemming from sampling on different platforms (TPU vs. GPU). We observed that sampling on GPU performs marginally worse than on TPU (2.15 FID
163 | versus 2.06 in the paper).
164 |
165 |
166 | ## License
167 | This project is under the MIT license. See [LICENSE](LICENSE.txt) for details.
168 |
169 |
170 |
--------------------------------------------------------------------------------
/download.py:
--------------------------------------------------------------------------------
1 | # This source code is licensed under the license found in the
2 | # LICENSE file in the root directory of this source tree.
3 |
4 | """
5 | Functions for downloading pre-trained SiT models
6 | """
7 | from torchvision.datasets.utils import download_url
8 | import torch
9 | import os
10 |
11 |
12 | pretrained_models = {'SiT-XL-2-256x256.pt'}
13 |
14 |
15 | def find_model(model_name):
16 | """
17 | Finds a pre-trained SiT model, downloading it if necessary. Alternatively, loads a model from a local path.
18 | """
19 | if model_name in pretrained_models:
20 | return download_model(model_name)
21 | else:
22 | assert os.path.isfile(model_name), f'Could not find SiT checkpoint at {model_name}'
23 | checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
24 | if "ema" in checkpoint: # supports checkpoints from train.py
25 | checkpoint = checkpoint["ema"]
26 | return checkpoint
27 |
28 |
29 | def download_model(model_name):
30 | """
31 | Downloads a pre-trained SiT model from the web.
32 | """
33 | assert model_name in pretrained_models
34 | local_path = f'pretrained_models/{model_name}'
35 | if not os.path.isfile(local_path):
36 | os.makedirs('pretrained_models', exist_ok=True)
37 |         web_path = 'https://www.dl.dropboxusercontent.com/scl/fi/as9oeomcbub47de5g4be0/SiT-XL-2-256.pt?rlkey=uxzxmpicu46coq3msb17b9ofa&dl=0'
38 | download_url(web_path, 'pretrained_models', filename=model_name)
39 | model = torch.load(local_path, map_location=lambda storage, loc: storage)
40 | return model
41 |
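42 | # Example usage (sketch): both pre-trained names and local paths go through find_model().
43 | #   state_dict = find_model('SiT-XL-2-256x256.pt')  # auto-downloads into pretrained_models/
44 | #   state_dict = find_model('/path/to/model.pt')    # local checkpoint; train.py checkpoints load their EMA weights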
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: SiT
2 | channels:
3 | - pytorch
4 | - nvidia
5 | dependencies:
6 | - python >= 3.8
7 | - pytorch >= 1.13
8 | - torchvision
9 | - pytorch-cuda >=11.7
10 | - pip
11 | - pip:
12 | - timm
13 | - diffusers
14 | - accelerate
15 | - torchdiffeq
16 | - wandb
17 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | # This source code is licensed under the license found in the
2 | # LICENSE file in the root directory of this source tree.
3 | # --------------------------------------------------------
4 | # References:
5 | # GLIDE: https://github.com/openai/glide-text2im
6 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
7 | # --------------------------------------------------------
8 |
9 | import torch
10 | import torch.nn as nn
11 | import numpy as np
12 | import math
13 | from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
14 |
15 |
16 | def modulate(x, shift, scale):
17 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
18 |
19 |
20 | #################################################################################
21 | # Embedding Layers for Timesteps and Class Labels #
22 | #################################################################################
23 |
24 | class TimestepEmbedder(nn.Module):
25 | """
26 | Embeds scalar timesteps into vector representations.
27 | """
28 | def __init__(self, hidden_size, frequency_embedding_size=256):
29 | super().__init__()
30 | self.mlp = nn.Sequential(
31 | nn.Linear(frequency_embedding_size, hidden_size, bias=True),
32 | nn.SiLU(),
33 | nn.Linear(hidden_size, hidden_size, bias=True),
34 | )
35 | self.frequency_embedding_size = frequency_embedding_size
36 |
37 | @staticmethod
38 | def timestep_embedding(t, dim, max_period=10000):
39 | """
40 | Create sinusoidal timestep embeddings.
41 | :param t: a 1-D Tensor of N indices, one per batch element.
42 | These may be fractional.
43 | :param dim: the dimension of the output.
44 | :param max_period: controls the minimum frequency of the embeddings.
45 | :return: an (N, D) Tensor of positional embeddings.
46 | """
47 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
48 | half = dim // 2
49 | freqs = torch.exp(
50 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
51 | ).to(device=t.device)
52 | args = t[:, None].float() * freqs[None]
53 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
54 | if dim % 2:
55 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
56 | return embedding
57 |
58 | def forward(self, t):
59 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
60 | t_emb = self.mlp(t_freq)
61 | return t_emb
62 |
63 |
64 | class LabelEmbedder(nn.Module):
65 | """
66 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
67 | """
68 | def __init__(self, num_classes, hidden_size, dropout_prob):
69 | super().__init__()
70 | use_cfg_embedding = dropout_prob > 0
71 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
72 | self.num_classes = num_classes
73 | self.dropout_prob = dropout_prob
74 |
75 | def token_drop(self, labels, force_drop_ids=None):
76 | """
77 | Drops labels to enable classifier-free guidance.
78 | """
79 | if force_drop_ids is None:
80 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
81 | else:
82 | drop_ids = force_drop_ids == 1
83 | labels = torch.where(drop_ids, self.num_classes, labels)
84 | return labels
85 |
86 | def forward(self, labels, train, force_drop_ids=None):
87 | use_dropout = self.dropout_prob > 0
88 | if (train and use_dropout) or (force_drop_ids is not None):
89 | labels = self.token_drop(labels, force_drop_ids)
90 | embeddings = self.embedding_table(labels)
91 | return embeddings
92 |
93 |
94 | #################################################################################
95 | # Core SiT Model #
96 | #################################################################################
97 |
98 | class SiTBlock(nn.Module):
99 | """
100 | A SiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
101 | """
102 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
103 | super().__init__()
104 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
105 | self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
106 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
107 | mlp_hidden_dim = int(hidden_size * mlp_ratio)
108 | approx_gelu = lambda: nn.GELU(approximate="tanh")
109 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
110 | self.adaLN_modulation = nn.Sequential(
111 | nn.SiLU(),
112 | nn.Linear(hidden_size, 6 * hidden_size, bias=True)
113 | )
114 |
115 | def forward(self, x, c):
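116 |         # adaLN-Zero: regress 6 modulation vectors (shift/scale/gate for attention and MLP) from the
117 |         # conditioning c; the modulation MLP is zero-initialized, so each block starts as the identity.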
116 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
117 | x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
118 | x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
119 | return x
120 |
121 |
122 | class FinalLayer(nn.Module):
123 | """
124 | The final layer of SiT.
125 | """
126 | def __init__(self, hidden_size, patch_size, out_channels):
127 | super().__init__()
128 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
129 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
130 | self.adaLN_modulation = nn.Sequential(
131 | nn.SiLU(),
132 | nn.Linear(hidden_size, 2 * hidden_size, bias=True)
133 | )
134 |
135 | def forward(self, x, c):
136 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
137 | x = modulate(self.norm_final(x), shift, scale)
138 | x = self.linear(x)
139 | return x
140 |
141 |
142 | class SiT(nn.Module):
143 | """
144 | Diffusion model with a Transformer backbone.
145 | """
146 | def __init__(
147 | self,
148 | input_size=32,
149 | patch_size=2,
150 | in_channels=4,
151 | hidden_size=1152,
152 | depth=28,
153 | num_heads=16,
154 | mlp_ratio=4.0,
155 | class_dropout_prob=0.1,
156 | num_classes=1000,
157 | learn_sigma=True,
158 | ):
159 | super().__init__()
160 | self.learn_sigma = learn_sigma
161 | self.in_channels = in_channels
162 | self.out_channels = in_channels * 2 if learn_sigma else in_channels
163 | self.patch_size = patch_size
164 | self.num_heads = num_heads
165 |
166 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
167 | self.t_embedder = TimestepEmbedder(hidden_size)
168 | self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
169 | num_patches = self.x_embedder.num_patches
170 | # Will use fixed sin-cos embedding:
171 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
172 |
173 | self.blocks = nn.ModuleList([
174 | SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
175 | ])
176 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
177 | self.initialize_weights()
178 |
179 | def initialize_weights(self):
180 | # Initialize transformer layers:
181 | def _basic_init(module):
182 | if isinstance(module, nn.Linear):
183 | torch.nn.init.xavier_uniform_(module.weight)
184 | if module.bias is not None:
185 | nn.init.constant_(module.bias, 0)
186 | self.apply(_basic_init)
187 |
188 | # Initialize (and freeze) pos_embed by sin-cos embedding:
189 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
190 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
191 |
192 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
193 | w = self.x_embedder.proj.weight.data
194 | nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
195 | nn.init.constant_(self.x_embedder.proj.bias, 0)
196 |
197 | # Initialize label embedding table:
198 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
199 |
200 | # Initialize timestep embedding MLP:
201 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
202 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
203 |
204 | # Zero-out adaLN modulation layers in SiT blocks:
205 | for block in self.blocks:
206 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
207 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
208 |
209 | # Zero-out output layers:
210 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
211 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
212 | nn.init.constant_(self.final_layer.linear.weight, 0)
213 | nn.init.constant_(self.final_layer.linear.bias, 0)
214 |
215 | def unpatchify(self, x):
216 | """
217 | x: (N, T, patch_size**2 * C)
218 |         imgs: (N, C, H, W)
219 | """
220 | c = self.out_channels
221 | p = self.x_embedder.patch_size[0]
222 | h = w = int(x.shape[1] ** 0.5)
223 | assert h * w == x.shape[1]
224 |
225 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
226 | x = torch.einsum('nhwpqc->nchpwq', x)
227 | imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
228 | return imgs
229 |
230 | def forward(self, x, t, y):
231 | """
232 | Forward pass of SiT.
233 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
234 | t: (N,) tensor of diffusion timesteps
235 | y: (N,) tensor of class labels
236 | """
237 | x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
238 | t = self.t_embedder(t) # (N, D)
239 | y = self.y_embedder(y, self.training) # (N, D)
240 | c = t + y # (N, D)
241 | for block in self.blocks:
242 | x = block(x, c) # (N, T, D)
243 | x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
244 | x = self.unpatchify(x) # (N, out_channels, H, W)
245 | if self.learn_sigma:
246 | x, _ = x.chunk(2, dim=1)
247 | return x
248 |
249 | def forward_with_cfg(self, x, t, y, cfg_scale):
250 | """
251 |         Forward pass of SiT, but also batches the unconditional forward pass for classifier-free guidance.
252 | """
253 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
254 | half = x[: len(x) // 2]
255 | combined = torch.cat([half, half], dim=0)
256 | model_out = self.forward(combined, t, y)
257 | # For exact reproducibility reasons, we apply classifier-free guidance on only
258 | # three channels by default. The standard approach to cfg applies it to all channels.
259 | # This can be done by uncommenting the following line and commenting-out the line following that.
260 | # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
261 | eps, rest = model_out[:, :3], model_out[:, 3:]
262 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
263 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
264 | eps = torch.cat([half_eps, half_eps], dim=0)
265 | return torch.cat([eps, rest], dim=1)
266 |
267 |
268 | #################################################################################
269 | # Sine/Cosine Positional Embedding Functions #
270 | #################################################################################
271 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
272 |
273 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
274 | """
275 | grid_size: int of the grid height and width
276 | return:
277 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
278 | """
279 | grid_h = np.arange(grid_size, dtype=np.float32)
280 | grid_w = np.arange(grid_size, dtype=np.float32)
281 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
282 | grid = np.stack(grid, axis=0)
283 |
284 | grid = grid.reshape([2, 1, grid_size, grid_size])
285 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
286 | if cls_token and extra_tokens > 0:
287 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
288 | return pos_embed
289 |
290 |
291 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
292 | assert embed_dim % 2 == 0
293 |
294 | # use half of dimensions to encode grid_h
295 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
296 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
297 |
298 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
299 | return emb
300 |
301 |
302 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
303 | """
304 | embed_dim: output dimension for each position
305 | pos: a list of positions to be encoded: size (M,)
306 | out: (M, D)
307 | """
308 | assert embed_dim % 2 == 0
309 | omega = np.arange(embed_dim // 2, dtype=np.float64)
310 | omega /= embed_dim / 2.
311 | omega = 1. / 10000**omega # (D/2,)
312 |
313 | pos = pos.reshape(-1) # (M,)
314 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
315 |
316 | emb_sin = np.sin(out) # (M, D/2)
317 | emb_cos = np.cos(out) # (M, D/2)
318 |
319 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
320 | return emb
321 |
322 |
323 | #################################################################################
324 | # SiT Configs #
325 | #################################################################################
326 |
327 | def SiT_XL_2(**kwargs):
328 | return SiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
329 |
330 | def SiT_XL_4(**kwargs):
331 | return SiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
332 |
333 | def SiT_XL_8(**kwargs):
334 | return SiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
335 |
336 | def SiT_L_2(**kwargs):
337 | return SiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
338 |
339 | def SiT_L_4(**kwargs):
340 | return SiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
341 |
342 | def SiT_L_8(**kwargs):
343 | return SiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
344 |
345 | def SiT_B_2(**kwargs):
346 | return SiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
347 |
348 | def SiT_B_4(**kwargs):
349 | return SiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
350 |
351 | def SiT_B_8(**kwargs):
352 | return SiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
353 |
354 | def SiT_S_2(**kwargs):
355 | return SiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
356 |
357 | def SiT_S_4(**kwargs):
358 | return SiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
359 |
360 | def SiT_S_8(**kwargs):
361 | return SiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
362 |
363 |
364 | SiT_models = {
365 | 'SiT-XL/2': SiT_XL_2, 'SiT-XL/4': SiT_XL_4, 'SiT-XL/8': SiT_XL_8,
366 | 'SiT-L/2': SiT_L_2, 'SiT-L/4': SiT_L_4, 'SiT-L/8': SiT_L_8,
367 | 'SiT-B/2': SiT_B_2, 'SiT-B/4': SiT_B_4, 'SiT-B/8': SiT_B_8,
368 | 'SiT-S/2': SiT_S_2, 'SiT-S/4': SiT_S_4, 'SiT-S/8': SiT_S_8,
369 | }
370 |
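371 | # Quick shape-check sketch (not part of the pipeline): with learn_sigma=True (default), the extra
372 | # variance channels are split off, so the output matches the input latent shape.
373 | #   model = SiT_models['SiT-B/2'](input_size=32, num_classes=1000)
374 | #   x, t, y = torch.randn(2, 4, 32, 32), torch.rand(2), torch.randint(0, 1000, (2,))
375 | #   model(x, t, y).shape  # torch.Size([2, 4, 32, 32])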
--------------------------------------------------------------------------------
/run_SiT.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "355UKMUQJxFd"
7 | },
8 | "source": [
9 | "# SiT: Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers\n",
10 | "\n",
11 | "This notebook samples from pre-trained SiT models. SiTs are class-conSiTional latent interpolant models trained on ImageNet, unifying Flow and Diffusion Methods. \n",
12 | "\n",
13 | "[Paper]() | [GitHub](github.com/willisma/SiT)"
14 | ]
15 | },
16 | {
17 | "cell_type": "markdown",
18 | "metadata": {
19 | "id": "zJlgLkSaKn7u"
20 | },
21 | "source": [
22 | "# 1. Setup\n",
23 | "\n",
24 | "We recommend using GPUs (Runtime > Change runtime type > Hardware accelerator > GPU). Run this cell to clone the SiT GitHub repo and setup PyTorch. You only have to run this once."
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": null,
30 | "metadata": {},
31 | "outputs": [],
32 | "source": [
33 | "!git clone https://github.com/willisma/SiT.git\n",
34 | "import SiT, os\n",
35 | "os.chdir('SiT')\n",
36 | "os.environ['PYTHONPATH'] = '/env/python:/content/SiT'\n",
37 | "!pip install diffusers timm torchdiffeq --upgrade\n",
38 | "# SiT imports:\n",
39 | "import torch\n",
40 | "from torchvision.utils import save_image\n",
41 | "from transport import create_transport, Sampler\n",
42 | "from diffusers.models import AutoencoderKL\n",
43 | "from download import find_model\n",
44 | "from models import SiT_XL_2\n",
45 | "from PIL import Image\n",
46 | "from IPython.display import display\n",
47 | "torch.set_grad_enabled(False)\n",
48 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
49 | "if device == \"cpu\":\n",
50 | " print(\"GPU not found. Using CPU instead.\")"
51 | ]
52 | },
53 | {
54 | "cell_type": "markdown",
55 | "metadata": {
56 | "id": "AXpziRkoOvV9"
57 | },
58 | "source": [
59 | "# Download SiT-XL/2 Models"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": null,
65 | "metadata": {
66 | "id": "EWG-WNimO59K"
67 | },
68 | "outputs": [],
69 | "source": [
70 | "image_size = \"256\"\n",
71 | "vae_model = \"stabilityai/sd-vae-ft-ema\" #@param [\"stabilityai/sd-vae-ft-mse\", \"stabilityai/sd-vae-ft-ema\"]\n",
72 | "latent_size = int(image_size) // 8\n",
73 | "# Load model:\n",
74 | "model = SiT_XL_2(input_size=latent_size).to(device)\n",
75 | "state_dict = find_model(f\"SiT-XL-2-{image_size}x{image_size}.pt\")\n",
76 | "model.load_state_dict(state_dict)\n",
77 | "model.eval() # important!\n",
78 | "vae = AutoencoderKL.from_pretrained(vae_model).to(device)"
79 | ]
80 | },
81 | {
82 | "cell_type": "markdown",
83 | "metadata": {
84 | "id": "5JTNyzNZKb9E"
85 | },
86 | "source": [
87 | "# 2. Sample from Pre-trained SiT Models\n",
88 | "\n",
89 | "You can customize several sampling options. For the full list of ImageNet classes, [check out this](https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a)."
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": null,
95 | "metadata": {
96 | "id": "-Hw7B5h4Kk4p"
97 | },
98 | "outputs": [],
99 | "source": [
100 | "# Set user inputs:\n",
101 | "seed = 0 #@param {type:\"number\"}\n",
102 | "torch.manual_seed(seed)\n",
103 | "num_sampling_steps = 250 #@param {type:\"slider\", min:0, max:1000, step:1}\n",
104 | "cfg_scale = 4 #@param {type:\"slider\", min:1, max:10, step:0.1}\n",
105 | "class_labels = 207, 360, 387, 974, 88, 979, 417, 279 #@param {type:\"raw\"}\n",
106 | "samples_per_row = 4 #@param {type:\"number\"}\n",
107 | "sampler_type = \"ODE\" #@param [\"ODE\", \"SDE\"]\n",
108 | "\n",
109 | "\n",
110 | "# Create diffusion object:\n",
111 | "transport = create_transport()\n",
112 | "sampler = Sampler(transport)\n",
113 | "\n",
114 | "# Create sampling noise:\n",
115 | "n = len(class_labels)\n",
116 | "z = torch.randn(n, 4, latent_size, latent_size, device=device)\n",
117 | "y = torch.tensor(class_labels, device=device)\n",
118 | "\n",
119 | "# Setup classifier-free guidance:\n",
120 | "z = torch.cat([z, z], 0)\n",
121 | "y_null = torch.tensor([1000] * n, device=device)\n",
122 | "y = torch.cat([y, y_null], 0)\n",
123 | "model_kwargs = dict(y=y, cfg_scale=cfg_scale)\n",
124 | "\n",
125 | "# Sample images:\n",
126 | "if sampler_type == \"SDE\":\n",
127 | " SDE_sampling_method = \"Euler\" #@param [\"Euler\", \"Heun\"]\n",
128 | " diffusion_form = \"linear\" #@param [\"constant\", \"SBDM\", \"sigma\", \"linear\", \"decreasing\", \"increasing-decreasing\"]\n",
129 | " diffusion_norm = 1 #@param {type:\"slider\", min:0, max:10.0, step:0.1}\n",
130 | " last_step = \"Mean\" #@param [\"Mean\", \"Tweedie\", \"Euler\"]\n",
131 | " last_step_size = 0.4 #@param {type:\"slider\", min:0, max:1.0, step:0.01}\n",
132 | " sample_fn = sampler.sample_sde(\n",
133 | " sampling_method=SDE_sampling_method,\n",
134 | " diffusion_form=diffusion_form, \n",
135 | " diffusion_norm=diffusion_norm,\n",
136 | " last_step_size=last_step_size, \n",
137 | " num_steps=num_sampling_steps,\n",
138 | " ) \n",
139 | "elif sampler_type == \"ODE\":\n",
140 | " # default to Adaptive Solver\n",
141 | " ODE_sampling_method = \"dopri5\" #@param [\"dopri5\", \"euler\", \"rk4\"]\n",
142 | " atol = 1e-6\n",
143 | " rtol = 1e-3\n",
144 | " sample_fn = sampler.sample_ode(\n",
145 | " sampling_method=ODE_sampling_method,\n",
146 | " atol=atol,\n",
147 | " rtol=rtol,\n",
148 | " num_steps=num_sampling_steps\n",
149 | " ) \n",
150 | "samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]\n",
151 | "samples = vae.decode(samples / 0.18215).sample\n",
152 | "\n",
153 | "# Save and display images:\n",
154 | "save_image(samples, \"sample.png\", nrow=int(samples_per_row), \n",
155 | " normalize=True, value_range=(-1, 1))\n",
156 | "samples = Image.open(\"sample.png\")\n",
157 | "display(samples)"
158 | ]
159 | }
160 | ],
161 | "metadata": {
162 | "colab": {
163 | "provenance": []
164 | },
165 | "kernelspec": {
166 | "display_name": "Python 3.8.10 64-bit",
167 | "language": "python",
168 | "name": "python3"
169 | },
170 | "language_info": {
171 | "name": "python",
172 | "version": "3.8.10"
173 | },
174 | "vscode": {
175 | "interpreter": {
176 | "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
177 | }
178 | }
179 | },
180 | "nbformat": 4,
181 | "nbformat_minor": 0
182 | }
183 |
--------------------------------------------------------------------------------
/sample.py:
--------------------------------------------------------------------------------
1 | # This source code is licensed under the license found in the
2 | # LICENSE file in the root directory of this source tree.
3 |
4 | """
5 | Sample new images from a pre-trained SiT.
6 | """
7 | import torch
8 | torch.backends.cuda.matmul.allow_tf32 = True
9 | torch.backends.cudnn.allow_tf32 = True
10 | from torchvision.utils import save_image
11 | from diffusers.models import AutoencoderKL
12 | from download import find_model
13 | from models import SiT_models
14 | from train_utils import parse_ode_args, parse_sde_args, parse_transport_args
15 | from transport import create_transport, Sampler
16 | import argparse
17 | import sys
18 | from time import time
19 |
20 |
21 | def main(mode, args):
22 | # Setup PyTorch:
23 | torch.manual_seed(args.seed)
24 | torch.set_grad_enabled(False)
25 | device = "cuda" if torch.cuda.is_available() else "cpu"
26 |
27 | if args.ckpt is None:
28 | assert args.model == "SiT-XL/2", "Only SiT-XL/2 models are available for auto-download."
29 | assert args.image_size in [256, 512]
30 | assert args.num_classes == 1000
31 | assert args.image_size == 256, "512x512 models are not yet available for auto-download." # remove this line when 512x512 models are available
32 | learn_sigma = args.image_size == 256
33 | else:
34 | learn_sigma = False
35 |
36 | # Load model:
37 | latent_size = args.image_size // 8
38 | model = SiT_models[args.model](
39 | input_size=latent_size,
40 | num_classes=args.num_classes,
41 | learn_sigma=learn_sigma,
42 | ).to(device)
43 | # Auto-download a pre-trained model or load a custom SiT checkpoint from train.py:
44 | ckpt_path = args.ckpt or f"SiT-XL-2-{args.image_size}x{args.image_size}.pt"
45 | state_dict = find_model(ckpt_path)
46 | model.load_state_dict(state_dict)
47 | model.eval() # important!
48 | transport = create_transport(
49 | args.path_type,
50 | args.prediction,
51 | args.loss_weight,
52 | args.train_eps,
53 | args.sample_eps
54 | )
55 | sampler = Sampler(transport)
56 | if mode == "ODE":
57 | if args.likelihood:
58 | assert args.cfg_scale == 1, "Likelihood is incompatible with guidance"
59 | sample_fn = sampler.sample_ode_likelihood(
60 | sampling_method=args.sampling_method,
61 | num_steps=args.num_sampling_steps,
62 | atol=args.atol,
63 | rtol=args.rtol,
64 | )
65 | else:
66 | sample_fn = sampler.sample_ode(
67 | sampling_method=args.sampling_method,
68 | num_steps=args.num_sampling_steps,
69 | atol=args.atol,
70 | rtol=args.rtol,
71 | reverse=args.reverse
72 | )
73 |
74 | elif mode == "SDE":
75 | sample_fn = sampler.sample_sde(
76 | sampling_method=args.sampling_method,
77 | diffusion_form=args.diffusion_form,
78 | diffusion_norm=args.diffusion_norm,
79 | last_step=args.last_step,
80 | last_step_size=args.last_step_size,
81 | num_steps=args.num_sampling_steps,
82 | )
83 |
84 |
85 | vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
86 |
87 | # Labels to condition the model with (feel free to change):
88 | class_labels = [207, 360, 387, 974, 88, 979, 417, 279]
89 |
90 | # Create sampling noise:
91 | n = len(class_labels)
92 | z = torch.randn(n, 4, latent_size, latent_size, device=device)
93 | y = torch.tensor(class_labels, device=device)
94 |
95 | # Setup classifier-free guidance:
96 | z = torch.cat([z, z], 0)
97 | y_null = torch.tensor([1000] * n, device=device)
98 | y = torch.cat([y, y_null], 0)
99 | model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)
100 |
101 | # Sample images:
102 | start_time = time()
103 | samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]
104 | samples, _ = samples.chunk(2, dim=0) # Remove null class samples
105 | samples = vae.decode(samples / 0.18215).sample
106 | print(f"Sampling took {time() - start_time:.2f} seconds.")
107 |
108 | # Save and display images:
109 | save_image(samples, "sample.png", nrow=4, normalize=True, value_range=(-1, 1))
110 |
111 |
112 | if __name__ == "__main__":
113 | parser = argparse.ArgumentParser()
114 |
115 | if len(sys.argv) < 2:
116 |         print("Usage: sample.py [ODE|SDE] [options]")
117 | sys.exit(1)
118 |
119 | mode = sys.argv[1]
120 |
121 |     assert mode[:2] != "--", "Usage: sample.py [ODE|SDE] [options]"
122 | assert mode in ["ODE", "SDE"], "Invalid mode. Please choose 'ODE' or 'SDE'"
123 |
124 | parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2")
125 | parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="mse")
126 | parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
127 | parser.add_argument("--num-classes", type=int, default=1000)
128 | parser.add_argument("--cfg-scale", type=float, default=4.0)
129 | parser.add_argument("--num-sampling-steps", type=int, default=250)
130 | parser.add_argument("--seed", type=int, default=0)
131 | parser.add_argument("--ckpt", type=str, default=None,
132 | help="Optional path to a SiT checkpoint (default: auto-download a pre-trained SiT-XL/2 model).")
133 |
134 |
135 | parse_transport_args(parser)
136 | if mode == "ODE":
137 | parse_ode_args(parser)
138 | # Further processing for ODE
139 | elif mode == "SDE":
140 | parse_sde_args(parser)
141 | # Further processing for SDE
142 |
143 | args = parser.parse_known_args()[0]
144 | main(mode, args)
145 |
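146 | # Example invocations (sketch; flags per the README and train_utils.py):
147 | #   python sample.py ODE --image-size 256 --seed 1
148 | #   python sample.py SDE --diffusion-form sigma --last-step Mean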
--------------------------------------------------------------------------------
/sample_ddp.py:
--------------------------------------------------------------------------------
1 | # This source code is licensed under the license found in the
2 | # LICENSE file in the root directory of this source tree.
3 |
4 | """
5 | Samples a large number of images from a pre-trained SiT model using DDP.
6 | Subsequently saves a .npz file that can be used to compute FID and other
7 | evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations
8 |
9 | For a simple single-GPU/CPU sampling script, see sample.py.
10 | """
11 | import torch
12 | import torch.distributed as dist
13 | from models import SiT_models
14 | from download import find_model
15 | from transport import create_transport, Sampler
16 | from diffusers.models import AutoencoderKL
17 | from train_utils import parse_ode_args, parse_sde_args, parse_transport_args
18 | from tqdm import tqdm
19 | import os
20 | from PIL import Image
21 | import numpy as np
22 | import math
23 | import argparse
24 | import sys
25 |
26 |
27 | def create_npz_from_sample_folder(sample_dir, num=50_000):
28 | """
29 | Builds a single .npz file from a folder of .png samples.
30 | """
31 | samples = []
32 | for i in tqdm(range(num), desc="Building .npz file from samples"):
33 | sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
34 | sample_np = np.asarray(sample_pil).astype(np.uint8)
35 | samples.append(sample_np)
36 | samples = np.stack(samples)
37 | assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
38 | npz_path = f"{sample_dir}.npz"
39 | np.savez(npz_path, arr_0=samples)
40 | print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
41 | return npz_path
42 |
43 |
44 | def main(mode, args):
45 | """
46 | Run sampling.
47 | """
48 | torch.backends.cuda.matmul.allow_tf32 = args.tf32 # True: fast but may lead to some small numerical differences
49 | assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
50 | torch.set_grad_enabled(False)
51 |
52 | # Setup DDP:
53 | dist.init_process_group("nccl")
54 | rank = dist.get_rank()
55 | device = rank % torch.cuda.device_count()
56 | seed = args.global_seed * dist.get_world_size() + rank
57 | torch.manual_seed(seed)
58 | torch.cuda.set_device(device)
59 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
60 |
61 | if args.ckpt is None:
62 | assert args.model == "SiT-XL/2", "Only SiT-XL/2 models are available for auto-download."
63 | assert args.image_size in [256, 512]
64 | assert args.num_classes == 1000
65 | assert args.image_size == 256, "512x512 models are not yet available for auto-download." # remove this line when 512x512 models are available
66 | learn_sigma = args.image_size == 256
67 | else:
68 | learn_sigma = False
69 |
70 | # Load model:
71 | latent_size = args.image_size // 8
72 | model = SiT_models[args.model](
73 | input_size=latent_size,
74 | num_classes=args.num_classes,
75 | learn_sigma=learn_sigma,
76 | ).to(device)
77 | # Auto-download a pre-trained model or load a custom SiT checkpoint from train.py:
78 | ckpt_path = args.ckpt or f"SiT-XL-2-{args.image_size}x{args.image_size}.pt"
79 | state_dict = find_model(ckpt_path)
80 | model.load_state_dict(state_dict)
81 | model.eval() # important!
82 |
83 |
84 | transport = create_transport(
85 | args.path_type,
86 | args.prediction,
87 | args.loss_weight,
88 | args.train_eps,
89 | args.sample_eps
90 | )
91 | sampler = Sampler(transport)
92 | if mode == "ODE":
93 | if args.likelihood:
94 | assert args.cfg_scale == 1, "Likelihood is incompatible with guidance"
95 | sample_fn = sampler.sample_ode_likelihood(
96 | sampling_method=args.sampling_method,
97 | num_steps=args.num_sampling_steps,
98 | atol=args.atol,
99 | rtol=args.rtol,
100 | )
101 | else:
102 | sample_fn = sampler.sample_ode(
103 | sampling_method=args.sampling_method,
104 | num_steps=args.num_sampling_steps,
105 | atol=args.atol,
106 | rtol=args.rtol,
107 | reverse=args.reverse
108 | )
109 | elif mode == "SDE":
110 | sample_fn = sampler.sample_sde(
111 | sampling_method=args.sampling_method,
112 | diffusion_form=args.diffusion_form,
113 | diffusion_norm=args.diffusion_norm,
114 | last_step=args.last_step,
115 | last_step_size=args.last_step_size,
116 | num_steps=args.num_sampling_steps,
117 | )
118 | vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
119 |     assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale should be >= 1.0"
120 | using_cfg = args.cfg_scale > 1.0
121 |
122 | # Create folder to save samples:
123 | model_string_name = args.model.replace("/", "-")
124 | ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained"
125 | if mode == "ODE":
126 | folder_name = f"{model_string_name}-{ckpt_string_name}-" \
127 | f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\
128 | f"{mode}-{args.num_sampling_steps}-{args.sampling_method}"
129 | elif mode == "SDE":
130 | folder_name = f"{model_string_name}-{ckpt_string_name}-" \
131 | f"cfg-{args.cfg_scale}-{args.per_proc_batch_size}-"\
132 | f"{mode}-{args.num_sampling_steps}-{args.sampling_method}-"\
133 | f"{args.diffusion_form}-{args.last_step}-{args.last_step_size}"
134 | sample_folder_dir = f"{args.sample_dir}/{folder_name}"
135 | if rank == 0:
136 | os.makedirs(sample_folder_dir, exist_ok=True)
137 | print(f"Saving .png samples at {sample_folder_dir}")
138 | dist.barrier()
139 |
140 | # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
141 | n = args.per_proc_batch_size
142 | global_batch_size = n * dist.get_world_size()
143 | # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
144 | num_samples = len([name for name in os.listdir(sample_folder_dir) if (os.path.isfile(os.path.join(sample_folder_dir, name)) and ".png" in name)])
145 | total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
146 | if rank == 0:
147 | print(f"Total number of images that will be sampled: {total_samples}")
148 | assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
149 | samples_needed_this_gpu = int(total_samples // dist.get_world_size())
150 | assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
151 | iterations = int(samples_needed_this_gpu // n)
152 | done_iterations = int( int(num_samples // dist.get_world_size()) // n)
153 | pbar = range(iterations)
154 | pbar = tqdm(pbar) if rank == 0 else pbar
155 | total = 0
156 |
157 | for i in pbar:
158 | # Sample inputs:
159 | z = torch.randn(n, model.in_channels, latent_size, latent_size, device=device)
160 | y = torch.randint(0, args.num_classes, (n,), device=device)
161 |
162 | # Setup classifier-free guidance:
163 | if using_cfg:
164 | z = torch.cat([z, z], 0)
165 | y_null = torch.tensor([1000] * n, device=device)
166 | y = torch.cat([y, y_null], 0)
167 | model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)
168 | model_fn = model.forward_with_cfg
169 | else:
170 | model_kwargs = dict(y=y)
171 | model_fn = model.forward
172 |
173 | samples = sample_fn(z, model_fn, **model_kwargs)[-1]
174 | if using_cfg:
175 | samples, _ = samples.chunk(2, dim=0) # Remove null class samples
176 |
177 | samples = vae.decode(samples / 0.18215).sample
178 | samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
179 |
180 | # Save samples to disk as individual .png files
181 |         for j, sample in enumerate(samples):
182 |             index = j * dist.get_world_size() + rank + total
183 | Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
184 | total += global_batch_size
185 | dist.barrier()
186 |
187 | # Make sure all processes have finished saving their samples before attempting to convert to .npz
188 | dist.barrier()
189 | if rank == 0:
190 | create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
191 | print("Done.")
192 | dist.barrier()
193 | dist.destroy_process_group()
194 |
195 |
196 | if __name__ == "__main__":
197 |
198 | parser = argparse.ArgumentParser()
199 |
200 | if len(sys.argv) < 2:
201 |         print("Usage: sample_ddp.py [ODE|SDE] [options]")
202 | sys.exit(1)
203 |
204 | mode = sys.argv[1]
205 |
206 |     assert mode[:2] != "--", "Usage: sample_ddp.py [ODE|SDE] [options]"
207 | assert mode in ["ODE", "SDE"], "Invalid mode. Please choose 'ODE' or 'SDE'"
208 |
209 | parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2")
210 | parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")
211 | parser.add_argument("--sample-dir", type=str, default="samples")
212 | parser.add_argument("--per-proc-batch-size", type=int, default=4)
213 | parser.add_argument("--num-fid-samples", type=int, default=50_000)
214 | parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
215 | parser.add_argument("--num-classes", type=int, default=1000)
216 | parser.add_argument("--cfg-scale", type=float, default=1.0)
217 | parser.add_argument("--num-sampling-steps", type=int, default=250)
218 | parser.add_argument("--global-seed", type=int, default=0)
219 | parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True,
220 | help="By default, use TF32 matmuls. This massively accelerates sampling on Ampere GPUs.")
221 | parser.add_argument("--ckpt", type=str, default=None,
222 | help="Optional path to a SiT checkpoint (default: auto-download a pre-trained SiT-XL/2 model).")
223 |
224 | parse_transport_args(parser)
225 | if mode == "ODE":
226 | parse_ode_args(parser)
227 | # Further processing for ODE
228 | elif mode == "SDE":
229 | parse_sde_args(parser)
230 | # Further processing for SDE
231 |
232 | args = parser.parse_known_args()[0]
233 | main(mode, args)
234 |
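235 | # Example invocation (sketch, per the README):
236 | #   torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py ODE --model SiT-XL/2 --num-fid-samples 50000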
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # This source code is licensed under the license found in the
2 | # LICENSE file in the root directory of this source tree.
3 |
4 | """
5 | A minimal training script for SiT using PyTorch DDP.
6 | """
7 | import torch
8 | # the first flag below was False when we tested this script but True makes A100 training a lot faster:
9 | torch.backends.cuda.matmul.allow_tf32 = True
10 | torch.backends.cudnn.allow_tf32 = True
11 | import torch.distributed as dist
12 | from torch.nn.parallel import DistributedDataParallel as DDP
13 | from torch.utils.data import DataLoader
14 | from torch.utils.data.distributed import DistributedSampler
15 | from torchvision.datasets import ImageFolder
16 | from torchvision import transforms
17 | import numpy as np
18 | from collections import OrderedDict
19 | from PIL import Image
20 | from copy import deepcopy
21 | from glob import glob
22 | from time import time
23 | import argparse
24 | import logging
25 | import os
26 |
27 | from models import SiT_models
28 | from download import find_model
29 | from transport import create_transport, Sampler
30 | from diffusers.models import AutoencoderKL
31 | from train_utils import parse_transport_args
32 | import wandb_utils
33 |
34 |
35 | #################################################################################
36 | # Training Helper Functions #
37 | #################################################################################
38 |
39 | @torch.no_grad()
40 | def update_ema(ema_model, model, decay=0.9999):
41 | """
42 | Step the EMA model towards the current model.
43 | """
44 | ema_params = OrderedDict(ema_model.named_parameters())
45 | model_params = OrderedDict(model.named_parameters())
46 |
47 | for name, param in model_params.items():
48 | # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
49 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
50 |
51 |
52 | def requires_grad(model, flag=True):
53 | """
54 | Set requires_grad flag for all parameters in a model.
55 | """
56 | for p in model.parameters():
57 | p.requires_grad = flag
58 |
59 |
60 | def cleanup():
61 | """
62 | End DDP training.
63 | """
64 | dist.destroy_process_group()
65 |
66 |
67 | def create_logger(logging_dir):
68 | """
69 | Create a logger that writes to a log file and stdout.
70 | """
71 | if dist.get_rank() == 0: # real logger
72 | logging.basicConfig(
73 | level=logging.INFO,
74 | format='[\033[34m%(asctime)s\033[0m] %(message)s',
75 | datefmt='%Y-%m-%d %H:%M:%S',
76 | handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
77 | )
78 | logger = logging.getLogger(__name__)
79 | else: # dummy logger (does nothing)
80 | logger = logging.getLogger(__name__)
81 | logger.addHandler(logging.NullHandler())
82 | return logger
83 |
84 |
85 | def center_crop_arr(pil_image, image_size):
86 | """
87 | Center cropping implementation from ADM.
88 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
89 | """
90 | while min(*pil_image.size) >= 2 * image_size:
91 | pil_image = pil_image.resize(
92 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
93 | )
94 |
95 | scale = image_size / min(*pil_image.size)
96 | pil_image = pil_image.resize(
97 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
98 | )
99 |
100 | arr = np.array(pil_image)
101 | crop_y = (arr.shape[0] - image_size) // 2
102 | crop_x = (arr.shape[1] - image_size) // 2
103 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
104 |
105 |
106 | #################################################################################
107 | # Training Loop #
108 | #################################################################################
109 |
110 | def main(args):
111 | """
112 | Trains a new SiT model.
113 | """
114 | assert torch.cuda.is_available(), "Training currently requires at least one GPU."
115 |
116 | # Setup DDP:
117 | dist.init_process_group("nccl")
118 | assert args.global_batch_size % dist.get_world_size() == 0, f"Batch size must be divisible by world size."
119 | rank = dist.get_rank()
120 | device = rank % torch.cuda.device_count()
121 | seed = args.global_seed * dist.get_world_size() + rank
122 | torch.manual_seed(seed)
123 | torch.cuda.set_device(device)
124 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
125 | local_batch_size = int(args.global_batch_size // dist.get_world_size())
126 |
127 | # Setup an experiment folder:
128 | if rank == 0:
129 | os.makedirs(args.results_dir, exist_ok=True) # Make results folder (holds all experiment subfolders)
130 | experiment_index = len(glob(f"{args.results_dir}/*"))
131 | model_string_name = args.model.replace("/", "-") # e.g., SiT-XL/2 --> SiT-XL-2 (for naming folders)
132 | experiment_name = f"{experiment_index:03d}-{model_string_name}-" \
133 | f"{args.path_type}-{args.prediction}-{args.loss_weight}"
134 | experiment_dir = f"{args.results_dir}/{experiment_name}" # Create an experiment folder
135 | checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints
136 | os.makedirs(checkpoint_dir, exist_ok=True)
137 | logger = create_logger(experiment_dir)
138 | logger.info(f"Experiment directory created at {experiment_dir}")
139 |
140 |         if args.wandb:
141 |             entity = os.environ["ENTITY"]
142 |             project = os.environ["PROJECT"]
143 |             wandb_utils.initialize(args, entity, experiment_name, project)
144 | else:
145 | logger = create_logger(None)
146 |
147 | # Create model:
148 | assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
149 | latent_size = args.image_size // 8
150 | model = SiT_models[args.model](
151 | input_size=latent_size,
152 | num_classes=args.num_classes
153 | )
154 |
155 | # Note that parameter initialization is done within the SiT constructor
156 | ema = deepcopy(model).to(device) # Create an EMA of the model for use after training
157 |
158 |     if args.ckpt is not None:
159 |         ckpt_path = args.ckpt
160 |         # Load the full train.py checkpoint directly; find_model() would strip it down to the EMA weights.
161 |         state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
162 |         model.load_state_dict(state_dict["model"])
163 |         ema.load_state_dict(state_dict["ema"])
164 |         args = state_dict["args"]  # restore training configs from the checkpoint
165 |         args.ckpt = ckpt_path      # keep the resume flag; optimizer state is restored below, once `opt` exists
165 |
166 | requires_grad(ema, False)
167 |
168 | model = DDP(model.to(device), device_ids=[rank])
169 | transport = create_transport(
170 | args.path_type,
171 | args.prediction,
172 | args.loss_weight,
173 | args.train_eps,
174 | args.sample_eps
175 | ) # default: velocity;
176 | transport_sampler = Sampler(transport)
177 | vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
178 | logger.info(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}")
179 |
180 | # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
181 | opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0)
182 |
183 | # Setup data:
184 | transform = transforms.Compose([
185 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
186 | transforms.RandomHorizontalFlip(),
187 | transforms.ToTensor(),
188 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
189 | ])
190 | dataset = ImageFolder(args.data_path, transform=transform)
191 | sampler = DistributedSampler(
192 | dataset,
193 | num_replicas=dist.get_world_size(),
194 | rank=rank,
195 | shuffle=True,
196 | seed=args.global_seed
197 | )
198 | loader = DataLoader(
199 | dataset,
200 | batch_size=local_batch_size,
201 | shuffle=False,
202 | sampler=sampler,
203 | num_workers=args.num_workers,
204 | pin_memory=True,
205 | drop_last=True
206 | )
207 | logger.info(f"Dataset contains {len(dataset):,} images ({args.data_path})")
208 |
209 | # Prepare models for training:
210 | update_ema(ema, model.module, decay=0) # Ensure EMA is initialized with synced weights
211 | model.train() # important! This enables embedding dropout for classifier-free guidance
212 | ema.eval() # EMA model should always be in eval mode
213 |
214 | # Variables for monitoring/logging purposes:
215 | train_steps = 0
216 | log_steps = 0
217 | running_loss = 0
218 | start_time = time()
219 |
220 | # Labels to condition the model with (feel free to change):
221 |     ys = torch.randint(args.num_classes, size=(local_batch_size,), device=device)
222 | use_cfg = args.cfg_scale > 1.0
223 | # Create sampling noise:
224 | n = ys.size(0)
225 | zs = torch.randn(n, 4, latent_size, latent_size, device=device)
226 |
227 | # Setup classifier-free guidance:
228 | if use_cfg:
229 | zs = torch.cat([zs, zs], 0)
230 |         y_null = torch.tensor([args.num_classes] * n, device=device)  # index num_classes is the CFG null class
231 | ys = torch.cat([ys, y_null], 0)
232 | sample_model_kwargs = dict(y=ys, cfg_scale=args.cfg_scale)
233 | model_fn = ema.forward_with_cfg
234 | else:
235 | sample_model_kwargs = dict(y=ys)
236 | model_fn = ema.forward
237 |
238 | logger.info(f"Training for {args.epochs} epochs...")
239 | for epoch in range(args.epochs):
240 | sampler.set_epoch(epoch)
241 | logger.info(f"Beginning epoch {epoch}...")
242 | for x, y in loader:
243 | x = x.to(device)
244 | y = y.to(device)
245 | with torch.no_grad():
246 | # Map input images to latent space + normalize latents:
247 | x = vae.encode(x).latent_dist.sample().mul_(0.18215)
248 | model_kwargs = dict(y=y)
249 | loss_dict = transport.training_losses(model, x, model_kwargs)
250 | loss = loss_dict["loss"].mean()
251 | opt.zero_grad()
252 | loss.backward()
253 | opt.step()
254 | update_ema(ema, model.module)
255 |
256 | # Log loss values:
257 | running_loss += loss.item()
258 | log_steps += 1
259 | train_steps += 1
260 | if train_steps % args.log_every == 0:
261 | # Measure training speed:
262 | torch.cuda.synchronize()
263 | end_time = time()
264 | steps_per_sec = log_steps / (end_time - start_time)
265 | # Reduce loss history over all processes:
266 | avg_loss = torch.tensor(running_loss / log_steps, device=device)
267 | dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
268 | avg_loss = avg_loss.item() / dist.get_world_size()
269 | logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
270 | if args.wandb:
271 | wandb_utils.log(
272 | { "train loss": avg_loss, "train steps/sec": steps_per_sec },
273 | step=train_steps
274 | )
275 | # Reset monitoring variables:
276 | running_loss = 0
277 | log_steps = 0
278 | start_time = time()
279 |
280 | # Save SiT checkpoint:
281 | if train_steps % args.ckpt_every == 0 and train_steps > 0:
282 | if rank == 0:
283 | checkpoint = {
284 | "model": model.module.state_dict(),
285 | "ema": ema.state_dict(),
286 | "opt": opt.state_dict(),
287 | "args": args
288 | }
289 | checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
290 | torch.save(checkpoint, checkpoint_path)
291 | logger.info(f"Saved checkpoint to {checkpoint_path}")
292 | dist.barrier()
293 |
294 | if train_steps % args.sample_every == 0 and train_steps > 0:
295 | logger.info("Generating EMA samples...")
296 | sample_fn = transport_sampler.sample_ode() # default to ode sampling
297 | samples = sample_fn(zs, model_fn, **sample_model_kwargs)[-1]
298 | dist.barrier()
299 |
300 |                 if use_cfg:  # remove null samples
301 |                     samples, _ = samples.chunk(2, dim=0)
302 |                 samples = vae.decode(samples / 0.18215).sample
303 |                 out_samples = torch.zeros((args.global_batch_size, 3, args.image_size, args.image_size), device=device)
304 |                 dist.all_gather_into_tensor(out_samples, samples)
305 |                 if args.wandb:
306 |                     wandb_utils.log_image(out_samples, train_steps)
307 |                 logger.info("Generating EMA samples done.")
308 |
309 | model.eval() # important! This disables randomized embedding dropout
310 | # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...
311 |
312 | logger.info("Done!")
313 | cleanup()
314 |
315 |
316 | if __name__ == "__main__":
317 | # Default args here will train SiT-XL/2 with the hyperparameters we used in our paper (except training iters).
318 | parser = argparse.ArgumentParser()
319 | parser.add_argument("--data-path", type=str, required=True)
320 | parser.add_argument("--results-dir", type=str, default="results")
321 | parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2")
322 | parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
323 | parser.add_argument("--num-classes", type=int, default=1000)
324 | parser.add_argument("--epochs", type=int, default=1400)
325 | parser.add_argument("--global-batch-size", type=int, default=256)
326 | parser.add_argument("--global-seed", type=int, default=0)
327 | parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema") # Choice doesn't affect training
328 | parser.add_argument("--num-workers", type=int, default=4)
329 | parser.add_argument("--log-every", type=int, default=100)
330 | parser.add_argument("--ckpt-every", type=int, default=50_000)
331 | parser.add_argument("--sample-every", type=int, default=10_000)
332 | parser.add_argument("--cfg-scale", type=float, default=4.0)
333 | parser.add_argument("--wandb", action="store_true")
334 | parser.add_argument("--ckpt", type=str, default=None,
335 | help="Optional path to a custom SiT checkpoint")
336 |
337 | parse_transport_args(parser)
338 | args = parser.parse_args()
339 | main(args)
340 |
--------------------------------------------------------------------------------
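The EMA bookkeeping above (`update_ema(ema, model.module, decay=0)` to initialize, then `update_ema(ema, model.module)` after every optimizer step) maintains an exponential moving average of the online weights; `update_ema` itself is defined earlier in `train.py`. For reference only, a minimal sketch of the conventional DiT-style update it implements (a hypothetical re-implementation, assuming the usual 0.9999 decay default, not the repo's function):

```python
import torch

@torch.no_grad()
def update_ema_sketch(ema_model, model, decay=0.9999):
    # ema <- decay * ema + (1 - decay) * online; decay=0 makes the EMA an exact copy,
    # which is why train.py calls update_ema(..., decay=0) right after construction.
    ema_params = dict(ema_model.named_parameters())
    for name, param in model.named_parameters():
        ema_params[name].mul_(decay).add_(param.detach(), alpha=1 - decay)
```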
/train_utils.py:
--------------------------------------------------------------------------------
1 | def none_or_str(value):
2 | if value == 'None':
3 | return None
4 | return value
5 |
6 | def parse_transport_args(parser):
7 | group = parser.add_argument_group("Transport arguments")
8 | group.add_argument("--path-type", type=str, default="Linear", choices=["Linear", "GVP", "VP"])
9 | group.add_argument("--prediction", type=str, default="velocity", choices=["velocity", "score", "noise"])
10 | group.add_argument("--loss-weight", type=none_or_str, default=None, choices=[None, "velocity", "likelihood"])
11 | group.add_argument("--sample-eps", type=float)
12 | group.add_argument("--train-eps", type=float)
13 |
14 | def parse_ode_args(parser):
15 | group = parser.add_argument_group("ODE arguments")
16 | group.add_argument("--sampling-method", type=str, default="dopri5", help="blackbox ODE solver methods; for full list check https://github.com/rtqichen/torchdiffeq")
17 | group.add_argument("--atol", type=float, default=1e-6, help="Absolute tolerance")
18 | group.add_argument("--rtol", type=float, default=1e-3, help="Relative tolerance")
19 | group.add_argument("--reverse", action="store_true")
20 | group.add_argument("--likelihood", action="store_true")
21 |
22 | def parse_sde_args(parser):
23 | group = parser.add_argument_group("SDE arguments")
24 | group.add_argument("--sampling-method", type=str, default="Euler", choices=["Euler", "Heun"])
25 | group.add_argument("--diffusion-form", type=str, default="sigma", \
26 | choices=["constant", "SBDM", "sigma", "linear", "decreasing", "increasing-decreasing"],\
27 | help="form of diffusion coefficient in the SDE")
28 | group.add_argument("--diffusion-norm", type=float, default=1.0)
29 | group.add_argument("--last-step", type=none_or_str, default="Mean", choices=[None, "Mean", "Tweedie", "Euler"],\
30 | help="form of last step taken in the SDE")
31 | group.add_argument("--last-step-size", type=float, default=0.04, \
32 | help="size of the last step taken")
--------------------------------------------------------------------------------
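Usage sketch for the helpers above (a hypothetical standalone script, not part of the repo). Note that `parse_ode_args` and `parse_sde_args` both register `--sampling-method`, so a given parser should receive only one of them:

```python
import argparse

from train_utils import parse_transport_args, parse_ode_args

parser = argparse.ArgumentParser()
parse_transport_args(parser)  # --path-type, --prediction, --loss-weight, --train-eps, --sample-eps
parse_ode_args(parser)        # --sampling-method, --atol, --rtol, --reverse, --likelihood

args = parser.parse_args(["--path-type", "GVP", "--prediction", "velocity"])
print(args.path_type, args.sampling_method)  # GVP dopri5
```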
/transport/__init__.py:
--------------------------------------------------------------------------------
1 | from .transport import Transport, ModelType, WeightType, PathType, Sampler
2 |
3 | def create_transport(
4 | path_type='Linear',
5 | prediction="velocity",
6 | loss_weight=None,
7 | train_eps=None,
8 | sample_eps=None,
9 | ):
10 | """function for creating Transport object
11 | **Note**: model prediction defaults to velocity
12 | Args:
13 | - path_type: type of path to use; default to linear
14 | - learn_score: set model prediction to score
15 | - learn_noise: set model prediction to noise
16 | - velocity_weighted: weight loss by velocity weight
17 | - likelihood_weighted: weight loss by likelihood weight
18 | - train_eps: small epsilon for avoiding instability during training
19 | - sample_eps: small epsilon for avoiding instability during sampling
20 | """
21 |
22 | if prediction == "noise":
23 | model_type = ModelType.NOISE
24 | elif prediction == "score":
25 | model_type = ModelType.SCORE
26 | else:
27 | model_type = ModelType.VELOCITY
28 |
29 | if loss_weight == "velocity":
30 | loss_type = WeightType.VELOCITY
31 | elif loss_weight == "likelihood":
32 | loss_type = WeightType.LIKELIHOOD
33 | else:
34 | loss_type = WeightType.NONE
35 |
36 | path_choice = {
37 | "Linear": PathType.LINEAR,
38 | "GVP": PathType.GVP,
39 | "VP": PathType.VP,
40 | }
41 |
42 | path_type = path_choice[path_type]
43 |
44 | if (path_type in [PathType.VP]):
45 | train_eps = 1e-5 if train_eps is None else train_eps
46 |         sample_eps = 1e-3 if sample_eps is None else sample_eps
47 | elif (path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY):
48 | train_eps = 1e-3 if train_eps is None else train_eps
49 |         sample_eps = 1e-3 if sample_eps is None else sample_eps
50 | else: # velocity & [GVP, LINEAR] is stable everywhere
51 | train_eps = 0
52 | sample_eps = 0
53 |
54 | # create flow state
55 | state = Transport(
56 | model_type=model_type,
57 | path_type=path_type,
58 | loss_type=loss_type,
59 | train_eps=train_eps,
60 | sample_eps=sample_eps,
61 | )
62 |
63 | return state
--------------------------------------------------------------------------------
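A minimal usage sketch for `create_transport`, mirroring the calls in `train.py`; the `model` handed to `training_losses` can be any callable with signature `model(x, t, **kwargs)`:

```python
from transport import create_transport, Sampler

# Default configuration: Linear path + velocity prediction + unweighted loss,
# which resolves to train_eps = sample_eps = 0 (stable everywhere, per the logic above).
transport = create_transport(path_type="Linear", prediction="velocity", loss_weight=None)
sampler = Sampler(transport)

sample_fn = sampler.sample_ode()  # adaptive dopri5 over [0, 1] by default
# loss_dict = transport.training_losses(model, x1, dict(y=y))
```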
/transport/integrators.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch as th
3 | import torch.nn as nn
4 | from torchdiffeq import odeint
5 | from functools import partial
6 | from tqdm import tqdm
7 |
8 | class sde:
9 | """SDE solver class"""
10 | def __init__(
11 | self,
12 | drift,
13 | diffusion,
14 | *,
15 | t0,
16 | t1,
17 | num_steps,
18 | sampler_type,
19 | ):
20 | assert t0 < t1, "SDE sampler has to be in forward time"
21 |
22 | self.num_timesteps = num_steps
23 | self.t = th.linspace(t0, t1, num_steps)
24 | self.dt = self.t[1] - self.t[0]
25 | self.drift = drift
26 | self.diffusion = diffusion
27 | self.sampler_type = sampler_type
28 |
29 | def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs):
30 | w_cur = th.randn(x.size()).to(x)
31 | t = th.ones(x.size(0)).to(x) * t
32 | dw = w_cur * th.sqrt(self.dt)
33 | drift = self.drift(x, t, model, **model_kwargs)
34 | diffusion = self.diffusion(x, t)
35 | mean_x = x + drift * self.dt
36 | x = mean_x + th.sqrt(2 * diffusion) * dw
37 | return x, mean_x
38 |
39 | def __Heun_step(self, x, _, t, model, **model_kwargs):
40 | w_cur = th.randn(x.size()).to(x)
41 | dw = w_cur * th.sqrt(self.dt)
42 | t_cur = th.ones(x.size(0)).to(x) * t
43 | diffusion = self.diffusion(x, t_cur)
44 | xhat = x + th.sqrt(2 * diffusion) * dw
45 | K1 = self.drift(xhat, t_cur, model, **model_kwargs)
46 | xp = xhat + self.dt * K1
47 | K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs)
48 |         return xhat + 0.5 * self.dt * (K1 + K2), xhat  # at the last time point we do not perform the Heun step
49 |
50 | def __forward_fn(self):
51 | """TODO: generalize here by adding all private functions ending with steps to it"""
52 | sampler_dict = {
53 | "Euler": self.__Euler_Maruyama_step,
54 | "Heun": self.__Heun_step,
55 | }
56 |
57 |         try:
58 |             sampler = sampler_dict[self.sampler_type]
59 |         except KeyError:
60 |             raise NotImplementedError("Sampler type not implemented.")
61 |
62 | return sampler
63 |
64 | def sample(self, init, model, **model_kwargs):
65 | """forward loop of sde"""
66 | x = init
67 | mean_x = init
68 | samples = []
69 | sampler = self.__forward_fn()
70 | for ti in self.t[:-1]:
71 | with th.no_grad():
72 | x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
73 | samples.append(x)
74 |
75 | return samples
76 |
77 | class ode:
78 | """ODE solver class"""
79 | def __init__(
80 | self,
81 | drift,
82 | *,
83 | t0,
84 | t1,
85 | sampler_type,
86 | num_steps,
87 | atol,
88 | rtol,
89 | ):
90 | assert t0 < t1, "ODE sampler has to be in forward time"
91 |
92 | self.drift = drift
93 | self.t = th.linspace(t0, t1, num_steps)
94 | self.atol = atol
95 | self.rtol = rtol
96 | self.sampler_type = sampler_type
97 |
98 | def sample(self, x, model, **model_kwargs):
99 |
100 | device = x[0].device if isinstance(x, tuple) else x.device
101 | def _fn(t, x):
102 | t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t
103 | model_output = self.drift(x, t, model, **model_kwargs)
104 | return model_output
105 |
106 | t = self.t.to(device)
107 | atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol]
108 | rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol]
109 | samples = odeint(
110 | _fn,
111 | x,
112 | t,
113 | method=self.sampler_type,
114 | atol=atol,
115 | rtol=rtol
116 | )
117 | return samples
--------------------------------------------------------------------------------
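For intuition, the `ode` wrapper can be exercised on a toy drift that ignores the `model` argument (real drifts, built by `Transport.get_drift`, call the network here); a small illustrative sketch:

```python
import torch as th

from transport.integrators import ode

toy_drift = lambda x, t, model, **kw: -x  # dx/dt = -x: exponential decay

solver = ode(
    drift=toy_drift,
    t0=0.0, t1=1.0,
    sampler_type="dopri5",
    num_steps=5,       # for adaptive solvers: number of saved points, not step count
    atol=1e-6, rtol=1e-3,
)
x0 = th.randn(8, 2)
xs = solver.sample(x0, model=None)  # shape [5, 8, 2]; xs[-1] ≈ x0 * exp(-1)
```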
/transport/path.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 | import numpy as np
3 | from functools import partial
4 |
5 | def expand_t_like_x(t, x):
6 | """Function to reshape time t to broadcastable dimension of x
7 | Args:
8 | t: [batch_dim,], time vector
9 | x: [batch_dim,...], data point
10 | """
11 | dims = [1] * (len(x.size()) - 1)
12 | t = t.view(t.size(0), *dims)
13 | return t
14 |
15 |
16 | #################### Coupling Plans ####################
17 |
18 | class ICPlan:
19 | """Linear Coupling Plan"""
20 | def __init__(self, sigma=0.0):
21 | self.sigma = sigma
22 |
23 | def compute_alpha_t(self, t):
24 | """Compute the data coefficient along the path"""
25 | return t, 1
26 |
27 | def compute_sigma_t(self, t):
28 | """Compute the noise coefficient along the path"""
29 | return 1 - t, -1
30 |
31 | def compute_d_alpha_alpha_ratio_t(self, t):
32 | """Compute the ratio between d_alpha and alpha"""
33 | return 1 / t
34 |
35 | def compute_drift(self, x, t):
36 | """We always output sde according to score parametrization; """
37 | t = expand_t_like_x(t, x)
38 | alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t)
39 | sigma_t, d_sigma_t = self.compute_sigma_t(t)
40 | drift = alpha_ratio * x
41 | diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t
42 |
43 | return -drift, diffusion
44 |
45 | def compute_diffusion(self, x, t, form="constant", norm=1.0):
46 | """Compute the diffusion term of the SDE
47 | Args:
48 | x: [batch_dim, ...], data point
49 | t: [batch_dim,], time vector
50 | form: str, form of the diffusion term
51 | norm: float, norm of the diffusion term
52 | """
53 | t = expand_t_like_x(t, x)
54 | choices = {
55 | "constant": norm,
56 | "SBDM": norm * self.compute_drift(x, t)[1],
57 | "sigma": norm * self.compute_sigma_t(t)[0],
58 | "linear": norm * (1 - t),
59 | "decreasing": 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2,
60 | "inccreasing-decreasing": norm * th.sin(np.pi * t) ** 2,
61 | }
62 |
63 | try:
64 | diffusion = choices[form]
65 | except KeyError:
66 | raise NotImplementedError(f"Diffusion form {form} not implemented")
67 |
68 | return diffusion
69 |
70 | def get_score_from_velocity(self, velocity, x, t):
71 | """Wrapper function: transfrom velocity prediction model to score
72 | Args:
73 | velocity: [batch_dim, ...] shaped tensor; velocity model output
74 | x: [batch_dim, ...] shaped tensor; x_t data point
75 | t: [batch_dim,] time tensor
76 | """
77 | t = expand_t_like_x(t, x)
78 | alpha_t, d_alpha_t = self.compute_alpha_t(t)
79 | sigma_t, d_sigma_t = self.compute_sigma_t(t)
80 | mean = x
81 | reverse_alpha_ratio = alpha_t / d_alpha_t
82 | var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
83 | score = (reverse_alpha_ratio * velocity - mean) / var
84 | return score
85 |
86 | def get_noise_from_velocity(self, velocity, x, t):
87 | """Wrapper function: transfrom velocity prediction model to denoiser
88 | Args:
89 | velocity: [batch_dim, ...] shaped tensor; velocity model output
90 | x: [batch_dim, ...] shaped tensor; x_t data point
91 | t: [batch_dim,] time tensor
92 | """
93 | t = expand_t_like_x(t, x)
94 | alpha_t, d_alpha_t = self.compute_alpha_t(t)
95 | sigma_t, d_sigma_t = self.compute_sigma_t(t)
96 | mean = x
97 | reverse_alpha_ratio = alpha_t / d_alpha_t
98 | var = reverse_alpha_ratio * d_sigma_t - sigma_t
99 | noise = (reverse_alpha_ratio * velocity - mean) / var
100 | return noise
101 |
102 | def get_velocity_from_score(self, score, x, t):
103 | """Wrapper function: transfrom score prediction model to velocity
104 | Args:
105 | score: [batch_dim, ...] shaped tensor; score model output
106 | x: [batch_dim, ...] shaped tensor; x_t data point
107 | t: [batch_dim,] time tensor
108 | """
109 | t = expand_t_like_x(t, x)
110 | drift, var = self.compute_drift(x, t)
111 | velocity = var * score - drift
112 | return velocity
113 |
114 | def compute_mu_t(self, t, x0, x1):
115 | """Compute the mean of time-dependent density p_t"""
116 | t = expand_t_like_x(t, x1)
117 | alpha_t, _ = self.compute_alpha_t(t)
118 | sigma_t, _ = self.compute_sigma_t(t)
119 | return alpha_t * x1 + sigma_t * x0
120 |
121 | def compute_xt(self, t, x0, x1):
122 | """Sample xt from time-dependent density p_t; rng is required"""
123 | xt = self.compute_mu_t(t, x0, x1)
124 | return xt
125 |
126 | def compute_ut(self, t, x0, x1, xt):
127 | """Compute the vector field corresponding to p_t"""
128 | t = expand_t_like_x(t, x1)
129 | _, d_alpha_t = self.compute_alpha_t(t)
130 | _, d_sigma_t = self.compute_sigma_t(t)
131 | return d_alpha_t * x1 + d_sigma_t * x0
132 |
133 | def plan(self, t, x0, x1):
134 | xt = self.compute_xt(t, x0, x1)
135 | ut = self.compute_ut(t, x0, x1, xt)
136 | return t, xt, ut
137 |
138 |
139 | class VPCPlan(ICPlan):
140 | """class for VP path flow matching"""
141 |
142 | def __init__(self, sigma_min=0.1, sigma_max=20.0):
143 | self.sigma_min = sigma_min
144 | self.sigma_max = sigma_max
145 | self.log_mean_coeff = lambda t: -0.25 * ((1 - t) ** 2) * (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min
146 | self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min
147 |
148 |
149 | def compute_alpha_t(self, t):
150 | """Compute coefficient of x1"""
151 | alpha_t = self.log_mean_coeff(t)
152 | alpha_t = th.exp(alpha_t)
153 | d_alpha_t = alpha_t * self.d_log_mean_coeff(t)
154 | return alpha_t, d_alpha_t
155 |
156 | def compute_sigma_t(self, t):
157 | """Compute coefficient of x0"""
158 | p_sigma_t = 2 * self.log_mean_coeff(t)
159 | sigma_t = th.sqrt(1 - th.exp(p_sigma_t))
160 | d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t)
161 | return sigma_t, d_sigma_t
162 |
163 | def compute_d_alpha_alpha_ratio_t(self, t):
164 | """Special purposed function for computing numerical stabled d_alpha_t / alpha_t"""
165 | return self.d_log_mean_coeff(t)
166 |
167 | def compute_drift(self, x, t):
168 | """Compute the drift term of the SDE"""
169 | t = expand_t_like_x(t, x)
170 | beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min)
171 | return -0.5 * beta_t * x, beta_t / 2
172 |
173 |
174 | class GVPCPlan(ICPlan):
175 | def __init__(self, sigma=0.0):
176 | super().__init__(sigma)
177 |
178 | def compute_alpha_t(self, t):
179 | """Compute coefficient of x1"""
180 | alpha_t = th.sin(t * np.pi / 2)
181 | d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2)
182 | return alpha_t, d_alpha_t
183 |
184 | def compute_sigma_t(self, t):
185 | """Compute coefficient of x0"""
186 | sigma_t = th.cos(t * np.pi / 2)
187 | d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2)
188 | return sigma_t, d_sigma_t
189 |
190 | def compute_d_alpha_alpha_ratio_t(self, t):
191 | """Special purposed function for computing numerical stabled d_alpha_t / alpha_t"""
192 | return np.pi / (2 * th.tan(t * np.pi / 2))
--------------------------------------------------------------------------------
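A quick numerical sanity check of the linear plan above: with `alpha_t = t` and `sigma_t = 1 - t`, the interpolant is x_t = t·x1 + (1 − t)·x0 and its velocity field is u_t = x1 − x0 (small illustrative sketch):

```python
import torch as th

from transport.path import ICPlan

plan = ICPlan()
x0, x1 = th.randn(4, 3), th.randn(4, 3)  # noise endpoint and data endpoint
t = th.rand(4)

_, xt, ut = plan.plan(t, x0, x1)
assert th.allclose(xt, t[:, None] * x1 + (1 - t)[:, None] * x0)
assert th.allclose(ut, x1 - x0)  # d/dt [t * x1 + (1 - t) * x0]
```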
/transport/transport.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 | import numpy as np
3 | import logging
4 |
5 | import enum
6 |
7 | from . import path
8 | from .utils import EasyDict, log_state, mean_flat
9 | from .integrators import ode, sde
10 |
11 | class ModelType(enum.Enum):
12 | """
13 | Which type of output the model predicts.
14 | """
15 |
16 | NOISE = enum.auto() # the model predicts epsilon
17 | SCORE = enum.auto() # the model predicts \nabla \log p(x)
18 | VELOCITY = enum.auto() # the model predicts v(x)
19 |
20 | class PathType(enum.Enum):
21 | """
22 | Which type of path to use.
23 | """
24 |
25 | LINEAR = enum.auto()
26 | GVP = enum.auto()
27 | VP = enum.auto()
28 |
29 | class WeightType(enum.Enum):
30 | """
31 | Which type of weighting to use.
32 | """
33 |
34 | NONE = enum.auto()
35 | VELOCITY = enum.auto()
36 | LIKELIHOOD = enum.auto()
37 |
38 |
39 | class Transport:
40 |
41 | def __init__(
42 | self,
43 | *,
44 | model_type,
45 | path_type,
46 | loss_type,
47 | train_eps,
48 | sample_eps,
49 | ):
50 | path_options = {
51 | PathType.LINEAR: path.ICPlan,
52 | PathType.GVP: path.GVPCPlan,
53 | PathType.VP: path.VPCPlan,
54 | }
55 |
56 | self.loss_type = loss_type
57 | self.model_type = model_type
58 | self.path_sampler = path_options[path_type]()
59 | self.train_eps = train_eps
60 | self.sample_eps = sample_eps
61 |
62 | def prior_logp(self, z):
63 | '''
64 | Standard multivariate normal prior
65 | Assume z is batched
66 | '''
67 | shape = th.tensor(z.size())
68 | N = th.prod(shape[1:])
69 | _fn = lambda x: -N / 2. * np.log(2 * np.pi) - th.sum(x ** 2) / 2.
70 | return th.vmap(_fn)(z)
71 |
72 |
73 | def check_interval(
74 | self,
75 | train_eps,
76 | sample_eps,
77 | *,
78 | diffusion_form="SBDM",
79 | sde=False,
80 | reverse=False,
81 | eval=False,
82 | last_step_size=0.0,
83 | ):
84 | t0 = 0
85 | t1 = 1
86 | eps = train_eps if not eval else sample_eps
87 | if (type(self.path_sampler) in [path.VPCPlan]):
88 |
89 | t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
90 |
91 | elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) \
92 | and (self.model_type != ModelType.VELOCITY or sde): # avoid numerical issue by taking a first semi-implicit step
93 |
94 | t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0
95 | t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
96 |
97 | if reverse:
98 | t0, t1 = 1 - t0, 1 - t1
99 |
100 | return t0, t1
101 |
102 |
103 | def sample(self, x1):
104 | """Sampling x0 & t based on shape of x1 (if needed)
105 | Args:
106 | x1 - data point; [batch, *dim]
107 | """
108 |
109 | x0 = th.randn_like(x1)
110 | t0, t1 = self.check_interval(self.train_eps, self.sample_eps)
111 | t = th.rand((x1.shape[0],)) * (t1 - t0) + t0
112 | t = t.to(x1)
113 | return t, x0, x1
114 |
115 |
116 | def training_losses(
117 | self,
118 | model,
119 | x1,
120 | model_kwargs=None
121 | ):
122 | """Loss for training the score model
123 | Args:
124 | - model: backbone model; could be score, noise, or velocity
125 | - x1: datapoint
126 | - model_kwargs: additional arguments for the model
127 | """
128 |         if model_kwargs is None:
129 | model_kwargs = {}
130 |
131 | t, x0, x1 = self.sample(x1)
132 | t, xt, ut = self.path_sampler.plan(t, x0, x1)
133 | model_output = model(xt, t, **model_kwargs)
134 | B, *_, C = xt.shape
135 | assert model_output.size() == (B, *xt.size()[1:-1], C)
136 |
137 | terms = {}
138 | terms['pred'] = model_output
139 | if self.model_type == ModelType.VELOCITY:
140 | terms['loss'] = mean_flat(((model_output - ut) ** 2))
141 | else:
142 | _, drift_var = self.path_sampler.compute_drift(xt, t)
143 | sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, xt))
144 | if self.loss_type in [WeightType.VELOCITY]:
145 | weight = (drift_var / sigma_t) ** 2
146 | elif self.loss_type in [WeightType.LIKELIHOOD]:
147 | weight = drift_var / (sigma_t ** 2)
148 | elif self.loss_type in [WeightType.NONE]:
149 | weight = 1
150 | else:
151 | raise NotImplementedError()
152 |
153 | if self.model_type == ModelType.NOISE:
154 | terms['loss'] = mean_flat(weight * ((model_output - x0) ** 2))
155 | else:
156 | terms['loss'] = mean_flat(weight * ((model_output * sigma_t + x0) ** 2))
157 |
158 | return terms
159 |
160 |
161 | def get_drift(
162 | self
163 | ):
164 | """member function for obtaining the drift of the probability flow ODE"""
165 | def score_ode(x, t, model, **model_kwargs):
166 | drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
167 | model_output = model(x, t, **model_kwargs)
168 | return (-drift_mean + drift_var * model_output) # by change of variable
169 |
170 | def noise_ode(x, t, model, **model_kwargs):
171 | drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
172 | sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))
173 | model_output = model(x, t, **model_kwargs)
174 | score = model_output / -sigma_t
175 | return (-drift_mean + drift_var * score)
176 |
177 | def velocity_ode(x, t, model, **model_kwargs):
178 | model_output = model(x, t, **model_kwargs)
179 | return model_output
180 |
181 | if self.model_type == ModelType.NOISE:
182 | drift_fn = noise_ode
183 | elif self.model_type == ModelType.SCORE:
184 | drift_fn = score_ode
185 | else:
186 | drift_fn = velocity_ode
187 |
188 | def body_fn(x, t, model, **model_kwargs):
189 | model_output = drift_fn(x, t, model, **model_kwargs)
190 | assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape"
191 | return model_output
192 |
193 | return body_fn
194 |
195 |
196 | def get_score(
197 | self,
198 | ):
199 | """member function for obtaining score of
200 | x_t = alpha_t * x + sigma_t * eps"""
201 | if self.model_type == ModelType.NOISE:
202 | score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs) / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]
203 | elif self.model_type == ModelType.SCORE:
204 |             score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs)
205 | elif self.model_type == ModelType.VELOCITY:
206 | score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(model(x, t, **kwargs), x, t)
207 | else:
208 | raise NotImplementedError()
209 |
210 | return score_fn
211 |
212 |
213 | class Sampler:
214 | """Sampler class for the transport model"""
215 | def __init__(
216 | self,
217 | transport,
218 | ):
219 | """Constructor for a general sampler; supporting different sampling methods
220 | Args:
221 |         - transport: a Transport object specifying the model prediction & interpolant type
222 | """
223 |
224 | self.transport = transport
225 | self.drift = self.transport.get_drift()
226 | self.score = self.transport.get_score()
227 |
228 | def __get_sde_diffusion_and_drift(
229 | self,
230 | *,
231 | diffusion_form="SBDM",
232 | diffusion_norm=1.0,
233 | ):
234 |
235 | def diffusion_fn(x, t):
236 | diffusion = self.transport.path_sampler.compute_diffusion(x, t, form=diffusion_form, norm=diffusion_norm)
237 | return diffusion
238 |
239 | sde_drift = \
240 | lambda x, t, model, **kwargs: \
241 | self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs)
242 |
243 | sde_diffusion = diffusion_fn
244 |
245 | return sde_drift, sde_diffusion
246 |
247 | def __get_last_step(
248 | self,
249 | sde_drift,
250 | *,
251 | last_step,
252 | last_step_size,
253 | ):
254 | """Get the last step function of the SDE solver"""
255 |
256 | if last_step is None:
257 | last_step_fn = \
258 | lambda x, t, model, **model_kwargs: \
259 | x
260 | elif last_step == "Mean":
261 | last_step_fn = \
262 | lambda x, t, model, **model_kwargs: \
263 | x + sde_drift(x, t, model, **model_kwargs) * last_step_size
264 | elif last_step == "Tweedie":
265 | alpha = self.transport.path_sampler.compute_alpha_t # simple aliasing; the original name was too long
266 | sigma = self.transport.path_sampler.compute_sigma_t
267 | last_step_fn = \
268 | lambda x, t, model, **model_kwargs: \
269 | x / alpha(t)[0][0] + (sigma(t)[0][0] ** 2) / alpha(t)[0][0] * self.score(x, t, model, **model_kwargs)
270 | elif last_step == "Euler":
271 | last_step_fn = \
272 | lambda x, t, model, **model_kwargs: \
273 | x + self.drift(x, t, model, **model_kwargs) * last_step_size
274 | else:
275 | raise NotImplementedError()
276 |
277 | return last_step_fn
278 |
279 | def sample_sde(
280 | self,
281 | *,
282 | sampling_method="Euler",
283 | diffusion_form="SBDM",
284 | diffusion_norm=1.0,
285 | last_step="Mean",
286 | last_step_size=0.04,
287 | num_steps=250,
288 | ):
289 | """returns a sampling function with given SDE settings
290 | Args:
291 | - sampling_method: type of sampler used in solving the SDE; default to be Euler-Maruyama
292 | - diffusion_form: function form of diffusion coefficient; default to be matching SBDM
293 | - diffusion_norm: function magnitude of diffusion coefficient; default to 1
294 |         - last_step: form of the last step; default to "Mean"
295 | - last_step_size: size of the last step; default to match the stride of 250 steps over [0,1]
296 | - num_steps: total integration step of SDE
297 | """
298 |
299 | if last_step is None:
300 | last_step_size = 0.0
301 |
302 | sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift(
303 | diffusion_form=diffusion_form,
304 | diffusion_norm=diffusion_norm,
305 | )
306 |
307 | t0, t1 = self.transport.check_interval(
308 | self.transport.train_eps,
309 | self.transport.sample_eps,
310 | diffusion_form=diffusion_form,
311 | sde=True,
312 | eval=True,
313 | reverse=False,
314 | last_step_size=last_step_size,
315 | )
316 |
317 | _sde = sde(
318 | sde_drift,
319 | sde_diffusion,
320 | t0=t0,
321 | t1=t1,
322 | num_steps=num_steps,
323 | sampler_type=sampling_method
324 | )
325 |
326 | last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size)
327 |
328 |
329 | def _sample(init, model, **model_kwargs):
330 | xs = _sde.sample(init, model, **model_kwargs)
331 | ts = th.ones(init.size(0), device=init.device) * t1
332 | x = last_step_fn(xs[-1], ts, model, **model_kwargs)
333 | xs.append(x)
334 |
335 |             assert len(xs) == num_steps, "Number of samples does not match the number of steps"
336 |
337 | return xs
338 |
339 | return _sample
340 |
341 | def sample_ode(
342 | self,
343 | *,
344 | sampling_method="dopri5",
345 | num_steps=50,
346 | atol=1e-6,
347 | rtol=1e-3,
348 | reverse=False,
349 | ):
350 | """returns a sampling function with given ODE settings
351 | Args:
352 | - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
353 | - num_steps:
354 | - fixed solver (Euler, Heun): the actual number of integration steps performed
355 | - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
356 | - atol: absolute error tolerance for the solver
357 | - rtol: relative error tolerance for the solver
358 | - reverse: whether solving the ODE in reverse (data to noise); default to False
359 | """
360 | if reverse:
361 | drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs)
362 | else:
363 | drift = self.drift
364 |
365 | t0, t1 = self.transport.check_interval(
366 | self.transport.train_eps,
367 | self.transport.sample_eps,
368 | sde=False,
369 | eval=True,
370 | reverse=reverse,
371 | last_step_size=0.0,
372 | )
373 |
374 | _ode = ode(
375 | drift=drift,
376 | t0=t0,
377 | t1=t1,
378 | sampler_type=sampling_method,
379 | num_steps=num_steps,
380 | atol=atol,
381 | rtol=rtol,
382 | )
383 |
384 | return _ode.sample
385 |
386 | def sample_ode_likelihood(
387 | self,
388 | *,
389 | sampling_method="dopri5",
390 | num_steps=50,
391 | atol=1e-6,
392 | rtol=1e-3,
393 | ):
394 |
395 | """returns a sampling function for calculating likelihood with given ODE settings
396 | Args:
397 | - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
398 | - num_steps:
399 | - fixed solver (Euler, Heun): the actual number of integration steps performed
400 | - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
401 | - atol: absolute error tolerance for the solver
402 | - rtol: relative error tolerance for the solver
403 | """
404 | def _likelihood_drift(x, t, model, **model_kwargs):
405 | x, _ = x
406 |             eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1  # Rademacher noise for the Hutchinson divergence estimate
407 | t = th.ones_like(t) * (1 - t)
408 | with th.enable_grad():
409 | x.requires_grad = True
410 | grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0]
411 | logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size()))))
412 | drift = self.drift(x, t, model, **model_kwargs)
413 | return (-drift, logp_grad)
414 |
415 | t0, t1 = self.transport.check_interval(
416 | self.transport.train_eps,
417 | self.transport.sample_eps,
418 | sde=False,
419 | eval=True,
420 | reverse=False,
421 | last_step_size=0.0,
422 | )
423 |
424 | _ode = ode(
425 | drift=_likelihood_drift,
426 | t0=t0,
427 | t1=t1,
428 | sampler_type=sampling_method,
429 | num_steps=num_steps,
430 | atol=atol,
431 | rtol=rtol,
432 | )
433 |
434 |         def _sample_fn(x, model, **model_kwargs):
435 |             init_logp = th.zeros(x.size(0)).to(x)
436 |             state = (x, init_logp)  # jointly integrate samples and the log-density change
437 |             zs, delta_logp = _ode.sample(state, model, **model_kwargs)
438 |             z, delta_logp = zs[-1], delta_logp[-1]
439 |             prior_logp = self.transport.prior_logp(z)
440 |             logp = prior_logp - delta_logp
441 |             return logp, z
442 |
443 | return _sample_fn
--------------------------------------------------------------------------------
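End to end, the classes above compose as follows; a minimal sketch with a stand-in model (any callable of signature `model(x, t, **kwargs)`, as `training_losses` and the samplers expect):

```python
import torch as th

from transport import create_transport, Sampler

transport = create_transport()                      # Linear path, velocity prediction
dummy_model = lambda x, t, **kw: th.zeros_like(x)   # stand-in for a trained SiT

x1 = th.randn(8, 4)                                 # pretend "data" batch
loss = transport.training_losses(dummy_model, x1)["loss"]  # per-sample MSE against u_t

sample_fn = Sampler(transport).sample_ode(sampling_method="euler", num_steps=10)
xs = sample_fn(th.randn(8, 4), dummy_model)         # [10, 8, 4] trajectory, noise -> "data"
```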
/transport/utils.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 |
3 | class EasyDict:
4 |
5 | def __init__(self, sub_dict):
6 | for k, v in sub_dict.items():
7 | setattr(self, k, v)
8 |
9 | def __getitem__(self, key):
10 | return getattr(self, key)
11 |
12 | def mean_flat(x):
13 | """
14 | Take the mean over all non-batch dimensions.
15 | """
16 | return th.mean(x, dim=list(range(1, len(x.size()))))
17 |
18 | def log_state(state):
19 | result = []
20 |
21 | sorted_state = dict(sorted(state.items()))
22 | for key, value in sorted_state.items():
23 | # Check if the value is an instance of a class
24 | if "