├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── muse.png
├── muse_maskgit_pytorch
│   ├── __init__.py
│   ├── attend.py
│   ├── muse_maskgit_pytorch.py
│   ├── t5.py
│   ├── trainers.py
│   └── vqgan_vae.py
└── setup.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
11 | name: Upload Python Package
12 |
13 | on:
14 | release:
15 | types: [published]
16 |
17 | jobs:
18 | deploy:
19 |
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v2
24 | - name: Set up Python
25 | uses: actions/setup-python@v2
26 | with:
27 | python-version: '3.x'
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install build
32 | - name: Build package
33 | run: python -m build
34 | - name: Publish package
35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 | with:
37 | user: __token__
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Muse - Pytorch
4 |
5 | Implementation of Muse: Text-to-Image Generation via Masked Generative Transformers, in Pytorch
6 |
7 | Please join the LAION community if you are interested in helping out with the replication
8 |
9 | ## Install
10 |
11 | ```bash
12 | $ pip install muse-maskgit-pytorch
13 | ```
14 |
15 | ## Usage
16 |
17 | First train your VAE - `VQGanVAE`
18 |
19 | ```python
20 | import torch
21 | from muse_maskgit_pytorch import VQGanVAE, VQGanVAETrainer
22 |
23 | vae = VQGanVAE(
24 | dim = 256,
25 | codebook_size = 65536
26 | )
27 |
28 | # train on folder of images, as many images as possible
29 |
30 | trainer = VQGanVAETrainer(
31 | vae = vae,
32 | image_size = 128, # you may want to start with small images, and then curriculum learn to larger ones, but because the vae is all convolution, it should generalize to 512 (as in paper) without training on it
33 | folder = '/path/to/images',
34 | batch_size = 4,
35 | grad_accum_every = 8,
36 | num_train_steps = 50000
37 | ).cuda()
38 |
39 | trainer.train()
40 | ```
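
A note on checkpoints: `VQGanVAETrainer` periodically writes `vae.{steps}.pt` (the raw VAE state dict) and, when EMA is enabled, `vae.{steps}.ema.pt` (the state dict of the `EMA` wrapper from `ema-pytorch`) into its `results_folder` (default `./results`), every `save_model_every` steps. A minimal sketch of resuming from one of those checkpoints, assuming default settings and a hypothetical step count:

```python
from muse_maskgit_pytorch import VQGanVAE

vae = VQGanVAE(
    dim = 256,
    codebook_size = 65536
)

vae.load('./results/vae.50000.pt') # hypothetical step count, produced by the trainer above
```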
41 |
42 | Then pass the trained `VQGanVAE` and a `Transformer` to `MaskGit`
43 |
44 | ```python
45 | import torch
46 | from muse_maskgit_pytorch import VQGanVAE, MaskGit, MaskGitTransformer
47 |
48 | # first instantiate your vae
49 |
50 | vae = VQGanVAE(
51 | dim = 256,
52 | codebook_size = 65536
53 | ).cuda()
54 |
55 | vae.load('/path/to/vae.pt') # you will want to load the exponentially moving averaged VAE
56 |
57 | # then you plug the vae and transformer into your MaskGit as so
58 |
59 | # (1) create your transformer / attention network
60 |
61 | transformer = MaskGitTransformer(
62 | num_tokens = 65536, # must be same as codebook size above
63 | seq_len = 256, # must be equivalent to fmap_size ** 2 in vae
64 | dim = 512, # model dimension
65 | depth = 8, # depth
66 | dim_head = 64, # attention head dimension
67 | heads = 8, # attention heads,
68 | ff_mult = 4, # feedforward expansion factor
69 | t5_name = 't5-small', # name of your T5
70 | )
71 |
72 | # (2) pass your trained VAE and the base transformer to MaskGit
73 |
74 | base_maskgit = MaskGit(
75 | vae = vae, # vqgan vae
76 | transformer = transformer, # transformer
77 | image_size = 256, # image size
78 | cond_drop_prob = 0.25, # conditional dropout, for classifier free guidance
79 | ).cuda()
80 |
81 | # ready your training text and images
82 |
83 | texts = [
84 | 'a child screaming at finding a worm within a half-eaten apple',
85 | 'lizard running across the desert on two feet',
86 | 'waking up to a psychedelic landscape',
87 | 'seashells sparkling in the shallow waters'
88 | ]
89 |
90 | images = torch.randn(4, 3, 256, 256).cuda()
91 |
92 | # feed it into your maskgit instance, with return_loss set to True
93 |
94 | loss = base_maskgit(
95 | images,
96 | texts = texts
97 | )
98 |
99 | loss.backward()
100 |
101 | # do this for a long time on much data
102 | # then...
103 |
104 | images = base_maskgit.generate(texts = [
105 | 'a whale breaching from afar',
106 | 'young girl blowing out candles on her birthday cake',
107 | 'fireworks with blue and green sparkles'
108 | ], cond_scale = 3.) # conditioning scale for classifier free guidance
109 |
110 | images.shape # (3, 3, 256, 256)
111 | ```
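
`generate` also supports negative prompting via `negative_texts` (one negative prompt per text, see `MaskGit.generate` in `muse_maskgit_pytorch.py`). A minimal sketch; the prompts here are only illustrative:

```python
images = base_maskgit.generate(
    texts = [
        'a whale breaching from afar'
    ],
    negative_texts = [
        'blurry, low quality'   # one negative prompt per positive prompt
    ],
    cond_scale = 3.
)

images.shape # (1, 3, 256, 256)
```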
112 |
113 |
114 | To train the super-resolution maskgit, you only need to change one field on `MaskGit` instantiation: pass in `cond_image_size`, the resolution of the lower-resolution image being conditioned on.
115 |
116 | Optionally, you can pass in a different `VAE` as `cond_vae` for the conditioning low-resolution image. By default it will use the same `vae` to tokenize both the high- and low-resolution images.
117 |
118 | ```python
119 | import torch
120 | import torch.nn.functional as F
121 | from muse_maskgit_pytorch import VQGanVAE, MaskGit, MaskGitTransformer
122 |
123 | # first instantiate your ViT VQGan VAE
124 | # a VQGan VAE made of transformers
125 |
126 | vae = VQGanVAE(
127 | dim = 256,
128 | codebook_size = 65536
129 | ).cuda()
130 |
131 | vae.load('./path/to/vae.pt') # you will want to load the exponentially moving averaged VAE
132 |
133 | # then you plug the VqGan VAE into your MaskGit as so
134 |
135 | # (1) create your transformer / attention network
136 |
137 | transformer = MaskGitTransformer(
138 | num_tokens = 65536, # must be same as codebook size above
139 | seq_len = 1024, # must be equivalent to fmap_size ** 2 in vae
140 | dim = 512, # model dimension
141 | depth = 2, # depth
142 | dim_head = 64, # attention head dimension
143 | heads = 8, # attention heads,
144 | ff_mult = 4, # feedforward expansion factor
145 | t5_name = 't5-small', # name of your T5
146 | )
147 |
148 | # (2) pass your trained VAE and the base transformer to MaskGit
149 |
150 | superres_maskgit = MaskGit(
151 | vae = vae,
152 | transformer = transformer,
153 | cond_drop_prob = 0.25,
154 | image_size = 512, # larger image size
155 | cond_image_size = 256, # conditioning image size <- this must be set
156 | ).cuda()
157 |
158 | # ready your training text and images
159 |
160 | texts = [
161 | 'a child screaming at finding a worm within a half-eaten apple',
162 | 'lizard running across the desert on two feet',
163 | 'waking up to a psychedelic landscape',
164 | 'seashells sparkling in the shallow waters'
165 | ]
166 |
167 | images = torch.randn(4, 3, 512, 512).cuda()
168 |
169 | # feed it into your maskgit instance, with return_loss set to True
170 |
171 | loss = superres_maskgit(
172 | images,
173 | texts = texts
174 | )
175 |
176 | loss.backward()
177 |
178 | # do this for a long time on much data
179 | # then...
180 |
181 | images = superres_maskgit.generate(
182 | texts = [
183 | 'a whale breaching from afar',
184 | 'young girl blowing out candles on her birthday cake',
185 | 'fireworks with blue and green sparkles',
186 | 'waking up to a psychedelic landscape'
187 | ],
188 | cond_images = F.interpolate(images, 256), # conditioning images must be passed in for generating from superres
189 | cond_scale = 3.
190 | )
191 |
192 | images.shape # (4, 3, 512, 512)
193 | ```
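
Both stages expose `save` / `load` (a plain `torch.save` of the module state dict, see `MaskGit.save` in `muse_maskgit_pytorch.py`), which is how the checkpoints loaded in the next section are produced:

```python
base_maskgit.save('./path/to/base.pt')
superres_maskgit.save('./path/to/superres.pt')
```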
194 |
195 | All together now
196 |
197 | ```python
198 | from muse_maskgit_pytorch import Muse
199 |
200 | base_maskgit.load('./path/to/base.pt')
201 |
202 | superres_maskgit.load('./path/to/superres.pt')
203 |
204 | # pass in the trained base_maskgit and superres_maskgit from above
205 |
206 | muse = Muse(
207 | base = base_maskgit,
208 | superres = superres_maskgit
209 | )
210 |
211 | images = muse([
212 | 'a whale breaching from afar',
213 | 'young girl blowing out candles on her birthday cake',
214 | 'fireworks with blue and green sparkles',
215 | 'waking up to a psychedelic landscape'
216 | ])
217 |
218 | images # List[PIL.Image.Image]
219 | ```
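
Since the output is a list of `PIL.Image.Image` objects, saving the samples to disk is just a loop, for example:

```python
for i, image in enumerate(images):
    image.save(f'./sample_{i}.png')
```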
220 |
221 | ## Appreciation
222 |
223 | - StabilityAI for the sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.
224 |
225 | - 🤗 Huggingface for the transformers and accelerate libraries, both of which are wonderful
226 |
227 | ## Todo
228 |
229 | - [x] test end-to-end
230 | - [x] separate cond_images_or_ids, it is not done right
231 | - [x] add training code for vae
232 | - [x] add optional self-conditioning on embeddings
233 | - [x] combine with token critic paper, already implemented at Phenaki
234 |
235 | - [ ] hook up accelerate training code for maskgit
236 |
237 | ## Citations
238 |
239 | ```bibtex
240 | @inproceedings{Chang2023MuseTG,
241 | title = {Muse: Text-To-Image Generation via Masked Generative Transformers},
242 | author = {Huiwen Chang and Han Zhang and Jarred Barber and AJ Maschinot and Jos{\'e} Lezama and Lu Jiang and Ming-Hsuan Yang and Kevin P. Murphy and William T. Freeman and Michael Rubinstein and Yuanzhen Li and Dilip Krishnan},
243 | year = {2023}
244 | }
245 | ```
246 |
247 | ```bibtex
248 | @article{Chen2022AnalogBG,
249 | title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
250 | author = {Ting Chen and Ruixiang Zhang and Geoffrey E. Hinton},
251 | journal = {ArXiv},
252 | year = {2022},
253 | volume = {abs/2208.04202}
254 | }
255 | ```
256 |
257 | ```bibtex
258 | @misc{jabri2022scalable,
259 | title = {Scalable Adaptive Computation for Iterative Generation},
260 | author = {Allan Jabri and David Fleet and Ting Chen},
261 | year = {2022},
262 | eprint = {2212.11972},
263 | archivePrefix = {arXiv},
264 | primaryClass = {cs.LG}
265 | }
266 | ```
267 |
268 | ```bibtex
269 | @article{Lezama2022ImprovedMI,
270 | title = {Improved Masked Image Generation with Token-Critic},
271 | author = {Jos{\'e} Lezama and Huiwen Chang and Lu Jiang and Irfan Essa},
272 | journal = {ArXiv},
273 | year = {2022},
274 | volume = {abs/2209.04439}
275 | }
276 | ```
277 |
278 | ```bibtex
279 | @inproceedings{Nijkamp2021SCRIPTSP,
280 | title = {SCRIPT: Self-Critic PreTraining of Transformers},
281 | author = {Erik Nijkamp and Bo Pang and Ying Nian Wu and Caiming Xiong},
282 | booktitle = {North American Chapter of the Association for Computational Linguistics},
283 | year = {2021}
284 | }
285 | ```
286 |
287 | ```bibtex
288 | @inproceedings{dao2022flashattention,
289 | title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
290 | author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
291 | booktitle = {Advances in Neural Information Processing Systems},
292 | year = {2022}
293 | }
294 | ```
295 |
296 | ```bibtex
297 | @misc{mentzer2023finite,
298 | title = {Finite Scalar Quantization: VQ-VAE Made Simple},
299 | author = {Fabian Mentzer and David Minnen and Eirikur Agustsson and Michael Tschannen},
300 | year = {2023},
301 | eprint = {2309.15505},
302 | archivePrefix = {arXiv},
303 | primaryClass = {cs.CV}
304 | }
305 | ```
306 |
307 | ```bibtex
308 | @misc{yu2023language,
309 | title = {Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation},
310 | author = {Lijun Yu and José Lezama and Nitesh B. Gundavarapu and Luca Versari and Kihyuk Sohn and David Minnen and Yong Cheng and Agrim Gupta and Xiuye Gu and Alexander G. Hauptmann and Boqing Gong and Ming-Hsuan Yang and Irfan Essa and David A. Ross and Lu Jiang},
311 | year = {2023},
312 | eprint = {2310.05737},
313 | archivePrefix = {arXiv},
314 | primaryClass = {cs.CV}
315 | }
316 | ```
317 |
--------------------------------------------------------------------------------
/muse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/muse-maskgit-pytorch/6df7f33bcd33ba28a2f682d5bd293e4f8a513e6c/muse.png
--------------------------------------------------------------------------------
/muse_maskgit_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from muse_maskgit_pytorch.vqgan_vae import VQGanVAE
2 | from muse_maskgit_pytorch.muse_maskgit_pytorch import Transformer, MaskGit, Muse, MaskGitTransformer, TokenCritic
3 |
4 | from muse_maskgit_pytorch.trainers import VQGanVAETrainer
5 |
--------------------------------------------------------------------------------
/muse_maskgit_pytorch/attend.py:
--------------------------------------------------------------------------------
1 | from functools import wraps
2 | from packaging import version
3 | from collections import namedtuple
4 |
5 | import torch
6 | from torch import nn, einsum
7 | import torch.nn.functional as F
8 |
9 | from memory_efficient_attention_pytorch.flash_attention import FlashAttentionFunction
10 | # constants
11 |
12 | AttentionConfig = namedtuple('AttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
13 |
14 | # helpers
15 |
16 | def exists(val):
17 | return val is not None
18 |
19 | def once(fn):
20 | called = False
21 | @wraps(fn)
22 | def inner(x):
23 | nonlocal called
24 | if called:
25 | return
26 | called = True
27 | return fn(x)
28 | return inner
29 |
30 | print_once = once(print)
31 |
32 | # main class
33 |
34 | class Attend(nn.Module):
35 | def __init__(
36 | self,
37 | scale = 8,
38 | dropout = 0.,
39 | flash = False
40 | ):
41 | super().__init__()
42 | self.scale = scale
43 | self.dropout = dropout
44 | self.attn_dropout = nn.Dropout(dropout)
45 |
46 | self.flash = flash
47 | assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
48 |
49 | # determine efficient attention configs for cuda and cpu
50 |
51 | self.cuda_config = None
52 | self.no_hardware_detected = False
53 |
54 | if not torch.cuda.is_available() or not flash:
55 | return
56 |
57 | device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
58 |
59 | if device_properties.major == 8 and device_properties.minor == 0:
60 | print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
61 | self.cuda_config = AttentionConfig(True, False, False)
62 | else:
63 | print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
64 | self.cuda_config = AttentionConfig(False, True, False)
65 |
66 | def flash_attn(self, q, k, v, mask = None):
67 | default_scale = q.shape[-1] ** -0.5
68 |
69 | is_cuda = q.is_cuda
70 |
71 | q, k, v = map(lambda t: t.contiguous(), (q, k, v))
72 |
73 | # scaled_dot_product_attention does not allow for custom scale
74 | # so hack it in, to support rmsnorm-ed queries and keys
75 |
76 | rescale = self.scale / default_scale
77 |
78 | q = q * (rescale ** 0.5)
79 | k = k * (rescale ** 0.5)
80 |
81 | # use naive implementation if not correct hardware
82 |
83 | # the below logic can also incorporate whether masking is needed or not
84 |
85 | use_naive = not is_cuda or not exists(self.cuda_config)
86 |
87 | if not is_cuda or self.no_hardware_detected:
88 | return FlashAttentionFunction.apply(q, k, v, mask, False, 512, 512)
89 |
90 | # otherwise attempt pytorch 2.0 flash attention
91 | # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
92 | 
93 | try:
94 | # falls back to the naive implementation below if the sdp kernel errors out
95 | with torch.backends.cuda.sdp_kernel(**self.cuda_config._asdict()):
96 | out = F.scaled_dot_product_attention(
97 | q, k, v,
98 | attn_mask = mask,
99 | dropout_p = self.dropout if self.training else 0.
100 | )
101 | except:
102 | print_once('no hardware detected, falling back to naive implementation from memory-efficient-attention-pytorch library')
103 | self.no_hardware_detected = True
104 |
105 | out = FlashAttentionFunction.apply(q, k, v, mask, False, 512, 512)
106 |
107 | return out
108 |
109 | def forward(self, q, k, v, mask = None, force_non_flash = False):
110 | """
111 | einstein notation
112 | b - batch
113 | h - heads
114 | n, i, j - sequence length (base sequence length, source, target)
115 | d - feature dimension
116 | """
117 |
118 | if self.flash and not force_non_flash:
119 | return self.flash_attn(q, k, v, mask = mask)
120 |
121 | # similarity
122 |
123 | sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
124 |
125 | # masking
126 |
127 | if exists(mask):
128 | mask_value = -torch.finfo(sim.dtype).max
129 | sim = sim.masked_fill(~mask, mask_value)
130 |
131 | # attention
132 |
133 | attn = sim.softmax(dim = -1)
134 | attn = self.attn_dropout(attn)
135 |
136 | # aggregate values
137 |
138 | out = einsum("b h i j, b h j d -> b h i d", attn, v)
139 |
140 | return out
141 |
--------------------------------------------------------------------------------
/muse_maskgit_pytorch/muse_maskgit_pytorch.py:
--------------------------------------------------------------------------------
1 | import math
2 | from random import random
3 | from functools import partial
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn, einsum
8 | import pathlib
9 | from pathlib import Path
10 | import torchvision.transforms as T
11 |
12 | from typing import Callable, Optional, List
13 |
14 | from einops import rearrange, repeat
15 |
16 | from beartype import beartype
17 |
18 | from muse_maskgit_pytorch.vqgan_vae import VQGanVAE
19 | from muse_maskgit_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME
20 | from muse_maskgit_pytorch.attend import Attend
21 |
22 | from tqdm.auto import tqdm
23 |
24 | # helpers
25 |
26 | def exists(val):
27 | return val is not None
28 |
29 | def default(val, d):
30 | return val if exists(val) else d
31 |
32 | def eval_decorator(fn):
33 | def inner(model, *args, **kwargs):
34 | was_training = model.training
35 | model.eval()
36 | out = fn(model, *args, **kwargs)
37 | model.train(was_training)
38 | return out
39 | return inner
40 |
41 | def l2norm(t):
42 | return F.normalize(t, dim = -1)
43 |
44 | # tensor helpers
45 |
46 | def get_mask_subset_prob(mask, prob, min_mask = 0):
47 | batch, seq, device = *mask.shape, mask.device
48 | num_to_mask = (mask.sum(dim = -1, keepdim = True) * prob).clamp(min = min_mask)
49 | logits = torch.rand((batch, seq), device = device)
50 | logits = logits.masked_fill(~mask, -1)
51 |
52 | randperm = logits.argsort(dim = -1).argsort(dim = -1).float()
53 |
54 | num_padding = (~mask).sum(dim = -1, keepdim = True)
55 | randperm -= num_padding
56 |
57 | subset_mask = randperm < num_to_mask
58 | subset_mask.masked_fill_(~mask, False)
59 | return subset_mask
60 |
61 | # classes
62 |
63 | class LayerNorm(nn.Module):
64 | def __init__(self, dim):
65 | super().__init__()
66 | self.gamma = nn.Parameter(torch.ones(dim))
67 | self.register_buffer('beta', torch.zeros(dim))
68 |
69 | def forward(self, x):
70 | return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
71 |
72 | class GEGLU(nn.Module):
73 | """ https://arxiv.org/abs/2002.05202 """
74 |
75 | def forward(self, x):
76 | x, gate = x.chunk(2, dim = -1)
77 | return gate * F.gelu(x)
78 |
79 | def FeedForward(dim, mult = 4):
80 | """ https://arxiv.org/abs/2110.09456 """
81 |
82 | inner_dim = int(dim * mult * 2 / 3)
83 | return nn.Sequential(
84 | LayerNorm(dim),
85 | nn.Linear(dim, inner_dim * 2, bias = False),
86 | GEGLU(),
87 | LayerNorm(inner_dim),
88 | nn.Linear(inner_dim, dim, bias = False)
89 | )
90 |
91 | class Attention(nn.Module):
92 | def __init__(
93 | self,
94 | dim,
95 | dim_head = 64,
96 | heads = 8,
97 | cross_attend = False,
98 | scale = 8,
99 | flash = True,
100 | dropout = 0.
101 | ):
102 | super().__init__()
103 | self.scale = scale
104 | self.heads = heads
105 | inner_dim = dim_head * heads
106 |
107 | self.cross_attend = cross_attend
108 | self.norm = LayerNorm(dim)
109 |
110 | self.attend = Attend(
111 | flash = flash,
112 | dropout = dropout,
113 | scale = scale
114 | )
115 |
116 | self.null_kv = nn.Parameter(torch.randn(2, heads, 1, dim_head))
117 |
118 | self.to_q = nn.Linear(dim, inner_dim, bias = False)
119 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
120 |
121 | self.q_scale = nn.Parameter(torch.ones(dim_head))
122 | self.k_scale = nn.Parameter(torch.ones(dim_head))
123 |
124 | self.to_out = nn.Linear(inner_dim, dim, bias = False)
125 |
126 | def forward(
127 | self,
128 | x,
129 | context = None,
130 | context_mask = None
131 | ):
132 | assert not (exists(context) ^ self.cross_attend)
133 |
134 | n = x.shape[-2]
135 | h, is_cross_attn = self.heads, exists(context)
136 |
137 | x = self.norm(x)
138 |
139 | kv_input = context if self.cross_attend else x
140 |
141 | q, k, v = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1))
142 |
143 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
144 |
145 | nk, nv = self.null_kv
146 | nk, nv = map(lambda t: repeat(t, 'h 1 d -> b h 1 d', b = x.shape[0]), (nk, nv))
147 |
148 | k = torch.cat((nk, k), dim = -2)
149 | v = torch.cat((nv, v), dim = -2)
150 |
151 | q, k = map(l2norm, (q, k))
152 | q = q * self.q_scale
153 | k = k * self.k_scale
154 |
155 | if exists(context_mask):
156 | context_mask = repeat(context_mask, 'b j -> b h i j', h = h, i = n)
157 | context_mask = F.pad(context_mask, (1, 0), value = True)
158 |
159 | out = self.attend(q, k, v, mask = context_mask)
160 |
161 | out = rearrange(out, 'b h n d -> b n (h d)')
162 | return self.to_out(out)
163 |
164 | class TransformerBlocks(nn.Module):
165 | def __init__(
166 | self,
167 | *,
168 | dim,
169 | depth,
170 | dim_head = 64,
171 | heads = 8,
172 | ff_mult = 4,
173 | flash = True
174 | ):
175 | super().__init__()
176 | self.layers = nn.ModuleList([])
177 |
178 | for _ in range(depth):
179 | self.layers.append(nn.ModuleList([
180 | Attention(dim = dim, dim_head = dim_head, heads = heads, flash = flash),
181 | Attention(dim = dim, dim_head = dim_head, heads = heads, cross_attend = True, flash = flash),
182 | FeedForward(dim = dim, mult = ff_mult)
183 | ]))
184 |
185 | self.norm = LayerNorm(dim)
186 |
187 | def forward(self, x, context = None, context_mask = None):
188 | for attn, cross_attn, ff in self.layers:
189 | x = attn(x) + x
190 |
191 | x = cross_attn(x, context = context, context_mask = context_mask) + x
192 |
193 | x = ff(x) + x
194 |
195 | return self.norm(x)
196 |
197 | # transformer - it's all we need
198 |
199 | class Transformer(nn.Module):
200 | def __init__(
201 | self,
202 | *,
203 | num_tokens,
204 | dim,
205 | seq_len,
206 | dim_out = None,
207 | t5_name = DEFAULT_T5_NAME,
208 | self_cond = False,
209 | add_mask_id = False,
210 | **kwargs
211 | ):
212 | super().__init__()
213 | self.dim = dim
214 | self.mask_id = num_tokens if add_mask_id else None
215 |
216 | self.num_tokens = num_tokens
217 | self.token_emb = nn.Embedding(num_tokens + int(add_mask_id), dim)
218 | self.pos_emb = nn.Embedding(seq_len, dim)
219 | self.seq_len = seq_len
220 |
221 | self.transformer_blocks = TransformerBlocks(dim = dim, **kwargs)
222 | self.norm = LayerNorm(dim)
223 |
224 | self.dim_out = default(dim_out, num_tokens)
225 | self.to_logits = nn.Linear(dim, self.dim_out, bias = False)
226 |
227 | # text conditioning
228 |
229 | self.encode_text = partial(t5_encode_text, name = t5_name)
230 |
231 | text_embed_dim = get_encoded_dim(t5_name)
232 |
233 | self.text_embed_proj = nn.Linear(text_embed_dim, dim, bias = False) if text_embed_dim != dim else nn.Identity()
234 |
235 | # optional self conditioning
236 |
237 | self.self_cond = self_cond
238 | self.self_cond_to_init_embed = FeedForward(dim)
239 |
240 | def forward_with_cond_scale(
241 | self,
242 | *args,
243 | cond_scale = 3.,
244 | return_embed = False,
245 | **kwargs
246 | ):
247 | if cond_scale == 1:
248 | return self.forward(*args, return_embed = return_embed, cond_drop_prob = 0., **kwargs)
249 |
250 | logits, embed = self.forward(*args, return_embed = True, cond_drop_prob = 0., **kwargs)
251 |
252 | null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
253 |
254 | scaled_logits = null_logits + (logits - null_logits) * cond_scale
255 |
256 | if return_embed:
257 | return scaled_logits, embed
258 |
259 | return scaled_logits
260 |
261 | def forward_with_neg_prompt(
262 | self,
263 | *args,
264 | text_embeds: torch.Tensor,
265 | neg_text_embeds: torch.Tensor,
266 | cond_scale = 3.,
267 | return_embed = False,
268 | **kwargs
269 | ):
270 | neg_logits = self.forward(*args, text_embeds = neg_text_embeds, cond_drop_prob = 0., **kwargs)
271 | pos_logits, embed = self.forward(*args, return_embed = True, text_embeds = text_embeds, cond_drop_prob = 0., **kwargs)
272 | 
273 | scaled_logits = neg_logits + (pos_logits - neg_logits) * cond_scale
274 | 
275 | if return_embed:
276 | return scaled_logits, embed
277 | return scaled_logits
278 |
279 | def forward(
280 | self,
281 | x,
282 | return_embed = False,
283 | return_logits = False,
284 | labels = None,
285 | ignore_index = 0,
286 | self_cond_embed = None,
287 | cond_drop_prob = 0.,
288 | conditioning_token_ids: Optional[torch.Tensor] = None,
289 | texts: Optional[List[str]] = None,
290 | text_embeds: Optional[torch.Tensor] = None
291 | ):
292 | device, b, n = x.device, *x.shape
293 | assert n <= self.seq_len
294 |
295 | # prepare texts
296 |
297 | assert exists(texts) ^ exists(text_embeds)
298 |
299 | if exists(texts):
300 | text_embeds = self.encode_text(texts)
301 |
302 | context = self.text_embed_proj(text_embeds)
303 |
304 | context_mask = (text_embeds != 0).any(dim = -1)
305 |
306 | # classifier free guidance
307 |
308 | if cond_drop_prob > 0.:
309 | mask = prob_mask_like((b, 1), 1. - cond_drop_prob, device)
310 | context_mask = context_mask & mask
311 |
312 | # concat conditioning image token ids if needed
313 |
314 | if exists(conditioning_token_ids):
315 | conditioning_token_ids = rearrange(conditioning_token_ids, 'b ... -> b (...)')
316 | cond_token_emb = self.token_emb(conditioning_token_ids)
317 | context = torch.cat((context, cond_token_emb), dim = -2)
318 | context_mask = F.pad(context_mask, (0, conditioning_token_ids.shape[-1]), value = True)
319 |
320 | # embed tokens
321 |
322 | x = self.token_emb(x)
323 | x = x + self.pos_emb(torch.arange(n, device = device))
324 |
325 | if self.self_cond:
326 | if not exists(self_cond_embed):
327 | self_cond_embed = torch.zeros_like(x)
328 | x = x + self.self_cond_to_init_embed(self_cond_embed)
329 |
330 | embed = self.transformer_blocks(x, context = context, context_mask = context_mask)
331 |
332 | logits = self.to_logits(embed)
333 |
334 | if return_embed:
335 | return logits, embed
336 |
337 | if not exists(labels):
338 | return logits
339 |
340 | if self.dim_out == 1:
341 | loss = F.binary_cross_entropy_with_logits(rearrange(logits, '... 1 -> ...'), labels)
342 | else:
343 | loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels, ignore_index = ignore_index)
344 |
345 | if not return_logits:
346 | return loss
347 |
348 | return loss, logits
349 |
350 | # self critic wrapper
351 |
352 | class SelfCritic(nn.Module):
353 | def __init__(self, net):
354 | super().__init__()
355 | self.net = net
356 | self.to_pred = nn.Linear(net.dim, 1)
357 |
358 | def forward_with_cond_scale(self, x, *args, **kwargs):
359 | _, embeds = self.net.forward_with_cond_scale(x, *args, return_embed = True, **kwargs)
360 | return self.to_pred(embeds)
361 |
362 | def forward_with_neg_prompt(self, x, *args, **kwargs):
363 | _, embeds = self.net.forward_with_neg_prompt(x, *args, return_embed = True, **kwargs)
364 | return self.to_pred(embeds)
365 |
366 | def forward(self, x, *args, labels = None, **kwargs):
367 | _, embeds = self.net(x, *args, return_embed = True, **kwargs)
368 | logits = self.to_pred(embeds)
369 |
370 | if not exists(labels):
371 | return logits
372 |
373 | logits = rearrange(logits, '... 1 -> ...')
374 | return F.binary_cross_entropy_with_logits(logits, labels)
375 |
376 | # specialized transformers
377 |
378 | class MaskGitTransformer(Transformer):
379 | def __init__(self, *args, **kwargs):
380 | assert 'add_mask_id' not in kwargs
381 | super().__init__(*args, add_mask_id = True, **kwargs)
382 |
383 | class TokenCritic(Transformer):
384 | def __init__(self, *args, **kwargs):
385 | assert 'dim_out' not in kwargs
386 | super().__init__(*args, dim_out = 1, **kwargs)
387 |
388 | # classifier free guidance functions
389 |
390 | def uniform(shape, min = 0, max = 1, device = None):
391 | return torch.zeros(shape, device = device).float().uniform_(min, max)
392 |
393 | def prob_mask_like(shape, prob, device = None):
394 | if prob == 1:
395 | return torch.ones(shape, device = device, dtype = torch.bool)
396 | elif prob == 0:
397 | return torch.zeros(shape, device = device, dtype = torch.bool)
398 | else:
399 | return uniform(shape, device = device) < prob
400 |
401 | # sampling helpers
402 |
403 | def log(t, eps = 1e-20):
404 | return torch.log(t.clamp(min = eps))
405 |
406 | def gumbel_noise(t):
407 | noise = torch.zeros_like(t).uniform_(0, 1)
408 | return -log(-log(noise))
409 |
410 | def gumbel_sample(t, temperature = 1., dim = -1):
411 | return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)
412 |
413 | def top_k(logits, thres = 0.9):
414 | k = math.ceil((1 - thres) * logits.shape[-1])
415 | val, ind = logits.topk(k, dim = -1)
416 | probs = torch.full_like(logits, float('-inf'))
417 | probs.scatter_(2, ind, val)
418 | return probs
419 |
420 | # noise schedules
421 |
422 | def cosine_schedule(t):
423 | return torch.cos(t * math.pi * 0.5)
424 |
425 | # main maskgit classes
426 |
427 | @beartype
428 | class MaskGit(nn.Module):
429 | def __init__(
430 | self,
431 | image_size,
432 | transformer: MaskGitTransformer,
433 | noise_schedule: Callable = cosine_schedule,
434 | token_critic: Optional[TokenCritic] = None,
435 | self_token_critic = False,
436 | vae: Optional[VQGanVAE] = None,
437 | cond_vae: Optional[VQGanVAE] = None,
438 | cond_image_size = None,
439 | cond_drop_prob = 0.5,
440 | self_cond_prob = 0.9,
441 | no_mask_token_prob = 0.,
442 | critic_loss_weight = 1.
443 | ):
444 | super().__init__()
445 | self.vae = vae.copy_for_eval() if exists(vae) else None
446 |
447 | if exists(cond_vae):
448 | self.cond_vae = cond_vae.eval()
449 | else:
450 | self.cond_vae = self.vae
451 |
452 | assert not (exists(cond_vae) and not exists(cond_image_size)), 'cond_image_size must be specified if conditioning'
453 |
454 | self.image_size = image_size
455 | self.cond_image_size = cond_image_size
456 | self.resize_image_for_cond_image = exists(cond_image_size)
457 |
458 | self.cond_drop_prob = cond_drop_prob
459 |
460 | self.transformer = transformer
461 | self.self_cond = transformer.self_cond
462 | assert self.vae.codebook_size == self.cond_vae.codebook_size == transformer.num_tokens, 'transformer num_tokens must be set to be equal to the vae codebook size'
463 |
464 | self.mask_id = transformer.mask_id
465 | self.noise_schedule = noise_schedule
466 |
467 | assert not (self_token_critic and exists(token_critic))
468 | self.token_critic = token_critic
469 |
470 | if self_token_critic:
471 | self.token_critic = SelfCritic(transformer)
472 |
473 | self.critic_loss_weight = critic_loss_weight
474 |
475 | # self conditioning
476 | self.self_cond_prob = self_cond_prob
477 |
478 | # percentage of tokens to be [mask]ed to remain the same token, so that transformer produces better embeddings across all tokens as done in original BERT paper
479 | # may be needed for self conditioning
480 | self.no_mask_token_prob = no_mask_token_prob
481 |
482 | def save(self, path):
483 | torch.save(self.state_dict(), path)
484 |
485 | def load(self, path):
486 | path = Path(path)
487 | assert path.exists()
488 | state_dict = torch.load(str(path))
489 | self.load_state_dict(state_dict)
490 |
491 | @torch.no_grad()
492 | @eval_decorator
493 | def generate(
494 | self,
495 | texts: List[str],
496 | negative_texts: Optional[List[str]] = None,
497 | cond_images: Optional[torch.Tensor] = None,
498 | fmap_size = None,
499 | temperature = 1.,
500 | topk_filter_thres = 0.9,
501 | can_remask_prev_masked = False,
502 | force_not_use_token_critic = False,
503 | timesteps = 18, # ideal number of steps is 18 in maskgit paper
504 | cond_scale = 3,
505 | critic_noise_scale = 1
506 | ):
507 | fmap_size = default(fmap_size, self.vae.get_encoded_fmap_size(self.image_size))
508 |
509 | # begin with all image token ids masked
510 |
511 | device = next(self.parameters()).device
512 |
513 | seq_len = fmap_size ** 2
514 |
515 | batch_size = len(texts)
516 |
517 | shape = (batch_size, seq_len)
518 |
519 | ids = torch.full(shape, self.mask_id, dtype = torch.long, device = device)
520 | scores = torch.zeros(shape, dtype = torch.float32, device = device)
521 |
522 | starting_temperature = temperature
523 |
524 | cond_ids = None
525 |
526 | text_embeds = self.transformer.encode_text(texts)
527 |
528 | demask_fn = self.transformer.forward_with_cond_scale
529 |
530 | # whether to use token critic for scores
531 |
532 | use_token_critic = exists(self.token_critic) and not force_not_use_token_critic
533 |
534 | if use_token_critic:
535 | token_critic_fn = self.token_critic.forward_with_cond_scale
536 |
537 | # negative prompting, as in paper
538 |
539 | neg_text_embeds = None
540 | if exists(negative_texts):
541 | assert len(texts) == len(negative_texts)
542 |
543 | neg_text_embeds = self.transformer.encode_text(negative_texts)
544 | demask_fn = partial(self.transformer.forward_with_neg_prompt, neg_text_embeds = neg_text_embeds)
545 |
546 | if use_token_critic:
547 | token_critic_fn = partial(self.token_critic.forward_with_neg_prompt, neg_text_embeds = neg_text_embeds)
548 |
549 | if self.resize_image_for_cond_image:
550 | assert exists(cond_images), 'conditioning image must be passed in to generate for super res maskgit'
551 | with torch.no_grad():
552 | _, cond_ids, _ = self.cond_vae.encode(cond_images)
553 |
554 | self_cond_embed = None
555 |
556 | for timestep, steps_until_x0 in tqdm(zip(torch.linspace(0, 1, timesteps, device = device), reversed(range(timesteps))), total = timesteps):
557 |
558 | rand_mask_prob = self.noise_schedule(timestep)
559 | num_token_masked = max(int((rand_mask_prob * seq_len).item()), 1)
560 |
561 | masked_indices = scores.topk(num_token_masked, dim = -1).indices
562 |
563 | ids = ids.scatter(1, masked_indices, self.mask_id)
564 |
565 | logits, embed = demask_fn(
566 | ids,
567 | text_embeds = text_embeds,
568 | self_cond_embed = self_cond_embed,
569 | conditioning_token_ids = cond_ids,
570 | cond_scale = cond_scale,
571 | return_embed = True
572 | )
573 |
574 | self_cond_embed = embed if self.self_cond else None
575 |
576 | filtered_logits = top_k(logits, topk_filter_thres)
577 |
578 | temperature = starting_temperature * (steps_until_x0 / timesteps) # temperature is annealed
579 |
580 | pred_ids = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)
581 |
582 | is_mask = ids == self.mask_id
583 |
584 | ids = torch.where(
585 | is_mask,
586 | pred_ids,
587 | ids
588 | )
589 |
590 | if use_token_critic:
591 | scores = token_critic_fn(
592 | ids,
593 | text_embeds = text_embeds,
594 | conditioning_token_ids = cond_ids,
595 | cond_scale = cond_scale
596 | )
597 |
598 | scores = rearrange(scores, '... 1 -> ...')
599 |
600 | scores = scores + (uniform(scores.shape, device = device) - 0.5) * critic_noise_scale * (steps_until_x0 / timesteps)
601 |
602 | else:
603 | probs_without_temperature = logits.softmax(dim = -1)
604 |
605 | scores = 1 - probs_without_temperature.gather(2, pred_ids[..., None])
606 | scores = rearrange(scores, '... 1 -> ...')
607 |
608 | if not can_remask_prev_masked:
609 | scores = scores.masked_fill(~is_mask, -1e5)
610 | else:
611 | assert self.no_mask_token_prob > 0., 'without training with some of the non-masked tokens forced to predict, not sure if the logits will be meaningful for these tokens'
612 |
613 | # get ids
614 |
615 | ids = rearrange(ids, 'b (i j) -> b i j', i = fmap_size, j = fmap_size)
616 |
617 | if not exists(self.vae):
618 | return ids
619 |
620 | images = self.vae.decode_from_ids(ids)
621 | return images
622 |
623 | def forward(
624 | self,
625 | images_or_ids: torch.Tensor,
626 | ignore_index = -1,
627 | cond_images: Optional[torch.Tensor] = None,
628 | cond_token_ids: Optional[torch.Tensor] = None,
629 | texts: Optional[List[str]] = None,
630 | text_embeds: Optional[torch.Tensor] = None,
631 | cond_drop_prob = None,
632 | train_only_generator = False,
633 | sample_temperature = None
634 | ):
635 | # tokenize if needed
636 |
637 | if images_or_ids.dtype == torch.float:
638 | assert exists(self.vae), 'vqgan vae must be passed in if training from raw images'
639 | assert all([height_or_width == self.image_size for height_or_width in images_or_ids.shape[-2:]]), 'the image you passed in is not of the correct dimensions'
640 |
641 | with torch.no_grad():
642 | _, ids, _ = self.vae.encode(images_or_ids)
643 | else:
644 | assert not self.resize_image_for_cond_image, 'you cannot pass in raw image token ids if you want the framework to autoresize image for conditioning super res transformer'
645 | ids = images_or_ids
646 |
647 | # take care of conditioning image if specified
648 |
649 | if self.resize_image_for_cond_image:
650 | cond_images = F.interpolate(images_or_ids, self.cond_image_size, mode = 'nearest')
651 |
652 | # get some basic variables
653 |
654 | ids = rearrange(ids, 'b ... -> b (...)')
655 |
656 | batch, seq_len, device, cond_drop_prob = *ids.shape, ids.device, default(cond_drop_prob, self.cond_drop_prob)
657 |
658 | # tokenize conditional images if needed
659 |
660 | assert not (exists(cond_images) and exists(cond_token_ids)), 'if conditioning on low resolution, cannot pass in both images and token ids'
661 |
662 | if exists(cond_images):
663 | assert exists(self.cond_vae), 'cond vqgan vae must be passed in'
664 | assert all([height_or_width == self.cond_image_size for height_or_width in cond_images.shape[-2:]])
665 |
666 | with torch.no_grad():
667 | _, cond_token_ids, _ = self.cond_vae.encode(cond_images)
668 |
669 | # prepare mask
670 |
671 | rand_time = uniform((batch,), device = device)
672 | rand_mask_probs = self.noise_schedule(rand_time)
673 | num_token_masked = (seq_len * rand_mask_probs).round().clamp(min = 1)
674 |
675 | mask_id = self.mask_id
676 | batch_randperm = torch.rand((batch, seq_len), device = device).argsort(dim = -1)
677 | mask = batch_randperm < rearrange(num_token_masked, 'b -> b 1')
678 |
679 | mask_id = self.transformer.mask_id
680 | labels = torch.where(mask, ids, ignore_index)
681 |
682 | if self.no_mask_token_prob > 0.:
683 | no_mask_mask = get_mask_subset_prob(mask, self.no_mask_token_prob)
684 | mask &= ~no_mask_mask
685 |
686 | x = torch.where(mask, mask_id, ids)
687 |
688 | # get text embeddings
689 |
690 | if exists(texts):
691 | text_embeds = self.transformer.encode_text(texts)
692 | texts = None
693 |
694 | # self conditioning
695 |
696 | self_cond_embed = None
697 |
698 | if self.transformer.self_cond and random() < self.self_cond_prob:
699 | with torch.no_grad():
700 | _, self_cond_embed = self.transformer(
701 | x,
702 | text_embeds = text_embeds,
703 | conditioning_token_ids = cond_token_ids,
704 | cond_drop_prob = 0.,
705 | return_embed = True
706 | )
707 |
708 | self_cond_embed.detach_()
709 |
710 | # get loss
711 |
712 | ce_loss, logits = self.transformer(
713 | x,
714 | text_embeds = text_embeds,
715 | self_cond_embed = self_cond_embed,
716 | conditioning_token_ids = cond_token_ids,
717 | labels = labels,
718 | cond_drop_prob = cond_drop_prob,
719 | ignore_index = ignore_index,
720 | return_logits = True
721 | )
722 |
723 | if not exists(self.token_critic) or train_only_generator:
724 | return ce_loss
725 |
726 | # token critic loss
727 |
728 | sampled_ids = gumbel_sample(logits, temperature = default(sample_temperature, random()))
729 |
730 | critic_input = torch.where(mask, sampled_ids, x)
731 | critic_labels = (ids != critic_input).float()
732 |
733 | bce_loss = self.token_critic(
734 | critic_input,
735 | text_embeds = text_embeds,
736 | conditioning_token_ids = cond_token_ids,
737 | labels = critic_labels,
738 | cond_drop_prob = cond_drop_prob
739 | )
740 |
741 | return ce_loss + self.critic_loss_weight * bce_loss
742 |
743 | # final Muse class
744 |
745 | @beartype
746 | class Muse(nn.Module):
747 | def __init__(
748 | self,
749 | base: MaskGit,
750 | superres: MaskGit
751 | ):
752 | super().__init__()
753 | self.base_maskgit = base.eval()
754 |
755 | assert superres.resize_image_for_cond_image
756 | self.superres_maskgit = superres.eval()
757 |
758 | @torch.no_grad()
759 | def forward(
760 | self,
761 | texts: List[str],
762 | cond_scale = 3.,
763 | temperature = 1.,
764 | timesteps = 18,
765 | superres_timesteps = None,
766 | return_lowres = False,
767 | return_pil_images = True
768 | ):
769 | lowres_image = self.base_maskgit.generate(
770 | texts = texts,
771 | cond_scale = cond_scale,
772 | temperature = temperature,
773 | timesteps = timesteps
774 | )
775 |
776 | superres_image = self.superres_maskgit.generate(
777 | texts = texts,
778 | cond_scale = cond_scale,
779 | cond_images = lowres_image,
780 | temperature = temperature,
781 | timesteps = default(superres_timesteps, timesteps)
782 | )
783 |
784 | if return_pil_images:
785 | lowres_image = list(map(T.ToPILImage(), lowres_image))
786 | superres_image = list(map(T.ToPILImage(), superres_image))
787 |
788 | if not return_lowres:
789 | return superres_image
790 |
791 | return superres_image, lowres_image
792 |
--------------------------------------------------------------------------------
/muse_maskgit_pytorch/t5.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 | import transformers
4 | from transformers import T5Tokenizer, T5EncoderModel, T5Config
5 |
6 | from beartype import beartype
7 | from typing import List, Union
8 |
9 | transformers.logging.set_verbosity_error()
10 |
11 | def exists(val):
12 | return val is not None
13 |
14 | # config
15 |
16 | MAX_LENGTH = 256
17 |
18 | DEFAULT_T5_NAME = 'google/t5-v1_1-base'
19 |
20 | T5_CONFIGS = {}
21 |
22 | # singleton globals
23 |
24 | def get_tokenizer(name):
25 | tokenizer = T5Tokenizer.from_pretrained(name)
26 | return tokenizer
27 |
28 | def get_model(name):
29 | model = T5EncoderModel.from_pretrained(name)
30 | return model
31 |
32 | def get_model_and_tokenizer(name):
33 | global T5_CONFIGS
34 |
35 | if name not in T5_CONFIGS:
36 | T5_CONFIGS[name] = dict()
37 | if "model" not in T5_CONFIGS[name]:
38 | T5_CONFIGS[name]["model"] = get_model(name)
39 | if "tokenizer" not in T5_CONFIGS[name]:
40 | T5_CONFIGS[name]["tokenizer"] = get_tokenizer(name)
41 |
42 | return T5_CONFIGS[name]['model'], T5_CONFIGS[name]['tokenizer']
43 |
44 | def get_encoded_dim(name):
45 | if name not in T5_CONFIGS:
46 | # avoids loading the model if we only want to get the dim
47 | config = T5Config.from_pretrained(name)
48 | T5_CONFIGS[name] = dict(config=config)
49 | elif "config" in T5_CONFIGS[name]:
50 | config = T5_CONFIGS[name]["config"]
51 | elif "model" in T5_CONFIGS[name]:
52 | config = T5_CONFIGS[name]["model"].config
53 | else:
54 | assert False
55 | return config.d_model
56 |
57 | # encoding text
58 |
59 | @beartype
60 | def t5_encode_text(
61 | texts: Union[str, List[str]],
62 | name = DEFAULT_T5_NAME,
63 | output_device = None
64 | ):
65 | if isinstance(texts, str):
66 | texts = [texts]
67 |
68 | t5, tokenizer = get_model_and_tokenizer(name)
69 |
70 | if torch.cuda.is_available():
71 | t5 = t5.cuda()
72 |
73 | device = next(t5.parameters()).device
74 |
75 | encoded = tokenizer.batch_encode_plus(
76 | texts,
77 | return_tensors = "pt",
78 | padding = 'longest',
79 | max_length = MAX_LENGTH,
80 | truncation = True
81 | )
82 |
83 | input_ids = encoded.input_ids.to(device)
84 | attn_mask = encoded.attention_mask.to(device)
85 |
86 | t5.eval()
87 |
88 | with torch.no_grad():
89 | output = t5(input_ids = input_ids, attention_mask = attn_mask)
90 | encoded_text = output.last_hidden_state.detach()
91 |
92 | attn_mask = attn_mask.bool()
93 | encoded_text = encoded_text.masked_fill(~attn_mask[..., None], 0.)
94 |
95 | if not exists(output_device):
96 | return encoded_text
97 |
98 | encoded_text = encoded_text.to(output_device)
99 | return encoded_text
100 |
--------------------------------------------------------------------------------
/muse_maskgit_pytorch/trainers.py:
--------------------------------------------------------------------------------
1 | from math import sqrt
2 | from random import choice
3 | from pathlib import Path
4 | from shutil import rmtree
5 | from functools import partial
6 |
7 | from beartype import beartype
8 |
9 | import torch
10 | from torch import nn
11 | from torch.optim import Adam
12 | from torch.utils.data import Dataset, DataLoader, random_split
13 |
14 | import torchvision.transforms as T
15 | from torchvision.datasets import ImageFolder
16 | from torchvision.utils import make_grid, save_image
17 |
18 | from muse_maskgit_pytorch.vqgan_vae import VQGanVAE
19 |
20 | from einops import rearrange
21 |
22 | from accelerate import Accelerator, DistributedType, DistributedDataParallelKwargs
23 |
24 | from ema_pytorch import EMA
25 |
26 | from PIL import Image, ImageFile
27 | ImageFile.LOAD_TRUNCATED_IMAGES = True
28 |
29 | # helper functions
30 |
31 | def exists(val):
32 | return val is not None
33 |
34 | def identity(t, *args, **kwargs):
35 | return t
36 |
37 | def noop(*args, **kwargs):
38 | pass
39 |
40 | def find_index(arr, cond):
41 | for ind, el in enumerate(arr):
42 | if cond(el):
43 | return ind
44 | return None
45 |
46 | def find_and_pop(arr, cond, default = None):
47 | ind = find_index(arr, cond)
48 |
49 | if exists(ind):
50 | return arr.pop(ind)
51 |
52 | if callable(default):
53 | return default()
54 |
55 | return default
56 |
57 | def cycle(dl):
58 | while True:
59 | for data in dl:
60 | yield data
61 |
62 | def cast_tuple(t):
63 | return t if isinstance(t, (tuple, list)) else (t,)
64 |
65 | def yes_or_no(question):
66 | answer = input(f'{question} (y/n) ')
67 | return answer.lower() in ('yes', 'y')
68 |
69 | def accum_log(log, new_logs):
70 | for key, new_value in new_logs.items():
71 | old_value = log.get(key, 0.)
72 | log[key] = old_value + new_value
73 | return log
74 |
75 | def pair(val):
76 | return val if isinstance(val, tuple) else (val, val)
77 |
78 | def convert_image_to_fn(img_type, image):
79 | if image.mode != img_type:
80 | return image.convert(img_type)
81 | return image
82 |
83 | # image related helper functions and dataset
84 |
85 | class ImageDataset(Dataset):
86 | def __init__(
87 | self,
88 | folder,
89 | image_size,
90 | exts = ['jpg', 'jpeg', 'png']
91 | ):
92 | super().__init__()
93 | self.folder = folder
94 | self.image_size = image_size
95 | self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
96 |
97 | print(f'{len(self.paths)} training samples found at {folder}')
98 |
99 | self.transform = T.Compose([
100 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
101 | T.Resize(image_size),
102 | T.RandomHorizontalFlip(),
103 | T.CenterCrop(image_size),
104 | T.ToTensor()
105 | ])
106 |
107 | def __len__(self):
108 | return len(self.paths)
109 |
110 | def __getitem__(self, index):
111 | path = self.paths[index]
112 | img = Image.open(path)
113 | return self.transform(img)
114 |
115 | # main trainer class
116 |
117 | @beartype
118 | class VQGanVAETrainer(nn.Module):
119 | def __init__(
120 | self,
121 | vae: VQGanVAE,
122 | *,
123 | folder,
124 | num_train_steps,
125 | batch_size,
126 | image_size,
127 | lr = 3e-4,
128 | grad_accum_every = 1,
129 | max_grad_norm = None,
130 | discr_max_grad_norm = None,
131 | save_results_every = 100,
132 | save_model_every = 1000,
133 | results_folder = './results',
134 | valid_frac = 0.05,
135 | random_split_seed = 42,
136 | use_ema = True,
137 | ema_beta = 0.995,
138 | ema_update_after_step = 0,
139 | ema_update_every = 1,
140 | apply_grad_penalty_every = 4,
141 | accelerate_kwargs: dict = dict()
142 | ):
143 | super().__init__()
144 |
145 | # instantiate accelerator
146 |
147 | kwargs_handlers = accelerate_kwargs.get('kwargs_handlers', [])
148 |
149 | ddp_kwargs = find_and_pop(
150 | kwargs_handlers,
151 | lambda x: isinstance(x, DistributedDataParallelKwargs),
152 | partial(DistributedDataParallelKwargs, find_unused_parameters = True)
153 | )
154 |
155 | ddp_kwargs.find_unused_parameters = True
156 | kwargs_handlers.append(ddp_kwargs)
157 | accelerate_kwargs.update(kwargs_handlers = kwargs_handlers)
158 |
159 | self.accelerator = Accelerator(**accelerate_kwargs)
160 |
161 | # vae
162 |
163 | self.vae = vae
164 |
165 | # training params
166 |
167 | self.register_buffer('steps', torch.Tensor([0]))
168 |
169 | self.num_train_steps = num_train_steps
170 | self.batch_size = batch_size
171 | self.grad_accum_every = grad_accum_every
172 |
173 | all_parameters = set(vae.parameters())
174 | discr_parameters = set(vae.discr.parameters())
175 | vae_parameters = all_parameters - discr_parameters
176 |
177 | self.vae_parameters = vae_parameters
178 |
179 | # optimizers
180 |
181 | self.optim = Adam(vae_parameters, lr = lr)
182 | self.discr_optim = Adam(discr_parameters, lr = lr)
183 |
184 | self.max_grad_norm = max_grad_norm
185 | self.discr_max_grad_norm = discr_max_grad_norm
186 |
187 | # create dataset
188 |
189 | self.ds = ImageDataset(folder, image_size)
190 |
191 | # split for validation
192 |
193 | if valid_frac > 0:
194 | train_size = int((1 - valid_frac) * len(self.ds))
195 | valid_size = len(self.ds) - train_size
196 | self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
197 | self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly split {len(self.valid_ds)} samples')
198 | else:
199 | self.valid_ds = self.ds
200 | self.print(f'training with shared training and valid dataset of {len(self.ds)} samples')
201 |
202 | # dataloader
203 |
204 | self.dl = DataLoader(
205 | self.ds,
206 | batch_size = batch_size,
207 | shuffle = True
208 | )
209 |
210 | self.valid_dl = DataLoader(
211 | self.valid_ds,
212 | batch_size = batch_size,
213 | shuffle = True
214 | )
215 |
216 | # prepare with accelerator
217 |
218 | (
219 | self.vae,
220 | self.optim,
221 | self.discr_optim,
222 | self.dl,
223 | self.valid_dl
224 | ) = self.accelerator.prepare(
225 | self.vae,
226 | self.optim,
227 | self.discr_optim,
228 | self.dl,
229 | self.valid_dl
230 | )
231 |
232 | self.use_ema = use_ema
233 |
234 | if use_ema:
235 | self.ema_vae = EMA(vae, update_after_step = ema_update_after_step, update_every = ema_update_every)
236 | self.ema_vae = self.accelerator.prepare(self.ema_vae)
237 |
238 | self.dl_iter = cycle(self.dl)
239 | self.valid_dl_iter = cycle(self.valid_dl)
240 |
241 | self.save_model_every = save_model_every
242 | self.save_results_every = save_results_every
243 |
244 | self.apply_grad_penalty_every = apply_grad_penalty_every
245 |
246 | self.results_folder = Path(results_folder)
247 |
248 | if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
249 | rmtree(str(self.results_folder))
250 |
251 | self.results_folder.mkdir(parents = True, exist_ok = True)
252 |
253 | def save(self, path):
254 | if not self.accelerator.is_local_main_process:
255 | return
256 |
257 | pkg = dict(
258 | model = self.accelerator.get_state_dict(self.vae),
259 | optim = self.optim.state_dict(),
260 | discr_optim = self.discr_optim.state_dict()
261 | )
262 | torch.save(pkg, path)
263 |
264 | def load(self, path):
265 | path = Path(path)
266 | assert path.exists()
267 | pkg = torch.load(path)
268 |
269 | vae = self.accelerator.unwrap_model(self.vae)
270 | vae.load_state_dict(pkg['model'])
271 |
272 | self.optim.load_state_dict(pkg['optim'])
273 | self.discr_optim.load_state_dict(pkg['discr_optim'])
274 |
275 | def print(self, msg):
276 | self.accelerator.print(msg)
277 |
278 | @property
279 | def device(self):
280 | return self.accelerator.device
281 |
282 | @property
283 | def is_distributed(self):
284 | return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)
285 |
286 | @property
287 | def is_main(self):
288 | return self.accelerator.is_main_process
289 |
290 | @property
291 | def is_local_main(self):
292 | return self.accelerator.is_local_main_process
293 |
294 | def train_step(self):
295 | device = self.device
296 |
297 | steps = int(self.steps.item())
298 | apply_grad_penalty = not (steps % self.apply_grad_penalty_every)
299 |
300 | self.vae.train()
301 | discr = self.vae.module.discr if self.is_distributed else self.vae.discr
302 | if self.use_ema:
303 | ema_vae = self.ema_vae.module if self.is_distributed else self.ema_vae
304 |
305 | # logs
306 |
307 | logs = {}
308 |
309 | # update vae (generator)
310 |
311 | for _ in range(self.grad_accum_every):
312 | img = next(self.dl_iter)
313 | img = img.to(device)
314 |
315 | with self.accelerator.autocast():
316 | loss = self.vae(
317 | img,
318 | add_gradient_penalty = apply_grad_penalty,
319 | return_loss = True
320 | )
321 |
322 | self.accelerator.backward(loss / self.grad_accum_every)
323 |
324 | accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
325 |
326 | if exists(self.max_grad_norm):
327 | self.accelerator.clip_grad_norm_(self.vae.parameters(), self.max_grad_norm)
328 |
329 | self.optim.step()
330 | self.optim.zero_grad()
331 |
332 | # update discriminator
333 |
334 | if exists(discr):
335 | self.discr_optim.zero_grad()
336 |
337 | for _ in range(self.grad_accum_every):
338 | img = next(self.dl_iter)
339 | img = img.to(device)
340 |
341 | loss = self.vae(img, return_discr_loss = True)
342 |
343 | self.accelerator.backward(loss / self.grad_accum_every)
344 |
345 | accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})
346 |
347 | if exists(self.discr_max_grad_norm):
348 | self.accelerator.clip_grad_norm_(discr.parameters(), self.discr_max_grad_norm)
349 |
350 | self.discr_optim.step()
351 |
352 | # log
353 |
354 |         self.print(f"{steps}: vae loss: {logs['loss']} - discr loss: {logs.get('discr_loss', 0.)}")
355 |
356 | # update exponential moving averaged generator
357 |
358 | if self.use_ema:
359 | ema_vae.update()
360 |
361 | # sample results every so often
362 |
363 | if not (steps % self.save_results_every):
364 | vaes_to_evaluate = ((self.vae, str(steps)),)
365 |
366 | if self.use_ema:
367 | vaes_to_evaluate = ((ema_vae.ema_model, f'{steps}.ema'),) + vaes_to_evaluate
368 |
369 | for model, filename in vaes_to_evaluate:
370 | model.eval()
371 |
372 | valid_data = next(self.valid_dl_iter)
373 | valid_data = valid_data.to(device)
374 |
375 | recons = model(valid_data, return_recons = True)
376 |
377 |             # save a grid of validation images and their reconstructions
378 |
379 | imgs_and_recons = torch.stack((valid_data, recons), dim = 0)
380 | imgs_and_recons = rearrange(imgs_and_recons, 'r b ... -> (b r) ...')
381 |
382 | imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0., 1.)
383 | grid = make_grid(imgs_and_recons, nrow = 2, normalize = True, value_range = (0, 1))
384 |
385 | logs['reconstructions'] = grid
386 |
387 | save_image(grid, str(self.results_folder / f'{filename}.png'))
388 |
389 | self.print(f'{steps}: saving to {str(self.results_folder)}')
390 |
391 | # save model every so often
392 | self.accelerator.wait_for_everyone()
393 | if self.is_main and not (steps % self.save_model_every):
394 | state_dict = self.accelerator.unwrap_model(self.vae).state_dict()
395 | model_path = str(self.results_folder / f'vae.{steps}.pt')
396 | self.accelerator.save(state_dict, model_path)
397 |
398 | if self.use_ema:
399 | ema_state_dict = self.accelerator.unwrap_model(self.ema_vae).state_dict()
400 | model_path = str(self.results_folder / f'vae.{steps}.ema.pt')
401 | self.accelerator.save(ema_state_dict, model_path)
402 |
403 | self.print(f'{steps}: saving model to {str(self.results_folder)}')
404 |
405 | self.steps += 1
406 | return logs
407 |
408 | def train(self, log_fn = noop):
409 | device = next(self.vae.parameters()).device
410 |
411 | while self.steps < self.num_train_steps:
412 | logs = self.train_step()
413 | log_fn(logs)
414 |
415 | self.print('training complete')
416 |
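# --- editor's note: hypothetical usage sketch, not part of the original file ---
# the keyword names follow the constructor and train loop above; the folder path and
# hyperparameter values below are illustrative only

def _example_vqgan_vae_training(folder = './images'):
    # assumes VQGanVAE is importable from the sibling module shown later in this repo
    from muse_maskgit_pytorch.vqgan_vae import VQGanVAE

    vae = VQGanVAE(dim = 256)

    trainer = VQGanVAETrainer(
        vae,
        folder = folder,                  # directory of training images
        image_size = 256,
        batch_size = 4,
        grad_accum_every = 8,
        num_train_steps = 50000,
        lr = 3e-4,
        results_folder = './results'      # reconstructions and checkpoints are written here
    )

    trainer.train()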
--------------------------------------------------------------------------------
/muse_maskgit_pytorch/vqgan_vae.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import copy
3 | import math
4 | from math import sqrt
5 | from functools import partial, wraps
6 |
7 | from vector_quantize_pytorch import VectorQuantize as VQ, LFQ
8 |
9 | import torch
10 | from torch import nn, einsum
11 | import torch.nn.functional as F
12 | from torch.autograd import grad as torch_grad
13 |
14 | import torchvision
15 |
16 | from einops import rearrange, reduce, repeat, pack, unpack
17 | from einops.layers.torch import Rearrange
18 |
19 | # constants
20 |
21 | MList = nn.ModuleList
22 |
23 | # helper functions
24 |
25 | def exists(val):
26 | return val is not None
27 |
28 | def default(val, d):
29 | return val if exists(val) else d
30 |
31 | # decorators
32 |
33 | def eval_decorator(fn):
34 | def inner(model, *args, **kwargs):
35 | was_training = model.training
36 | model.eval()
37 | out = fn(model, *args, **kwargs)
38 | model.train(was_training)
39 | return out
40 | return inner
41 |
42 | def remove_vgg(fn):
43 | @wraps(fn)
44 | def inner(self, *args, **kwargs):
45 | has_vgg = hasattr(self, '_vgg')
46 | if has_vgg:
47 | vgg = self._vgg
48 | delattr(self, '_vgg')
49 |
50 | out = fn(self, *args, **kwargs)
51 |
52 | if has_vgg:
53 | self._vgg = vgg
54 |
55 | return out
56 | return inner
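
# note (editor): remove_vgg temporarily detaches the lazily constructed, frozen pretrained
# vgg network before calling the wrapped method, so that state_dict / load_state_dict
# (decorated further below) neither serialize nor expect the vgg weights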
57 |
58 | # keyword argument helpers
59 |
60 | def pick_and_pop(keys, d):
61 | values = list(map(lambda key: d.pop(key), keys))
62 | return dict(zip(keys, values))
63 |
64 | def group_dict_by_key(cond, d):
65 | return_val = [dict(),dict()]
66 | for key in d.keys():
67 | match = bool(cond(key))
68 | ind = int(not match)
69 | return_val[ind][key] = d[key]
70 | return (*return_val,)
71 |
72 | def string_begins_with(prefix, string_input):
73 | return string_input.startswith(prefix)
74 |
75 | def group_by_key_prefix(prefix, d):
76 | return group_dict_by_key(partial(string_begins_with, prefix), d)
77 |
78 | def groupby_prefix_and_trim(prefix, d):
79 | kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
80 | kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
81 | return kwargs_without_prefix, kwargs
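
# example (editor's illustration):
#   groupby_prefix_and_trim('vq_', {'vq_decay': 0.9, 'lr': 1e-4}) -> ({'decay': 0.9}, {'lr': 1e-4})
# prefixed keys are split out with the prefix removed, everything else is passed through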
82 |
83 | # tensor helper functions
84 |
85 | def log(t, eps = 1e-10):
86 | return torch.log(t + eps)
87 |
88 | def gradient_penalty(images, output, weight = 10):
89 | batch_size = images.shape[0]
90 |
91 | gradients = torch_grad(
92 | outputs = output,
93 | inputs = images,
94 | grad_outputs = torch.ones(output.size(), device = images.device),
95 | create_graph = True,
96 | retain_graph = True,
97 | only_inputs = True
98 | )[0]
99 |
100 | gradients = rearrange(gradients, 'b ... -> b (...)')
101 | return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()
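
# note (editor): this is the standard WGAN-GP style penalty,
#   weight * (||grad_x D(x)||_2 - 1)^2 averaged over the batch,
# computed on the real images fed to the discriminator (see VQGanVAE.forward below)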
102 |
103 | def leaky_relu(p = 0.1):
104 |     return nn.LeakyReLU(p)
105 |
106 | def safe_div(numer, denom, eps = 1e-8):
107 | return numer / denom.clamp(min = eps)
108 |
109 | # gan losses
110 |
111 | def hinge_discr_loss(fake, real):
112 | return (F.relu(1 + fake) + F.relu(1 - real)).mean()
113 |
114 | def hinge_gen_loss(fake):
115 | return -fake.mean()
116 |
117 | def bce_discr_loss(fake, real):
118 | return (-log(1 - torch.sigmoid(fake)) - log(torch.sigmoid(real))).mean()
119 |
120 | def bce_gen_loss(fake):
121 | return -log(torch.sigmoid(fake)).mean()
122 |
123 | def grad_layer_wrt_loss(loss, layer):
124 | return torch_grad(
125 | outputs = loss,
126 | inputs = layer,
127 | grad_outputs = torch.ones_like(loss),
128 | retain_graph = True
129 | )[0].detach()
130 |
131 | # vqgan vae
132 |
133 | class LayerNormChan(nn.Module):
134 | def __init__(
135 | self,
136 | dim,
137 | eps = 1e-5
138 | ):
139 | super().__init__()
140 | self.eps = eps
141 | self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))
142 |
143 | def forward(self, x):
144 | var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
145 | mean = torch.mean(x, dim = 1, keepdim = True)
146 | return (x - mean) * var.clamp(min = self.eps).rsqrt() * self.gamma
147 |
148 | # discriminator
149 |
150 | class Discriminator(nn.Module):
151 | def __init__(
152 | self,
153 | dims,
154 | channels = 3,
155 | groups = 16,
156 | init_kernel_size = 5
157 | ):
158 | super().__init__()
159 | dim_pairs = zip(dims[:-1], dims[1:])
160 |
161 | self.layers = MList([nn.Sequential(nn.Conv2d(channels, dims[0], init_kernel_size, padding = init_kernel_size // 2), leaky_relu())])
162 |
163 | for dim_in, dim_out in dim_pairs:
164 | self.layers.append(nn.Sequential(
165 | nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1),
166 | nn.GroupNorm(groups, dim_out),
167 | leaky_relu()
168 | ))
169 |
170 | dim = dims[-1]
171 |         self.to_logits = nn.Sequential( # returns a small grid of patch logits (e.g. 5 x 5), for PatchGAN-esque training
172 | nn.Conv2d(dim, dim, 1),
173 | leaky_relu(),
174 | nn.Conv2d(dim, 1, 4)
175 | )
176 |
177 | def forward(self, x):
178 | for net in self.layers:
179 | x = net(x)
180 |
181 | return self.to_logits(x)
182 |
183 | # resnet encoder / decoder
184 |
185 | class ResnetEncDec(nn.Module):
186 | def __init__(
187 | self,
188 | dim,
189 | *,
190 | channels = 3,
191 | layers = 4,
192 | layer_mults = None,
193 | num_resnet_blocks = 1,
194 | resnet_groups = 16,
195 | first_conv_kernel_size = 5
196 | ):
197 | super().__init__()
198 | assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'
199 |
200 | self.layers = layers
201 |
202 | self.encoders = MList([])
203 | self.decoders = MList([])
204 |
205 | layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
206 | assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
207 |
208 | layer_dims = [dim * mult for mult in layer_mults]
209 | dims = (dim, *layer_dims)
210 |
211 | self.encoded_dim = dims[-1]
212 |
213 | dim_pairs = zip(dims[:-1], dims[1:])
214 |
215 | append = lambda arr, t: arr.append(t)
216 | prepend = lambda arr, t: arr.insert(0, t)
217 |
218 | if not isinstance(num_resnet_blocks, tuple):
219 | num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)
220 |
221 | assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'
222 |
223 | for layer_index, (dim_in, dim_out), layer_num_resnet_blocks in zip(range(layers), dim_pairs, num_resnet_blocks):
224 | append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
225 | prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))
226 |
227 | for _ in range(layer_num_resnet_blocks):
228 | append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
229 | prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))
230 |
231 | prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
232 | append(self.decoders, nn.Conv2d(dim, channels, 1))
233 |
234 | def get_encoded_fmap_size(self, image_size):
235 | return image_size // (2 ** self.layers)
236 |
237 | @property
238 | def last_dec_layer(self):
239 | return self.decoders[-1].weight
240 |
241 | def encode(self, x):
242 | for enc in self.encoders:
243 | x = enc(x)
244 | return x
245 |
246 | def decode(self, x):
247 | for dec in self.decoders:
248 | x = dec(x)
249 | return x
250 |
251 | class GLUResBlock(nn.Module):
252 | def __init__(self, chan, groups = 16):
253 | super().__init__()
254 | self.net = nn.Sequential(
255 | nn.Conv2d(chan, chan * 2, 3, padding = 1),
256 | nn.GLU(dim = 1),
257 | nn.GroupNorm(groups, chan),
258 | nn.Conv2d(chan, chan * 2, 3, padding = 1),
259 | nn.GLU(dim = 1),
260 | nn.GroupNorm(groups, chan),
261 | nn.Conv2d(chan, chan, 1)
262 | )
263 |
264 | def forward(self, x):
265 | return self.net(x) + x
266 |
267 | class ResBlock(nn.Module):
268 | def __init__(self, chan, groups = 16):
269 | super().__init__()
270 | self.net = nn.Sequential(
271 | nn.Conv2d(chan, chan, 3, padding = 1),
272 | nn.GroupNorm(groups, chan),
273 | leaky_relu(),
274 | nn.Conv2d(chan, chan, 3, padding = 1),
275 | nn.GroupNorm(groups, chan),
276 | leaky_relu(),
277 | nn.Conv2d(chan, chan, 1)
278 | )
279 |
280 | def forward(self, x):
281 | return self.net(x) + x
282 |
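# editor's sketch (not part of the original file): a quick shape sanity check for the
# resnet encoder / decoder above; dim, layers and the image size are illustrative values

def _encdec_shape_check():
    enc_dec = ResnetEncDec(dim = 64, layers = 4)
    img = torch.randn(1, 3, 256, 256)

    fmap = enc_dec.encode(img)  # (1, 512, 16, 16): channels grow to dim * 2 ** (layers - 1), spatial size shrinks by 2 ** layers
    out = enc_dec.decode(fmap)  # (1, 3, 256, 256): the decoder mirrors the encoder back to the input resolution

    assert fmap.shape[-1] == enc_dec.get_encoded_fmap_size(256)
    return fmap.shape, out.shape
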
283 | # main vqgan-vae classes
284 |
285 | class VQGanVAE(nn.Module):
286 | def __init__(
287 | self,
288 | *,
289 | dim,
290 | channels = 3,
291 | layers = 4,
292 | l2_recon_loss = False,
293 | use_hinge_loss = True,
294 | vgg = None,
295 | lookup_free_quantization = True,
296 | codebook_size = 65536,
297 | vq_kwargs: dict = dict(
298 | codebook_dim = 256,
299 | decay = 0.8,
300 | commitment_weight = 1.,
301 | kmeans_init = True,
302 | use_cosine_sim = True,
303 | ),
304 | lfq_kwargs: dict = dict(
305 | diversity_gamma = 4.
306 | ),
307 | use_vgg_and_gan = True,
308 | discr_layers = 4,
309 | **kwargs
310 | ):
311 | super().__init__()
312 |         vq_kwargs = {**vq_kwargs, **groupby_prefix_and_trim('vq_', kwargs)[0]}  # merge vq_ prefixed kwargs into the explicit vq_kwargs dict rather than overwriting it
313 | encdec_kwargs, kwargs = groupby_prefix_and_trim('encdec_', kwargs)
314 |
315 | self.channels = channels
316 | self.codebook_size = codebook_size
317 | self.dim_divisor = 2 ** layers
318 |
319 | enc_dec_klass = ResnetEncDec
320 |
321 | self.enc_dec = enc_dec_klass(
322 | dim = dim,
323 | channels = channels,
324 | layers = layers,
325 | **encdec_kwargs
326 | )
327 |
328 | self.lookup_free_quantization = lookup_free_quantization
329 |
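        # editor's note: with lookup_free_quantization enabled, latents are quantized with
        # LFQ (lookup-free quantization, in the style of MagViT-2) over an implicit binary
        # codebook of size codebook_size; otherwise a standard VectorQuantize layer with a
        # learned codebook is used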
330 | if lookup_free_quantization:
331 | self.quantizer = LFQ(
332 | dim = self.enc_dec.encoded_dim,
333 | codebook_size = codebook_size,
334 | **lfq_kwargs
335 | )
336 | else:
337 | self.quantizer = VQ(
338 | dim = self.enc_dec.encoded_dim,
339 | codebook_size = codebook_size,
340 |                 accept_image_fmap = True,
341 | **vq_kwargs
342 | )
343 |
344 | # reconstruction loss
345 |
346 | self.recon_loss_fn = F.mse_loss if l2_recon_loss else F.l1_loss
347 |
348 |         # optionally turn off GAN and perceptual (vgg) losses, e.g. when training on grayscale
349 |
350 | self._vgg = None
351 | self.discr = None
352 | self.use_vgg_and_gan = use_vgg_and_gan
353 |
354 | if not use_vgg_and_gan:
355 | return
356 |
357 |         # perceptual loss
358 |
359 | if exists(vgg):
360 | self._vgg = vgg
361 |
362 | # gan related losses
363 |
364 | layer_mults = list(map(lambda t: 2 ** t, range(discr_layers)))
365 | layer_dims = [dim * mult for mult in layer_mults]
366 | dims = (dim, *layer_dims)
367 |
368 | self.discr = Discriminator(dims = dims, channels = channels)
369 |
370 | self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
371 | self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
372 |
373 | @property
374 | def device(self):
375 | return next(self.parameters()).device
376 |
377 | @property
378 | def vgg(self):
379 | if exists(self._vgg):
380 | return self._vgg
381 |
382 | vgg = torchvision.models.vgg16(pretrained = True)
383 | vgg.classifier = nn.Sequential(*vgg.classifier[:-2])
384 | self._vgg = vgg.to(self.device)
385 | return self._vgg
386 |
387 | @property
388 | def encoded_dim(self):
389 | return self.enc_dec.encoded_dim
390 |
391 | def get_encoded_fmap_size(self, image_size):
392 | return self.enc_dec.get_encoded_fmap_size(image_size)
393 |
394 | def copy_for_eval(self):
395 | device = next(self.parameters()).device
396 | vae_copy = copy.deepcopy(self.cpu())
397 |
398 | if vae_copy.use_vgg_and_gan:
399 | del vae_copy.discr
400 | del vae_copy._vgg
401 |
402 | vae_copy.eval()
403 | return vae_copy.to(device)
404 |
405 | @remove_vgg
406 | def state_dict(self, *args, **kwargs):
407 | return super().state_dict(*args, **kwargs)
408 |
409 | @remove_vgg
410 | def load_state_dict(self, *args, **kwargs):
411 | return super().load_state_dict(*args, **kwargs)
412 |
413 | def save(self, path):
414 | torch.save(self.state_dict(), path)
415 |
416 | def load(self, path):
417 | path = Path(path)
418 | assert path.exists()
419 | state_dict = torch.load(str(path))
420 | self.load_state_dict(state_dict)
421 |
422 | def encode(self, fmap):
423 | fmap = self.enc_dec.encode(fmap)
424 | fmap, indices, vq_aux_loss = self.quantizer(fmap)
425 | return fmap, indices, vq_aux_loss
426 |
427 | def decode_from_ids(self, ids):
428 |
429 | if self.lookup_free_quantization:
430 | ids, ps = pack([ids], 'b *')
431 | fmap = self.quantizer.indices_to_codes(ids)
432 | fmap, = unpack(fmap, ps, 'b * c')
433 | else:
434 |             codes = self.quantizer.codebook[ids]
435 | fmap = self.quantizer.project_out(codes)
436 |
437 | fmap = rearrange(fmap, 'b h w c -> b c h w')
438 | return self.decode(fmap)
439 |
440 | def decode(self, fmap):
441 | return self.enc_dec.decode(fmap)
442 |
443 | def forward(
444 | self,
445 | img,
446 | return_loss = False,
447 | return_discr_loss = False,
448 | return_recons = False,
449 | add_gradient_penalty = True
450 | ):
451 | batch, channels, height, width, device = *img.shape, img.device
452 |
453 | for dim_name, size in (('height', height), ('width', width)):
454 | assert (size % self.dim_divisor) == 0, f'{dim_name} must be divisible by {self.dim_divisor}'
455 |
456 | assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'
457 |
458 | fmap, indices, commit_loss = self.encode(img)
459 |
460 | fmap = self.decode(fmap)
461 |
462 | if not return_loss and not return_discr_loss:
463 | return fmap
464 |
465 | assert return_loss ^ return_discr_loss, 'you should either return autoencoder loss or discriminator loss, but not both'
466 |
467 | # whether to return discriminator loss
468 |
469 | if return_discr_loss:
470 | assert exists(self.discr), 'discriminator must exist to train it'
471 |
472 | fmap.detach_()
473 | img.requires_grad_()
474 |
475 | fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
476 |
477 | discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
478 |
479 |             loss = discr_loss
480 |             if add_gradient_penalty:
481 |                 loss = loss + gradient_penalty(img, img_discr_logits)
482 |
483 | if return_recons:
484 | return loss, fmap
485 |
486 | return loss
487 |
488 | # reconstruction loss
489 |
490 | recon_loss = self.recon_loss_fn(fmap, img)
491 |
492 |         # early return if the vgg / gan losses are turned off (e.g. grayscale training)
493 |
494 | if not self.use_vgg_and_gan:
495 | if return_recons:
496 | return recon_loss, fmap
497 |
498 | return recon_loss
499 |
500 | # perceptual loss
501 |
502 | img_vgg_input = img
503 | fmap_vgg_input = fmap
504 |
505 | if img.shape[1] == 1:
506 | # handle grayscale for vgg
507 | img_vgg_input, fmap_vgg_input = map(lambda t: repeat(t, 'b 1 ... -> b c ...', c = 3), (img_vgg_input, fmap_vgg_input))
508 |
509 | img_vgg_feats = self.vgg(img_vgg_input)
510 | recon_vgg_feats = self.vgg(fmap_vgg_input)
511 | perceptual_loss = F.mse_loss(img_vgg_feats, recon_vgg_feats)
512 |
513 | # generator loss
514 |
515 | gen_loss = self.gen_loss(self.discr(fmap))
516 |
517 | # calculate adaptive weight
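        # editor's note: following the VQGAN recipe, the generator (gan) loss is rescaled by
        #   adaptive_weight = ||grad_W L_perceptual|| / ||grad_W L_gen||
        # measured at the last decoder layer's weight W, and clamped so a vanishing
        # generator gradient cannot blow the weight up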
518 |
519 | last_dec_layer = self.enc_dec.last_dec_layer
520 |
521 | norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
522 | norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
523 |
524 | adaptive_weight = safe_div(norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss)
525 | adaptive_weight.clamp_(max = 1e4)
526 |
527 | # combine losses
528 |
529 | loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss
530 |
531 | if return_recons:
532 | return loss, fmap
533 |
534 | return loss
535 |
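# --- editor's note: hypothetical usage sketch, not part of the original file ---
# the flag names follow VQGanVAE.forward above; dim and image size are illustrative, and
# return_loss = True will lazily download the pretrained vgg16 used for the perceptual loss

def _example_vqgan_vae_forward():
    vae = VQGanVAE(dim = 128)
    img = torch.rand(1, 3, 256, 256)                  # images are assumed to lie in [0, 1]

    recon = vae(img)                                  # reconstructions only
    gen_loss = vae(img, return_loss = True)           # recon + perceptual + commit + adaptive gan loss
    discr_loss = vae(img, return_discr_loss = True)   # hinge (or bce) discriminator loss + gradient penalty

    return recon.shape, gen_loss, discr_loss
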
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'muse-maskgit-pytorch',
5 | packages = find_packages(exclude=[]),
6 | version = '0.3.5',
7 | license='MIT',
8 | description = 'MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | long_description_content_type = 'text/markdown',
12 | url = 'https://github.com/lucidrains/muse-maskgit-pytorch',
13 | keywords = [
14 | 'artificial intelligence',
15 | 'deep learning',
16 | 'transformers',
17 | 'attention mechanism',
18 | 'text-to-image'
19 | ],
20 | install_requires=[
21 | 'accelerate',
22 | 'beartype',
23 | 'einops>=0.7',
24 | 'ema-pytorch>=0.2.2',
25 | 'memory-efficient-attention-pytorch>=0.1.4',
26 | 'pillow',
27 | 'sentencepiece',
28 | 'torch>=1.6',
29 | 'transformers',
31 | 'torchvision',
32 | 'tqdm',
33 | 'vector-quantize-pytorch>=1.11.8'
34 | ],
35 | classifiers=[
36 | 'Development Status :: 4 - Beta',
37 | 'Intended Audience :: Developers',
38 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
39 | 'License :: OSI Approved :: MIT License',
40 | 'Programming Language :: Python :: 3.6',
41 | ],
42 | )
43 |
--------------------------------------------------------------------------------