├── ar-diffusion.png
├── images
│   └── results.96600.png
├── autoregressive_diffusion_pytorch
│   ├── __init__.py
│   ├── image_trainer.py
│   ├── autoregressive_flow.py
│   └── autoregressive_diffusion.py
├── LICENSE
├── .github
│   └── workflows
│       └── python-publish.yml
├── train_ar_flow_oxford.py
├── pyproject.toml
├── .gitignore
└── README.md
/ar-diffusion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/autoregressive-diffusion-pytorch/HEAD/ar-diffusion.png -------------------------------------------------------------------------------- /images/results.96600.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/autoregressive-diffusion-pytorch/HEAD/images/results.96600.png -------------------------------------------------------------------------------- /autoregressive_diffusion_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from autoregressive_diffusion_pytorch.autoregressive_diffusion import ( 2 | MLP, 3 | AutoregressiveDiffusion, 4 | ImageAutoregressiveDiffusion 5 | ) 6 | 7 | from autoregressive_diffusion_pytorch.autoregressive_flow import ( 8 | AutoregressiveFlow, 9 | ImageAutoregressiveFlow 10 | ) 11 | 12 | from autoregressive_diffusion_pytorch.image_trainer import ( 13 | ImageDataset, 14 | ImageTrainer 15 | ) 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation.
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: '3.x' 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install build 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /train_ar_flow_oxford.py: -------------------------------------------------------------------------------- 1 | # hf datasets for easy oxford flowers training 2 | 3 | import torchvision.transforms as T 4 | from torch.utils.data import Dataset 5 | from datasets import load_dataset 6 | 7 | class OxfordFlowersDataset(Dataset): 8 | def __init__( 9 | self, 10 | image_size 11 | ): 12 | self.ds = load_dataset('nelorth/oxford-flowers')['train'] 13 | 14 | self.transform = T.Compose([ 15 | T.Resize((image_size, image_size)), 16 | T.PILToTensor() 17 | ]) 18 | 19 | def __len__(self): 20 | return len(self.ds) 21 | 22 | def __getitem__(self, idx): 23 | pil = self.ds[idx]['image'] 24 | tensor = self.transform(pil) 25 | return tensor / 255. 26 | 27 | flowers_dataset = OxfordFlowersDataset( 28 | image_size = 64 29 | ) 30 | 31 | from autoregressive_diffusion_pytorch import ImageAutoregressiveFlow, ImageTrainer 32 | 33 | model = ImageAutoregressiveFlow( 34 | model = dict( 35 | dim = 1024, 36 | depth = 8, 37 | heads = 8, 38 | mlp_depth = 4, 39 | decoder_kwargs = dict( 40 | rotary_pos_emb = True 41 | ) 42 | ), 43 | image_size = 64, 44 | patch_size = 8, 45 | model_output_clean = True 46 | ) 47 | 48 | trainer = ImageTrainer( 49 | model, 50 | dataset = flowers_dataset, 51 | num_train_steps = 1_000_000, 52 | learning_rate = 7e-5, 53 | batch_size = 32 54 | ) 55 | 56 | trainer() 57 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "autoregressive-diffusion-pytorch" 3 | version = "0.3.0" 4 | description = "Autoregressive Diffusion - Pytorch" 5 | authors = [ 6 | { name = "Phil Wang", email = "lucidrains@gmail.com" } 7 | ] 8 | readme = "README.md" 9 | requires-python = ">= 3.8" 10 | license = { file = "LICENSE" } 11 | keywords = [ 12 | 'artificial intelligence', 13 | 'deep learning', 14 | 'transformers', 15 | 'denoising diffusion', 16 | ] 17 | classifiers=[ 18 | 'Development Status :: 4 - Beta', 19 | 'Intended Audience :: Developers', 20 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 21 | 'License :: OSI Approved :: MIT License', 22 | 'Programming Language :: Python :: 3.8', 23 | ] 24 | 25 | dependencies = [ 26 | 'accelerate<=1.6.0', 27 | 'einx>=0.3.0', 28 | 'einops>=0.8.0', 29 | 'ema-pytorch', 30 | 'x-transformers>=1.31.14', 31 | 'torch>=2.0', 32 | 'torchdiffeq', 33 | 'tqdm' 34 | ] 35 | 36 | [project.urls] 37 | Homepage = "https://pypi.org/project/autoregressive-diffusion-pytorch/" 38 | Repository = "https://github.com/lucidrains/autoregressive-diffusion-pytorch" 39 | 40 | [project.optional-dependencies] 41 | examples = ["tqdm", "numpy"] 42 | 43 | [build-system] 44 | requires = ["hatchling"] 45 | build-backend = 
"hatchling.build" 46 | 47 | [tool.hatch.metadata] 48 | allow-direct-references = true 49 | 50 | [tool.hatch.build.targets.wheel] 51 | packages = ["autoregressive_diffusion_pytorch"] 52 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | results/ 2 | checkpoints/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 113 | .pdm.toml 114 | .pdm-python 115 | .pdm-build/ 116 | 117 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 165 | #.idea/ 166 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Autoregressive Diffusion - Pytorch 4 | 5 | Implementation of the architecture behind Autoregressive Image Generation without Vector Quantization in Pytorch 6 | 7 | Official repository has been released here 8 | 9 | Alternative route 10 | 11 | 12 | 13 | *oxford flowers at 96k steps* 14 | 15 | ## Install 16 | 17 | ```bash 18 | $ pip install autoregressive-diffusion-pytorch 19 | ``` 20 | 21 | ## Usage 22 | 23 | ```python 24 | import torch 25 | from autoregressive_diffusion_pytorch import AutoregressiveDiffusion 26 | 27 | model = AutoregressiveDiffusion( 28 | dim_input = 512, 29 | dim = 1024, 30 | max_seq_len = 32, 31 | depth = 8, 32 | mlp_depth = 3, 33 | mlp_width = 1024 34 | ) 35 | 36 | seq = torch.randn(3, 32, 512) 37 | 38 | loss = model(seq) 39 | loss.backward() 40 | 41 | sampled = model.sample(batch_size = 3) 42 | 43 | assert sampled.shape == seq.shape 44 | 45 | ``` 46 | 47 | For images treated as a sequence of tokens (as in the paper) 48 | 49 | ```python 50 | import torch 51 | from autoregressive_diffusion_pytorch import ImageAutoregressiveDiffusion 52 | 53 | model = ImageAutoregressiveDiffusion( 54 | model = dict( 55 | dim = 1024, 56 | depth = 12, 57 | heads = 12, 58 | ), 59 | image_size = 64, 60 | patch_size = 8 61 | ) 62 | 63 | images = torch.randn(3, 3, 64, 64) 64 | 65 | loss = model(images) 66 | loss.backward() 67 | 68 | sampled = model.sample(batch_size = 3) 69 | 70 | assert sampled.shape == images.shape 71 | 72 | ``` 73 | 74 | An image trainer 75 | 76 | ```python 77 | import torch 78 | 79 | from autoregressive_diffusion_pytorch import ( 80 | ImageDataset, 81 | ImageAutoregressiveDiffusion, 82 | ImageTrainer 83 | ) 84 | 85 | dataset = ImageDataset( 86 | '/path/to/your/images', 87 | image_size = 128 88 | ) 89 | 90 | model = ImageAutoregressiveDiffusion( 91 | model = dict( 92 | dim = 512 93 | ), 94 | image_size = 128, 95 | patch_size = 16 96 | ) 97 | 98 | trainer = ImageTrainer( 99 | model = model, 100 | dataset = dataset 101 | ) 102 | 103 | trainer() 104 | ``` 105 | 106 | For an improvised version using flow matching, just import `ImageAutoregressiveFlow` and `AutoregressiveFlow` instead 107 | 108 | The rest is the same 109 | 110 | ex.
111 | 112 | ```python 113 | import torch 114 | 115 | from autoregressive_diffusion_pytorch import ( 116 | ImageDataset, 117 | ImageTrainer, 118 | ImageAutoregressiveFlow, 119 | ) 120 | 121 | dataset = ImageDataset( 122 | '/path/to/your/images', 123 | image_size = 128 124 | ) 125 | 126 | model = ImageAutoregressiveFlow( 127 | model = dict( 128 | dim = 512 129 | ), 130 | image_size = 128, 131 | patch_size = 16 132 | ) 133 | 134 | trainer = ImageTrainer( 135 | model = model, 136 | dataset = dataset 137 | ) 138 | 139 | trainer() 140 | ``` 141 | 142 | ## Citations 143 | 144 | ```bibtex 145 | @article{Li2024AutoregressiveIG, 146 | title = {Autoregressive Image Generation without Vector Quantization}, 147 | author = {Tianhong Li and Yonglong Tian and He Li and Mingyang Deng and Kaiming He}, 148 | journal = {ArXiv}, 149 | year = {2024}, 150 | volume = {abs/2406.11838}, 151 | url = {https://api.semanticscholar.org/CorpusID:270560593} 152 | } 153 | ``` 154 | 155 | ```bibtex 156 | @article{Wu2023ARDiffusionAD, 157 | title = {AR-Diffusion: Auto-Regressive Diffusion Model for Text Generation}, 158 | author = {Tong Wu and Zhihao Fan and Xiao Liu and Yeyun Gong and Yelong Shen and Jian Jiao and Haitao Zheng and Juntao Li and Zhongyu Wei and Jian Guo and Nan Duan and Weizhu Chen}, 159 | journal = {ArXiv}, 160 | year = {2023}, 161 | volume = {abs/2305.09515}, 162 | url = {https://api.semanticscholar.org/CorpusID:258714669} 163 | } 164 | ``` 165 | 166 | ```bibtex 167 | @article{Karras2022ElucidatingTD, 168 | title = {Elucidating the Design Space of Diffusion-Based Generative Models}, 169 | author = {Tero Karras and Miika Aittala and Timo Aila and Samuli Laine}, 170 | journal = {ArXiv}, 171 | year = {2022}, 172 | volume = {abs/2206.00364}, 173 | url = {https://api.semanticscholar.org/CorpusID:249240415} 174 | } 175 | ``` 176 | 177 | ```bibtex 178 | @article{Liu2022FlowSA, 179 | title = {Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow}, 180 | author = {Xingchao Liu and Chengyue Gong and Qiang Liu}, 181 | journal = {ArXiv}, 182 | year = {2022}, 183 | volume = {abs/2209.03003}, 184 | url = {https://api.semanticscholar.org/CorpusID:252111177} 185 | } 186 | ``` 187 | 188 | ```bibtex 189 | @article{Esser2024ScalingRF, 190 | title = {Scaling Rectified Flow Transformers for High-Resolution Image Synthesis}, 191 | author = {Patrick Esser and Sumith Kulal and A. 
Blattmann and Rahim Entezari and Jonas Muller and Harry Saini and Yam Levi and Dominik Lorenz and Axel Sauer and Frederic Boesel and Dustin Podell and Tim Dockhorn and Zion English and Kyle Lacey and Alex Goodwin and Yannik Marek and Robin Rombach}, 192 | journal = {ArXiv}, 193 | year = {2024}, 194 | volume = {abs/2403.03206}, 195 | url = {https://api.semanticscholar.org/CorpusID:268247980} 196 | } 197 | ``` 198 | 199 | ```bibtex 200 | @misc{li2025basicsletdenoisinggenerative, 201 | title = {Back to Basics: Let Denoising Generative Models Denoise}, 202 | author = {Tianhong Li and Kaiming He}, 203 | year = {2025}, 204 | eprint = {2511.13720}, 205 | archivePrefix = {arXiv}, 206 | primaryClass = {cs.CV}, 207 | url = {https://arxiv.org/abs/2511.13720}, 208 | } 209 | ``` 210 | -------------------------------------------------------------------------------- /autoregressive_diffusion_pytorch/image_trainer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from functools import partial 3 | import math 4 | from pathlib import Path 5 | 6 | from accelerate import Accelerator 7 | from ema_pytorch import EMA 8 | 9 | import torch 10 | from torch import nn 11 | from torch.optim import Adam 12 | from torch.utils.data import DataLoader 13 | from torch.nn import Module, ModuleList 14 | from torch.utils.data import Dataset 15 | 16 | from torchvision.utils import save_image 17 | import torchvision.transforms as T 18 | 19 | from PIL import Image 20 | 21 | # functions 22 | 23 | def exists(v): 24 | return v is not None 25 | 26 | def default(v, d): 27 | return v if exists(v) else d 28 | 29 | def divisible_by(num, den): 30 | return (num % den) == 0 31 | 32 | def cycle(dl): 33 | while True: 34 | for batch in dl: 35 | yield batch 36 | 37 | # dataset classes 38 | 39 | class ImageDataset(Dataset): 40 | def __init__( 41 | self, 42 | folder: str | Path, 43 | image_size: int, 44 | exts: List[str] = ['jpg', 'jpeg', 'png', 'tiff'], 45 | augment_horizontal_flip = False, 46 | convert_image_to = None 47 | ): 48 | super().__init__() 49 | if isinstance(folder, str): 50 | folder = Path(folder) 51 | 52 | assert folder.is_dir() 53 | 54 | self.folder = folder 55 | self.image_size = image_size 56 | 57 | self.paths = [p for ext in exts for p in folder.glob(f'**/*.{ext}')] 58 | 59 | def convert_image_to_fn(img_type, image): 60 | if image.mode == img_type: 61 | return image 62 | 63 | return image.convert(img_type) 64 | 65 | maybe_convert_fn = partial(convert_image_to_fn, convert_image_to) if exists(convert_image_to) else nn.Identity() 66 | 67 | self.transform = T.Compose([ 68 | T.Lambda(maybe_convert_fn), 69 | T.Resize(image_size), 70 | T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(), 71 | T.CenterCrop(image_size), 72 | T.ToTensor() 73 | ]) 74 | 75 | def __len__(self): 76 | return len(self.paths) 77 | 78 | def __getitem__(self, index): 79 | path = self.paths[index] 80 | img = Image.open(path) 81 | return self.transform(img) 82 | 83 | # trainer 84 | 85 | class ImageTrainer(Module): 86 | def __init__( 87 | self, 88 | model, 89 | *, 90 | dataset: Dataset, 91 | num_train_steps = 70_000, 92 | learning_rate = 3e-4, 93 | batch_size = 16, 94 | checkpoints_folder: str = './checkpoints', 95 | results_folder: str = './results', 96 | save_results_every: int = 100, 97 | checkpoint_every: int = 1000, 98 | num_samples: int = 16, 99 | adam_kwargs: dict = dict(), 100 | accelerate_kwargs: dict = dict(), 101 | ema_kwargs: dict = dict() 102 | ): 103 | super().__init__() 104 | self.accelerator =
Accelerator(**accelerate_kwargs) 105 | 106 | self.model = model 107 | 108 | if self.is_main: 109 | self.ema_model = EMA( 110 | self.model, 111 | forward_method_names = ('sample',), 112 | **ema_kwargs 113 | ) 114 | 115 | self.ema_model.to(self.accelerator.device) 116 | 117 | self.optimizer = Adam(model.parameters(), lr = learning_rate, **adam_kwargs) 118 | self.dl = DataLoader(dataset, batch_size = batch_size, shuffle = True, drop_last = True) 119 | 120 | self.model, self.optimizer, self.dl = self.accelerator.prepare(self.model, self.optimizer, self.dl) 121 | 122 | self.num_train_steps = num_train_steps 123 | 124 | self.checkpoints_folder = Path(checkpoints_folder) 125 | self.results_folder = Path(results_folder) 126 | 127 | self.checkpoints_folder.mkdir(exist_ok = True, parents = True) 128 | self.results_folder.mkdir(exist_ok = True, parents = True) 129 | 130 | self.checkpoint_every = checkpoint_every 131 | self.save_results_every = save_results_every 132 | 133 | self.num_sample_rows = int(math.sqrt(num_samples)) 134 | assert (self.num_sample_rows ** 2) == num_samples, f'{num_samples} must be a square' 135 | self.num_samples = num_samples 136 | 137 | assert self.checkpoints_folder.is_dir() 138 | assert self.results_folder.is_dir() 139 | 140 | @property 141 | def is_main(self): 142 | return self.accelerator.is_main_process 143 | 144 | def save(self, path): 145 | if not self.is_main: 146 | return 147 | 148 | save_package = dict( 149 | model = self.accelerator.unwrap_model(self.model).state_dict(), 150 | ema_model = self.ema_model.state_dict(), 151 | optimizer = self.accelerator.unwrap_model(self.optimizer).state_dict(), 152 | ) 153 | 154 | torch.save(save_package, str(self.checkpoints_folder / path)) 155 | 156 | def forward(self): 157 | 158 | dl = cycle(self.dl) 159 | 160 | for ind in range(self.num_train_steps): 161 | step = ind + 1 162 | 163 | self.model.train() 164 | 165 | data = next(dl) 166 | loss = self.model(data) 167 | 168 | self.accelerator.print(f'[{step}] loss: {loss.item():.3f}') 169 | self.accelerator.backward(loss) 170 | 171 | self.optimizer.step() 172 | self.optimizer.zero_grad() 173 | 174 | if self.is_main: 175 | self.ema_model.update() 176 | 177 | self.accelerator.wait_for_everyone() 178 | 179 | if self.is_main: 180 | if divisible_by(step, self.save_results_every): 181 | 182 | with torch.no_grad(): 183 | sampled = self.ema_model.sample(batch_size = self.num_samples) 184 | 185 | sampled.clamp_(0., 1.) 
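# the image wrappers unnormalize samples from [-1, 1] back to [0, 1], so values can
# overshoot slightly; clamp before arranging the samples into a grid and writing to disk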
186 | save_image(sampled, str(self.results_folder / f'results.{step}.png'), nrow = self.num_sample_rows) 187 | 188 | if divisible_by(step, self.checkpoint_every): 189 | self.save(f'checkpoint.{step}.pt') 190 | 191 | self.accelerator.wait_for_everyone() 192 | 193 | 194 | print('training complete') 195 | -------------------------------------------------------------------------------- /autoregressive_diffusion_pytorch/autoregressive_flow.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import math 4 | from math import sqrt 5 | from typing import Literal 6 | from functools import partial 7 | 8 | import torch 9 | from torch import nn, pi 10 | import torch.nn.functional as F 11 | from torch.nn import Module, ModuleList 12 | 13 | from torchdiffeq import odeint 14 | 15 | import einx 16 | from einops import rearrange, repeat, reduce, pack, unpack 17 | from einops.layers.torch import Rearrange 18 | 19 | from tqdm import tqdm 20 | 21 | from x_transformers import Decoder 22 | 23 | from autoregressive_diffusion_pytorch.autoregressive_diffusion import MLP 24 | 25 | # helpers 26 | 27 | def exists(v): 28 | return v is not None 29 | 30 | def default(v, d): 31 | return v if exists(v) else d 32 | 33 | def divisible_by(num, den): 34 | return (num % den) == 0 35 | 36 | def cast_tuple(t): 37 | return (t,) if not isinstance(t, tuple) else t 38 | 39 | # tensor helpers 40 | 41 | def log(t, eps = 1e-20): 42 | return torch.log(t.clamp(min = eps)) 43 | 44 | def safe_div(num, den, eps = 1e-5): 45 | return num / den.clamp(min = eps) 46 | 47 | def right_pad_dims_to(x, t): 48 | padding_dims = x.ndim - t.ndim 49 | 50 | if padding_dims <= 0: 51 | return t 52 | 53 | return t.view(*t.shape, *((1,) * padding_dims)) 54 | 55 | def pack_one(t, pattern): 56 | packed, ps = pack([t], pattern) 57 | 58 | def unpack_one(to_unpack, unpack_pattern = None): 59 | unpacked, = unpack(to_unpack, ps, default(unpack_pattern, pattern)) 60 | return unpacked 61 | 62 | return packed, unpack_one 63 | 64 | # rectified flow 65 | 66 | class Flow(Module): 67 | def __init__( 68 | self, 69 | dim: int, 70 | net: MLP, 71 | *, 72 | atol = 1e-5, 73 | rtol = 1e-5, 74 | method = 'midpoint', 75 | model_output_clean = True, 76 | eps = 5e-2 77 | ): 78 | super().__init__() 79 | self.net = net 80 | self.dim = dim 81 | 82 | self.odeint_kwargs = dict( 83 | atol = atol, 84 | rtol = rtol, 85 | method = method 86 | ) 87 | 88 | self.model_output_clean = model_output_clean 89 | self.eps = eps 90 | 91 | @property 92 | def device(self): 93 | return next(self.net.parameters()).device 94 | 95 | @torch.no_grad() 96 | def sample( 97 | self, 98 | cond, 99 | num_sample_steps = 16 100 | ): 101 | 102 | batch = cond.shape[0] 103 | 104 | sampled_data_shape = (batch, self.dim) 105 | 106 | # start with random gaussian noise - y0 107 | 108 | noise = torch.randn(sampled_data_shape, device = self.device) 109 | 110 | # time steps 111 | 112 | times = torch.linspace(0., 1., num_sample_steps, device = self.device) 113 | 114 | # ode 115 | 116 | def ode_fn(t, x): 117 | t = repeat(t, '-> b', b = batch) 118 | out = self.net(x, times = t, cond = cond) 119 | 120 | if not self.model_output_clean: 121 | flow = out 122 | else: 123 | t = right_pad_dims_to(x, t) 124 | flow = (out - x) / (1. 
- t).clamp_min(self.eps) 125 | 126 | return flow 127 | 128 | trajectory = odeint(ode_fn, noise, times, **self.odeint_kwargs) 129 | 130 | sampled = trajectory[-1] 131 | 132 | return sampled 133 | 134 | # training 135 | 136 | def forward(self, seq, *, cond): 137 | batch_size, dim, device = *seq.shape, self.device 138 | 139 | assert dim == self.dim, f'dimension of sequence being passed in must be {self.dim} but received {dim}' 140 | 141 | times = torch.rand(batch_size, device = device) 142 | noise = torch.randn_like(seq) 143 | padded_times = right_pad_dims_to(seq, times) 144 | 145 | flow = seq - noise 146 | 147 | noised = (1.- padded_times) * noise + padded_times * seq 148 | 149 | model_out = self.net(noised, times = times, cond = cond) 150 | 151 | if not self.model_output_clean: 152 | pred_flow = model_out 153 | else: 154 | pred_flow = (model_out - noised) / (1. - padded_times).clamp_min(self.eps) 155 | 156 | return F.mse_loss(pred_flow, flow) 157 | 158 | # main model, a decoder with continuous wrapper + small denoising mlp 159 | 160 | class AutoregressiveFlow(Module): 161 | def __init__( 162 | self, 163 | dim, 164 | *, 165 | max_seq_len: int | tuple[int, ...], 166 | depth = 8, 167 | dim_head = 64, 168 | heads = 8, 169 | mlp_depth = 3, 170 | mlp_width = 1024, 171 | dim_input = None, 172 | decoder_kwargs: dict = dict(), 173 | mlp_kwargs: dict = dict(), 174 | flow_kwargs: dict = dict(), 175 | model_output_clean = True # output in x-space 176 | ): 177 | super().__init__() 178 | 179 | self.start_token = nn.Parameter(torch.zeros(dim)) 180 | 181 | max_seq_len = cast_tuple(max_seq_len) 182 | self.abs_pos_emb = nn.ParameterList([nn.Parameter(torch.zeros(seq_len, dim)) for seq_len in max_seq_len]) 183 | 184 | self.max_seq_len = math.prod(max_seq_len) 185 | 186 | dim_input = default(dim_input, dim) 187 | self.dim_input = dim_input 188 | self.proj_in = nn.Linear(dim_input, dim) 189 | 190 | self.transformer = Decoder( 191 | dim = dim, 192 | depth = depth, 193 | heads = heads, 194 | attn_dim_head = dim_head, 195 | **decoder_kwargs 196 | ) 197 | 198 | self.to_cond_emb = nn.Linear(dim, dim, bias = False) 199 | 200 | self.denoiser = MLP( 201 | dim_cond = dim, 202 | dim_input = dim_input, 203 | depth = mlp_depth, 204 | width = mlp_width, 205 | **mlp_kwargs 206 | ) 207 | 208 | self.flow = Flow( 209 | dim_input, 210 | self.denoiser, 211 | model_output_clean = model_output_clean, 212 | **flow_kwargs 213 | ) 214 | 215 | @property 216 | def device(self): 217 | return next(self.transformer.parameters()).device 218 | 219 | def axial_pos_emb(self): 220 | # prepare maybe axial positional embedding 221 | 222 | pos_emb, *rest_pos_embs = self.abs_pos_emb 223 | 224 | for rest_pos_emb in rest_pos_embs: 225 | pos_emb = einx.add('i d, j d -> (i j) d', pos_emb, rest_pos_emb) 226 | 227 | return F.pad(pos_emb, (0, 0, 1, 0), value = 0.) 
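# note: the zero row padded onto the front of the positional embeddings lines up with the
# learned start token, so the real tokens occupy positions 1 .. max_seq_len

# sampling decodes one continuous token at a time - the causal transformer encodes the
# tokens generated so far (plus the start token), its last hidden state conditions the
# small MLP denoiser, and the flow ODE is integrated from pure noise to yield the next
# token, which is appended and fed back in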
228 | 229 | @torch.no_grad() 230 | def sample( 231 | self, 232 | batch_size = 1, 233 | prompt = None 234 | ): 235 | self.eval() 236 | 237 | start_tokens = repeat(self.start_token, 'd -> b 1 d', b = batch_size) 238 | 239 | if not exists(prompt): 240 | out = torch.empty((batch_size, 0, self.dim_input), device = self.device, dtype = torch.float32) 241 | else: 242 | out = prompt 243 | 244 | cache = None 245 | 246 | for _ in tqdm(range(self.max_seq_len - out.shape[1]), desc = 'tokens'): 247 | 248 | cond = self.proj_in(out) 249 | 250 | cond = torch.cat((start_tokens, cond), dim = 1) 251 | 252 | seq_len = cond.shape[-2] 253 | axial_pos_emb = self.axial_pos_emb() 254 | cond += axial_pos_emb[:seq_len] 255 | 256 | cond, cache = self.transformer(cond, cache = cache, return_hiddens = True) 257 | 258 | last_cond = cond[:, -1] 259 | 260 | last_cond += axial_pos_emb[seq_len] 261 | last_cond = self.to_cond_emb(last_cond) 262 | 263 | denoised_pred = self.flow.sample(cond = last_cond) 264 | 265 | denoised_pred = rearrange(denoised_pred, 'b d -> b 1 d') 266 | out = torch.cat((out, denoised_pred), dim = 1) 267 | 268 | return out 269 | 270 | def forward( 271 | self, 272 | seq, 273 | noised_seq = None 274 | ): 275 | b, seq_len, dim = seq.shape 276 | 277 | assert dim == self.dim_input 278 | assert seq_len == self.max_seq_len 279 | 280 | # break into seq and the continuous targets to be predicted 281 | 282 | seq, target = seq[:, :-1], seq 283 | 284 | if exists(noised_seq): 285 | seq = noised_seq[:, :-1] 286 | 287 | # append start tokens 288 | 289 | seq = self.proj_in(seq) 290 | start_token = repeat(self.start_token, 'd -> b 1 d', b = b) 291 | 292 | seq = torch.cat((start_token, seq), dim = 1) 293 | 294 | axial_pos_emb = self.axial_pos_emb() 295 | seq = seq + axial_pos_emb[:seq_len] 296 | 297 | cond = self.transformer(seq) 298 | 299 | cond = cond + axial_pos_emb[1:(seq_len + 1)] 300 | cond = self.to_cond_emb(cond) 301 | 302 | # pack batch and sequence dimensions, so to train each token with different noise levels 303 | 304 | target, _ = pack_one(target, '* d') 305 | cond, _ = pack_one(cond, '* d') 306 | 307 | return self.flow(target, cond = cond) 308 | 309 | # image wrapper 310 | 311 | def normalize_to_neg_one_to_one(img): 312 | return img * 2 - 1 313 | 314 | def unnormalize_to_zero_to_one(t): 315 | return (t + 1) * 0.5 316 | 317 | class ImageAutoregressiveFlow(Module): 318 | def __init__( 319 | self, 320 | *, 321 | image_size, 322 | patch_size, 323 | channels = 3, 324 | train_max_noise = 0., 325 | model_output_clean = True, # for outputting in x-space 326 | model: dict = dict(), 327 | ): 328 | super().__init__() 329 | assert divisible_by(image_size, patch_size) 330 | 331 | patch_height_width = image_size // patch_size 332 | num_patches = patch_height_width ** 2 333 | dim_in = channels * patch_size ** 2 334 | 335 | self.image_size = image_size 336 | self.patch_size = patch_size 337 | 338 | assert 0. <= train_max_noise < 1. 
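# when train_max_noise is greater than 0., forward() conditions the transformer on patches
# taken from a lightly noised copy of the image (noise level drawn uniformly up to this
# value), while the flow matching loss is still computed against the clean patches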
339 | 340 | self.train_max_noise = train_max_noise 341 | 342 | self.to_tokens = Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_size, p2 = patch_size) 343 | 344 | self.model = AutoregressiveFlow( 345 | **model, 346 | dim_input = dim_in, 347 | max_seq_len = (patch_height_width, patch_height_width), 348 | model_output_clean = model_output_clean 349 | ) 350 | 351 | self.to_image = Rearrange('b (h w) (c p1 p2) -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size, h = int(math.sqrt(num_patches))) 352 | 353 | def sample(self, batch_size = 1): 354 | tokens = self.model.sample(batch_size = batch_size) 355 | images = self.to_image(tokens) 356 | return unnormalize_to_zero_to_one(images) 357 | 358 | def forward(self, images): 359 | train_under_noise, device = self.train_max_noise > 0., images.device 360 | 361 | images = normalize_to_neg_one_to_one(images) 362 | tokens = self.to_tokens(images) 363 | 364 | if not train_under_noise: 365 | return self.model(tokens) 366 | 367 | # allow for the network to predict from slightly noised images of the past 368 | 369 | times = torch.rand(images.shape[0], device = device) * self.train_max_noise 370 | noise = torch.randn_like(images) 371 | padded_times = right_pad_dims_to(images, times) 372 | noised_images = images * (1. - padded_times) + noise * padded_times 373 | noised_tokens = self.to_tokens(noised_images) 374 | 375 | return self.model(tokens, noised_tokens) 376 | -------------------------------------------------------------------------------- /autoregressive_diffusion_pytorch/autoregressive_diffusion.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import math 4 | from math import sqrt 5 | from typing import Literal 6 | from functools import partial 7 | 8 | import torch 9 | from torch import nn, pi 10 | from torch.special import expm1 11 | import torch.nn.functional as F 12 | from torch.nn import Module, ModuleList 13 | 14 | import einx 15 | from einops import rearrange, repeat, reduce, pack, unpack 16 | from einops.layers.torch import Rearrange 17 | 18 | from tqdm import tqdm 19 | 20 | from x_transformers import Decoder 21 | 22 | # helpers 23 | 24 | def exists(v): 25 | return v is not None 26 | 27 | def default(v, d): 28 | return v if exists(v) else d 29 | 30 | def divisible_by(num, den): 31 | return (num % den) == 0 32 | 33 | # tensor helpers 34 | 35 | def log(t, eps = 1e-20): 36 | return torch.log(t.clamp(min = eps)) 37 | 38 | def safe_div(num, den, eps = 1e-5): 39 | return num / den.clamp(min = eps) 40 | 41 | def right_pad_dims_to(x, t): 42 | padding_dims = x.ndim - t.ndim 43 | 44 | if padding_dims <= 0: 45 | return t 46 | 47 | return t.view(*t.shape, *((1,) * padding_dims)) 48 | 49 | def pack_one(t, pattern): 50 | packed, ps = pack([t], pattern) 51 | 52 | def unpack_one(to_unpack, unpack_pattern = None): 53 | unpacked, = unpack(to_unpack, ps, default(unpack_pattern, pattern)) 54 | return unpacked 55 | 56 | return packed, unpack_one 57 | 58 | # sinusoidal embedding 59 | 60 | class AdaptiveLayerNorm(Module): 61 | def __init__( 62 | self, 63 | dim, 64 | dim_condition = None 65 | ): 66 | super().__init__() 67 | dim_condition = default(dim_condition, dim) 68 | 69 | self.ln = nn.LayerNorm(dim, elementwise_affine = False) 70 | self.to_gamma = nn.Linear(dim_condition, dim, bias = False) 71 | nn.init.zeros_(self.to_gamma.weight) 72 | 73 | def forward(self, x, *, condition): 74 | normed = self.ln(x) 75 | gamma = self.to_gamma(condition) 76 | return normed * (gamma 
+ 1.) 77 | 78 | class LearnedSinusoidalPosEmb(Module): 79 | def __init__(self, dim): 80 | super().__init__() 81 | assert divisible_by(dim, 2) 82 | half_dim = dim // 2 83 | self.weights = nn.Parameter(torch.randn(half_dim)) 84 | 85 | def forward(self, x): 86 | x = rearrange(x, 'b -> b 1') 87 | freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * pi 88 | fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1) 89 | fouriered = torch.cat((x, fouriered), dim = -1) 90 | return fouriered 91 | 92 | # simple mlp 93 | 94 | class MLP(Module): 95 | def __init__( 96 | self, 97 | dim_cond, 98 | dim_input, 99 | depth = 3, 100 | width = 1024, 101 | dropout = 0. 102 | ): 103 | super().__init__() 104 | layers = ModuleList([]) 105 | 106 | self.to_time_emb = nn.Sequential( 107 | LearnedSinusoidalPosEmb(dim_cond), 108 | nn.Linear(dim_cond + 1, dim_cond), 109 | ) 110 | 111 | for _ in range(depth): 112 | 113 | adaptive_layernorm = AdaptiveLayerNorm( 114 | dim_input, 115 | dim_condition = dim_cond 116 | ) 117 | 118 | block = nn.Sequential( 119 | nn.Linear(dim_input, width), 120 | nn.SiLU(), 121 | nn.Dropout(dropout), 122 | nn.Linear(width, dim_input) 123 | ) 124 | 125 | block_out_gamma = nn.Linear(dim_cond, dim_input, bias = False) 126 | nn.init.zeros_(block_out_gamma.weight) 127 | 128 | layers.append(ModuleList([ 129 | adaptive_layernorm, 130 | block, 131 | block_out_gamma 132 | ])) 133 | 134 | self.layers = layers 135 | 136 | def forward( 137 | self, 138 | noised, 139 | *, 140 | times, 141 | cond 142 | ): 143 | assert noised.ndim == 2 144 | 145 | time_emb = self.to_time_emb(times) 146 | cond = F.silu(time_emb + cond) 147 | 148 | denoised = noised 149 | 150 | for adaln, block, block_out_gamma in self.layers: 151 | residual = denoised 152 | denoised = adaln(denoised, condition = cond) 153 | 154 | block_out = block(denoised) * (block_out_gamma(cond) + 1.) 
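# block_out_gamma is zero-initialized, so (gamma + 1.) starts at 1. and each block begins
# as a plain residual MLP; modulation by the time / condition embedding is learned during training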
155 | denoised = block_out + residual 156 | 157 | return denoised 158 | 159 | # gaussian diffusion 160 | 161 | class ElucidatedDiffusion(Module): 162 | def __init__( 163 | self, 164 | dim: int, 165 | net: MLP, 166 | *, 167 | num_sample_steps = 32, # number of sampling steps 168 | sigma_min = 0.002, # min noise level 169 | sigma_max = 80, # max noise level 170 | sigma_data = 0.5, # standard deviation of data distribution 171 | rho = 7, # controls the sampling schedule 172 | P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training 173 | P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training 174 | S_churn = 80, # parameters for stochastic sampling - depends on dataset, Table 5 in apper 175 | S_tmin = 0.05, 176 | S_tmax = 50, 177 | S_noise = 1.003, 178 | clamp_during_sampling = True 179 | ): 180 | super().__init__() 181 | 182 | self.net = net 183 | self.dim = dim 184 | 185 | # parameters 186 | 187 | self.sigma_min = sigma_min 188 | self.sigma_max = sigma_max 189 | self.sigma_data = sigma_data 190 | 191 | self.rho = rho 192 | 193 | self.P_mean = P_mean 194 | self.P_std = P_std 195 | 196 | self.num_sample_steps = num_sample_steps # otherwise known as N in the paper 197 | 198 | self.S_churn = S_churn 199 | self.S_tmin = S_tmin 200 | self.S_tmax = S_tmax 201 | self.S_noise = S_noise 202 | 203 | self.clamp_during_sampling = clamp_during_sampling 204 | 205 | @property 206 | def device(self): 207 | return next(self.net.parameters()).device 208 | 209 | # derived preconditioning params - Table 1 210 | 211 | def c_skip(self, sigma): 212 | return (self.sigma_data ** 2) / (sigma ** 2 + self.sigma_data ** 2) 213 | 214 | def c_out(self, sigma): 215 | return sigma * self.sigma_data * (self.sigma_data ** 2 + sigma ** 2) ** -0.5 216 | 217 | def c_in(self, sigma): 218 | return 1 * (sigma ** 2 + self.sigma_data ** 2) ** -0.5 219 | 220 | def c_noise(self, sigma): 221 | return log(sigma) * 0.25 222 | 223 | # preconditioned network output 224 | # equation (7) in the paper 225 | 226 | def preconditioned_network_forward(self, noised_seq, sigma, *, cond, clamp = None): 227 | clamp = default(clamp, self.clamp_during_sampling) 228 | 229 | batch, device = noised_seq.shape[0], noised_seq.device 230 | 231 | if isinstance(sigma, float): 232 | sigma = torch.full((batch,), sigma, device = device) 233 | 234 | padded_sigma = right_pad_dims_to(noised_seq, sigma) 235 | 236 | net_out = self.net( 237 | self.c_in(padded_sigma) * noised_seq, 238 | times = self.c_noise(sigma), 239 | cond = cond 240 | ) 241 | 242 | out = self.c_skip(padded_sigma) * noised_seq + self.c_out(padded_sigma) * net_out 243 | 244 | if clamp: 245 | out = out.clamp(-1., 1.) 246 | 247 | return out 248 | 249 | # sampling 250 | 251 | # sample schedule 252 | # equation (5) in the paper 253 | 254 | def sample_schedule(self, num_sample_steps = None): 255 | num_sample_steps = default(num_sample_steps, self.num_sample_steps) 256 | 257 | N = num_sample_steps 258 | inv_rho = 1 / self.rho 259 | 260 | steps = torch.arange(num_sample_steps, device = self.device, dtype = torch.float32) 261 | sigmas = (self.sigma_max ** inv_rho + steps / (N - 1) * (self.sigma_min ** inv_rho - self.sigma_max ** inv_rho)) ** self.rho 262 | 263 | sigmas = F.pad(sigmas, (0, 1), value = 0.) # last step is sigma value of 0. 
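# the resulting schedule decreases from sigma_max down to sigma_min following the
# rho-warped spacing of equation (5), with a trailing 0. so the final step lands on clean data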
264 | return sigmas 265 | 266 | @torch.no_grad() 267 | def sample(self, cond, num_sample_steps = None, clamp = None): 268 | clamp = default(clamp, self.clamp_during_sampling) 269 | num_sample_steps = default(num_sample_steps, self.num_sample_steps) 270 | 271 | shape = (cond.shape[0], self.dim) 272 | 273 | # get the schedule, which is returned as (sigma, gamma) tuple, and pair up with the next sigma and gamma 274 | 275 | sigmas = self.sample_schedule(num_sample_steps) 276 | 277 | gammas = torch.where( 278 | (sigmas >= self.S_tmin) & (sigmas <= self.S_tmax), 279 | min(self.S_churn / num_sample_steps, sqrt(2) - 1), 280 | 0. 281 | ) 282 | 283 | sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[:-1])) 284 | 285 | # images is noise at the beginning 286 | 287 | init_sigma = sigmas[0] 288 | 289 | seq = init_sigma * torch.randn(shape, device = self.device) 290 | 291 | # gradually denoise 292 | 293 | for sigma, sigma_next, gamma in tqdm(sigmas_and_gammas, desc = 'sampling time step'): 294 | sigma, sigma_next, gamma = map(lambda t: t.item(), (sigma, sigma_next, gamma)) 295 | 296 | eps = self.S_noise * torch.randn(shape, device = self.device) # stochastic sampling 297 | 298 | sigma_hat = sigma + gamma * sigma 299 | seq_hat = seq + sqrt(sigma_hat ** 2 - sigma ** 2) * eps 300 | 301 | model_output = self.preconditioned_network_forward(seq_hat, sigma_hat, cond = cond, clamp = clamp) 302 | denoised_over_sigma = (seq_hat - model_output) / sigma_hat 303 | 304 | seq_next = seq_hat + (sigma_next - sigma_hat) * denoised_over_sigma 305 | 306 | # second order correction, if not the last timestep 307 | 308 | if sigma_next != 0: 309 | model_output_next = self.preconditioned_network_forward(seq_next, sigma_next, cond = cond, clamp = clamp) 310 | denoised_prime_over_sigma = (seq_next - model_output_next) / sigma_next 311 | seq_next = seq_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma) 312 | 313 | seq = seq_next 314 | 315 | if clamp: 316 | seq = seq.clamp(-1., 1.) 317 | 318 | return seq 319 | 320 | # training 321 | 322 | def loss_weight(self, sigma): 323 | return (sigma ** 2 + self.sigma_data ** 2) * (sigma * self.sigma_data) ** -2 324 | 325 | def noise_distribution(self, batch_size): 326 | return (self.P_mean + self.P_std * torch.randn((batch_size,), device = self.device)).exp() 327 | 328 | def forward(self, seq, *, cond): 329 | batch_size, dim, device = *seq.shape, self.device 330 | 331 | assert dim == self.dim, f'dimension of sequence being passed in must be {self.dim} but received {dim}' 332 | 333 | sigmas = self.noise_distribution(batch_size) 334 | padded_sigmas = right_pad_dims_to(seq, sigmas) 335 | 336 | noise = torch.randn_like(seq) 337 | 338 | noised_seq = seq + padded_sigmas * noise # alphas are 1. in the paper 339 | 340 | denoised = self.preconditioned_network_forward(noised_seq, sigmas, cond = cond) 341 | 342 | losses = F.mse_loss(denoised, seq, reduction = 'none') 343 | losses = reduce(losses, 'b ... 
-> b', 'mean') 344 | 345 | losses = losses * self.loss_weight(sigmas) 346 | 347 | return losses.mean() 348 | 349 | # main model, a decoder with continuous wrapper + small denoising mlp 350 | 351 | class AutoregressiveDiffusion(Module): 352 | def __init__( 353 | self, 354 | dim, 355 | *, 356 | max_seq_len, 357 | depth = 8, 358 | dim_head = 64, 359 | heads = 8, 360 | mlp_depth = 3, 361 | mlp_width = None, 362 | dim_input = None, 363 | decoder_kwargs: dict = dict(), 364 | mlp_kwargs: dict = dict(), 365 | diffusion_kwargs: dict = dict( 366 | clamp_during_sampling = True 367 | ) 368 | ): 369 | super().__init__() 370 | 371 | self.start_token = nn.Parameter(torch.zeros(dim)) 372 | self.max_seq_len = max_seq_len 373 | self.abs_pos_emb = nn.Embedding(max_seq_len, dim) 374 | 375 | dim_input = default(dim_input, dim) 376 | self.dim_input = dim_input 377 | self.proj_in = nn.Linear(dim_input, dim) 378 | 379 | self.transformer = Decoder( 380 | dim = dim, 381 | depth = depth, 382 | heads = heads, 383 | attn_dim_head = dim_head, 384 | **decoder_kwargs 385 | ) 386 | 387 | self.denoiser = MLP( 388 | dim_cond = dim, 389 | dim_input = dim_input, 390 | depth = mlp_depth, 391 | width = default(mlp_width, dim), 392 | **mlp_kwargs 393 | ) 394 | 395 | self.diffusion = ElucidatedDiffusion( 396 | dim_input, 397 | self.denoiser, 398 | **diffusion_kwargs 399 | ) 400 | 401 | @property 402 | def device(self): 403 | return next(self.transformer.parameters()).device 404 | 405 | @torch.no_grad() 406 | def sample( 407 | self, 408 | batch_size = 1, 409 | prompt = None 410 | ): 411 | self.eval() 412 | 413 | start_tokens = repeat(self.start_token, 'd -> b 1 d', b = batch_size) 414 | 415 | if not exists(prompt): 416 | out = torch.empty((batch_size, 0, self.dim_input), device = self.device, dtype = torch.float32) 417 | else: 418 | out = prompt 419 | 420 | cache = None 421 | 422 | for _ in tqdm(range(self.max_seq_len - out.shape[1]), desc = 'tokens'): 423 | 424 | cond = self.proj_in(out) 425 | 426 | cond = torch.cat((start_tokens, cond), dim = 1) 427 | cond = cond + self.abs_pos_emb(torch.arange(cond.shape[1], device = self.device)) 428 | 429 | cond, cache = self.transformer(cond, cache = cache, return_hiddens = True) 430 | 431 | last_cond = cond[:, -1] 432 | 433 | denoised_pred = self.diffusion.sample(cond = last_cond) 434 | 435 | denoised_pred = rearrange(denoised_pred, 'b d -> b 1 d') 436 | out = torch.cat((out, denoised_pred), dim = 1) 437 | 438 | return out 439 | 440 | def forward( 441 | self, 442 | seq 443 | ): 444 | b, seq_len, dim = seq.shape 445 | 446 | assert dim == self.dim_input 447 | assert seq_len == self.max_seq_len 448 | 449 | # break into seq and the continuous targets to be predicted 450 | 451 | seq, target = seq[:, :-1], seq 452 | 453 | # append start tokens 454 | 455 | seq = self.proj_in(seq) 456 | start_token = repeat(self.start_token, 'd -> b 1 d', b = b) 457 | 458 | seq = torch.cat((start_token, seq), dim = 1) 459 | seq = seq + self.abs_pos_emb(torch.arange(seq_len, device = self.device)) 460 | 461 | cond = self.transformer(seq) 462 | 463 | # pack batch and sequence dimensions, so to train each token with different noise levels 464 | 465 | target, _ = pack_one(target, '* d') 466 | cond, _ = pack_one(cond, '* d') 467 | 468 | diffusion_loss = self.diffusion(target, cond = cond) 469 | 470 | return diffusion_loss 471 | 472 | # image wrapper 473 | 474 | def normalize_to_neg_one_to_one(img): 475 | return img * 2 - 1 476 | 477 | def unnormalize_to_zero_to_one(t): 478 | return (t + 1) * 0.5 479 | 480 | class 
ImageAutoregressiveDiffusion(Module): 481 | def __init__( 482 | self, 483 | *, 484 | image_size, 485 | patch_size, 486 | channels = 3, 487 | model: dict = dict(), 488 | ): 489 | super().__init__() 490 | assert divisible_by(image_size, patch_size) 491 | 492 | num_patches = (image_size // patch_size) ** 2 493 | dim_in = channels * patch_size ** 2 494 | 495 | self.image_size = image_size 496 | self.patch_size = patch_size 497 | 498 | self.to_tokens = Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_size, p2 = patch_size) 499 | 500 | self.model = AutoregressiveDiffusion( 501 | **model, 502 | dim_input = dim_in, 503 | max_seq_len = num_patches 504 | ) 505 | 506 | self.to_image = Rearrange('b (h w) (c p1 p2) -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size, h = int(math.sqrt(num_patches))) 507 | 508 | def sample(self, batch_size = 1): 509 | tokens = self.model.sample(batch_size = batch_size) 510 | images = self.to_image(tokens) 511 | return unnormalize_to_zero_to_one(images) 512 | 513 | def forward(self, images): 514 | images = normalize_to_neg_one_to_one(images) 515 | tokens = self.to_tokens(images) 516 | return self.model(tokens) 517 | --------------------------------------------------------------------------------