├── ar-diffusion.png
├── images
│   └── results.96600.png
├── autoregressive_diffusion_pytorch
│   ├── __init__.py
│   ├── image_trainer.py
│   ├── autoregressive_flow.py
│   └── autoregressive_diffusion.py
├── LICENSE
├── .github
│   └── workflows
│       └── python-publish.yml
├── train_ar_flow_oxford.py
├── pyproject.toml
├── .gitignore
└── README.md
/ar-diffusion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/autoregressive-diffusion-pytorch/HEAD/ar-diffusion.png
--------------------------------------------------------------------------------
/images/results.96600.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/autoregressive-diffusion-pytorch/HEAD/images/results.96600.png
--------------------------------------------------------------------------------
/autoregressive_diffusion_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from autoregressive_diffusion_pytorch.autoregressive_diffusion import (
2 | MLP,
3 | AutoregressiveDiffusion,
4 | ImageAutoregressiveDiffusion
5 | )
6 |
7 | from autoregressive_diffusion_pytorch.autoregressive_flow import (
8 | AutoregressiveFlow,
9 | ImageAutoregressiveFlow
10 | )
11 |
12 | from autoregressive_diffusion_pytorch.image_trainer import (
13 | ImageDataset,
14 | ImageTrainer
15 | )
16 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | jobs:
16 | deploy:
17 |
18 | runs-on: ubuntu-latest
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 | - name: Set up Python
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: '3.x'
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install build
30 | - name: Build package
31 | run: python -m build
32 | - name: Publish package
33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
34 | with:
35 | user: __token__
36 | password: ${{ secrets.PYPI_API_TOKEN }}
37 |
--------------------------------------------------------------------------------
/train_ar_flow_oxford.py:
--------------------------------------------------------------------------------
1 | # hf datasets for easy oxford flowers training
2 |
3 | import torchvision.transforms as T
4 | from torch.utils.data import Dataset
5 | from datasets import load_dataset
6 |
7 | class OxfordFlowersDataset(Dataset):
8 | def __init__(
9 | self,
10 | image_size
11 | ):
12 | self.ds = load_dataset('nelorth/oxford-flowers')['train']
13 |
14 | self.transform = T.Compose([
15 | T.Resize((image_size, image_size)),
16 | T.PILToTensor()
17 | ])
18 |
19 | def __len__(self):
20 | return len(self.ds)
21 |
22 | def __getitem__(self, idx):
23 | pil = self.ds[idx]['image']
24 | tensor = self.transform(pil)
25 | return tensor / 255.
26 |
27 | flowers_dataset = OxfordFlowersDataset(
28 | image_size = 64
29 | )
30 |
31 | from autoregressive_diffusion_pytorch import ImageAutoregressiveFlow, ImageTrainer
32 |
33 | model = ImageAutoregressiveFlow(
34 | model = dict(
35 | dim = 1024,
36 | depth = 8,
37 | heads = 8,
38 | mlp_depth = 4,
39 | decoder_kwargs = dict(
40 | rotary_pos_emb = True
41 | )
42 | ),
43 | image_size = 64,
44 | patch_size = 8,
45 | model_output_clean = True
46 | )
47 |
48 | trainer = ImageTrainer(
49 | model,
50 | dataset = flowers_dataset,
51 | num_train_steps = 1_000_000,
52 | learning_rate = 7e-5,
53 | batch_size = 32
54 | )
55 |
56 | trainer()
57 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "autoregressive-diffusion-pytorch"
3 | version = "0.3.0"
4 | description = "Autoregressive Diffusion - Pytorch"
5 | authors = [
6 | { name = "Phil Wang", email = "lucidrains@gmail.com" }
7 | ]
8 | readme = "README.md"
9 | requires-python = ">= 3.8"
10 | license = { file = "LICENSE" }
11 | keywords = [
12 | 'artificial intelligence',
13 | 'deep learning',
14 | 'transformers',
15 | 'denoising diffusion',
16 | ]
17 | classifiers = [
18 | 'Development Status :: 4 - Beta',
19 | 'Intended Audience :: Developers',
20 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
21 | 'License :: OSI Approved :: MIT License',
22 | 'Programming Language :: Python :: 3.8',
23 | ]
24 |
25 | dependencies = [
26 | 'accelerate<=1.6.0',
27 | 'einx>=0.3.0',
28 | 'einops>=0.8.0',
29 | 'ema-pytorch',
30 | 'x-transformers>=1.31.14',
31 | 'torch>=2.0',
32 | 'torchdiffeq',
33 | 'tqdm'
34 | ]
35 |
36 | [project.urls]
37 | Homepage = "https://pypi.org/project/autoregressive-diffusion-pytorch/"
38 | Repository = "https://github.com/lucidrains/autoregressive-diffusion-pytorch"
39 |
40 | [project.optional-dependencies]
41 | examples = ["tqdm", "numpy"]
42 |
43 | [build-system]
44 | requires = ["hatchling"]
45 | build-backend = "hatchling.build"
46 |
47 | [tool.hatch.metadata]
48 | allow-direct-references = true
49 |
50 | [tool.hatch.build.targets.wheel]
51 | packages = ["autoregressive_diffusion_pytorch"]
52 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | results/
2 | checkpoints/
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 | cover/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | .pybuilder/
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | # For a library or package, you might want to ignore these files since the code is
90 | # intended to run in multiple environments; otherwise, check them in:
91 | # .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # poetry
101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102 | # This is especially recommended for binary packages to ensure reproducibility, and is more
103 | # commonly ignored for libraries.
104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105 | #poetry.lock
106 |
107 | # pdm
108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109 | #pdm.lock
110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111 | # in version control.
112 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
113 | .pdm.toml
114 | .pdm-python
115 | .pdm-build/
116 |
117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
118 | __pypackages__/
119 |
120 | # Celery stuff
121 | celerybeat-schedule
122 | celerybeat.pid
123 |
124 | # SageMath parsed files
125 | *.sage.py
126 |
127 | # Environments
128 | .env
129 | .venv
130 | env/
131 | venv/
132 | ENV/
133 | env.bak/
134 | venv.bak/
135 |
136 | # Spyder project settings
137 | .spyderproject
138 | .spyproject
139 |
140 | # Rope project settings
141 | .ropeproject
142 |
143 | # mkdocs documentation
144 | /site
145 |
146 | # mypy
147 | .mypy_cache/
148 | .dmypy.json
149 | dmypy.json
150 |
151 | # Pyre type checker
152 | .pyre/
153 |
154 | # pytype static type analyzer
155 | .pytype/
156 |
157 | # Cython debug symbols
158 | cython_debug/
159 |
160 | # PyCharm
161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163 | # and can be added to the global gitignore or merged into this file. For a more nuclear
164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165 | #.idea/
166 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ![ar-diffusion](./ar-diffusion.png)
2 |
3 | ## Autoregressive Diffusion - Pytorch
4 |
5 | Implementation of the architecture behind [Autoregressive Image Generation without Vector Quantization](https://arxiv.org/abs/2406.11838) in Pytorch
6 |
7 | Official repository has been released here
8 |
9 | Alternative route
10 |
11 | ![oxford flowers results](./images/results.96600.png)
12 |
13 | *Oxford flowers at 96k steps*
14 |
15 | ## Install
16 |
17 | ```bash
18 | $ pip install autoregressive-diffusion-pytorch
19 | ```
20 |
21 | ## Usage
22 |
23 | ```python
24 | import torch
25 | from autoregressive_diffusion_pytorch import AutoregressiveDiffusion
26 |
27 | model = AutoregressiveDiffusion(
28 | dim_input = 512,
29 | dim = 1024,
30 | max_seq_len = 32,
31 | depth = 8,
32 | mlp_depth = 3,
33 | mlp_width = 1024
34 | )
35 |
36 | seq = torch.randn(3, 32, 512)
37 |
38 | loss = model(seq)
39 | loss.backward()
40 |
41 | sampled = model.sample(batch_size = 3)
42 |
43 | assert sampled.shape == seq.shape
44 |
45 | ```
46 |
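The loss returned by the model is a single scalar, so it drops into any standard training loop. Below is a minimal sketch with a stand-in Adam optimizer and random data; the optimizer choice, learning rate, and step count are illustrative placeholders, not part of the library.

```python
import torch
from torch.optim import Adam

from autoregressive_diffusion_pytorch import AutoregressiveDiffusion

model = AutoregressiveDiffusion(
    dim_input = 512,
    dim = 1024,
    max_seq_len = 32,
    depth = 8,
    mlp_depth = 3,
    mlp_width = 1024
)

optimizer = Adam(model.parameters(), lr = 3e-4)

for _ in range(100):
    # substitute your own batch of continuous tokens, shaped (batch, max_seq_len, dim_input)
    seq = torch.randn(3, 32, 512)

    loss = model(seq)
    loss.backward()

    optimizer.step()
    optimizer.zero_grad()
```
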
47 | For images treated as a sequence of tokens (as in the paper)
48 |
49 | ```python
50 | import torch
51 | from autoregressive_diffusion_pytorch import ImageAutoregressiveDiffusion
52 |
53 | model = ImageAutoregressiveDiffusion(
54 | model = dict(
55 | dim = 1024,
56 | depth = 12,
57 | heads = 12,
58 | ),
59 | image_size = 64,
60 | patch_size = 8
61 | )
62 |
63 | images = torch.randn(3, 3, 64, 64)
64 |
65 | loss = model(images)
66 | loss.backward()
67 |
68 | sampled = model.sample(batch_size = 3)
69 |
70 | assert sampled.shape == images.shape
71 |
72 | ```
73 |
74 | An image trainer
75 |
76 | ```python
77 | import torch
78 |
79 | from autoregressive_diffusion_pytorch import (
80 | ImageDataset,
81 | ImageAutoregressiveDiffusion,
82 | ImageTrainer
83 | )
84 |
85 | dataset = ImageDataset(
86 | '/path/to/your/images',
87 | image_size = 128
88 | )
89 |
90 | model = ImageAutoregressiveDiffusion(
91 | model = dict(
92 | dim = 512
93 | ),
94 | image_size = 128,
95 | patch_size = 16
96 | )
97 |
98 | trainer = ImageTrainer(
99 | model = model,
100 | dataset = dataset
101 | )
102 |
103 | trainer()
104 | ```
105 |
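After training, samples come off the wrapper as images in `[0, 1]`, so they can be written to disk with torchvision (the trainer above does the same internally every `save_results_every` steps). A small sketch; the output filename is arbitrary:

```python
from torchvision.utils import save_image

# (16, 3, 128, 128) images in [0, 1], clamped for safety before saving
sampled = model.sample(batch_size = 16).clamp(0., 1.)

save_image(sampled, 'samples.png', nrow = 4)
```
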
106 | For an improvised version using flow matching, just import `ImageAutoregressiveFlow` and `AutoregressiveFlow` instead.
107 |
108 | The rest of the interface is the same.
109 |
110 | For example:
111 |
112 | ```python
113 | import torch
114 |
115 | from autoregressive_diffusion_pytorch import (
116 | ImageDataset,
117 | ImageTrainer,
118 | ImageAutoregressiveFlow,
119 | )
120 |
121 | dataset = ImageDataset(
122 | '/path/to/your/images',
123 | image_size = 128
124 | )
125 |
126 | model = ImageAutoregressiveFlow(
127 | model = dict(
128 | dim = 512
129 | ),
130 | image_size = 128,
131 | patch_size = 16
132 | )
133 |
134 | trainer = ImageTrainer(
135 | model = model,
136 | dataset = dataset
137 | )
138 |
139 | trainer()
140 | ```
141 |
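For a complete end-to-end run on a public dataset, see `train_ar_flow_oxford.py` at the repository root, which trains `ImageAutoregressiveFlow` on Oxford Flowers pulled through Hugging Face `datasets` (that package is not a declared dependency, so install it first):

```bash
$ pip install datasets
$ python train_ar_flow_oxford.py
```
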
142 | ## Citations
143 |
144 | ```bibtex
145 | @article{Li2024AutoregressiveIG,
146 | title = {Autoregressive Image Generation without Vector Quantization},
147 | author = {Tianhong Li and Yonglong Tian and He Li and Mingyang Deng and Kaiming He},
148 | journal = {ArXiv},
149 | year = {2024},
150 | volume = {abs/2406.11838},
151 | url = {https://api.semanticscholar.org/CorpusID:270560593}
152 | }
153 | ```
154 |
155 | ```bibtex
156 | @article{Wu2023ARDiffusionAD,
157 | title = {AR-Diffusion: Auto-Regressive Diffusion Model for Text Generation},
158 | author = {Tong Wu and Zhihao Fan and Xiao Liu and Yeyun Gong and Yelong Shen and Jian Jiao and Haitao Zheng and Juntao Li and Zhongyu Wei and Jian Guo and Nan Duan and Weizhu Chen},
159 | journal = {ArXiv},
160 | year = {2023},
161 | volume = {abs/2305.09515},
162 | url = {https://api.semanticscholar.org/CorpusID:258714669}
163 | }
164 | ```
165 |
166 | ```bibtex
167 | @article{Karras2022ElucidatingTD,
168 | title = {Elucidating the Design Space of Diffusion-Based Generative Models},
169 | author = {Tero Karras and Miika Aittala and Timo Aila and Samuli Laine},
170 | journal = {ArXiv},
171 | year = {2022},
172 | volume = {abs/2206.00364},
173 | url = {https://api.semanticscholar.org/CorpusID:249240415}
174 | }
175 | ```
176 |
177 | ```bibtex
178 | @article{Liu2022FlowSA,
179 | title = {Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow},
180 | author = {Xingchao Liu and Chengyue Gong and Qiang Liu},
181 | journal = {ArXiv},
182 | year = {2022},
183 | volume = {abs/2209.03003},
184 | url = {https://api.semanticscholar.org/CorpusID:252111177}
185 | }
186 | ```
187 |
188 | ```bibtex
189 | @article{Esser2024ScalingRF,
190 | title = {Scaling Rectified Flow Transformers for High-Resolution Image Synthesis},
191 | author = {Patrick Esser and Sumith Kulal and A. Blattmann and Rahim Entezari and Jonas Muller and Harry Saini and Yam Levi and Dominik Lorenz and Axel Sauer and Frederic Boesel and Dustin Podell and Tim Dockhorn and Zion English and Kyle Lacey and Alex Goodwin and Yannik Marek and Robin Rombach},
192 | journal = {ArXiv},
193 | year = {2024},
194 | volume = {abs/2403.03206},
195 | url = {https://api.semanticscholar.org/CorpusID:268247980}
196 | }
197 | ```
198 |
199 | ```bibtex
200 | @misc{li2025basicsletdenoisinggenerative,
201 | title = {Back to Basics: Let Denoising Generative Models Denoise},
202 | author = {Tianhong Li and Kaiming He},
203 | year = {2025},
204 | eprint = {2511.13720},
205 | archivePrefix = {arXiv},
206 | primaryClass = {cs.CV},
207 | url = {https://arxiv.org/abs/2511.13720},
208 | }
209 | ```
210 |
--------------------------------------------------------------------------------
/autoregressive_diffusion_pytorch/image_trainer.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import List
3 | import math
4 | from pathlib import Path
5 | from functools import partial
6 | from accelerate import Accelerator
7 | from ema_pytorch import EMA
8 |
9 | import torch
10 | from torch import nn
11 | from torch.optim import Adam
12 | from torch.utils.data import DataLoader
13 | from torch.nn import Module, ModuleList
14 | from torch.utils.data import Dataset
15 |
16 | from torchvision.utils import save_image
17 | import torchvision.transforms as T
18 |
19 | from PIL import Image
20 |
21 | # functions
22 |
23 | def exists(v):
24 | return v is not None
25 |
26 | def default(v, d):
27 | return v if exists(v) else d
28 |
29 | def divisible_by(num, den):
30 | return (num % den) == 0
31 |
32 | def cycle(dl):
33 | while True:
34 | for batch in dl:
35 | yield batch
36 |
37 | # dataset classes
38 |
39 | class ImageDataset(Dataset):
40 | def __init__(
41 | self,
42 | folder: str | Path,
43 | image_size: int,
44 | exts: List[str] = ['jpg', 'jpeg', 'png', 'tiff'],
45 | augment_horizontal_flip = False,
46 | convert_image_to = None
47 | ):
48 | super().__init__()
49 | if isinstance(folder, str):
50 | folder = Path(folder)
51 |
52 | assert folder.is_dir()
53 |
54 | self.folder = folder
55 | self.image_size = image_size
56 |
57 | self.paths = [p for ext in exts for p in folder.glob(f'**/*.{ext}')]
58 |
59 | def convert_image_to_fn(img_type, image):
60 | if image.mode == img_type:
61 | return image
62 |
63 | return image.convert(img_type)
64 |
65 | maybe_convert_fn = partial(convert_image_to_fn, convert_image_to) if exists(convert_image_to) else nn.Identity()
66 |
67 | self.transform = T.Compose([
68 | T.Lambda(maybe_convert_fn),
69 | T.Resize(image_size),
70 | T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
71 | T.CenterCrop(image_size),
72 | T.ToTensor()
73 | ])
74 |
75 | def __len__(self):
76 | return len(self.paths)
77 |
78 | def __getitem__(self, index):
79 | path = self.paths[index]
80 | img = Image.open(path)
81 | return self.transform(img)
82 |
83 | # trainer
84 |
85 | class ImageTrainer(Module):
86 | def __init__(
87 | self,
88 | model,
89 | *,
90 | dataset: Dataset,
91 | num_train_steps = 70_000,
92 | learning_rate = 3e-4,
93 | batch_size = 16,
94 | checkpoints_folder: str = './checkpoints',
95 | results_folder: str = './results',
96 | save_results_every: int = 100,
97 | checkpoint_every: int = 1000,
98 | num_samples: int = 16,
99 | adam_kwargs: dict = dict(),
100 | accelerate_kwargs: dict = dict(),
101 | ema_kwargs: dict = dict()
102 | ):
103 | super().__init__()
104 | self.accelerator = Accelerator(**accelerate_kwargs)
105 |
106 | self.model = model
107 |
108 | if self.is_main:
109 | self.ema_model = EMA(
110 | self.model,
111 | forward_method_names = ('sample',),
112 | **ema_kwargs
113 | )
114 |
115 | self.ema_model.to(self.accelerator.device)
116 |
117 | self.optimizer = Adam(model.parameters(), lr = learning_rate, **adam_kwargs)
118 | self.dl = DataLoader(dataset, batch_size = batch_size, shuffle = True, drop_last = True)
119 |
120 | self.model, self.optimizer, self.dl = self.accelerator.prepare(self.model, self.optimizer, self.dl)
121 |
122 | self.num_train_steps = num_train_steps
123 |
124 | self.checkpoints_folder = Path(checkpoints_folder)
125 | self.results_folder = Path(results_folder)
126 |
127 | self.checkpoints_folder.mkdir(exist_ok = True, parents = True)
128 | self.results_folder.mkdir(exist_ok = True, parents = True)
129 |
130 | self.checkpoint_every = checkpoint_every
131 | self.save_results_every = save_results_every
132 |
133 | self.num_sample_rows = int(math.sqrt(num_samples))
134 | assert (self.num_sample_rows ** 2) == num_samples, f'{num_samples} must be a square'
135 | self.num_samples = num_samples
136 |
137 | assert self.checkpoints_folder.is_dir()
138 | assert self.results_folder.is_dir()
139 |
140 | @property
141 | def is_main(self):
142 | return self.accelerator.is_main_process
143 |
144 | def save(self, path):
145 | if not self.is_main:
146 | return
147 |
148 | save_package = dict(
149 | model = self.accelerator.unwrap_model(self.model).state_dict(),
150 | ema_model = self.ema_model.state_dict(),
151 | optimizer = self.accelerator.unwrap_model(self.optimizer).state_dict(),
152 | )
153 |
154 | torch.save(save_package, str(self.checkpoints_folder / path))
155 |
156 | def forward(self):
157 |
158 | dl = cycle(self.dl)
159 |
160 | for ind in range(self.num_train_steps):
161 | step = ind + 1
162 |
163 | self.model.train()
164 |
165 | data = next(dl)
166 | loss = self.model(data)
167 |
168 | self.accelerator.print(f'[{step}] loss: {loss.item():.3f}')
169 | self.accelerator.backward(loss)
170 |
171 | self.optimizer.step()
172 | self.optimizer.zero_grad()
173 |
174 | if self.is_main:
175 | self.ema_model.update()
176 |
177 | self.accelerator.wait_for_everyone()
178 |
179 | if self.is_main:
180 | if divisible_by(step, self.save_results_every):
181 |
182 | with torch.no_grad():
183 | sampled = self.ema_model.sample(batch_size = self.num_samples)
184 |
185 | sampled.clamp_(0., 1.)
186 | save_image(sampled, str(self.results_folder / f'results.{step}.png'), nrow = self.num_sample_rows)
187 |
188 | if divisible_by(step, self.checkpoint_every):
189 | self.save(f'checkpoint.{step}.pt')
190 |
191 | self.accelerator.wait_for_everyone()
192 |
193 |
194 | print('training complete')
195 |
--------------------------------------------------------------------------------
/autoregressive_diffusion_pytorch/autoregressive_flow.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import math
4 | from math import sqrt
5 | from typing import Literal
6 | from functools import partial
7 |
8 | import torch
9 | from torch import nn, pi
10 | import torch.nn.functional as F
11 | from torch.nn import Module, ModuleList
12 |
13 | from torchdiffeq import odeint
14 |
15 | import einx
16 | from einops import rearrange, repeat, reduce, pack, unpack
17 | from einops.layers.torch import Rearrange
18 |
19 | from tqdm import tqdm
20 |
21 | from x_transformers import Decoder
22 |
23 | from autoregressive_diffusion_pytorch.autoregressive_diffusion import MLP
24 |
25 | # helpers
26 |
27 | def exists(v):
28 | return v is not None
29 |
30 | def default(v, d):
31 | return v if exists(v) else d
32 |
33 | def divisible_by(num, den):
34 | return (num % den) == 0
35 |
36 | def cast_tuple(t):
37 | return (t,) if not isinstance(t, tuple) else t
38 |
39 | # tensor helpers
40 |
41 | def log(t, eps = 1e-20):
42 | return torch.log(t.clamp(min = eps))
43 |
44 | def safe_div(num, den, eps = 1e-5):
45 | return num / den.clamp(min = eps)
46 |
47 | def right_pad_dims_to(x, t):
48 | padding_dims = x.ndim - t.ndim
49 |
50 | if padding_dims <= 0:
51 | return t
52 |
53 | return t.view(*t.shape, *((1,) * padding_dims))
54 |
55 | def pack_one(t, pattern):
56 | packed, ps = pack([t], pattern)
57 |
58 | def unpack_one(to_unpack, unpack_pattern = None):
59 | unpacked, = unpack(to_unpack, ps, default(unpack_pattern, pattern))
60 | return unpacked
61 |
62 | return packed, unpack_one
63 |
64 | # rectified flow
65 |
66 | class Flow(Module):
67 | def __init__(
68 | self,
69 | dim: int,
70 | net: MLP,
71 | *,
72 | atol = 1e-5,
73 | rtol = 1e-5,
74 | method = 'midpoint',
75 | model_output_clean = True,
76 | eps = 5e-2
77 | ):
78 | super().__init__()
79 | self.net = net
80 | self.dim = dim
81 |
82 | self.odeint_kwargs = dict(
83 | atol = atol,
84 | rtol = rtol,
85 | method = method
86 | )
87 |
88 | self.model_output_clean = model_output_clean
89 | self.eps = eps
90 |
91 | @property
92 | def device(self):
93 | return next(self.net.parameters()).device
94 |
95 | @torch.no_grad()
96 | def sample(
97 | self,
98 | cond,
99 | num_sample_steps = 16
100 | ):
101 |
102 | batch = cond.shape[0]
103 |
104 | sampled_data_shape = (batch, self.dim)
105 |
106 | # start with random gaussian noise - y0
107 |
108 | noise = torch.randn(sampled_data_shape, device = self.device)
109 |
110 | # time steps
111 |
112 | times = torch.linspace(0., 1., num_sample_steps, device = self.device)
113 |
114 | # ode
115 |
116 | def ode_fn(t, x):
117 | t = repeat(t, '-> b', b = batch)
118 | out = self.net(x, times = t, cond = cond)
119 |
120 | if not self.model_output_clean:
121 | flow = out
122 | else:
123 | t = right_pad_dims_to(x, t)
124 | flow = (out - x) / (1. - t).clamp_min(self.eps)
125 |
126 | return flow
127 |
128 | trajectory = odeint(ode_fn, noise, times, **self.odeint_kwargs)
129 |
130 | sampled = trajectory[-1]
131 |
132 | return sampled
133 |
134 | # training
135 |
136 | def forward(self, seq, *, cond):
137 | batch_size, dim, device = *seq.shape, self.device
138 |
139 | assert dim == self.dim, f'dimension of sequence being passed in must be {self.dim} but received {dim}'
140 |
141 | times = torch.rand(batch_size, device = device)
142 | noise = torch.randn_like(seq)
143 | padded_times = right_pad_dims_to(seq, times)
144 |
145 | flow = seq - noise
146 |
147 | noised = (1.- padded_times) * noise + padded_times * seq
148 |
149 | model_out = self.net(noised, times = times, cond = cond)
150 |
151 | if not self.model_output_clean:
152 | pred_flow = model_out
153 | else:
154 | pred_flow = (model_out - noised) / (1. - padded_times).clamp_min(self.eps)
155 |
156 | return F.mse_loss(pred_flow, flow)
157 |
158 | # main model, a decoder with continuous wrapper + small denoising mlp
159 |
160 | class AutoregressiveFlow(Module):
161 | def __init__(
162 | self,
163 | dim,
164 | *,
165 | max_seq_len: int | tuple[int, ...],
166 | depth = 8,
167 | dim_head = 64,
168 | heads = 8,
169 | mlp_depth = 3,
170 | mlp_width = 1024,
171 | dim_input = None,
172 | decoder_kwargs: dict = dict(),
173 | mlp_kwargs: dict = dict(),
174 | flow_kwargs: dict = dict(),
175 | model_output_clean = True # output in x-space
176 | ):
177 | super().__init__()
178 |
179 | self.start_token = nn.Parameter(torch.zeros(dim))
180 |
181 | max_seq_len = cast_tuple(max_seq_len)
182 | self.abs_pos_emb = nn.ParameterList([nn.Parameter(torch.zeros(seq_len, dim)) for seq_len in max_seq_len])
183 |
184 | self.max_seq_len = math.prod(max_seq_len)
185 |
186 | dim_input = default(dim_input, dim)
187 | self.dim_input = dim_input
188 | self.proj_in = nn.Linear(dim_input, dim)
189 |
190 | self.transformer = Decoder(
191 | dim = dim,
192 | depth = depth,
193 | heads = heads,
194 | attn_dim_head = dim_head,
195 | **decoder_kwargs
196 | )
197 |
198 | self.to_cond_emb = nn.Linear(dim, dim, bias = False)
199 |
200 | self.denoiser = MLP(
201 | dim_cond = dim,
202 | dim_input = dim_input,
203 | depth = mlp_depth,
204 | width = mlp_width,
205 | **mlp_kwargs
206 | )
207 |
208 | self.flow = Flow(
209 | dim_input,
210 | self.denoiser,
211 | model_output_clean = model_output_clean,
212 | **flow_kwargs
213 | )
214 |
215 | @property
216 | def device(self):
217 | return next(self.transformer.parameters()).device
218 |
219 | def axial_pos_emb(self):
220 | # prepare maybe axial positional embedding
221 |
222 | pos_emb, *rest_pos_embs = self.abs_pos_emb
223 |
224 | for rest_pos_emb in rest_pos_embs:
225 | pos_emb = einx.add('i d, j d -> (i j) d', pos_emb, rest_pos_emb)
226 |
227 | return F.pad(pos_emb, (0, 0, 1, 0), value = 0.)
228 |
229 | @torch.no_grad()
230 | def sample(
231 | self,
232 | batch_size = 1,
233 | prompt = None
234 | ):
235 | self.eval()
236 |
237 | start_tokens = repeat(self.start_token, 'd -> b 1 d', b = batch_size)
238 |
239 | if not exists(prompt):
240 | out = torch.empty((batch_size, 0, self.dim_input), device = self.device, dtype = torch.float32)
241 | else:
242 | out = prompt
243 |
244 | cache = None
245 |
246 | for _ in tqdm(range(self.max_seq_len - out.shape[1]), desc = 'tokens'):
247 |
248 | cond = self.proj_in(out)
249 |
250 | cond = torch.cat((start_tokens, cond), dim = 1)
251 |
252 | seq_len = cond.shape[-2]
253 | axial_pos_emb = self.axial_pos_emb()
254 | cond += axial_pos_emb[:seq_len]
255 |
256 | cond, cache = self.transformer(cond, cache = cache, return_hiddens = True)
257 |
258 | last_cond = cond[:, -1]
259 |
260 | last_cond += axial_pos_emb[seq_len]
261 | last_cond = self.to_cond_emb(last_cond)
262 |
263 | denoised_pred = self.flow.sample(cond = last_cond)
264 |
265 | denoised_pred = rearrange(denoised_pred, 'b d -> b 1 d')
266 | out = torch.cat((out, denoised_pred), dim = 1)
267 |
268 | return out
269 |
270 | def forward(
271 | self,
272 | seq,
273 | noised_seq = None
274 | ):
275 | b, seq_len, dim = seq.shape
276 |
277 | assert dim == self.dim_input
278 | assert seq_len == self.max_seq_len
279 |
280 | # break into seq and the continuous targets to be predicted
281 |
282 | seq, target = seq[:, :-1], seq
283 |
284 | if exists(noised_seq):
285 | seq = noised_seq[:, :-1]
286 |
287 | # append start tokens
288 |
289 | seq = self.proj_in(seq)
290 | start_token = repeat(self.start_token, 'd -> b 1 d', b = b)
291 |
292 | seq = torch.cat((start_token, seq), dim = 1)
293 |
294 | axial_pos_emb = self.axial_pos_emb()
295 | seq = seq + axial_pos_emb[:seq_len]
296 |
297 | cond = self.transformer(seq)
298 |
299 | cond = cond + axial_pos_emb[1:(seq_len + 1)]
300 | cond = self.to_cond_emb(cond)
301 |
302 | # pack batch and sequence dimensions, so to train each token with different noise levels
303 |
304 | target, _ = pack_one(target, '* d')
305 | cond, _ = pack_one(cond, '* d')
306 |
307 | return self.flow(target, cond = cond)
308 |
309 | # image wrapper
310 |
311 | def normalize_to_neg_one_to_one(img):
312 | return img * 2 - 1
313 |
314 | def unnormalize_to_zero_to_one(t):
315 | return (t + 1) * 0.5
316 |
317 | class ImageAutoregressiveFlow(Module):
318 | def __init__(
319 | self,
320 | *,
321 | image_size,
322 | patch_size,
323 | channels = 3,
324 | train_max_noise = 0.,
325 | model_output_clean = True, # for outputting in x-space
326 | model: dict = dict(),
327 | ):
328 | super().__init__()
329 | assert divisible_by(image_size, patch_size)
330 |
331 | patch_height_width = image_size // patch_size
332 | num_patches = patch_height_width ** 2
333 | dim_in = channels * patch_size ** 2
334 |
335 | self.image_size = image_size
336 | self.patch_size = patch_size
337 |
338 | assert 0. <= train_max_noise < 1.
339 |
340 | self.train_max_noise = train_max_noise
341 |
342 | self.to_tokens = Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_size, p2 = patch_size)
343 |
344 | self.model = AutoregressiveFlow(
345 | **model,
346 | dim_input = dim_in,
347 | max_seq_len = (patch_height_width, patch_height_width),
348 | model_output_clean = model_output_clean
349 | )
350 |
351 | self.to_image = Rearrange('b (h w) (c p1 p2) -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size, h = int(math.sqrt(num_patches)))
352 |
353 | def sample(self, batch_size = 1):
354 | tokens = self.model.sample(batch_size = batch_size)
355 | images = self.to_image(tokens)
356 | return unnormalize_to_zero_to_one(images)
357 |
358 | def forward(self, images):
359 | train_under_noise, device = self.train_max_noise > 0., images.device
360 |
361 | images = normalize_to_neg_one_to_one(images)
362 | tokens = self.to_tokens(images)
363 |
364 | if not train_under_noise:
365 | return self.model(tokens)
366 |
367 | # allow for the network to predict from slightly noised images of the past
368 |
369 | times = torch.rand(images.shape[0], device = device) * self.train_max_noise
370 | noise = torch.randn_like(images)
371 | padded_times = right_pad_dims_to(images, times)
372 | noised_images = images * (1. - padded_times) + noise * padded_times
373 | noised_tokens = self.to_tokens(noised_images)
374 |
375 | return self.model(tokens, noised_tokens)
376 |
--------------------------------------------------------------------------------
/autoregressive_diffusion_pytorch/autoregressive_diffusion.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import math
4 | from math import sqrt
5 | from typing import Literal
6 | from functools import partial
7 |
8 | import torch
9 | from torch import nn, pi
10 | from torch.special import expm1
11 | import torch.nn.functional as F
12 | from torch.nn import Module, ModuleList
13 |
14 | import einx
15 | from einops import rearrange, repeat, reduce, pack, unpack
16 | from einops.layers.torch import Rearrange
17 |
18 | from tqdm import tqdm
19 |
20 | from x_transformers import Decoder
21 |
22 | # helpers
23 |
24 | def exists(v):
25 | return v is not None
26 |
27 | def default(v, d):
28 | return v if exists(v) else d
29 |
30 | def divisible_by(num, den):
31 | return (num % den) == 0
32 |
33 | # tensor helpers
34 |
35 | def log(t, eps = 1e-20):
36 | return torch.log(t.clamp(min = eps))
37 |
38 | def safe_div(num, den, eps = 1e-5):
39 | return num / den.clamp(min = eps)
40 |
41 | def right_pad_dims_to(x, t):
42 | padding_dims = x.ndim - t.ndim
43 |
44 | if padding_dims <= 0:
45 | return t
46 |
47 | return t.view(*t.shape, *((1,) * padding_dims))
48 |
49 | def pack_one(t, pattern):
50 | packed, ps = pack([t], pattern)
51 |
52 | def unpack_one(to_unpack, unpack_pattern = None):
53 | unpacked, = unpack(to_unpack, ps, default(unpack_pattern, pattern))
54 | return unpacked
55 |
56 | return packed, unpack_one
57 |
58 | # adaptive layernorm and learned sinusoidal embedding
59 |
60 | class AdaptiveLayerNorm(Module):
61 | def __init__(
62 | self,
63 | dim,
64 | dim_condition = None
65 | ):
66 | super().__init__()
67 | dim_condition = default(dim_condition, dim)
68 |
69 | self.ln = nn.LayerNorm(dim, elementwise_affine = False)
70 | self.to_gamma = nn.Linear(dim_condition, dim, bias = False)
71 | nn.init.zeros_(self.to_gamma.weight)
72 |
73 | def forward(self, x, *, condition):
74 | normed = self.ln(x)
75 | gamma = self.to_gamma(condition)
76 | return normed * (gamma + 1.)
77 |
78 | class LearnedSinusoidalPosEmb(Module):
79 | def __init__(self, dim):
80 | super().__init__()
81 | assert divisible_by(dim, 2)
82 | half_dim = dim // 2
83 | self.weights = nn.Parameter(torch.randn(half_dim))
84 |
85 | def forward(self, x):
86 | x = rearrange(x, 'b -> b 1')
87 | freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * pi
88 | fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
89 | fouriered = torch.cat((x, fouriered), dim = -1)
90 | return fouriered
91 |
92 | # simple mlp
93 |
94 | class MLP(Module):
95 | def __init__(
96 | self,
97 | dim_cond,
98 | dim_input,
99 | depth = 3,
100 | width = 1024,
101 | dropout = 0.
102 | ):
103 | super().__init__()
104 | layers = ModuleList([])
105 |
106 | self.to_time_emb = nn.Sequential(
107 | LearnedSinusoidalPosEmb(dim_cond),
108 | nn.Linear(dim_cond + 1, dim_cond),
109 | )
110 |
111 | for _ in range(depth):
112 |
113 | adaptive_layernorm = AdaptiveLayerNorm(
114 | dim_input,
115 | dim_condition = dim_cond
116 | )
117 |
118 | block = nn.Sequential(
119 | nn.Linear(dim_input, width),
120 | nn.SiLU(),
121 | nn.Dropout(dropout),
122 | nn.Linear(width, dim_input)
123 | )
124 |
125 | block_out_gamma = nn.Linear(dim_cond, dim_input, bias = False)
126 | nn.init.zeros_(block_out_gamma.weight)
127 |
128 | layers.append(ModuleList([
129 | adaptive_layernorm,
130 | block,
131 | block_out_gamma
132 | ]))
133 |
134 | self.layers = layers
135 |
136 | def forward(
137 | self,
138 | noised,
139 | *,
140 | times,
141 | cond
142 | ):
143 | assert noised.ndim == 2
144 |
145 | time_emb = self.to_time_emb(times)
146 | cond = F.silu(time_emb + cond)
147 |
148 | denoised = noised
149 |
150 | for adaln, block, block_out_gamma in self.layers:
151 | residual = denoised
152 | denoised = adaln(denoised, condition = cond)
153 |
154 | block_out = block(denoised) * (block_out_gamma(cond) + 1.)
155 | denoised = block_out + residual
156 |
157 | return denoised
158 |
159 | # gaussian diffusion
160 |
161 | class ElucidatedDiffusion(Module):
162 | def __init__(
163 | self,
164 | dim: int,
165 | net: MLP,
166 | *,
167 | num_sample_steps = 32, # number of sampling steps
168 | sigma_min = 0.002, # min noise level
169 | sigma_max = 80, # max noise level
170 | sigma_data = 0.5, # standard deviation of data distribution
171 | rho = 7, # controls the sampling schedule
172 | P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training
173 | P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training
174 | S_churn = 80, # parameters for stochastic sampling - depends on dataset, Table 5 in paper
175 | S_tmin = 0.05,
176 | S_tmax = 50,
177 | S_noise = 1.003,
178 | clamp_during_sampling = True
179 | ):
180 | super().__init__()
181 |
182 | self.net = net
183 | self.dim = dim
184 |
185 | # parameters
186 |
187 | self.sigma_min = sigma_min
188 | self.sigma_max = sigma_max
189 | self.sigma_data = sigma_data
190 |
191 | self.rho = rho
192 |
193 | self.P_mean = P_mean
194 | self.P_std = P_std
195 |
196 | self.num_sample_steps = num_sample_steps # otherwise known as N in the paper
197 |
198 | self.S_churn = S_churn
199 | self.S_tmin = S_tmin
200 | self.S_tmax = S_tmax
201 | self.S_noise = S_noise
202 |
203 | self.clamp_during_sampling = clamp_during_sampling
204 |
205 | @property
206 | def device(self):
207 | return next(self.net.parameters()).device
208 |
209 | # derived preconditioning params - Table 1
210 |
211 | def c_skip(self, sigma):
212 | return (self.sigma_data ** 2) / (sigma ** 2 + self.sigma_data ** 2)
213 |
214 | def c_out(self, sigma):
215 | return sigma * self.sigma_data * (self.sigma_data ** 2 + sigma ** 2) ** -0.5
216 |
217 | def c_in(self, sigma):
218 | return 1 * (sigma ** 2 + self.sigma_data ** 2) ** -0.5
219 |
220 | def c_noise(self, sigma):
221 | return log(sigma) * 0.25
222 |
223 | # preconditioned network output
224 | # equation (7) in the paper
225 |
226 | def preconditioned_network_forward(self, noised_seq, sigma, *, cond, clamp = None):
227 | clamp = default(clamp, self.clamp_during_sampling)
228 |
229 | batch, device = noised_seq.shape[0], noised_seq.device
230 |
231 | if isinstance(sigma, float):
232 | sigma = torch.full((batch,), sigma, device = device)
233 |
234 | padded_sigma = right_pad_dims_to(noised_seq, sigma)
235 |
236 | net_out = self.net(
237 | self.c_in(padded_sigma) * noised_seq,
238 | times = self.c_noise(sigma),
239 | cond = cond
240 | )
241 |
242 | out = self.c_skip(padded_sigma) * noised_seq + self.c_out(padded_sigma) * net_out
243 |
244 | if clamp:
245 | out = out.clamp(-1., 1.)
246 |
247 | return out
248 |
249 | # sampling
250 |
251 | # sample schedule
252 | # equation (5) in the paper
253 |
254 | def sample_schedule(self, num_sample_steps = None):
255 | num_sample_steps = default(num_sample_steps, self.num_sample_steps)
256 |
257 | N = num_sample_steps
258 | inv_rho = 1 / self.rho
259 |
260 | steps = torch.arange(num_sample_steps, device = self.device, dtype = torch.float32)
261 | sigmas = (self.sigma_max ** inv_rho + steps / (N - 1) * (self.sigma_min ** inv_rho - self.sigma_max ** inv_rho)) ** self.rho
262 |
263 | sigmas = F.pad(sigmas, (0, 1), value = 0.) # last step is sigma value of 0.
264 | return sigmas
265 |
266 | @torch.no_grad()
267 | def sample(self, cond, num_sample_steps = None, clamp = None):
268 | clamp = default(clamp, self.clamp_during_sampling)
269 | num_sample_steps = default(num_sample_steps, self.num_sample_steps)
270 |
271 | shape = (cond.shape[0], self.dim)
272 |
273 | # get the schedule, which is returned as (sigma, gamma) tuple, and pair up with the next sigma and gamma
274 |
275 | sigmas = self.sample_schedule(num_sample_steps)
276 |
277 | gammas = torch.where(
278 | (sigmas >= self.S_tmin) & (sigmas <= self.S_tmax),
279 | min(self.S_churn / num_sample_steps, sqrt(2) - 1),
280 | 0.
281 | )
282 |
283 | sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[:-1]))
284 |
285 | # images is noise at the beginning
286 |
287 | init_sigma = sigmas[0]
288 |
289 | seq = init_sigma * torch.randn(shape, device = self.device)
290 |
291 | # gradually denoise
292 |
293 | for sigma, sigma_next, gamma in tqdm(sigmas_and_gammas, desc = 'sampling time step'):
294 | sigma, sigma_next, gamma = map(lambda t: t.item(), (sigma, sigma_next, gamma))
295 |
296 | eps = self.S_noise * torch.randn(shape, device = self.device) # stochastic sampling
297 |
298 | sigma_hat = sigma + gamma * sigma
299 | seq_hat = seq + sqrt(sigma_hat ** 2 - sigma ** 2) * eps
300 |
301 | model_output = self.preconditioned_network_forward(seq_hat, sigma_hat, cond = cond, clamp = clamp)
302 | denoised_over_sigma = (seq_hat - model_output) / sigma_hat
303 |
304 | seq_next = seq_hat + (sigma_next - sigma_hat) * denoised_over_sigma
305 |
306 | # second order correction, if not the last timestep
307 |
308 | if sigma_next != 0:
309 | model_output_next = self.preconditioned_network_forward(seq_next, sigma_next, cond = cond, clamp = clamp)
310 | denoised_prime_over_sigma = (seq_next - model_output_next) / sigma_next
311 | seq_next = seq_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma)
312 |
313 | seq = seq_next
314 |
315 | if clamp:
316 | seq = seq.clamp(-1., 1.)
317 |
318 | return seq
319 |
320 | # training
321 |
322 | def loss_weight(self, sigma):
323 | return (sigma ** 2 + self.sigma_data ** 2) * (sigma * self.sigma_data) ** -2
324 |
325 | def noise_distribution(self, batch_size):
326 | return (self.P_mean + self.P_std * torch.randn((batch_size,), device = self.device)).exp()
327 |
328 | def forward(self, seq, *, cond):
329 | batch_size, dim, device = *seq.shape, self.device
330 |
331 | assert dim == self.dim, f'dimension of sequence being passed in must be {self.dim} but received {dim}'
332 |
333 | sigmas = self.noise_distribution(batch_size)
334 | padded_sigmas = right_pad_dims_to(seq, sigmas)
335 |
336 | noise = torch.randn_like(seq)
337 |
338 | noised_seq = seq + padded_sigmas * noise # alphas are 1. in the paper
339 |
340 | denoised = self.preconditioned_network_forward(noised_seq, sigmas, cond = cond)
341 |
342 | losses = F.mse_loss(denoised, seq, reduction = 'none')
343 | losses = reduce(losses, 'b ... -> b', 'mean')
344 |
345 | losses = losses * self.loss_weight(sigmas)
346 |
347 | return losses.mean()
348 |
349 | # main model, a decoder with continuous wrapper + small denoising mlp
350 |
351 | class AutoregressiveDiffusion(Module):
352 | def __init__(
353 | self,
354 | dim,
355 | *,
356 | max_seq_len,
357 | depth = 8,
358 | dim_head = 64,
359 | heads = 8,
360 | mlp_depth = 3,
361 | mlp_width = None,
362 | dim_input = None,
363 | decoder_kwargs: dict = dict(),
364 | mlp_kwargs: dict = dict(),
365 | diffusion_kwargs: dict = dict(
366 | clamp_during_sampling = True
367 | )
368 | ):
369 | super().__init__()
370 |
371 | self.start_token = nn.Parameter(torch.zeros(dim))
372 | self.max_seq_len = max_seq_len
373 | self.abs_pos_emb = nn.Embedding(max_seq_len, dim)
374 |
375 | dim_input = default(dim_input, dim)
376 | self.dim_input = dim_input
377 | self.proj_in = nn.Linear(dim_input, dim)
378 |
379 | self.transformer = Decoder(
380 | dim = dim,
381 | depth = depth,
382 | heads = heads,
383 | attn_dim_head = dim_head,
384 | **decoder_kwargs
385 | )
386 |
387 | self.denoiser = MLP(
388 | dim_cond = dim,
389 | dim_input = dim_input,
390 | depth = mlp_depth,
391 | width = default(mlp_width, dim),
392 | **mlp_kwargs
393 | )
394 |
395 | self.diffusion = ElucidatedDiffusion(
396 | dim_input,
397 | self.denoiser,
398 | **diffusion_kwargs
399 | )
400 |
401 | @property
402 | def device(self):
403 | return next(self.transformer.parameters()).device
404 |
405 | @torch.no_grad()
406 | def sample(
407 | self,
408 | batch_size = 1,
409 | prompt = None
410 | ):
411 | self.eval()
412 |
413 | start_tokens = repeat(self.start_token, 'd -> b 1 d', b = batch_size)
414 |
415 | if not exists(prompt):
416 | out = torch.empty((batch_size, 0, self.dim_input), device = self.device, dtype = torch.float32)
417 | else:
418 | out = prompt
419 |
420 | cache = None
421 |
422 | for _ in tqdm(range(self.max_seq_len - out.shape[1]), desc = 'tokens'):
423 |
424 | cond = self.proj_in(out)
425 |
426 | cond = torch.cat((start_tokens, cond), dim = 1)
427 | cond = cond + self.abs_pos_emb(torch.arange(cond.shape[1], device = self.device))
428 |
429 | cond, cache = self.transformer(cond, cache = cache, return_hiddens = True)
430 |
431 | last_cond = cond[:, -1]
432 |
433 | denoised_pred = self.diffusion.sample(cond = last_cond)
434 |
435 | denoised_pred = rearrange(denoised_pred, 'b d -> b 1 d')
436 | out = torch.cat((out, denoised_pred), dim = 1)
437 |
438 | return out
439 |
440 | def forward(
441 | self,
442 | seq
443 | ):
444 | b, seq_len, dim = seq.shape
445 |
446 | assert dim == self.dim_input
447 | assert seq_len == self.max_seq_len
448 |
449 | # break into seq and the continuous targets to be predicted
450 |
451 | seq, target = seq[:, :-1], seq
452 |
453 | # append start tokens
454 |
455 | seq = self.proj_in(seq)
456 | start_token = repeat(self.start_token, 'd -> b 1 d', b = b)
457 |
458 | seq = torch.cat((start_token, seq), dim = 1)
459 | seq = seq + self.abs_pos_emb(torch.arange(seq_len, device = self.device))
460 |
461 | cond = self.transformer(seq)
462 |
463 | # pack batch and sequence dimensions, so to train each token with different noise levels
464 |
465 | target, _ = pack_one(target, '* d')
466 | cond, _ = pack_one(cond, '* d')
467 |
468 | diffusion_loss = self.diffusion(target, cond = cond)
469 |
470 | return diffusion_loss
471 |
472 | # image wrapper
473 |
474 | def normalize_to_neg_one_to_one(img):
475 | return img * 2 - 1
476 |
477 | def unnormalize_to_zero_to_one(t):
478 | return (t + 1) * 0.5
479 |
480 | class ImageAutoregressiveDiffusion(Module):
481 | def __init__(
482 | self,
483 | *,
484 | image_size,
485 | patch_size,
486 | channels = 3,
487 | model: dict = dict(),
488 | ):
489 | super().__init__()
490 | assert divisible_by(image_size, patch_size)
491 |
492 | num_patches = (image_size // patch_size) ** 2
493 | dim_in = channels * patch_size ** 2
494 |
495 | self.image_size = image_size
496 | self.patch_size = patch_size
497 |
498 | self.to_tokens = Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_size, p2 = patch_size)
499 |
500 | self.model = AutoregressiveDiffusion(
501 | **model,
502 | dim_input = dim_in,
503 | max_seq_len = num_patches
504 | )
505 |
506 | self.to_image = Rearrange('b (h w) (c p1 p2) -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size, h = int(math.sqrt(num_patches)))
507 |
508 | def sample(self, batch_size = 1):
509 | tokens = self.model.sample(batch_size = batch_size)
510 | images = self.to_image(tokens)
511 | return unnormalize_to_zero_to_one(images)
512 |
513 | def forward(self, images):
514 | images = normalize_to_neg_one_to_one(images)
515 | tokens = self.to_tokens(images)
516 | return self.model(tokens)
517 |
--------------------------------------------------------------------------------