├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── ar-diffusion.png ├── autoregressive_diffusion_pytorch ├── __init__.py ├── autoregressive_diffusion.py ├── autoregressive_flow.py └── image_trainer.py ├── images └── results.96600.png └── pyproject.toml /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: '3.x' 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install build 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Autoregressive Diffusion - Pytorch 4 | 5 | Implementation of the architecture behind Autoregressive Image Generation without Vector Quantization in Pytorch 6 | 7 | The official repository has been released here 8 | 9 | Alternative route 10 | 11 | 12 | 13 | *oxford flowers at 96k steps* 14 | 15 | ## Install 16 | 17 | ```bash 18 | $ pip install autoregressive-diffusion-pytorch 19 | ``` 20 | 21 | ## Usage 22 | 23 | ```python 24 | import torch 25 | from autoregressive_diffusion_pytorch import AutoregressiveDiffusion 26 | 27 | model = AutoregressiveDiffusion( 28 | dim_input = 512, 29 | dim = 1024, 30 | max_seq_len = 32, 31 | depth = 8, 32 | mlp_depth = 3, 33 | mlp_width = 1024 34 | ) 35 | 36 | seq = torch.randn(3, 32, 512) 37 | 38 | loss = model(seq) 39 | loss.backward() 40 | 41 | sampled = model.sample(batch_size = 3) 42 | 43 | assert sampled.shape == seq.shape 44 | 45 | ``` 46 | 47 | For images treated as a sequence of tokens (as in the paper) 48 | 49 | ```python 50 | import torch 51 | from autoregressive_diffusion_pytorch import ImageAutoregressiveDiffusion 52 | 53 | model = ImageAutoregressiveDiffusion( 54 | model = dict( 55 | dim = 1024, 56 | depth = 12, 57 | heads = 12, 58 | ), 59 | image_size = 64, 60 | patch_size = 8 61 | ) 62 | 63 | images = torch.randn(3, 3, 64, 64) 64 | 65 | loss = model(images) 66 | loss.backward() 67 | 68 | sampled = model.sample(batch_size = 3) 69 | 70 | assert sampled.shape == images.shape 71 | 72 | ``` 73 | 74 | An image trainer 75 | 76 | ```python 77 | import torch 78 | 79 | from autoregressive_diffusion_pytorch import ( 80 | ImageDataset, 81 | ImageAutoregressiveDiffusion, 82 | ImageTrainer 83 | ) 84 | 85 | dataset = ImageDataset( 86 | '/path/to/your/images', 87 | image_size = 128 88 | ) 89 | 90 | model = ImageAutoregressiveDiffusion( 91 | model = dict( 92 | dim = 512 93 | ), 94 | image_size = 128, 95 | patch_size = 16 96 | ) 97 | 98 | trainer = ImageTrainer( 99 | model = model, 100 | dataset = dataset 101 | ) 102 | 103 | trainer() 104 | ``` 105 | 106 | For an improvised version using flow matching, just import `ImageAutoregressiveFlow` and `AutoregressiveFlow` instead 107 | 108 | The rest is the same 109 | 110 | ex.
111 | 112 | ```python 113 | import torch 114 | 115 | from autoregressive_diffusion_pytorch import ( 116 | ImageDataset, 117 | ImageTrainer, 118 | ImageAutoregressiveFlow, 119 | ) 120 | 121 | dataset = ImageDataset( 122 | '/path/to/your/images', 123 | image_size = 128 124 | ) 125 | 126 | model = ImageAutoregressiveFlow( 127 | model = dict( 128 | dim = 512 129 | ), 130 | image_size = 128, 131 | patch_size = 16 132 | ) 133 | 134 | trainer = ImageTrainer( 135 | model = model, 136 | dataset = dataset 137 | ) 138 | 139 | trainer() 140 | ``` 141 | 142 | ## Citations 143 | 144 | ```bibtex 145 | @article{Li2024AutoregressiveIG, 146 | title = {Autoregressive Image Generation without Vector Quantization}, 147 | author = {Tianhong Li and Yonglong Tian and He Li and Mingyang Deng and Kaiming He}, 148 | journal = {ArXiv}, 149 | year = {2024}, 150 | volume = {abs/2406.11838}, 151 | url = {https://api.semanticscholar.org/CorpusID:270560593} 152 | } 153 | ``` 154 | 155 | ```bibtex 156 | @article{Wu2023ARDiffusionAD, 157 | title = {AR-Diffusion: Auto-Regressive Diffusion Model for Text Generation}, 158 | author = {Tong Wu and Zhihao Fan and Xiao Liu and Yeyun Gong and Yelong Shen and Jian Jiao and Haitao Zheng and Juntao Li and Zhongyu Wei and Jian Guo and Nan Duan and Weizhu Chen}, 159 | journal = {ArXiv}, 160 | year = {2023}, 161 | volume = {abs/2305.09515}, 162 | url = {https://api.semanticscholar.org/CorpusID:258714669} 163 | } 164 | ``` 165 | 166 | ```bibtex 167 | @article{Karras2022ElucidatingTD, 168 | title = {Elucidating the Design Space of Diffusion-Based Generative Models}, 169 | author = {Tero Karras and Miika Aittala and Timo Aila and Samuli Laine}, 170 | journal = {ArXiv}, 171 | year = {2022}, 172 | volume = {abs/2206.00364}, 173 | url = {https://api.semanticscholar.org/CorpusID:249240415} 174 | } 175 | ``` 176 | 177 | ```bibtex 178 | @article{Liu2022FlowSA, 179 | title = {Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow}, 180 | author = {Xingchao Liu and Chengyue Gong and Qiang Liu}, 181 | journal = {ArXiv}, 182 | year = {2022}, 183 | volume = {abs/2209.03003}, 184 | url = {https://api.semanticscholar.org/CorpusID:252111177} 185 | } 186 | ``` 187 | 188 | ```bibtex 189 | @article{Esser2024ScalingRF, 190 | title = {Scaling Rectified Flow Transformers for High-Resolution Image Synthesis}, 191 | author = {Patrick Esser and Sumith Kulal and A. 
Blattmann and Rahim Entezari and Jonas Muller and Harry Saini and Yam Levi and Dominik Lorenz and Axel Sauer and Frederic Boesel and Dustin Podell and Tim Dockhorn and Zion English and Kyle Lacey and Alex Goodwin and Yannik Marek and Robin Rombach}, 192 | journal = {ArXiv}, 193 | year = {2024}, 194 | volume = {abs/2403.03206}, 195 | url = {https://api.semanticscholar.org/CorpusID:268247980} 196 | } 197 | ``` 198 | -------------------------------------------------------------------------------- /ar-diffusion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/autoregressive-diffusion-pytorch/6387e02449f8f5ebd69b0709c9aab11e869bda07/ar-diffusion.png -------------------------------------------------------------------------------- /autoregressive_diffusion_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from autoregressive_diffusion_pytorch.autoregressive_diffusion import ( 2 | MLP, 3 | AutoregressiveDiffusion, 4 | ImageAutoregressiveDiffusion 5 | ) 6 | 7 | from autoregressive_diffusion_pytorch.autoregressive_flow import ( 8 | AutoregressiveFlow, 9 | ImageAutoregressiveFlow 10 | ) 11 | 12 | from autoregressive_diffusion_pytorch.image_trainer import ( 13 | ImageDataset, 14 | ImageTrainer 15 | ) 16 | -------------------------------------------------------------------------------- /autoregressive_diffusion_pytorch/autoregressive_diffusion.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import math 4 | from math import sqrt 5 | from typing import Literal 6 | from functools import partial 7 | 8 | import torch 9 | from torch import nn, pi 10 | from torch.special import expm1 11 | import torch.nn.functional as F 12 | from torch.nn import Module, ModuleList 13 | 14 | import einx 15 | from einops import rearrange, repeat, reduce, pack, unpack 16 | from einops.layers.torch import Rearrange 17 | 18 | from tqdm import tqdm 19 | 20 | from x_transformers import Decoder 21 | 22 | # helpers 23 | 24 | def exists(v): 25 | return v is not None 26 | 27 | def default(v, d): 28 | return v if exists(v) else d 29 | 30 | def divisible_by(num, den): 31 | return (num % den) == 0 32 | 33 | # tensor helpers 34 | 35 | def log(t, eps = 1e-20): 36 | return torch.log(t.clamp(min = eps)) 37 | 38 | def safe_div(num, den, eps = 1e-5): 39 | return num / den.clamp(min = eps) 40 | 41 | def right_pad_dims_to(x, t): 42 | padding_dims = x.ndim - t.ndim 43 | 44 | if padding_dims <= 0: 45 | return t 46 | 47 | return t.view(*t.shape, *((1,) * padding_dims)) 48 | 49 | def pack_one(t, pattern): 50 | packed, ps = pack([t], pattern) 51 | 52 | def unpack_one(to_unpack, unpack_pattern = None): 53 | unpacked, = unpack(to_unpack, ps, default(unpack_pattern, pattern)) 54 | return unpacked 55 | 56 | return packed, unpack_one 57 | 58 | # sinusoidal embedding 59 | 60 | class AdaptiveLayerNorm(Module): 61 | def __init__( 62 | self, 63 | dim, 64 | dim_condition = None 65 | ): 66 | super().__init__() 67 | dim_condition = default(dim_condition, dim) 68 | 69 | self.ln = nn.LayerNorm(dim, elementwise_affine = False) 70 | self.to_gamma = nn.Linear(dim_condition, dim, bias = False) 71 | nn.init.zeros_(self.to_gamma.weight) 72 | 73 | def forward(self, x, *, condition): 74 | normed = self.ln(x) 75 | gamma = self.to_gamma(condition) 76 | return normed * (gamma + 1.) 
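# note on the conditioning mechanism above: `to_gamma` is zero-initialized, so at the start of training the adaptive layernorm reduces to a plain non-affine layernorm and the condition has no effect until the scale modulation is learned.
# a minimal usage sketch follows - the shapes are illustrative assumptions, not values taken from this repo:
#
#   adaln = AdaptiveLayerNorm(dim = 256, dim_condition = 512)
#   x     = torch.randn(8, 256)            # per-token input being denoised
#   cond  = torch.randn(8, 512)            # transformer hidden state (plus time embedding)
#   out   = adaln(x, condition = cond)     # same shape as x, scaled by (gamma + 1)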
77 | 78 | class LearnedSinusoidalPosEmb(Module): 79 | def __init__(self, dim): 80 | super().__init__() 81 | assert divisible_by(dim, 2) 82 | half_dim = dim // 2 83 | self.weights = nn.Parameter(torch.randn(half_dim)) 84 | 85 | def forward(self, x): 86 | x = rearrange(x, 'b -> b 1') 87 | freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * pi 88 | fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1) 89 | fouriered = torch.cat((x, fouriered), dim = -1) 90 | return fouriered 91 | 92 | # simple mlp 93 | 94 | class MLP(Module): 95 | def __init__( 96 | self, 97 | dim_cond, 98 | dim_input, 99 | depth = 3, 100 | width = 1024, 101 | dropout = 0. 102 | ): 103 | super().__init__() 104 | layers = ModuleList([]) 105 | 106 | self.to_time_emb = nn.Sequential( 107 | LearnedSinusoidalPosEmb(dim_cond), 108 | nn.Linear(dim_cond + 1, dim_cond), 109 | ) 110 | 111 | for _ in range(depth): 112 | 113 | adaptive_layernorm = AdaptiveLayerNorm( 114 | dim_input, 115 | dim_condition = dim_cond 116 | ) 117 | 118 | block = nn.Sequential( 119 | nn.Linear(dim_input, width), 120 | nn.SiLU(), 121 | nn.Dropout(dropout), 122 | nn.Linear(width, dim_input) 123 | ) 124 | 125 | block_out_gamma = nn.Linear(dim_cond, dim_input, bias = False) 126 | nn.init.zeros_(block_out_gamma.weight) 127 | 128 | layers.append(ModuleList([ 129 | adaptive_layernorm, 130 | block, 131 | block_out_gamma 132 | ])) 133 | 134 | self.layers = layers 135 | 136 | def forward( 137 | self, 138 | noised, 139 | *, 140 | times, 141 | cond 142 | ): 143 | assert noised.ndim == 2 144 | 145 | time_emb = self.to_time_emb(times) 146 | cond = F.silu(time_emb + cond) 147 | 148 | denoised = noised 149 | 150 | for adaln, block, block_out_gamma in self.layers: 151 | residual = denoised 152 | denoised = adaln(denoised, condition = cond) 153 | 154 | block_out = block(denoised) * (block_out_gamma(cond) + 1.) 
155 | denoised = block_out + residual 156 | 157 | return denoised 158 | 159 | # gaussian diffusion 160 | 161 | class ElucidatedDiffusion(Module): 162 | def __init__( 163 | self, 164 | dim: int, 165 | net: MLP, 166 | *, 167 | num_sample_steps = 32, # number of sampling steps 168 | sigma_min = 0.002, # min noise level 169 | sigma_max = 80, # max noise level 170 | sigma_data = 0.5, # standard deviation of data distribution 171 | rho = 7, # controls the sampling schedule 172 | P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training 173 | P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training 174 | S_churn = 80, # parameters for stochastic sampling - depends on dataset, Table 5 in the paper 175 | S_tmin = 0.05, 176 | S_tmax = 50, 177 | S_noise = 1.003, 178 | clamp_during_sampling = True 179 | ): 180 | super().__init__() 181 | 182 | self.net = net 183 | self.dim = dim 184 | 185 | # parameters 186 | 187 | self.sigma_min = sigma_min 188 | self.sigma_max = sigma_max 189 | self.sigma_data = sigma_data 190 | 191 | self.rho = rho 192 | 193 | self.P_mean = P_mean 194 | self.P_std = P_std 195 | 196 | self.num_sample_steps = num_sample_steps # otherwise known as N in the paper 197 | 198 | self.S_churn = S_churn 199 | self.S_tmin = S_tmin 200 | self.S_tmax = S_tmax 201 | self.S_noise = S_noise 202 | 203 | self.clamp_during_sampling = clamp_during_sampling 204 | 205 | @property 206 | def device(self): 207 | return next(self.net.parameters()).device 208 | 209 | # derived preconditioning params - Table 1 210 | 211 | def c_skip(self, sigma): 212 | return (self.sigma_data ** 2) / (sigma ** 2 + self.sigma_data ** 2) 213 | 214 | def c_out(self, sigma): 215 | return sigma * self.sigma_data * (self.sigma_data ** 2 + sigma ** 2) ** -0.5 216 | 217 | def c_in(self, sigma): 218 | return 1 * (sigma ** 2 + self.sigma_data ** 2) ** -0.5 219 | 220 | def c_noise(self, sigma): 221 | return log(sigma) * 0.25 222 | 223 | # preconditioned network output 224 | # equation (7) in the paper 225 | 226 | def preconditioned_network_forward(self, noised_seq, sigma, *, cond, clamp = None): 227 | clamp = default(clamp, self.clamp_during_sampling) 228 | 229 | batch, device = noised_seq.shape[0], noised_seq.device 230 | 231 | if isinstance(sigma, float): 232 | sigma = torch.full((batch,), sigma, device = device) 233 | 234 | padded_sigma = right_pad_dims_to(noised_seq, sigma) 235 | 236 | net_out = self.net( 237 | self.c_in(padded_sigma) * noised_seq, 238 | times = self.c_noise(sigma), 239 | cond = cond 240 | ) 241 | 242 | out = self.c_skip(padded_sigma) * noised_seq + self.c_out(padded_sigma) * net_out 243 | 244 | if clamp: 245 | out = out.clamp(-1., 1.) 246 | 247 | return out 248 | 249 | # sampling 250 | 251 | # sample schedule 252 | # equation (5) in the paper 253 | 254 | def sample_schedule(self, num_sample_steps = None): 255 | num_sample_steps = default(num_sample_steps, self.num_sample_steps) 256 | 257 | N = num_sample_steps 258 | inv_rho = 1 / self.rho 259 | 260 | steps = torch.arange(num_sample_steps, device = self.device, dtype = torch.float32) 261 | sigmas = (self.sigma_max ** inv_rho + steps / (N - 1) * (self.sigma_min ** inv_rho - self.sigma_max ** inv_rho)) ** self.rho 262 | 263 | sigmas = F.pad(sigmas, (0, 1), value = 0.) # last step is sigma value of 0.
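        # at this point sigmas decays monotonically from sigma_max down to sigma_min under the rho-warped spacing of equation (5), so the appended 0. lets the final step land on fully denoised data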
264 | return sigmas 265 | 266 | @torch.no_grad() 267 | def sample(self, cond, num_sample_steps = None, clamp = None): 268 | clamp = default(clamp, self.clamp_during_sampling) 269 | num_sample_steps = default(num_sample_steps, self.num_sample_steps) 270 | 271 | shape = (cond.shape[0], self.dim) 272 | 273 | # get the schedule, which is returned as (sigma, gamma) tuple, and pair up with the next sigma and gamma 274 | 275 | sigmas = self.sample_schedule(num_sample_steps) 276 | 277 | gammas = torch.where( 278 | (sigmas >= self.S_tmin) & (sigmas <= self.S_tmax), 279 | min(self.S_churn / num_sample_steps, sqrt(2) - 1), 280 | 0. 281 | ) 282 | 283 | sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[:-1])) 284 | 285 | # images is noise at the beginning 286 | 287 | init_sigma = sigmas[0] 288 | 289 | seq = init_sigma * torch.randn(shape, device = self.device) 290 | 291 | # gradually denoise 292 | 293 | for sigma, sigma_next, gamma in tqdm(sigmas_and_gammas, desc = 'sampling time step'): 294 | sigma, sigma_next, gamma = map(lambda t: t.item(), (sigma, sigma_next, gamma)) 295 | 296 | eps = self.S_noise * torch.randn(shape, device = self.device) # stochastic sampling 297 | 298 | sigma_hat = sigma + gamma * sigma 299 | seq_hat = seq + sqrt(sigma_hat ** 2 - sigma ** 2) * eps 300 | 301 | model_output = self.preconditioned_network_forward(seq_hat, sigma_hat, cond = cond, clamp = clamp) 302 | denoised_over_sigma = (seq_hat - model_output) / sigma_hat 303 | 304 | seq_next = seq_hat + (sigma_next - sigma_hat) * denoised_over_sigma 305 | 306 | # second order correction, if not the last timestep 307 | 308 | if sigma_next != 0: 309 | model_output_next = self.preconditioned_network_forward(seq_next, sigma_next, cond = cond, clamp = clamp) 310 | denoised_prime_over_sigma = (seq_next - model_output_next) / sigma_next 311 | seq_next = seq_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma) 312 | 313 | seq = seq_next 314 | 315 | if clamp: 316 | seq = seq.clamp(-1., 1.) 317 | 318 | return seq 319 | 320 | # training 321 | 322 | def loss_weight(self, sigma): 323 | return (sigma ** 2 + self.sigma_data ** 2) * (sigma * self.sigma_data) ** -2 324 | 325 | def noise_distribution(self, batch_size): 326 | return (self.P_mean + self.P_std * torch.randn((batch_size,), device = self.device)).exp() 327 | 328 | def forward(self, seq, *, cond): 329 | batch_size, dim, device = *seq.shape, self.device 330 | 331 | assert dim == self.dim, f'dimension of sequence being passed in must be {self.dim} but received {dim}' 332 | 333 | sigmas = self.noise_distribution(batch_size) 334 | padded_sigmas = right_pad_dims_to(seq, sigmas) 335 | 336 | noise = torch.randn_like(seq) 337 | 338 | noised_seq = seq + padded_sigmas * noise # alphas are 1. in the paper 339 | 340 | denoised = self.preconditioned_network_forward(noised_seq, sigmas, cond = cond) 341 | 342 | losses = F.mse_loss(denoised, seq, reduction = 'none') 343 | losses = reduce(losses, 'b ... 
-> b', 'mean') 344 | 345 | losses = losses * self.loss_weight(sigmas) 346 | 347 | return losses.mean() 348 | 349 | # main model, a decoder with continuous wrapper + small denoising mlp 350 | 351 | class AutoregressiveDiffusion(Module): 352 | def __init__( 353 | self, 354 | dim, 355 | *, 356 | max_seq_len, 357 | depth = 8, 358 | dim_head = 64, 359 | heads = 8, 360 | mlp_depth = 3, 361 | mlp_width = None, 362 | dim_input = None, 363 | decoder_kwargs: dict = dict(), 364 | mlp_kwargs: dict = dict(), 365 | diffusion_kwargs: dict = dict( 366 | clamp_during_sampling = True 367 | ) 368 | ): 369 | super().__init__() 370 | 371 | self.start_token = nn.Parameter(torch.zeros(dim)) 372 | self.max_seq_len = max_seq_len 373 | self.abs_pos_emb = nn.Embedding(max_seq_len, dim) 374 | 375 | dim_input = default(dim_input, dim) 376 | self.dim_input = dim_input 377 | self.proj_in = nn.Linear(dim_input, dim) 378 | 379 | self.transformer = Decoder( 380 | dim = dim, 381 | depth = depth, 382 | heads = heads, 383 | attn_dim_head = dim_head, 384 | **decoder_kwargs 385 | ) 386 | 387 | self.denoiser = MLP( 388 | dim_cond = dim, 389 | dim_input = dim_input, 390 | depth = mlp_depth, 391 | width = default(mlp_width, dim), 392 | **mlp_kwargs 393 | ) 394 | 395 | self.diffusion = ElucidatedDiffusion( 396 | dim_input, 397 | self.denoiser, 398 | **diffusion_kwargs 399 | ) 400 | 401 | @property 402 | def device(self): 403 | return next(self.transformer.parameters()).device 404 | 405 | @torch.no_grad() 406 | def sample( 407 | self, 408 | batch_size = 1, 409 | prompt = None 410 | ): 411 | self.eval() 412 | 413 | start_tokens = repeat(self.start_token, 'd -> b 1 d', b = batch_size) 414 | 415 | if not exists(prompt): 416 | out = torch.empty((batch_size, 0, self.dim_input), device = self.device, dtype = torch.float32) 417 | else: 418 | out = prompt 419 | 420 | cache = None 421 | 422 | for _ in tqdm(range(self.max_seq_len - out.shape[1]), desc = 'tokens'): 423 | 424 | cond = self.proj_in(out) 425 | 426 | cond = torch.cat((start_tokens, cond), dim = 1) 427 | cond = cond + self.abs_pos_emb(torch.arange(cond.shape[1], device = self.device)) 428 | 429 | cond, cache = self.transformer(cond, cache = cache, return_hiddens = True) 430 | 431 | last_cond = cond[:, -1] 432 | 433 | denoised_pred = self.diffusion.sample(cond = last_cond) 434 | 435 | denoised_pred = rearrange(denoised_pred, 'b d -> b 1 d') 436 | out = torch.cat((out, denoised_pred), dim = 1) 437 | 438 | return out 439 | 440 | def forward( 441 | self, 442 | seq 443 | ): 444 | b, seq_len, dim = seq.shape 445 | 446 | assert dim == self.dim_input 447 | assert seq_len == self.max_seq_len 448 | 449 | # break into seq and the continuous targets to be predicted 450 | 451 | seq, target = seq[:, :-1], seq 452 | 453 | # append start tokens 454 | 455 | seq = self.proj_in(seq) 456 | start_token = repeat(self.start_token, 'd -> b 1 d', b = b) 457 | 458 | seq = torch.cat((start_token, seq), dim = 1) 459 | seq = seq + self.abs_pos_emb(torch.arange(seq_len, device = self.device)) 460 | 461 | cond = self.transformer(seq) 462 | 463 | # pack batch and sequence dimensions, so to train each token with different noise levels 464 | 465 | target, _ = pack_one(target, '* d') 466 | cond, _ = pack_one(cond, '* d') 467 | 468 | diffusion_loss = self.diffusion(target, cond = cond) 469 | 470 | return diffusion_loss 471 | 472 | # image wrapper 473 | 474 | def normalize_to_neg_one_to_one(img): 475 | return img * 2 - 1 476 | 477 | def unnormalize_to_zero_to_one(t): 478 | return (t + 1) * 0.5 479 | 480 | class 
ImageAutoregressiveDiffusion(Module): 481 | def __init__( 482 | self, 483 | *, 484 | image_size, 485 | patch_size, 486 | channels = 3, 487 | model: dict = dict(), 488 | ): 489 | super().__init__() 490 | assert divisible_by(image_size, patch_size) 491 | 492 | num_patches = (image_size // patch_size) ** 2 493 | dim_in = channels * patch_size ** 2 494 | 495 | self.image_size = image_size 496 | self.patch_size = patch_size 497 | 498 | self.to_tokens = Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_size, p2 = patch_size) 499 | 500 | self.model = AutoregressiveDiffusion( 501 | **model, 502 | dim_input = dim_in, 503 | max_seq_len = num_patches 504 | ) 505 | 506 | self.to_image = Rearrange('b (h w) (c p1 p2) -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size, h = int(math.sqrt(num_patches))) 507 | 508 | def sample(self, batch_size = 1): 509 | tokens = self.model.sample(batch_size = batch_size) 510 | images = self.to_image(tokens) 511 | return unnormalize_to_zero_to_one(images) 512 | 513 | def forward(self, images): 514 | images = normalize_to_neg_one_to_one(images) 515 | tokens = self.to_tokens(images) 516 | return self.model(tokens) 517 | -------------------------------------------------------------------------------- /autoregressive_diffusion_pytorch/autoregressive_flow.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import math 4 | from math import sqrt 5 | from typing import Literal 6 | from functools import partial 7 | 8 | import torch 9 | from torch import nn, pi 10 | import torch.nn.functional as F 11 | from torch.nn import Module, ModuleList 12 | 13 | from torchdiffeq import odeint 14 | 15 | import einx 16 | from einops import rearrange, repeat, reduce, pack, unpack 17 | from einops.layers.torch import Rearrange 18 | 19 | from tqdm import tqdm 20 | 21 | from x_transformers import Decoder 22 | 23 | from autoregressive_diffusion_pytorch.autoregressive_diffusion import MLP 24 | 25 | # helpers 26 | 27 | def exists(v): 28 | return v is not None 29 | 30 | def default(v, d): 31 | return v if exists(v) else d 32 | 33 | def divisible_by(num, den): 34 | return (num % den) == 0 35 | 36 | def cast_tuple(t): 37 | return (t,) if not isinstance(t, tuple) else t 38 | 39 | # tensor helpers 40 | 41 | def log(t, eps = 1e-20): 42 | return torch.log(t.clamp(min = eps)) 43 | 44 | def safe_div(num, den, eps = 1e-5): 45 | return num / den.clamp(min = eps) 46 | 47 | def right_pad_dims_to(x, t): 48 | padding_dims = x.ndim - t.ndim 49 | 50 | if padding_dims <= 0: 51 | return t 52 | 53 | return t.view(*t.shape, *((1,) * padding_dims)) 54 | 55 | def pack_one(t, pattern): 56 | packed, ps = pack([t], pattern) 57 | 58 | def unpack_one(to_unpack, unpack_pattern = None): 59 | unpacked, = unpack(to_unpack, ps, default(unpack_pattern, pattern)) 60 | return unpacked 61 | 62 | return packed, unpack_one 63 | 64 | # rectified flow 65 | 66 | class Flow(Module): 67 | def __init__( 68 | self, 69 | dim: int, 70 | net: MLP, 71 | *, 72 | atol = 1e-5, 73 | rtol = 1e-5, 74 | method = 'midpoint' 75 | ): 76 | super().__init__() 77 | self.net = net 78 | self.dim = dim 79 | 80 | self.odeint_kwargs = dict( 81 | atol = atol, 82 | rtol = rtol, 83 | method = method 84 | ) 85 | 86 | @property 87 | def device(self): 88 | return next(self.net.parameters()).device 89 | 90 | @torch.no_grad() 91 | def sample( 92 | self, 93 | cond, 94 | num_sample_steps = 16 95 | ): 96 | 97 | batch = cond.shape[0] 98 | 99 | sampled_data_shape = (batch, self.dim) 
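        # sampling integrates the learned velocity field with torchdiffeq's odeint from t = 0 (pure gaussian noise) to t = 1 (data); the small MLP denoiser acts as the velocity network, conditioned per token on the transformer hidden state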
100 | 101 | # start with random gaussian noise - y0 102 | 103 | noise = torch.randn(sampled_data_shape, device = self.device) 104 | 105 | # time steps 106 | 107 | times = torch.linspace(0., 1., num_sample_steps, device = self.device) 108 | 109 | # ode 110 | 111 | def ode_fn(t, x): 112 | t = repeat(t, '-> b', b = batch) 113 | flow = self.net(x, times = t, cond = cond) 114 | return flow 115 | 116 | trajectory = odeint(ode_fn, noise, times, **self.odeint_kwargs) 117 | 118 | sampled = trajectory[-1] 119 | 120 | return sampled 121 | 122 | # training 123 | 124 | def forward(self, seq, *, cond): 125 | batch_size, dim, device = *seq.shape, self.device 126 | 127 | assert dim == self.dim, f'dimension of sequence being passed in must be {self.dim} but received {dim}' 128 | 129 | times = torch.rand(batch_size, device = device) 130 | noise = torch.randn_like(seq) 131 | padded_times = right_pad_dims_to(seq, times) 132 | 133 | flow = seq - noise 134 | 135 | noised = (1.- padded_times) * noise + padded_times * seq 136 | 137 | pred_flow = self.net(noised, times = times, cond = cond) 138 | 139 | return F.mse_loss(pred_flow, flow) 140 | 141 | # main model, a decoder with continuous wrapper + small denoising mlp 142 | 143 | class AutoregressiveFlow(Module): 144 | def __init__( 145 | self, 146 | dim, 147 | *, 148 | max_seq_len: int | tuple[int, ...], 149 | depth = 8, 150 | dim_head = 64, 151 | heads = 8, 152 | mlp_depth = 3, 153 | mlp_width = 1024, 154 | dim_input = None, 155 | decoder_kwargs: dict = dict(), 156 | mlp_kwargs: dict = dict(), 157 | flow_kwargs: dict = dict() 158 | ): 159 | super().__init__() 160 | 161 | self.start_token = nn.Parameter(torch.zeros(dim)) 162 | 163 | max_seq_len = cast_tuple(max_seq_len) 164 | self.abs_pos_emb = nn.ParameterList([nn.Parameter(torch.zeros(seq_len, dim)) for seq_len in max_seq_len]) 165 | 166 | self.max_seq_len = math.prod(max_seq_len) 167 | 168 | dim_input = default(dim_input, dim) 169 | self.dim_input = dim_input 170 | self.proj_in = nn.Linear(dim_input, dim) 171 | 172 | self.transformer = Decoder( 173 | dim = dim, 174 | depth = depth, 175 | heads = heads, 176 | attn_dim_head = dim_head, 177 | **decoder_kwargs 178 | ) 179 | 180 | self.to_cond_emb = nn.Linear(dim, dim, bias = False) 181 | 182 | self.denoiser = MLP( 183 | dim_cond = dim, 184 | dim_input = dim_input, 185 | depth = mlp_depth, 186 | width = mlp_width, 187 | **mlp_kwargs 188 | ) 189 | 190 | self.flow = Flow( 191 | dim_input, 192 | self.denoiser, 193 | **flow_kwargs 194 | ) 195 | 196 | @property 197 | def device(self): 198 | return next(self.transformer.parameters()).device 199 | 200 | def axial_pos_emb(self): 201 | # prepare maybe axial positional embedding 202 | 203 | pos_emb, *rest_pos_embs = self.abs_pos_emb 204 | 205 | for rest_pos_emb in rest_pos_embs: 206 | pos_emb = einx.add('i d, j d -> (i j) d', pos_emb, rest_pos_emb) 207 | 208 | return F.pad(pos_emb, (0, 0, 1, 0), value = 0.) 
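    # a small sketch of what `axial_pos_emb` computes, assuming an illustrative max_seq_len of (2, 3) and model dimension d:
    #
    #   rows = self.abs_pos_emb[0]                           # (2, d)
    #   cols = self.abs_pos_emb[1]                           # (3, d)
    #   pos  = einx.add('i d, j d -> (i j) d', rows, cols)   # (6, d) outer sum, row-major over the patch grid
    #   pos  = F.pad(pos, (0, 0, 1, 0), value = 0.)          # (7, d) leading zero row, aligning position 0 with the start token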
209 | 210 | @torch.no_grad() 211 | def sample( 212 | self, 213 | batch_size = 1, 214 | prompt = None 215 | ): 216 | self.eval() 217 | 218 | start_tokens = repeat(self.start_token, 'd -> b 1 d', b = batch_size) 219 | 220 | if not exists(prompt): 221 | out = torch.empty((batch_size, 0, self.dim_input), device = self.device, dtype = torch.float32) 222 | else: 223 | out = prompt 224 | 225 | cache = None 226 | 227 | for _ in tqdm(range(self.max_seq_len - out.shape[1]), desc = 'tokens'): 228 | 229 | cond = self.proj_in(out) 230 | 231 | cond = torch.cat((start_tokens, cond), dim = 1) 232 | 233 | seq_len = cond.shape[-2] 234 | axial_pos_emb = self.axial_pos_emb() 235 | cond += axial_pos_emb[:seq_len] 236 | 237 | cond, cache = self.transformer(cond, cache = cache, return_hiddens = True) 238 | 239 | last_cond = cond[:, -1] 240 | 241 | last_cond += axial_pos_emb[seq_len] 242 | last_cond = self.to_cond_emb(last_cond) 243 | 244 | denoised_pred = self.flow.sample(cond = last_cond) 245 | 246 | denoised_pred = rearrange(denoised_pred, 'b d -> b 1 d') 247 | out = torch.cat((out, denoised_pred), dim = 1) 248 | 249 | return out 250 | 251 | def forward( 252 | self, 253 | seq, 254 | noised_seq = None 255 | ): 256 | b, seq_len, dim = seq.shape 257 | 258 | assert dim == self.dim_input 259 | assert seq_len == self.max_seq_len 260 | 261 | # break into seq and the continuous targets to be predicted 262 | 263 | seq, target = seq[:, :-1], seq 264 | 265 | if exists(noised_seq): 266 | seq = noised_seq[:, :-1] 267 | 268 | # append start tokens 269 | 270 | seq = self.proj_in(seq) 271 | start_token = repeat(self.start_token, 'd -> b 1 d', b = b) 272 | 273 | seq = torch.cat((start_token, seq), dim = 1) 274 | 275 | axial_pos_emb = self.axial_pos_emb() 276 | seq = seq + axial_pos_emb[:seq_len] 277 | 278 | cond = self.transformer(seq) 279 | 280 | cond = cond + axial_pos_emb[1:(seq_len + 1)] 281 | cond = self.to_cond_emb(cond) 282 | 283 | # pack batch and sequence dimensions, so to train each token with different noise levels 284 | 285 | target, _ = pack_one(target, '* d') 286 | cond, _ = pack_one(cond, '* d') 287 | 288 | return self.flow(target, cond = cond) 289 | 290 | # image wrapper 291 | 292 | def normalize_to_neg_one_to_one(img): 293 | return img * 2 - 1 294 | 295 | def unnormalize_to_zero_to_one(t): 296 | return (t + 1) * 0.5 297 | 298 | class ImageAutoregressiveFlow(Module): 299 | def __init__( 300 | self, 301 | *, 302 | image_size, 303 | patch_size, 304 | channels = 3, 305 | train_max_noise = 0., 306 | model: dict = dict(), 307 | ): 308 | super().__init__() 309 | assert divisible_by(image_size, patch_size) 310 | 311 | patch_height_width = image_size // patch_size 312 | num_patches = patch_height_width ** 2 313 | dim_in = channels * patch_size ** 2 314 | 315 | self.image_size = image_size 316 | self.patch_size = patch_size 317 | 318 | assert 0. <= train_max_noise < 1. 
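        # `train_max_noise` optionally trains on a slightly noised copy of the preceding patches (see `forward` below), so sampling is more robust to imperfections in the model's own previously generated tokens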
319 | 320 | self.train_max_noise = train_max_noise 321 | 322 | self.to_tokens = Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_size, p2 = patch_size) 323 | 324 | self.model = AutoregressiveFlow( 325 | **model, 326 | dim_input = dim_in, 327 | max_seq_len = (patch_height_width, patch_height_width) 328 | ) 329 | 330 | self.to_image = Rearrange('b (h w) (c p1 p2) -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size, h = int(math.sqrt(num_patches))) 331 | 332 | def sample(self, batch_size = 1): 333 | tokens = self.model.sample(batch_size = batch_size) 334 | images = self.to_image(tokens) 335 | return unnormalize_to_zero_to_one(images) 336 | 337 | def forward(self, images): 338 | train_under_noise, device = self.train_max_noise > 0., images.device 339 | 340 | images = normalize_to_neg_one_to_one(images) 341 | tokens = self.to_tokens(images) 342 | 343 | if not train_under_noise: 344 | return self.model(tokens) 345 | 346 | # allow for the network to predict from slightly noised images of the past 347 | 348 | times = torch.rand(images.shape[0], device = device) * self.train_max_noise 349 | noise = torch.randn_like(images) 350 | padded_times = right_pad_dims_to(images, times) 351 | noised_images = images * (1. - padded_times) + noise * padded_times 352 | noised_tokens = self.to_tokens(noised_images) 353 | 354 | return self.model(tokens, noised_tokens) 355 | -------------------------------------------------------------------------------- /autoregressive_diffusion_pytorch/image_trainer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from functools import partial 3 | import math 4 | from pathlib import Path 5 | 6 | from accelerate import Accelerator 7 | from ema_pytorch import EMA 8 | 9 | import torch 10 | from torch import nn 11 | from torch.optim import Adam 12 | from torch.utils.data import DataLoader 13 | from torch.nn import Module, ModuleList 14 | from torch.utils.data import Dataset 15 | 16 | from torchvision.utils import save_image 17 | import torchvision.transforms as T 18 | 19 | from PIL import Image 20 | 21 | # functions 22 | 23 | def exists(v): 24 | return v is not None 25 | 26 | def default(v, d): 27 | return v if exists(v) else d 28 | 29 | def divisible_by(num, den): 30 | return (num % den) == 0 31 | 32 | def cycle(dl): 33 | while True: 34 | for batch in dl: 35 | yield batch 36 | 37 | # dataset classes 38 | 39 | class ImageDataset(Dataset): 40 | def __init__( 41 | self, 42 | folder: str | Path, 43 | image_size: int, 44 | exts: List[str] = ['jpg', 'jpeg', 'png', 'tiff'], 45 | augment_horizontal_flip = False, 46 | convert_image_to = None 47 | ): 48 | super().__init__() 49 | if isinstance(folder, str): 50 | folder = Path(folder) 51 | 52 | assert folder.is_dir() 53 | 54 | self.folder = folder 55 | self.image_size = image_size 56 | 57 | self.paths = [p for ext in exts for p in folder.glob(f'**/*.{ext}')] 58 | 59 | def convert_image_to_fn(img_type, image): 60 | if image.mode == img_type: 61 | return image 62 | 63 | return image.convert(img_type) 64 | 65 | maybe_convert_fn = partial(convert_image_to_fn, convert_image_to) if exists(convert_image_to) else nn.Identity() 66 | 67 | self.transform = T.Compose([ 68 | T.Lambda(maybe_convert_fn), 69 | T.Resize(image_size), 70 | T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(), 71 | T.CenterCrop(image_size), 72 | T.ToTensor() 73 | ]) 74 | 75 | def __len__(self): 76 | return len(self.paths) 77 | 78 | def __getitem__(self, index): 79 | path = self.paths[index] 80 | img =
Image.open(path) 81 | return self.transform(img) 82 | 83 | # trainer 84 | 85 | class ImageTrainer(Module): 86 | def __init__( 87 | self, 88 | model, 89 | *, 90 | dataset: Dataset, 91 | num_train_steps = 70_000, 92 | learning_rate = 3e-4, 93 | batch_size = 16, 94 | checkpoints_folder: str = './checkpoints', 95 | results_folder: str = './results', 96 | save_results_every: int = 100, 97 | checkpoint_every: int = 1000, 98 | num_samples: int = 16, 99 | adam_kwargs: dict = dict(), 100 | accelerate_kwargs: dict = dict(), 101 | ema_kwargs: dict = dict() 102 | ): 103 | super().__init__() 104 | self.accelerator = Accelerator(**accelerate_kwargs) 105 | 106 | self.model = model 107 | 108 | if self.is_main: 109 | self.ema_model = EMA( 110 | self.model, 111 | forward_method_names = ('sample',), 112 | **ema_kwargs 113 | ) 114 | 115 | self.ema_model.to(self.accelerator.device) 116 | 117 | self.optimizer = Adam(model.parameters(), lr = learning_rate, **adam_kwargs) 118 | self.dl = DataLoader(dataset, batch_size = batch_size, shuffle = True, drop_last = True) 119 | 120 | self.model, self.optimizer, self.dl = self.accelerator.prepare(self.model, self.optimizer, self.dl) 121 | 122 | self.num_train_steps = num_train_steps 123 | 124 | self.checkpoints_folder = Path(checkpoints_folder) 125 | self.results_folder = Path(results_folder) 126 | 127 | self.checkpoints_folder.mkdir(exist_ok = True, parents = True) 128 | self.results_folder.mkdir(exist_ok = True, parents = True) 129 | 130 | self.checkpoint_every = checkpoint_every 131 | self.save_results_every = save_results_every 132 | 133 | self.num_sample_rows = int(math.sqrt(num_samples)) 134 | assert (self.num_sample_rows ** 2) == num_samples, f'{num_samples} must be a square' 135 | self.num_samples = num_samples 136 | 137 | assert self.checkpoints_folder.is_dir() 138 | assert self.results_folder.is_dir() 139 | 140 | @property 141 | def is_main(self): 142 | return self.accelerator.is_main_process 143 | 144 | def save(self, path): 145 | if not self.is_main: 146 | return 147 | 148 | save_package = dict( 149 | model = self.accelerator.unwrap_model(self.model).state_dict(), 150 | ema_model = self.ema_model.state_dict(), 151 | optimizer = self.accelerator.unwrap_model(self.optimizer).state_dict(), 152 | ) 153 | 154 | torch.save(save_package, str(self.checkpoints_folder / path)) 155 | 156 | def forward(self): 157 | 158 | dl = cycle(self.dl) 159 | 160 | for ind in range(self.num_train_steps): 161 | step = ind + 1 162 | 163 | self.model.train() 164 | 165 | data = next(dl) 166 | loss = self.model(data) 167 | 168 | self.accelerator.print(f'[{step}] loss: {loss.item():.3f}') 169 | self.accelerator.backward(loss) 170 | 171 | self.optimizer.step() 172 | self.optimizer.zero_grad() 173 | 174 | if self.is_main: 175 | self.ema_model.update() 176 | 177 | self.accelerator.wait_for_everyone() 178 | 179 | if self.is_main: 180 | if divisible_by(step, self.save_results_every): 181 | 182 | with torch.no_grad(): 183 | sampled = self.ema_model.sample(batch_size = self.num_samples) 184 | 185 | sampled.clamp_(0., 1.) 
186 | save_image(sampled, str(self.results_folder / f'results.{step}.png'), nrow = self.num_sample_rows) 187 | 188 | if divisible_by(step, self.checkpoint_every): 189 | self.save(f'checkpoint.{step}.pt') 190 | 191 | self.accelerator.wait_for_everyone() 192 | 193 | 194 | print('training complete') 195 | -------------------------------------------------------------------------------- /images/results.96600.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/autoregressive-diffusion-pytorch/6387e02449f8f5ebd69b0709c9aab11e869bda07/images/results.96600.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "autoregressive-diffusion-pytorch" 3 | version = "0.2.8" 4 | description = "Autoregressive Diffusion - Pytorch" 5 | authors = [ 6 | { name = "Phil Wang", email = "lucidrains@gmail.com" } 7 | ] 8 | readme = "README.md" 9 | requires-python = ">= 3.8" 10 | license = { file = "LICENSE" } 11 | keywords = [ 12 | 'artificial intelligence', 13 | 'deep learning', 14 | 'transformers', 15 | 'denoising diffusion', 16 | ] 17 | classifiers = [ 18 | 'Development Status :: 4 - Beta', 19 | 'Intended Audience :: Developers', 20 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 21 | 'License :: OSI Approved :: MIT License', 22 | 'Programming Language :: Python :: 3.8', 23 | ] 24 | 25 | dependencies = [ 26 | 'accelerate', 27 | 'einx>=0.3.0', 28 | 'einops>=0.8.0', 29 | 'ema-pytorch', 30 | 'pillow', 31 | 'torch>=2.0', 32 | 'torchdiffeq', 33 | 'torchvision', 34 | 'tqdm', 35 | 'x-transformers>=1.31.14' 36 | ] 37 | 38 | [project.urls] 39 | Homepage = "https://pypi.org/project/autoregressive-diffusion-pytorch/" 40 | Repository = "https://github.com/lucidrains/autoregressive-diffusion-pytorch" 41 | 42 | [project.optional-dependencies] 43 | examples = ["tqdm", "numpy"] 44 | 45 | [build-system] 46 | requires = ["hatchling"] 47 | build-backend = "hatchling.build" 48 | 49 | [tool.hatch.metadata] 50 | allow-direct-references = true 51 | 52 | [tool.hatch.build.targets.wheel] 53 | packages = ["autoregressive_diffusion_pytorch"] 54 | --------------------------------------------------------------------------------