├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── ar-diffusion.png
├── autoregressive_diffusion_pytorch
│   ├── __init__.py
│   ├── autoregressive_diffusion.py
│   ├── autoregressive_flow.py
│   └── image_trainer.py
├── images
│   └── results.96600.png
└── pyproject.toml
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | jobs:
16 | deploy:
17 |
18 | runs-on: ubuntu-latest
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 | - name: Set up Python
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: '3.x'
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install build
30 | - name: Build package
31 | run: python -m build
32 | - name: Publish package
33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
34 | with:
35 | user: __token__
36 | password: ${{ secrets.PYPI_API_TOKEN }}
37 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Autoregressive Diffusion - Pytorch
4 |
5 | Implementation of the architecture behind *Autoregressive Image Generation without Vector Quantization* in Pytorch
6 |
7 | Official repository has been released here
8 |
9 | Alternative route
10 |
11 |
12 | ![oxford flowers sample](./images/results.96600.png)
13 | *oxford flowers at 96k steps*
14 |
15 | ## Install
16 |
17 | ```bash
18 | $ pip install autoregressive-diffusion-pytorch
19 | ```
20 |
21 | ## Usage
22 |
23 | ```python
24 | import torch
25 | from autoregressive_diffusion_pytorch import AutoregressiveDiffusion
26 |
27 | model = AutoregressiveDiffusion(
28 | dim_input = 512,
29 | dim = 1024,
30 | max_seq_len = 32,
31 | depth = 8,
32 | mlp_depth = 3,
33 | mlp_width = 1024
34 | )
35 |
36 | seq = torch.randn(3, 32, 512)
37 |
38 | loss = model(seq)
39 | loss.backward()
40 |
41 | sampled = model.sample(batch_size = 3)
42 |
43 | assert sampled.shape == seq.shape
44 |
45 | ```
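
Sampling can also be conditioned on a partial sequence via the `prompt` argument of `sample` (see `autoregressive_diffusion.py`). A minimal sketch, reusing `model` and `seq` from above:

```python
prompt = seq[:, :16]   # first 16 of the 32 tokens as context

continued = model.sample(batch_size = 3, prompt = prompt)

assert continued.shape == (3, 32, 512)
```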
46 |
47 | For images treated as a sequence of tokens (as in the paper)
48 |
49 | ```python
50 | import torch
51 | from autoregressive_diffusion_pytorch import ImageAutoregressiveDiffusion
52 |
53 | model = ImageAutoregressiveDiffusion(
54 | model = dict(
55 | dim = 1024,
56 | depth = 12,
57 | heads = 12,
58 | ),
59 | image_size = 64,
60 | patch_size = 8
61 | )
62 |
63 | images = torch.randn(3, 3, 64, 64)
64 |
65 | loss = model(images)
66 | loss.backward()
67 |
68 | sampled = model.sample(batch_size = 3)
69 |
70 | assert sampled.shape == images.shape
71 |
72 | ```
73 |
74 | An image trainer
75 |
76 | ```python
77 | import torch
78 |
79 | from autoregressive_diffusion_pytorch import (
80 | ImageDataset,
81 | ImageAutoregressiveDiffusion,
82 | ImageTrainer
83 | )
84 |
85 | dataset = ImageDataset(
86 | '/path/to/your/images',
87 | image_size = 128
88 | )
89 |
90 | model = ImageAutoregressiveDiffusion(
91 | model = dict(
92 | dim = 512
93 | ),
94 | image_size = 128,
95 | patch_size = 16
96 | )
97 |
98 | trainer = ImageTrainer(
99 | model = model,
100 | dataset = dataset
101 | )
102 |
103 | trainer()
104 | ```
105 |
106 | For an improvised version using flow matching, just import `ImageAutoregressiveFlow` and `AutoregressiveFlow` instead
107 |
108 | The rest is the same
109 |
110 | ex.
111 |
112 | ```python
113 | import torch
114 |
115 | from autoregressive_diffusion_pytorch import (
116 | ImageDataset,
117 | ImageTrainer,
118 | ImageAutoregressiveFlow,
119 | )
120 |
121 | dataset = ImageDataset(
122 | '/path/to/your/images',
123 | image_size = 128
124 | )
125 |
126 | model = ImageAutoregressiveFlow(
127 | model = dict(
128 | dim = 512
129 | ),
130 | image_size = 128,
131 | patch_size = 16
132 | )
133 |
134 | trainer = ImageTrainer(
135 | model = model,
136 | dataset = dataset
137 | )
138 |
139 | trainer()
140 | ```
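
The sequence-level `AutoregressiveFlow` mirrors `AutoregressiveDiffusion` above. A minimal sketch of the analogous usage, with illustrative hyperparameters (sampling integrates the learned flow with `torchdiffeq`, defaulting to 16 ODE steps):

```python
import torch
from autoregressive_diffusion_pytorch import AutoregressiveFlow

model = AutoregressiveFlow(
    dim_input = 512,
    dim = 1024,
    max_seq_len = 32,
    depth = 8,
    mlp_depth = 3,
    mlp_width = 1024
)

seq = torch.randn(3, 32, 512)

loss = model(seq)
loss.backward()

sampled = model.sample(batch_size = 3)

assert sampled.shape == seq.shape
```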
141 |
142 | ## Citations
143 |
144 | ```bibtex
145 | @article{Li2024AutoregressiveIG,
146 | title = {Autoregressive Image Generation without Vector Quantization},
147 | author = {Tianhong Li and Yonglong Tian and He Li and Mingyang Deng and Kaiming He},
148 | journal = {ArXiv},
149 | year = {2024},
150 | volume = {abs/2406.11838},
151 | url = {https://api.semanticscholar.org/CorpusID:270560593}
152 | }
153 | ```
154 |
155 | ```bibtex
156 | @article{Wu2023ARDiffusionAD,
157 | title = {AR-Diffusion: Auto-Regressive Diffusion Model for Text Generation},
158 | author = {Tong Wu and Zhihao Fan and Xiao Liu and Yeyun Gong and Yelong Shen and Jian Jiao and Haitao Zheng and Juntao Li and Zhongyu Wei and Jian Guo and Nan Duan and Weizhu Chen},
159 | journal = {ArXiv},
160 | year = {2023},
161 | volume = {abs/2305.09515},
162 | url = {https://api.semanticscholar.org/CorpusID:258714669}
163 | }
164 | ```
165 |
166 | ```bibtex
167 | @article{Karras2022ElucidatingTD,
168 | title = {Elucidating the Design Space of Diffusion-Based Generative Models},
169 | author = {Tero Karras and Miika Aittala and Timo Aila and Samuli Laine},
170 | journal = {ArXiv},
171 | year = {2022},
172 | volume = {abs/2206.00364},
173 | url = {https://api.semanticscholar.org/CorpusID:249240415}
174 | }
175 | ```
176 |
177 | ```bibtex
178 | @article{Liu2022FlowSA,
179 | title = {Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow},
180 | author = {Xingchao Liu and Chengyue Gong and Qiang Liu},
181 | journal = {ArXiv},
182 | year = {2022},
183 | volume = {abs/2209.03003},
184 | url = {https://api.semanticscholar.org/CorpusID:252111177}
185 | }
186 | ```
187 |
188 | ```bibtex
189 | @article{Esser2024ScalingRF,
190 | title = {Scaling Rectified Flow Transformers for High-Resolution Image Synthesis},
191 | author = {Patrick Esser and Sumith Kulal and A. Blattmann and Rahim Entezari and Jonas Muller and Harry Saini and Yam Levi and Dominik Lorenz and Axel Sauer and Frederic Boesel and Dustin Podell and Tim Dockhorn and Zion English and Kyle Lacey and Alex Goodwin and Yannik Marek and Robin Rombach},
192 | journal = {ArXiv},
193 | year = {2024},
194 | volume = {abs/2403.03206},
195 | url = {https://api.semanticscholar.org/CorpusID:268247980}
196 | }
197 | ```
198 |
--------------------------------------------------------------------------------
/ar-diffusion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/autoregressive-diffusion-pytorch/6387e02449f8f5ebd69b0709c9aab11e869bda07/ar-diffusion.png
--------------------------------------------------------------------------------
/autoregressive_diffusion_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from autoregressive_diffusion_pytorch.autoregressive_diffusion import (
2 | MLP,
3 | AutoregressiveDiffusion,
4 | ImageAutoregressiveDiffusion
5 | )
6 |
7 | from autoregressive_diffusion_pytorch.autoregressive_flow import (
8 | AutoregressiveFlow,
9 | ImageAutoregressiveFlow
10 | )
11 |
12 | from autoregressive_diffusion_pytorch.image_trainer import (
13 | ImageDataset,
14 | ImageTrainer
15 | )
16 |
--------------------------------------------------------------------------------
/autoregressive_diffusion_pytorch/autoregressive_diffusion.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import math
4 | from math import sqrt
5 | from typing import Literal
6 | from functools import partial
7 |
8 | import torch
9 | from torch import nn, pi
10 | from torch.special import expm1
11 | import torch.nn.functional as F
12 | from torch.nn import Module, ModuleList
13 |
14 | import einx
15 | from einops import rearrange, repeat, reduce, pack, unpack
16 | from einops.layers.torch import Rearrange
17 |
18 | from tqdm import tqdm
19 |
20 | from x_transformers import Decoder
21 |
22 | # helpers
23 |
24 | def exists(v):
25 | return v is not None
26 |
27 | def default(v, d):
28 | return v if exists(v) else d
29 |
30 | def divisible_by(num, den):
31 | return (num % den) == 0
32 |
33 | # tensor helpers
34 |
35 | def log(t, eps = 1e-20):
36 | return torch.log(t.clamp(min = eps))
37 |
38 | def safe_div(num, den, eps = 1e-5):
39 | return num / den.clamp(min = eps)
40 |
41 | def right_pad_dims_to(x, t):
42 | padding_dims = x.ndim - t.ndim
43 |
44 | if padding_dims <= 0:
45 | return t
46 |
47 | return t.view(*t.shape, *((1,) * padding_dims))
48 |
49 | def pack_one(t, pattern):
50 | packed, ps = pack([t], pattern)
51 |
52 | def unpack_one(to_unpack, unpack_pattern = None):
53 | unpacked, = unpack(to_unpack, ps, default(unpack_pattern, pattern))
54 | return unpacked
55 |
56 | return packed, unpack_one
57 |
58 | # adaptive layernorm + sinusoidal embedding
59 |
60 | class AdaptiveLayerNorm(Module):
61 | def __init__(
62 | self,
63 | dim,
64 | dim_condition = None
65 | ):
66 | super().__init__()
67 | dim_condition = default(dim_condition, dim)
68 |
69 | self.ln = nn.LayerNorm(dim, elementwise_affine = False)
70 | self.to_gamma = nn.Linear(dim_condition, dim, bias = False)
71 | nn.init.zeros_(self.to_gamma.weight)
72 |
73 | def forward(self, x, *, condition):
74 | normed = self.ln(x)
75 | gamma = self.to_gamma(condition)
76 | return normed * (gamma + 1.)
77 |
78 | class LearnedSinusoidalPosEmb(Module):
79 | def __init__(self, dim):
80 | super().__init__()
81 | assert divisible_by(dim, 2)
82 | half_dim = dim // 2
83 | self.weights = nn.Parameter(torch.randn(half_dim))
84 |
85 | def forward(self, x):
86 | x = rearrange(x, 'b -> b 1')
87 | freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * pi
88 | fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
89 | fouriered = torch.cat((x, fouriered), dim = -1)
90 | return fouriered
91 |
92 | # simple mlp
93 |
94 | class MLP(Module):
95 | def __init__(
96 | self,
97 | dim_cond,
98 | dim_input,
99 | depth = 3,
100 | width = 1024,
101 | dropout = 0.
102 | ):
103 | super().__init__()
104 | layers = ModuleList([])
105 |
106 | self.to_time_emb = nn.Sequential(
107 | LearnedSinusoidalPosEmb(dim_cond),
108 | nn.Linear(dim_cond + 1, dim_cond),
109 | )
110 |
111 | for _ in range(depth):
112 |
113 | adaptive_layernorm = AdaptiveLayerNorm(
114 | dim_input,
115 | dim_condition = dim_cond
116 | )
117 |
118 | block = nn.Sequential(
119 | nn.Linear(dim_input, width),
120 | nn.SiLU(),
121 | nn.Dropout(dropout),
122 | nn.Linear(width, dim_input)
123 | )
124 |
125 | block_out_gamma = nn.Linear(dim_cond, dim_input, bias = False)
126 | nn.init.zeros_(block_out_gamma.weight)
127 |
128 | layers.append(ModuleList([
129 | adaptive_layernorm,
130 | block,
131 | block_out_gamma
132 | ]))
133 |
134 | self.layers = layers
135 |
136 | def forward(
137 | self,
138 | noised,
139 | *,
140 | times,
141 | cond
142 | ):
143 | assert noised.ndim == 2
144 |
145 | time_emb = self.to_time_emb(times)
146 | cond = F.silu(time_emb + cond)
147 |
148 | denoised = noised
149 |
150 | for adaln, block, block_out_gamma in self.layers:
151 | residual = denoised
152 | denoised = adaln(denoised, condition = cond)
153 |
154 | block_out = block(denoised) * (block_out_gamma(cond) + 1.)
155 | denoised = block_out + residual
156 |
157 | return denoised
158 |
159 | # gaussian diffusion
160 |
161 | class ElucidatedDiffusion(Module):
162 | def __init__(
163 | self,
164 | dim: int,
165 | net: MLP,
166 | *,
167 | num_sample_steps = 32, # number of sampling steps
168 | sigma_min = 0.002, # min noise level
169 | sigma_max = 80, # max noise level
170 | sigma_data = 0.5, # standard deviation of data distribution
171 | rho = 7, # controls the sampling schedule
172 | P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training
173 | P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training
174 | S_churn = 80, # parameters for stochastic sampling - depends on dataset, Table 5 in the paper
175 | S_tmin = 0.05,
176 | S_tmax = 50,
177 | S_noise = 1.003,
178 | clamp_during_sampling = True
179 | ):
180 | super().__init__()
181 |
182 | self.net = net
183 | self.dim = dim
184 |
185 | # parameters
186 |
187 | self.sigma_min = sigma_min
188 | self.sigma_max = sigma_max
189 | self.sigma_data = sigma_data
190 |
191 | self.rho = rho
192 |
193 | self.P_mean = P_mean
194 | self.P_std = P_std
195 |
196 | self.num_sample_steps = num_sample_steps # otherwise known as N in the paper
197 |
198 | self.S_churn = S_churn
199 | self.S_tmin = S_tmin
200 | self.S_tmax = S_tmax
201 | self.S_noise = S_noise
202 |
203 | self.clamp_during_sampling = clamp_during_sampling
204 |
205 | @property
206 | def device(self):
207 | return next(self.net.parameters()).device
208 |
209 | # derived preconditioning params - Table 1
210 |
211 | def c_skip(self, sigma):
212 | return (self.sigma_data ** 2) / (sigma ** 2 + self.sigma_data ** 2)
213 |
214 | def c_out(self, sigma):
215 | return sigma * self.sigma_data * (self.sigma_data ** 2 + sigma ** 2) ** -0.5
216 |
217 | def c_in(self, sigma):
218 | return 1 * (sigma ** 2 + self.sigma_data ** 2) ** -0.5
219 |
220 | def c_noise(self, sigma):
221 | return log(sigma) * 0.25
222 |
223 | # preconditioned network output
224 | # equation (7) in the paper
225 |
226 | def preconditioned_network_forward(self, noised_seq, sigma, *, cond, clamp = None):
227 | clamp = default(clamp, self.clamp_during_sampling)
228 |
229 | batch, device = noised_seq.shape[0], noised_seq.device
230 |
231 | if isinstance(sigma, float):
232 | sigma = torch.full((batch,), sigma, device = device)
233 |
234 | padded_sigma = right_pad_dims_to(noised_seq, sigma)
235 |
236 | net_out = self.net(
237 | self.c_in(padded_sigma) * noised_seq,
238 | times = self.c_noise(sigma),
239 | cond = cond
240 | )
241 |
242 | out = self.c_skip(padded_sigma) * noised_seq + self.c_out(padded_sigma) * net_out
243 |
244 | if clamp:
245 | out = out.clamp(-1., 1.)
246 |
247 | return out
248 |
249 | # sampling
250 |
251 | # sample schedule
252 | # equation (5) in the paper
253 |
254 | def sample_schedule(self, num_sample_steps = None):
255 | num_sample_steps = default(num_sample_steps, self.num_sample_steps)
256 |
257 | N = num_sample_steps
258 | inv_rho = 1 / self.rho
259 |
260 | steps = torch.arange(num_sample_steps, device = self.device, dtype = torch.float32)
261 | sigmas = (self.sigma_max ** inv_rho + steps / (N - 1) * (self.sigma_min ** inv_rho - self.sigma_max ** inv_rho)) ** self.rho
262 |
263 | sigmas = F.pad(sigmas, (0, 1), value = 0.) # last step is sigma value of 0.
264 | return sigmas
265 |
266 | @torch.no_grad()
267 | def sample(self, cond, num_sample_steps = None, clamp = None):
268 | clamp = default(clamp, self.clamp_during_sampling)
269 | num_sample_steps = default(num_sample_steps, self.num_sample_steps)
270 |
271 | shape = (cond.shape[0], self.dim)
272 |
273 | # get the sigma schedule, derive the gammas, and pair each sigma with the next sigma and its gamma
274 |
275 | sigmas = self.sample_schedule(num_sample_steps)
276 |
277 | gammas = torch.where(
278 | (sigmas >= self.S_tmin) & (sigmas <= self.S_tmax),
279 | min(self.S_churn / num_sample_steps, sqrt(2) - 1),
280 | 0.
281 | )
282 |
283 | sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[:-1]))
284 |
285 | # the sequence starts as pure noise
286 |
287 | init_sigma = sigmas[0]
288 |
289 | seq = init_sigma * torch.randn(shape, device = self.device)
290 |
291 | # gradually denoise
292 |
293 | for sigma, sigma_next, gamma in tqdm(sigmas_and_gammas, desc = 'sampling time step'):
294 | sigma, sigma_next, gamma = map(lambda t: t.item(), (sigma, sigma_next, gamma))
295 |
296 | eps = self.S_noise * torch.randn(shape, device = self.device) # stochastic sampling
297 |
298 | sigma_hat = sigma + gamma * sigma
299 | seq_hat = seq + sqrt(sigma_hat ** 2 - sigma ** 2) * eps
300 |
301 | model_output = self.preconditioned_network_forward(seq_hat, sigma_hat, cond = cond, clamp = clamp)
302 | denoised_over_sigma = (seq_hat - model_output) / sigma_hat
303 |
304 | seq_next = seq_hat + (sigma_next - sigma_hat) * denoised_over_sigma
305 |
306 | # second order correction, if not the last timestep
307 |
308 | if sigma_next != 0:
309 | model_output_next = self.preconditioned_network_forward(seq_next, sigma_next, cond = cond, clamp = clamp)
310 | denoised_prime_over_sigma = (seq_next - model_output_next) / sigma_next
311 | seq_next = seq_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma)
312 |
313 | seq = seq_next
314 |
315 | if clamp:
316 | seq = seq.clamp(-1., 1.)
317 |
318 | return seq
319 |
320 | # training
321 |
322 | def loss_weight(self, sigma):
323 | return (sigma ** 2 + self.sigma_data ** 2) * (sigma * self.sigma_data) ** -2
324 |
325 | def noise_distribution(self, batch_size):
326 | return (self.P_mean + self.P_std * torch.randn((batch_size,), device = self.device)).exp()
327 |
328 | def forward(self, seq, *, cond):
329 | batch_size, dim, device = *seq.shape, self.device
330 |
331 | assert dim == self.dim, f'dimension of sequence being passed in must be {self.dim} but received {dim}'
332 |
333 | sigmas = self.noise_distribution(batch_size)
334 | padded_sigmas = right_pad_dims_to(seq, sigmas)
335 |
336 | noise = torch.randn_like(seq)
337 |
338 | noised_seq = seq + padded_sigmas * noise # alphas are 1. in the paper
339 |
340 | denoised = self.preconditioned_network_forward(noised_seq, sigmas, cond = cond)
341 |
342 | losses = F.mse_loss(denoised, seq, reduction = 'none')
343 | losses = reduce(losses, 'b ... -> b', 'mean')
344 |
345 | losses = losses * self.loss_weight(sigmas)
346 |
347 | return losses.mean()
348 |
349 | # main model, a decoder with continuous wrapper + small denoising mlp
350 |
351 | class AutoregressiveDiffusion(Module):
352 | def __init__(
353 | self,
354 | dim,
355 | *,
356 | max_seq_len,
357 | depth = 8,
358 | dim_head = 64,
359 | heads = 8,
360 | mlp_depth = 3,
361 | mlp_width = None,
362 | dim_input = None,
363 | decoder_kwargs: dict = dict(),
364 | mlp_kwargs: dict = dict(),
365 | diffusion_kwargs: dict = dict(
366 | clamp_during_sampling = True
367 | )
368 | ):
369 | super().__init__()
370 |
371 | self.start_token = nn.Parameter(torch.zeros(dim))
372 | self.max_seq_len = max_seq_len
373 | self.abs_pos_emb = nn.Embedding(max_seq_len, dim)
374 |
375 | dim_input = default(dim_input, dim)
376 | self.dim_input = dim_input
377 | self.proj_in = nn.Linear(dim_input, dim)
378 |
379 | self.transformer = Decoder(
380 | dim = dim,
381 | depth = depth,
382 | heads = heads,
383 | attn_dim_head = dim_head,
384 | **decoder_kwargs
385 | )
386 |
387 | self.denoiser = MLP(
388 | dim_cond = dim,
389 | dim_input = dim_input,
390 | depth = mlp_depth,
391 | width = default(mlp_width, dim),
392 | **mlp_kwargs
393 | )
394 |
395 | self.diffusion = ElucidatedDiffusion(
396 | dim_input,
397 | self.denoiser,
398 | **diffusion_kwargs
399 | )
400 |
401 | @property
402 | def device(self):
403 | return next(self.transformer.parameters()).device
404 |
405 | @torch.no_grad()
406 | def sample(
407 | self,
408 | batch_size = 1,
409 | prompt = None
410 | ):
411 | self.eval()
412 |
413 | start_tokens = repeat(self.start_token, 'd -> b 1 d', b = batch_size)
414 |
415 | if not exists(prompt):
416 | out = torch.empty((batch_size, 0, self.dim_input), device = self.device, dtype = torch.float32)
417 | else:
418 | out = prompt
419 |
420 | cache = None
421 |
422 | for _ in tqdm(range(self.max_seq_len - out.shape[1]), desc = 'tokens'):
423 |
424 | cond = self.proj_in(out)
425 |
426 | cond = torch.cat((start_tokens, cond), dim = 1)
427 | cond = cond + self.abs_pos_emb(torch.arange(cond.shape[1], device = self.device))
428 |
429 | cond, cache = self.transformer(cond, cache = cache, return_hiddens = True)
430 |
431 | last_cond = cond[:, -1]
432 |
433 | denoised_pred = self.diffusion.sample(cond = last_cond)
434 |
435 | denoised_pred = rearrange(denoised_pred, 'b d -> b 1 d')
436 | out = torch.cat((out, denoised_pred), dim = 1)
437 |
438 | return out
439 |
440 | def forward(
441 | self,
442 | seq
443 | ):
444 | b, seq_len, dim = seq.shape
445 |
446 | assert dim == self.dim_input
447 | assert seq_len == self.max_seq_len
448 |
449 | # break into seq and the continuous targets to be predicted
450 |
451 | seq, target = seq[:, :-1], seq
452 |
453 | # append start tokens
454 |
455 | seq = self.proj_in(seq)
456 | start_token = repeat(self.start_token, 'd -> b 1 d', b = b)
457 |
458 | seq = torch.cat((start_token, seq), dim = 1)
459 | seq = seq + self.abs_pos_emb(torch.arange(seq_len, device = self.device))
460 |
461 | cond = self.transformer(seq)
462 |
463 | # pack batch and sequence dimensions, so as to train each token with a different noise level
464 |
465 | target, _ = pack_one(target, '* d')
466 | cond, _ = pack_one(cond, '* d')
467 |
468 | diffusion_loss = self.diffusion(target, cond = cond)
469 |
470 | return diffusion_loss
471 |
472 | # image wrapper
473 |
474 | def normalize_to_neg_one_to_one(img):
475 | return img * 2 - 1
476 |
477 | def unnormalize_to_zero_to_one(t):
478 | return (t + 1) * 0.5
479 |
480 | class ImageAutoregressiveDiffusion(Module):
481 | def __init__(
482 | self,
483 | *,
484 | image_size,
485 | patch_size,
486 | channels = 3,
487 | model: dict = dict(),
488 | ):
489 | super().__init__()
490 | assert divisible_by(image_size, patch_size)
491 |
492 | num_patches = (image_size // patch_size) ** 2
493 | dim_in = channels * patch_size ** 2
494 |
495 | self.image_size = image_size
496 | self.patch_size = patch_size
497 |
498 | self.to_tokens = Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_size, p2 = patch_size)
499 |
500 | self.model = AutoregressiveDiffusion(
501 | **model,
502 | dim_input = dim_in,
503 | max_seq_len = num_patches
504 | )
505 |
506 | self.to_image = Rearrange('b (h w) (c p1 p2) -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size, h = int(math.sqrt(num_patches)))
507 |
508 | def sample(self, batch_size = 1):
509 | tokens = self.model.sample(batch_size = batch_size)
510 | images = self.to_image(tokens)
511 | return unnormalize_to_zero_to_one(images)
512 |
513 | def forward(self, images):
514 | images = normalize_to_neg_one_to_one(images)
515 | tokens = self.to_tokens(images)
516 | return self.model(tokens)
517 |
--------------------------------------------------------------------------------
/autoregressive_diffusion_pytorch/autoregressive_flow.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import math
4 | from math import sqrt
5 | from typing import Literal
6 | from functools import partial
7 |
8 | import torch
9 | from torch import nn, pi
10 | import torch.nn.functional as F
11 | from torch.nn import Module, ModuleList
12 |
13 | from torchdiffeq import odeint
14 |
15 | import einx
16 | from einops import rearrange, repeat, reduce, pack, unpack
17 | from einops.layers.torch import Rearrange
18 |
19 | from tqdm import tqdm
20 |
21 | from x_transformers import Decoder
22 |
23 | from autoregressive_diffusion_pytorch.autoregressive_diffusion import MLP
24 |
25 | # helpers
26 |
27 | def exists(v):
28 | return v is not None
29 |
30 | def default(v, d):
31 | return v if exists(v) else d
32 |
33 | def divisible_by(num, den):
34 | return (num % den) == 0
35 |
36 | def cast_tuple(t):
37 | return (t,) if not isinstance(t, tuple) else t
38 |
39 | # tensor helpers
40 |
41 | def log(t, eps = 1e-20):
42 | return torch.log(t.clamp(min = eps))
43 |
44 | def safe_div(num, den, eps = 1e-5):
45 | return num / den.clamp(min = eps)
46 |
47 | def right_pad_dims_to(x, t):
48 | padding_dims = x.ndim - t.ndim
49 |
50 | if padding_dims <= 0:
51 | return t
52 |
53 | return t.view(*t.shape, *((1,) * padding_dims))
54 |
55 | def pack_one(t, pattern):
56 | packed, ps = pack([t], pattern)
57 |
58 | def unpack_one(to_unpack, unpack_pattern = None):
59 | unpacked, = unpack(to_unpack, ps, default(unpack_pattern, pattern))
60 | return unpacked
61 |
62 | return packed, unpack_one
63 |
64 | # rectified flow
65 |
66 | class Flow(Module):
67 | def __init__(
68 | self,
69 | dim: int,
70 | net: MLP,
71 | *,
72 | atol = 1e-5,
73 | rtol = 1e-5,
74 | method = 'midpoint'
75 | ):
76 | super().__init__()
77 | self.net = net
78 | self.dim = dim
79 |
80 | self.odeint_kwargs = dict(
81 | atol = atol,
82 | rtol = rtol,
83 | method = method
84 | )
85 |
86 | @property
87 | def device(self):
88 | return next(self.net.parameters()).device
89 |
90 | @torch.no_grad()
91 | def sample(
92 | self,
93 | cond,
94 | num_sample_steps = 16
95 | ):
96 |
97 | batch = cond.shape[0]
98 |
99 | sampled_data_shape = (batch, self.dim)
100 |
101 | # start with random gaussian noise - y0
102 |
103 | noise = torch.randn(sampled_data_shape, device = self.device)
104 |
105 | # time steps
106 |
107 | times = torch.linspace(0., 1., num_sample_steps, device = self.device)
108 |
109 | # ode
110 |
111 | def ode_fn(t, x):
112 | t = repeat(t, '-> b', b = batch)
113 | flow = self.net(x, times = t, cond = cond)
114 | return flow
115 |
116 | trajectory = odeint(ode_fn, noise, times, **self.odeint_kwargs)
117 |
118 | sampled = trajectory[-1]
119 |
120 | return sampled
121 |
122 | # training
123 |
124 | def forward(self, seq, *, cond):
125 | batch_size, dim, device = *seq.shape, self.device
126 |
127 | assert dim == self.dim, f'dimension of sequence being passed in must be {self.dim} but received {dim}'
128 |
129 | times = torch.rand(batch_size, device = device)
130 | noise = torch.randn_like(seq)
131 | padded_times = right_pad_dims_to(seq, times)
132 |
133 | flow = seq - noise
134 |
135 | noised = (1.- padded_times) * noise + padded_times * seq
136 |
137 | pred_flow = self.net(noised, times = times, cond = cond)
138 |
139 | return F.mse_loss(pred_flow, flow)
140 |
141 | # main model, a decoder with continuous wrapper + small denoising mlp
142 |
143 | class AutoregressiveFlow(Module):
144 | def __init__(
145 | self,
146 | dim,
147 | *,
148 | max_seq_len: int | tuple[int, ...],
149 | depth = 8,
150 | dim_head = 64,
151 | heads = 8,
152 | mlp_depth = 3,
153 | mlp_width = 1024,
154 | dim_input = None,
155 | decoder_kwargs: dict = dict(),
156 | mlp_kwargs: dict = dict(),
157 | flow_kwargs: dict = dict()
158 | ):
159 | super().__init__()
160 |
161 | self.start_token = nn.Parameter(torch.zeros(dim))
162 |
163 | max_seq_len = cast_tuple(max_seq_len)
164 | self.abs_pos_emb = nn.ParameterList([nn.Parameter(torch.zeros(seq_len, dim)) for seq_len in max_seq_len])
165 |
166 | self.max_seq_len = math.prod(max_seq_len)
167 |
168 | dim_input = default(dim_input, dim)
169 | self.dim_input = dim_input
170 | self.proj_in = nn.Linear(dim_input, dim)
171 |
172 | self.transformer = Decoder(
173 | dim = dim,
174 | depth = depth,
175 | heads = heads,
176 | attn_dim_head = dim_head,
177 | **decoder_kwargs
178 | )
179 |
180 | self.to_cond_emb = nn.Linear(dim, dim, bias = False)
181 |
182 | self.denoiser = MLP(
183 | dim_cond = dim,
184 | dim_input = dim_input,
185 | depth = mlp_depth,
186 | width = mlp_width,
187 | **mlp_kwargs
188 | )
189 |
190 | self.flow = Flow(
191 | dim_input,
192 | self.denoiser,
193 | **flow_kwargs
194 | )
195 |
196 | @property
197 | def device(self):
198 | return next(self.transformer.parameters()).device
199 |
200 | def axial_pos_emb(self):
201 | # prepare maybe axial positional embedding
202 |
203 | pos_emb, *rest_pos_embs = self.abs_pos_emb
204 |
205 | for rest_pos_emb in rest_pos_embs:
206 | pos_emb = einx.add('i d, j d -> (i j) d', pos_emb, rest_pos_emb)
207 |
208 | return F.pad(pos_emb, (0, 0, 1, 0), value = 0.)
209 |
210 | @torch.no_grad()
211 | def sample(
212 | self,
213 | batch_size = 1,
214 | prompt = None
215 | ):
216 | self.eval()
217 |
218 | start_tokens = repeat(self.start_token, 'd -> b 1 d', b = batch_size)
219 |
220 | if not exists(prompt):
221 | out = torch.empty((batch_size, 0, self.dim_input), device = self.device, dtype = torch.float32)
222 | else:
223 | out = prompt
224 |
225 | cache = None
226 |
227 | for _ in tqdm(range(self.max_seq_len - out.shape[1]), desc = 'tokens'):
228 |
229 | cond = self.proj_in(out)
230 |
231 | cond = torch.cat((start_tokens, cond), dim = 1)
232 |
233 | seq_len = cond.shape[-2]
234 | axial_pos_emb = self.axial_pos_emb()
235 | cond += axial_pos_emb[:seq_len]
236 |
237 | cond, cache = self.transformer(cond, cache = cache, return_hiddens = True)
238 |
239 | last_cond = cond[:, -1]
240 |
241 | last_cond += axial_pos_emb[seq_len]
242 | last_cond = self.to_cond_emb(last_cond)
243 |
244 | denoised_pred = self.flow.sample(cond = last_cond)
245 |
246 | denoised_pred = rearrange(denoised_pred, 'b d -> b 1 d')
247 | out = torch.cat((out, denoised_pred), dim = 1)
248 |
249 | return out
250 |
251 | def forward(
252 | self,
253 | seq,
254 | noised_seq = None
255 | ):
256 | b, seq_len, dim = seq.shape
257 |
258 | assert dim == self.dim_input
259 | assert seq_len == self.max_seq_len
260 |
261 | # break into seq and the continuous targets to be predicted
262 |
263 | seq, target = seq[:, :-1], seq
264 |
265 | if exists(noised_seq):
266 | seq = noised_seq[:, :-1]
267 |
268 | # append start tokens
269 |
270 | seq = self.proj_in(seq)
271 | start_token = repeat(self.start_token, 'd -> b 1 d', b = b)
272 |
273 | seq = torch.cat((start_token, seq), dim = 1)
274 |
275 | axial_pos_emb = self.axial_pos_emb()
276 | seq = seq + axial_pos_emb[:seq_len]
277 |
278 | cond = self.transformer(seq)
279 |
280 | cond = cond + axial_pos_emb[1:(seq_len + 1)]
281 | cond = self.to_cond_emb(cond)
282 |
283 | # pack batch and sequence dimensions, so as to train each token with a different noise level
284 |
285 | target, _ = pack_one(target, '* d')
286 | cond, _ = pack_one(cond, '* d')
287 |
288 | return self.flow(target, cond = cond)
289 |
290 | # image wrapper
291 |
292 | def normalize_to_neg_one_to_one(img):
293 | return img * 2 - 1
294 |
295 | def unnormalize_to_zero_to_one(t):
296 | return (t + 1) * 0.5
297 |
298 | class ImageAutoregressiveFlow(Module):
299 | def __init__(
300 | self,
301 | *,
302 | image_size,
303 | patch_size,
304 | channels = 3,
305 | train_max_noise = 0.,
306 | model: dict = dict(),
307 | ):
308 | super().__init__()
309 | assert divisible_by(image_size, patch_size)
310 |
311 | patch_height_width = image_size // patch_size
312 | num_patches = patch_height_width ** 2
313 | dim_in = channels * patch_size ** 2
314 |
315 | self.image_size = image_size
316 | self.patch_size = patch_size
317 |
318 | assert 0. <= train_max_noise < 1.
319 |
320 | self.train_max_noise = train_max_noise
321 |
322 | self.to_tokens = Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_size, p2 = patch_size)
323 |
324 | self.model = AutoregressiveFlow(
325 | **model,
326 | dim_input = dim_in,
327 | max_seq_len = (patch_height_width, patch_height_width)
328 | )
329 |
330 | self.to_image = Rearrange('b (h w) (c p1 p2) -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size, h = int(math.sqrt(num_patches)))
331 |
332 | def sample(self, batch_size = 1):
333 | tokens = self.model.sample(batch_size = batch_size)
334 | images = self.to_image(tokens)
335 | return unnormalize_to_zero_to_one(images)
336 |
337 | def forward(self, images):
338 | train_under_noise, device = self.train_max_noise > 0., images.device
339 |
340 | images = normalize_to_neg_one_to_one(images)
341 | tokens = self.to_tokens(images)
342 |
343 | if not train_under_noise:
344 | return self.model(tokens)
345 |
346 | # allow for the network to predict from slightly noised images of the past
347 |
348 | times = torch.rand(images.shape[0], device = device) * self.train_max_noise
349 | noise = torch.randn_like(images)
350 | padded_times = right_pad_dims_to(images, times)
351 | noised_images = images * (1. - padded_times) + noise * padded_times
352 | noised_tokens = self.to_tokens(noised_images)
353 |
354 | return self.model(tokens, noised_tokens)
355 |
--------------------------------------------------------------------------------
/autoregressive_diffusion_pytorch/image_trainer.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from functools import partial
3 | import math
4 | from pathlib import Path
5 |
6 | from accelerate import Accelerator
7 | from ema_pytorch import EMA
8 |
9 | import torch
10 | from torch import nn
11 | from torch.optim import Adam
12 | from torch.utils.data import DataLoader
13 | from torch.nn import Module, ModuleList
14 | from torch.utils.data import Dataset
15 |
16 | from torchvision.utils import save_image
17 | import torchvision.transforms as T
18 |
19 | from PIL import Image
20 |
21 | # functions
22 |
23 | def exists(v):
24 | return v is not None
25 |
26 | def default(v, d):
27 | return v if exists(v) else d
28 |
29 | def divisible_by(num, den):
30 | return (num % den) == 0
31 |
32 | def cycle(dl):
33 | while True:
34 | for batch in dl:
35 | yield batch
36 |
37 | # dataset classes
38 |
39 | class ImageDataset(Dataset):
40 | def __init__(
41 | self,
42 | folder: str | Path,
43 | image_size: int,
44 | exts: List[str] = ['jpg', 'jpeg', 'png', 'tiff'],
45 | augment_horizontal_flip = False,
46 | convert_image_to = None
47 | ):
48 | super().__init__()
49 | if isinstance(folder, str):
50 | folder = Path(folder)
51 |
52 | assert folder.is_dir()
53 |
54 | self.folder = folder
55 | self.image_size = image_size
56 |
57 | self.paths = [p for ext in exts for p in folder.glob(f'**/*.{ext}')]
58 |
59 | def convert_image_to_fn(img_type, image):
60 | if image.mode == img_type:
61 | return image
62 |
63 | return image.convert(img_type)
64 |
65 | maybe_convert_fn = partial(convert_image_to_fn, convert_image_to) if exists(convert_image_to) else nn.Identity()
66 |
67 | self.transform = T.Compose([
68 | T.Lambda(maybe_convert_fn),
69 | T.Resize(image_size),
70 | T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
71 | T.CenterCrop(image_size),
72 | T.ToTensor()
73 | ])
74 |
75 | def __len__(self):
76 | return len(self.paths)
77 |
78 | def __getitem__(self, index):
79 | path = self.paths[index]
80 | img = Image.open(path)
81 | return self.transform(img)
82 |
83 | # trainer
84 |
85 | class ImageTrainer(Module):
86 | def __init__(
87 | self,
88 | model,
89 | *,
90 | dataset: Dataset,
91 | num_train_steps = 70_000,
92 | learning_rate = 3e-4,
93 | batch_size = 16,
94 | checkpoints_folder: str = './checkpoints',
95 | results_folder: str = './results',
96 | save_results_every: int = 100,
97 | checkpoint_every: int = 1000,
98 | num_samples: int = 16,
99 | adam_kwargs: dict = dict(),
100 | accelerate_kwargs: dict = dict(),
101 | ema_kwargs: dict = dict()
102 | ):
103 | super().__init__()
104 | self.accelerator = Accelerator(**accelerate_kwargs)
105 |
106 | self.model = model
107 |
108 | if self.is_main:
109 | self.ema_model = EMA(
110 | self.model,
111 | forward_method_names = ('sample',),
112 | **ema_kwargs
113 | )
114 |
115 | self.ema_model.to(self.accelerator.device)
116 |
117 | self.optimizer = Adam(model.parameters(), lr = learning_rate, **adam_kwargs)
118 | self.dl = DataLoader(dataset, batch_size = batch_size, shuffle = True, drop_last = True)
119 |
120 | self.model, self.optimizer, self.dl = self.accelerator.prepare(self.model, self.optimizer, self.dl)
121 |
122 | self.num_train_steps = num_train_steps
123 |
124 | self.checkpoints_folder = Path(checkpoints_folder)
125 | self.results_folder = Path(results_folder)
126 |
127 | self.checkpoints_folder.mkdir(exist_ok = True, parents = True)
128 | self.results_folder.mkdir(exist_ok = True, parents = True)
129 |
130 | self.checkpoint_every = checkpoint_every
131 | self.save_results_every = save_results_every
132 |
133 | self.num_sample_rows = int(math.sqrt(num_samples))
134 | assert (self.num_sample_rows ** 2) == num_samples, f'{num_samples} must be a square'
135 | self.num_samples = num_samples
136 |
137 | assert self.checkpoints_folder.is_dir()
138 | assert self.results_folder.is_dir()
139 |
140 | @property
141 | def is_main(self):
142 | return self.accelerator.is_main_process
143 |
144 | def save(self, path):
145 | if not self.is_main:
146 | return
147 |
148 | save_package = dict(
149 | model = self.accelerator.unwrap_model(self.model).state_dict(),
150 | ema_model = self.ema_model.state_dict(),
151 | optimizer = self.accelerator.unwrap_model(self.optimizer).state_dict(),
152 | )
153 |
154 | torch.save(save_package, str(self.checkpoints_folder / path))
155 |
156 | def forward(self):
157 |
158 | dl = cycle(self.dl)
159 |
160 | for ind in range(self.num_train_steps):
161 | step = ind + 1
162 |
163 | self.model.train()
164 |
165 | data = next(dl)
166 | loss = self.model(data)
167 |
168 | self.accelerator.print(f'[{step}] loss: {loss.item():.3f}')
169 | self.accelerator.backward(loss)
170 |
171 | self.optimizer.step()
172 | self.optimizer.zero_grad()
173 |
174 | if self.is_main:
175 | self.ema_model.update()
176 |
177 | self.accelerator.wait_for_everyone()
178 |
179 | if self.is_main:
180 | if divisible_by(step, self.save_results_every):
181 |
182 | with torch.no_grad():
183 | sampled = self.ema_model.sample(batch_size = self.num_samples)
184 |
185 | sampled.clamp_(0., 1.)
186 | save_image(sampled, str(self.results_folder / f'results.{step}.png'), nrow = self.num_sample_rows)
187 |
188 | if divisible_by(step, self.checkpoint_every):
189 | self.save(f'checkpoint.{step}.pt')
190 |
191 | self.accelerator.wait_for_everyone()
192 |
193 |
194 | print('training complete')
195 |
--------------------------------------------------------------------------------
/images/results.96600.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/autoregressive-diffusion-pytorch/6387e02449f8f5ebd69b0709c9aab11e869bda07/images/results.96600.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "autoregressive-diffusion-pytorch"
3 | version = "0.2.8"
4 | description = "Autoregressive Diffusion - Pytorch"
5 | authors = [
6 | { name = "Phil Wang", email = "lucidrains@gmail.com" }
7 | ]
8 | readme = "README.md"
9 | requires-python = ">= 3.8"
10 | license = { file = "LICENSE" }
11 | keywords = [
12 | 'artificial intelligence',
13 | 'deep learning',
14 | 'transformers',
15 | 'denoising diffusion',
16 | ]
17 | classifiers=[
18 | 'Development Status :: 4 - Beta',
19 | 'Intended Audience :: Developers',
20 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
21 | 'License :: OSI Approved :: MIT License',
22 | 'Programming Language :: Python :: 3.8',
23 | ]
24 |
25 | dependencies = [
26 | 'einx>=0.3.0',
27 | 'einops>=0.8.0',
28 | 'ema-pytorch',
29 | 'x-transformers>=1.31.14',
30 | 'torch>=2.0',
31 | 'torchdiffeq',
32 | 'tqdm'
33 | ]
34 |
35 | [project.urls]
36 | Homepage = "https://pypi.org/project/autoregressive-diffusion-pytorch/"
37 | Repository = "https://github.com/lucidrains/autoregressive-diffusion-pytorch"
38 |
39 | [project.optional-dependencies]
40 | examples = ["tqdm", "numpy"]
41 |
42 | [build-system]
43 | requires = ["hatchling"]
44 | build-backend = "hatchling.build"
45 |
46 | [tool.hatch.metadata]
47 | allow-direct-references = true
48 |
49 | [tool.hatch.build.targets.wheel]
50 | packages = ["autoregressive_diffusion_pytorch"]
51 |
--------------------------------------------------------------------------------