├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── muse.png ├── muse_maskgit_pytorch ├── __init__.py ├── attend.py ├── muse_maskgit_pytorch.py ├── t5.py ├── trainers.py └── vqgan_vae.py └── setup.py /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | # This workflow will upload a Python Package using Twine when a release is created 4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 5 | 6 | # This workflow uses actions that are not certified by GitHub. 7 | # They are provided by a third-party and are governed by 8 | # separate terms of service, privacy policy, and support 9 | # documentation. 10 | 11 | name: Upload Python Package 12 | 13 | on: 14 | release: 15 | types: [published] 16 | 17 | jobs: 18 | deploy: 19 | 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: '3.x' 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install build 32 | - name: Build package 33 | run: python -m build 34 | - name: Publish package 35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 36 | with: 37 | user: __token__ 38 | password: ${{ secrets.PYPI_API_TOKEN }} 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | ## Muse - Pytorch
4 | 
5 | Implementation of Muse: Text-to-Image Generation via Masked Generative Transformers, in Pytorch
6 | 
7 | Please join us on Discord if you are interested in helping out with the replication with the LAION community
8 | 
9 | ## Install
10 | 
11 | ```bash
12 | $ pip install muse-maskgit-pytorch
13 | ```
14 | 
15 | ## Usage
16 | 
17 | First train your VAE - `VQGanVAE`
18 | 
19 | ```python
20 | import torch
21 | from muse_maskgit_pytorch import VQGanVAE, VQGanVAETrainer
22 | 
23 | vae = VQGanVAE(
24 |     dim = 256,
25 |     codebook_size = 65536
26 | )
27 | 
28 | # train on folder of images, as many images as possible
29 | 
30 | trainer = VQGanVAETrainer(
31 |     vae = vae,
32 |     image_size = 128, # you may want to start with small images, and then curriculum learn to larger ones, but because the vae is all convolution, it should generalize to 512 (as in paper) without training on it
33 |     folder = '/path/to/images',
34 |     batch_size = 4,
35 |     grad_accum_every = 8,
36 |     num_train_steps = 50000
37 | ).cuda()
38 | 
39 | trainer.train()
40 | ```
41 | 
42 | Then pass the trained `VQGanVAE` and a `Transformer` to `MaskGit`
43 | 
44 | ```python
45 | import torch
46 | from muse_maskgit_pytorch import VQGanVAE, MaskGit, MaskGitTransformer
47 | 
48 | # first instantiate your vae
49 | 
50 | vae = VQGanVAE(
51 |     dim = 256,
52 |     codebook_size = 65536
53 | ).cuda()
54 | 
55 | vae.load('/path/to/vae.pt') # you will want to load the exponentially moving averaged VAE
56 | 
57 | # then you plug the vae and transformer into your MaskGit as so
58 | 
59 | # (1) create your transformer / attention network
60 | 
61 | transformer = MaskGitTransformer(
62 |     num_tokens = 65536, # must be same as codebook size above
63 |     seq_len = 256, # must be equivalent to fmap_size ** 2 in vae
64 |     dim = 512, # model dimension
65 |     depth = 8, # depth
66 |     dim_head = 64, # attention head dimension
67 |     heads = 8, # attention heads,
68 |     ff_mult = 4, # feedforward expansion factor
69 |     t5_name = 't5-small', # name of your T5
70 | )
71 | 
72 | # (2) pass your trained VAE and the base transformer to MaskGit
73 | 
74 | base_maskgit = MaskGit(
75 |     vae = vae, # vqgan vae
76 |     transformer = transformer, # transformer
77 |     image_size = 256, # image size
78 |     cond_drop_prob = 0.25, # conditional dropout, for classifier free guidance
79 | ).cuda()
80 | 
81 | # ready your training text and images
82 | 
83 | texts = [
84 |     'a child screaming at finding a worm within a half-eaten apple',
85 |     'lizard running across the desert on two feet',
86 |     'waking up to a psychedelic landscape',
87 |     'seashells sparkling in the shallow waters'
88 | ]
89 | 
90 | images = torch.randn(4, 3, 256, 256).cuda()
91 | 
92 | # feed it into your maskgit instance, with return_loss set to True
93 | 
94 | loss = base_maskgit(
95 |     images,
96 |     texts = texts
97 | )
98 | 
99 | loss.backward()
100 | 
101 | # do this for a long time on much data
102 | # then...
103 | 
104 | images = base_maskgit.generate(texts = [
105 |     'a whale breaching from afar',
106 |     'young girl blowing out candles on her birthday cake',
107 |     'fireworks with blue and green sparkles'
108 | ], cond_scale = 3.) # conditioning scale for classifier free guidance
109 | 
110 | images.shape # (3, 3, 256, 256)
111 | ```
112 | 
113 | 
114 | Training the super-resolution maskgit requires changing one field on `MaskGit` instantiation (you now need to pass in `cond_image_size`, the size of the lower-resolution image being conditioned on)
115 | 
116 | Optionally, you can pass in a different `VAE` as `cond_vae` for the conditioning low-resolution image. By default it will use `vae` to tokenize both the high- and low-resolution images.
117 | 
118 | ```python
119 | import torch
120 | import torch.nn.functional as F
121 | from muse_maskgit_pytorch import VQGanVAE, MaskGit, MaskGitTransformer
122 | 
123 | # first instantiate your ViT VQGan VAE
124 | # a VQGan VAE made of transformers
125 | 
126 | vae = VQGanVAE(
127 |     dim = 256,
128 |     codebook_size = 65536
129 | ).cuda()
130 | 
131 | vae.load('./path/to/vae.pt') # you will want to load the exponentially moving averaged VAE
132 | 
133 | # then you plug the VqGan VAE into your MaskGit as so
134 | 
135 | # (1) create your transformer / attention network
136 | 
137 | transformer = MaskGitTransformer(
138 |     num_tokens = 65536, # must be same as codebook size above
139 |     seq_len = 1024, # must be equivalent to fmap_size ** 2 in vae
140 |     dim = 512, # model dimension
141 |     depth = 2, # depth
142 |     dim_head = 64, # attention head dimension
143 |     heads = 8, # attention heads,
144 |     ff_mult = 4, # feedforward expansion factor
145 |     t5_name = 't5-small', # name of your T5
146 | )
147 | 
148 | # (2) pass your trained VAE and the base transformer to MaskGit
149 | 
150 | superres_maskgit = MaskGit(
151 |     vae = vae,
152 |     transformer = transformer,
153 |     cond_drop_prob = 0.25,
154 |     image_size = 512, # larger image size
155 |     cond_image_size = 256, # conditioning image size <- this must be set
156 | ).cuda()
157 | 
158 | # ready your training text and images
159 | 
160 | texts = [
161 |     'a child screaming at finding a worm within a half-eaten apple',
162 |     'lizard running across the desert on two feet',
163 |     'waking up to a psychedelic landscape',
164 |     'seashells sparkling in the shallow waters'
165 | ]
166 | 
167 | images = torch.randn(4, 3, 512, 512).cuda()
168 | 
169 | # feed it into your maskgit instance, with return_loss set to True
170 | 
171 | loss = superres_maskgit(
172 |     images,
173 |     texts = texts
174 | )
175 | 
176 | loss.backward()
177 | 
178 | # do this for a long time on much data
179 | # then...
180 | 
181 | images = superres_maskgit.generate(
182 |     texts = [
183 |         'a whale breaching from afar',
184 |         'young girl blowing out candles on her birthday cake',
185 |         'fireworks with blue and green sparkles',
186 |         'waking up to a psychedelic landscape'
187 |     ],
188 |     cond_images = F.interpolate(images, 256), # conditioning images must be passed in for generating from superres
189 |     cond_scale = 3.
190 | )
191 | 
192 | images.shape # (4, 3, 512, 512)
193 | ```
194 | 
195 | All together now
196 | 
197 | ```python
198 | from muse_maskgit_pytorch import Muse
199 | 
200 | base_maskgit.load('./path/to/base.pt')
201 | 
202 | superres_maskgit.load('./path/to/superres.pt')
203 | 
204 | # pass in the trained base_maskgit and superres_maskgit from above
205 | 
206 | muse = Muse(
207 |     base = base_maskgit,
208 |     superres = superres_maskgit
209 | )
210 | 
211 | images = muse([
212 |     'a whale breaching from afar',
213 |     'young girl blowing out candles on her birthday cake',
214 |     'fireworks with blue and green sparkles',
215 |     'waking up to a psychedelic landscape'
216 | ])
217 | 
218 | images # List[PIL.Image.Image]
219 | ```
220 | 
221 | ## Appreciation
222 | 
223 | - StabilityAI for the sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.
224 | 
225 | - 🤗 Huggingface for the transformers and accelerate library, both of which are wonderful
226 | 
227 | ## Todo
228 | 
229 | - [x] test end-to-end
230 | - [x] separate cond_images_or_ids, it is not done right
231 | - [x] add training code for vae
232 | - [x] add optional self-conditioning on embeddings
233 | - [x] combine with token critic paper, already implemented in Phenaki
234 | 
235 | - [ ] hook up accelerate training code for maskgit
236 | 
237 | ## Citations
238 | 
239 | ```bibtex
240 | @inproceedings{Chang2023MuseTG,
241 |     title = {Muse: Text-To-Image Generation via Masked Generative Transformers},
242 |     author = {Huiwen Chang and Han Zhang and Jarred Barber and AJ Maschinot and Jos{\'e} Lezama and Lu Jiang and Ming-Hsuan Yang and Kevin P. Murphy and William T. Freeman and Michael Rubinstein and Yuanzhen Li and Dilip Krishnan},
243 |     year = {2023}
244 | }
245 | ```
246 | 
247 | ```bibtex
248 | @article{Chen2022AnalogBG,
249 |     title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
250 |     author = {Ting Chen and Ruixiang Zhang and Geoffrey E. Hinton},
251 |     journal = {ArXiv},
252 |     year = {2022},
253 |     volume = {abs/2208.04202}
254 | }
255 | ```
256 | 
257 | ```bibtex
258 | @misc{jabri2022scalable,
259 |     title = {Scalable Adaptive Computation for Iterative Generation},
260 |     author = {Allan Jabri and David Fleet and Ting Chen},
261 |     year = {2022},
262 |     eprint = {2212.11972},
263 |     archivePrefix = {arXiv},
264 |     primaryClass = {cs.LG}
265 | }
266 | ```
267 | 
268 | ```bibtex
269 | @article{Lezama2022ImprovedMI,
270 |     title = {Improved Masked Image Generation with Token-Critic},
271 |     author = {Jos{\'e} Lezama and Huiwen Chang and Lu Jiang and Irfan Essa},
272 |     journal = {ArXiv},
273 |     year = {2022},
274 |     volume = {abs/2209.04439}
275 | }
276 | ```
277 | 
278 | ```bibtex
279 | @inproceedings{Nijkamp2021SCRIPTSP,
280 |     title = {SCRIPT: Self-Critic PreTraining of Transformers},
281 |     author = {Erik Nijkamp and Bo Pang and Ying Nian Wu and Caiming Xiong},
282 |     booktitle = {North American Chapter of the Association for Computational Linguistics},
283 |     year = {2021}
284 | }
285 | ```
286 | 
287 | ```bibtex
288 | @inproceedings{dao2022flashattention,
289 |     title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
290 |     author = {Dao, Tri and Fu, Daniel Y.
and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher}, 291 | booktitle = {Advances in Neural Information Processing Systems}, 292 | year = {2022} 293 | } 294 | ``` 295 | 296 | ```bibtex 297 | @misc{mentzer2023finite, 298 | title = {Finite Scalar Quantization: VQ-VAE Made Simple}, 299 | author = {Fabian Mentzer and David Minnen and Eirikur Agustsson and Michael Tschannen}, 300 | year = {2023}, 301 | eprint = {2309.15505}, 302 | archivePrefix = {arXiv}, 303 | primaryClass = {cs.CV} 304 | } 305 | ``` 306 | 307 | ```bibtex 308 | @misc{yu2023language, 309 | title = {Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation}, 310 | author = {Lijun Yu and José Lezama and Nitesh B. Gundavarapu and Luca Versari and Kihyuk Sohn and David Minnen and Yong Cheng and Agrim Gupta and Xiuye Gu and Alexander G. Hauptmann and Boqing Gong and Ming-Hsuan Yang and Irfan Essa and David A. Ross and Lu Jiang}, 311 | year = {2023}, 312 | eprint = {2310.05737}, 313 | archivePrefix = {arXiv}, 314 | primaryClass = {cs.CV} 315 | } 316 | ``` 317 | -------------------------------------------------------------------------------- /muse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/muse-maskgit-pytorch/6df7f33bcd33ba28a2f682d5bd293e4f8a513e6c/muse.png -------------------------------------------------------------------------------- /muse_maskgit_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from muse_maskgit_pytorch.vqgan_vae import VQGanVAE 2 | from muse_maskgit_pytorch.muse_maskgit_pytorch import Transformer, MaskGit, Muse, MaskGitTransformer, TokenCritic 3 | 4 | from muse_maskgit_pytorch.trainers import VQGanVAETrainer 5 | -------------------------------------------------------------------------------- /muse_maskgit_pytorch/attend.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | from packaging import version 3 | from collections import namedtuple 4 | 5 | import torch 6 | from torch import nn, einsum 7 | import torch.nn.functional as F 8 | 9 | from memory_efficient_attention_pytorch.flash_attention import FlashAttentionFunction 10 | # constants 11 | 12 | AttentionConfig = namedtuple('AttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient']) 13 | 14 | # helpers 15 | 16 | def exists(val): 17 | return val is not None 18 | 19 | def once(fn): 20 | called = False 21 | @wraps(fn) 22 | def inner(x): 23 | nonlocal called 24 | if called: 25 | return 26 | called = True 27 | return fn(x) 28 | return inner 29 | 30 | print_once = once(print) 31 | 32 | # main class 33 | 34 | class Attend(nn.Module): 35 | def __init__( 36 | self, 37 | scale = 8, 38 | dropout = 0., 39 | flash = False 40 | ): 41 | super().__init__() 42 | self.scale = scale 43 | self.dropout = dropout 44 | self.attn_dropout = nn.Dropout(dropout) 45 | 46 | self.flash = flash 47 | assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above' 48 | 49 | # determine efficient attention configs for cuda and cpu 50 | 51 | self.cuda_config = None 52 | self.no_hardware_detected = False 53 | 54 | if not torch.cuda.is_available() or not flash: 55 | return 56 | 57 | device_properties = torch.cuda.get_device_properties(torch.device('cuda')) 58 | 59 | if device_properties.major == 8 and device_properties.minor == 0: 60 | print_once('A100 GPU 
detected, using flash attention if input tensor is on cuda') 61 | self.cuda_config = AttentionConfig(True, False, False) 62 | else: 63 | print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda') 64 | self.cuda_config = AttentionConfig(False, True, False) 65 | 66 | def flash_attn(self, q, k, v, mask = None): 67 | default_scale = q.shape[-1] ** -0.5 68 | 69 | is_cuda = q.is_cuda 70 | 71 | q, k, v = map(lambda t: t.contiguous(), (q, k, v)) 72 | 73 | # scaled_dot_product_attention does not allow for custom scale 74 | # so hack it in, to support rmsnorm-ed queries and keys 75 | 76 | rescale = self.scale / default_scale 77 | 78 | q = q * (rescale ** 0.5) 79 | k = k * (rescale ** 0.5) 80 | 81 | # use naive implementation if not correct hardware 82 | 83 | # the below logic can also incorporate whether masking is needed or not 84 | 85 | use_naive = not is_cuda or not exists(self.cuda_config) 86 | 87 | if not is_cuda or self.no_hardware_detected: 88 | return FlashAttentionFunction.apply(q, k, v, mask, False, 512, 512) 89 | 90 | # use naive implementation 91 | # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale 92 | 93 | try: 94 | raise Exception() 95 | with torch.backends.cuda.sdp_kernel(**self.cuda_config._asdict()): 96 | out = F.scaled_dot_product_attention( 97 | q, k, v, 98 | attn_mask = mask, 99 | dropout_p = self.dropout if self.training else 0. 100 | ) 101 | except: 102 | print_once('no hardware detected, falling back to naive implementation from memory-efficient-attention-pytorch library') 103 | self.no_hardware_detected = True 104 | 105 | out = FlashAttentionFunction.apply(q, k, v, mask, False, 512, 512) 106 | 107 | return out 108 | 109 | def forward(self, q, k, v, mask = None, force_non_flash = False): 110 | """ 111 | einstein notation 112 | b - batch 113 | h - heads 114 | n, i, j - sequence length (base sequence length, source, target) 115 | d - feature dimension 116 | """ 117 | 118 | if self.flash and not force_non_flash: 119 | return self.flash_attn(q, k, v, mask = mask) 120 | 121 | # similarity 122 | 123 | sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale 124 | 125 | # masking 126 | 127 | if exists(mask): 128 | mask_value = -torch.finfo(sim.dtype).max 129 | sim = sim.masked_fill(~mask, mask_value) 130 | 131 | # attention 132 | 133 | attn = sim.softmax(dim = -1) 134 | attn = self.attn_dropout(attn) 135 | 136 | # aggregate values 137 | 138 | out = einsum("b h i j, b h j d -> b h i d", attn, v) 139 | 140 | return out 141 | -------------------------------------------------------------------------------- /muse_maskgit_pytorch/muse_maskgit_pytorch.py: -------------------------------------------------------------------------------- 1 | import math 2 | from random import random 3 | from functools import partial 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn, einsum 8 | import pathlib 9 | from pathlib import Path 10 | import torchvision.transforms as T 11 | 12 | from typing import Callable, Optional, List 13 | 14 | from einops import rearrange, repeat 15 | 16 | from beartype import beartype 17 | 18 | from muse_maskgit_pytorch.vqgan_vae import VQGanVAE 19 | from muse_maskgit_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME 20 | from muse_maskgit_pytorch.attend import Attend 21 | 22 | from tqdm.auto import tqdm 23 | 24 | # helpers 25 | 26 | def exists(val): 27 | return val is not None 28 | 29 | def default(val, d): 30 | return val if exists(val) else d 31 | 32 | def 
eval_decorator(fn): 33 | def inner(model, *args, **kwargs): 34 | was_training = model.training 35 | model.eval() 36 | out = fn(model, *args, **kwargs) 37 | model.train(was_training) 38 | return out 39 | return inner 40 | 41 | def l2norm(t): 42 | return F.normalize(t, dim = -1) 43 | 44 | # tensor helpers 45 | 46 | def get_mask_subset_prob(mask, prob, min_mask = 0): 47 | batch, seq, device = *mask.shape, mask.device 48 | num_to_mask = (mask.sum(dim = -1, keepdim = True) * prob).clamp(min = min_mask) 49 | logits = torch.rand((batch, seq), device = device) 50 | logits = logits.masked_fill(~mask, -1) 51 | 52 | randperm = logits.argsort(dim = -1).argsort(dim = -1).float() 53 | 54 | num_padding = (~mask).sum(dim = -1, keepdim = True) 55 | randperm -= num_padding 56 | 57 | subset_mask = randperm < num_to_mask 58 | subset_mask.masked_fill_(~mask, False) 59 | return subset_mask 60 | 61 | # classes 62 | 63 | class LayerNorm(nn.Module): 64 | def __init__(self, dim): 65 | super().__init__() 66 | self.gamma = nn.Parameter(torch.ones(dim)) 67 | self.register_buffer('beta', torch.zeros(dim)) 68 | 69 | def forward(self, x): 70 | return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) 71 | 72 | class GEGLU(nn.Module): 73 | """ https://arxiv.org/abs/2002.05202 """ 74 | 75 | def forward(self, x): 76 | x, gate = x.chunk(2, dim = -1) 77 | return gate * F.gelu(x) 78 | 79 | def FeedForward(dim, mult = 4): 80 | """ https://arxiv.org/abs/2110.09456 """ 81 | 82 | inner_dim = int(dim * mult * 2 / 3) 83 | return nn.Sequential( 84 | LayerNorm(dim), 85 | nn.Linear(dim, inner_dim * 2, bias = False), 86 | GEGLU(), 87 | LayerNorm(inner_dim), 88 | nn.Linear(inner_dim, dim, bias = False) 89 | ) 90 | 91 | class Attention(nn.Module): 92 | def __init__( 93 | self, 94 | dim, 95 | dim_head = 64, 96 | heads = 8, 97 | cross_attend = False, 98 | scale = 8, 99 | flash = True, 100 | dropout = 0. 
101 | ): 102 | super().__init__() 103 | self.scale = scale 104 | self.heads = heads 105 | inner_dim = dim_head * heads 106 | 107 | self.cross_attend = cross_attend 108 | self.norm = LayerNorm(dim) 109 | 110 | self.attend = Attend( 111 | flash = flash, 112 | dropout = dropout, 113 | scale = scale 114 | ) 115 | 116 | self.null_kv = nn.Parameter(torch.randn(2, heads, 1, dim_head)) 117 | 118 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 119 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False) 120 | 121 | self.q_scale = nn.Parameter(torch.ones(dim_head)) 122 | self.k_scale = nn.Parameter(torch.ones(dim_head)) 123 | 124 | self.to_out = nn.Linear(inner_dim, dim, bias = False) 125 | 126 | def forward( 127 | self, 128 | x, 129 | context = None, 130 | context_mask = None 131 | ): 132 | assert not (exists(context) ^ self.cross_attend) 133 | 134 | n = x.shape[-2] 135 | h, is_cross_attn = self.heads, exists(context) 136 | 137 | x = self.norm(x) 138 | 139 | kv_input = context if self.cross_attend else x 140 | 141 | q, k, v = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1)) 142 | 143 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) 144 | 145 | nk, nv = self.null_kv 146 | nk, nv = map(lambda t: repeat(t, 'h 1 d -> b h 1 d', b = x.shape[0]), (nk, nv)) 147 | 148 | k = torch.cat((nk, k), dim = -2) 149 | v = torch.cat((nv, v), dim = -2) 150 | 151 | q, k = map(l2norm, (q, k)) 152 | q = q * self.q_scale 153 | k = k * self.k_scale 154 | 155 | if exists(context_mask): 156 | context_mask = repeat(context_mask, 'b j -> b h i j', h = h, i = n) 157 | context_mask = F.pad(context_mask, (1, 0), value = True) 158 | 159 | out = self.attend(q, k, v, mask = context_mask) 160 | 161 | out = rearrange(out, 'b h n d -> b n (h d)') 162 | return self.to_out(out) 163 | 164 | class TransformerBlocks(nn.Module): 165 | def __init__( 166 | self, 167 | *, 168 | dim, 169 | depth, 170 | dim_head = 64, 171 | heads = 8, 172 | ff_mult = 4, 173 | flash = True 174 | ): 175 | super().__init__() 176 | self.layers = nn.ModuleList([]) 177 | 178 | for _ in range(depth): 179 | self.layers.append(nn.ModuleList([ 180 | Attention(dim = dim, dim_head = dim_head, heads = heads, flash = flash), 181 | Attention(dim = dim, dim_head = dim_head, heads = heads, cross_attend = True, flash = flash), 182 | FeedForward(dim = dim, mult = ff_mult) 183 | ])) 184 | 185 | self.norm = LayerNorm(dim) 186 | 187 | def forward(self, x, context = None, context_mask = None): 188 | for attn, cross_attn, ff in self.layers: 189 | x = attn(x) + x 190 | 191 | x = cross_attn(x, context = context, context_mask = context_mask) + x 192 | 193 | x = ff(x) + x 194 | 195 | return self.norm(x) 196 | 197 | # transformer - it's all we need 198 | 199 | class Transformer(nn.Module): 200 | def __init__( 201 | self, 202 | *, 203 | num_tokens, 204 | dim, 205 | seq_len, 206 | dim_out = None, 207 | t5_name = DEFAULT_T5_NAME, 208 | self_cond = False, 209 | add_mask_id = False, 210 | **kwargs 211 | ): 212 | super().__init__() 213 | self.dim = dim 214 | self.mask_id = num_tokens if add_mask_id else None 215 | 216 | self.num_tokens = num_tokens 217 | self.token_emb = nn.Embedding(num_tokens + int(add_mask_id), dim) 218 | self.pos_emb = nn.Embedding(seq_len, dim) 219 | self.seq_len = seq_len 220 | 221 | self.transformer_blocks = TransformerBlocks(dim = dim, **kwargs) 222 | self.norm = LayerNorm(dim) 223 | 224 | self.dim_out = default(dim_out, num_tokens) 225 | self.to_logits = nn.Linear(dim, self.dim_out, bias = False) 226 | 227 | # text 
conditioning 228 | 229 | self.encode_text = partial(t5_encode_text, name = t5_name) 230 | 231 | text_embed_dim = get_encoded_dim(t5_name) 232 | 233 | self.text_embed_proj = nn.Linear(text_embed_dim, dim, bias = False) if text_embed_dim != dim else nn.Identity() 234 | 235 | # optional self conditioning 236 | 237 | self.self_cond = self_cond 238 | self.self_cond_to_init_embed = FeedForward(dim) 239 | 240 | def forward_with_cond_scale( 241 | self, 242 | *args, 243 | cond_scale = 3., 244 | return_embed = False, 245 | **kwargs 246 | ): 247 | if cond_scale == 1: 248 | return self.forward(*args, return_embed = return_embed, cond_drop_prob = 0., **kwargs) 249 | 250 | logits, embed = self.forward(*args, return_embed = True, cond_drop_prob = 0., **kwargs) 251 | 252 | null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs) 253 | 254 | scaled_logits = null_logits + (logits - null_logits) * cond_scale 255 | 256 | if return_embed: 257 | return scaled_logits, embed 258 | 259 | return scaled_logits 260 | 261 | def forward_with_neg_prompt( 262 | self, 263 | text_embed: torch.Tensor, 264 | neg_text_embed: torch.Tensor, 265 | cond_scale = 3., 266 | return_embed = False, 267 | **kwargs 268 | ): 269 | neg_logits = self.forward(*args, neg_text_embed = neg_text_embed, cond_drop_prob = 0., **kwargs) 270 | pos_logits, embed = self.forward(*args, return_embed = True, text_embed = text_embed, cond_drop_prob = 0., **kwargs) 271 | 272 | logits = neg_logits + (pos_logits - neg_logits) * cond_scale 273 | 274 | if return_embed: 275 | return scaled_logits, embed 276 | 277 | return scaled_logits 278 | 279 | def forward( 280 | self, 281 | x, 282 | return_embed = False, 283 | return_logits = False, 284 | labels = None, 285 | ignore_index = 0, 286 | self_cond_embed = None, 287 | cond_drop_prob = 0., 288 | conditioning_token_ids: Optional[torch.Tensor] = None, 289 | texts: Optional[List[str]] = None, 290 | text_embeds: Optional[torch.Tensor] = None 291 | ): 292 | device, b, n = x.device, *x.shape 293 | assert n <= self.seq_len 294 | 295 | # prepare texts 296 | 297 | assert exists(texts) ^ exists(text_embeds) 298 | 299 | if exists(texts): 300 | text_embeds = self.encode_text(texts) 301 | 302 | context = self.text_embed_proj(text_embeds) 303 | 304 | context_mask = (text_embeds != 0).any(dim = -1) 305 | 306 | # classifier free guidance 307 | 308 | if cond_drop_prob > 0.: 309 | mask = prob_mask_like((b, 1), 1. - cond_drop_prob, device) 310 | context_mask = context_mask & mask 311 | 312 | # concat conditioning image token ids if needed 313 | 314 | if exists(conditioning_token_ids): 315 | conditioning_token_ids = rearrange(conditioning_token_ids, 'b ... -> b (...)') 316 | cond_token_emb = self.token_emb(conditioning_token_ids) 317 | context = torch.cat((context, cond_token_emb), dim = -2) 318 | context_mask = F.pad(context_mask, (0, conditioning_token_ids.shape[-1]), value = True) 319 | 320 | # embed tokens 321 | 322 | x = self.token_emb(x) 323 | x = x + self.pos_emb(torch.arange(n, device = device)) 324 | 325 | if self.self_cond: 326 | if not exists(self_cond_embed): 327 | self_cond_embed = torch.zeros_like(x) 328 | x = x + self.self_cond_to_init_embed(self_cond_embed) 329 | 330 | embed = self.transformer_blocks(x, context = context, context_mask = context_mask) 331 | 332 | logits = self.to_logits(embed) 333 | 334 | if return_embed: 335 | return logits, embed 336 | 337 | if not exists(labels): 338 | return logits 339 | 340 | if self.dim_out == 1: 341 | loss = F.binary_cross_entropy_with_logits(rearrange(logits, '... 
1 -> ...'), labels) 342 | else: 343 | loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels, ignore_index = ignore_index) 344 | 345 | if not return_logits: 346 | return loss 347 | 348 | return loss, logits 349 | 350 | # self critic wrapper 351 | 352 | class SelfCritic(nn.Module): 353 | def __init__(self, net): 354 | super().__init__() 355 | self.net = net 356 | self.to_pred = nn.Linear(net.dim, 1) 357 | 358 | def forward_with_cond_scale(self, x, *args, **kwargs): 359 | _, embeds = self.net.forward_with_cond_scale(x, *args, return_embed = True, **kwargs) 360 | return self.to_pred(embeds) 361 | 362 | def forward_with_neg_prompt(self, x, *args, **kwargs): 363 | _, embeds = self.net.forward_with_neg_prompt(x, *args, return_embed = True, **kwargs) 364 | return self.to_pred(embeds) 365 | 366 | def forward(self, x, *args, labels = None, **kwargs): 367 | _, embeds = self.net(x, *args, return_embed = True, **kwargs) 368 | logits = self.to_pred(embeds) 369 | 370 | if not exists(labels): 371 | return logits 372 | 373 | logits = rearrange(logits, '... 1 -> ...') 374 | return F.binary_cross_entropy_with_logits(logits, labels) 375 | 376 | # specialized transformers 377 | 378 | class MaskGitTransformer(Transformer): 379 | def __init__(self, *args, **kwargs): 380 | assert 'add_mask_id' not in kwargs 381 | super().__init__(*args, add_mask_id = True, **kwargs) 382 | 383 | class TokenCritic(Transformer): 384 | def __init__(self, *args, **kwargs): 385 | assert 'dim_out' not in kwargs 386 | super().__init__(*args, dim_out = 1, **kwargs) 387 | 388 | # classifier free guidance functions 389 | 390 | def uniform(shape, min = 0, max = 1, device = None): 391 | return torch.zeros(shape, device = device).float().uniform_(0, 1) 392 | 393 | def prob_mask_like(shape, prob, device = None): 394 | if prob == 1: 395 | return torch.ones(shape, device = device, dtype = torch.bool) 396 | elif prob == 0: 397 | return torch.zeros(shape, device = device, dtype = torch.bool) 398 | else: 399 | return uniform(shape, device = device) < prob 400 | 401 | # sampling helpers 402 | 403 | def log(t, eps = 1e-20): 404 | return torch.log(t.clamp(min = eps)) 405 | 406 | def gumbel_noise(t): 407 | noise = torch.zeros_like(t).uniform_(0, 1) 408 | return -log(-log(noise)) 409 | 410 | def gumbel_sample(t, temperature = 1., dim = -1): 411 | return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim) 412 | 413 | def top_k(logits, thres = 0.9): 414 | k = math.ceil((1 - thres) * logits.shape[-1]) 415 | val, ind = logits.topk(k, dim = -1) 416 | probs = torch.full_like(logits, float('-inf')) 417 | probs.scatter_(2, ind, val) 418 | return probs 419 | 420 | # noise schedules 421 | 422 | def cosine_schedule(t): 423 | return torch.cos(t * math.pi * 0.5) 424 | 425 | # main maskgit classes 426 | 427 | @beartype 428 | class MaskGit(nn.Module): 429 | def __init__( 430 | self, 431 | image_size, 432 | transformer: MaskGitTransformer, 433 | noise_schedule: Callable = cosine_schedule, 434 | token_critic: Optional[TokenCritic] = None, 435 | self_token_critic = False, 436 | vae: Optional[VQGanVAE] = None, 437 | cond_vae: Optional[VQGanVAE] = None, 438 | cond_image_size = None, 439 | cond_drop_prob = 0.5, 440 | self_cond_prob = 0.9, 441 | no_mask_token_prob = 0., 442 | critic_loss_weight = 1. 
443 | ): 444 | super().__init__() 445 | self.vae = vae.copy_for_eval() if exists(vae) else None 446 | 447 | if exists(cond_vae): 448 | self.cond_vae = cond_vae.eval() 449 | else: 450 | self.cond_vae = self.vae 451 | 452 | assert not (exists(cond_vae) and not exists(cond_image_size)), 'cond_image_size must be specified if conditioning' 453 | 454 | self.image_size = image_size 455 | self.cond_image_size = cond_image_size 456 | self.resize_image_for_cond_image = exists(cond_image_size) 457 | 458 | self.cond_drop_prob = cond_drop_prob 459 | 460 | self.transformer = transformer 461 | self.self_cond = transformer.self_cond 462 | assert self.vae.codebook_size == self.cond_vae.codebook_size == transformer.num_tokens, 'transformer num_tokens must be set to be equal to the vae codebook size' 463 | 464 | self.mask_id = transformer.mask_id 465 | self.noise_schedule = noise_schedule 466 | 467 | assert not (self_token_critic and exists(token_critic)) 468 | self.token_critic = token_critic 469 | 470 | if self_token_critic: 471 | self.token_critic = SelfCritic(transformer) 472 | 473 | self.critic_loss_weight = critic_loss_weight 474 | 475 | # self conditioning 476 | self.self_cond_prob = self_cond_prob 477 | 478 | # percentage of tokens to be [mask]ed to remain the same token, so that transformer produces better embeddings across all tokens as done in original BERT paper 479 | # may be needed for self conditioning 480 | self.no_mask_token_prob = no_mask_token_prob 481 | 482 | def save(self, path): 483 | torch.save(self.state_dict(), path) 484 | 485 | def load(self, path): 486 | path = Path(path) 487 | assert path.exists() 488 | state_dict = torch.load(str(path)) 489 | self.load_state_dict(state_dict) 490 | 491 | @torch.no_grad() 492 | @eval_decorator 493 | def generate( 494 | self, 495 | texts: List[str], 496 | negative_texts: Optional[List[str]] = None, 497 | cond_images: Optional[torch.Tensor] = None, 498 | fmap_size = None, 499 | temperature = 1., 500 | topk_filter_thres = 0.9, 501 | can_remask_prev_masked = False, 502 | force_not_use_token_critic = False, 503 | timesteps = 18, # ideal number of steps is 18 in maskgit paper 504 | cond_scale = 3, 505 | critic_noise_scale = 1 506 | ): 507 | fmap_size = default(fmap_size, self.vae.get_encoded_fmap_size(self.image_size)) 508 | 509 | # begin with all image token ids masked 510 | 511 | device = next(self.parameters()).device 512 | 513 | seq_len = fmap_size ** 2 514 | 515 | batch_size = len(texts) 516 | 517 | shape = (batch_size, seq_len) 518 | 519 | ids = torch.full(shape, self.mask_id, dtype = torch.long, device = device) 520 | scores = torch.zeros(shape, dtype = torch.float32, device = device) 521 | 522 | starting_temperature = temperature 523 | 524 | cond_ids = None 525 | 526 | text_embeds = self.transformer.encode_text(texts) 527 | 528 | demask_fn = self.transformer.forward_with_cond_scale 529 | 530 | # whether to use token critic for scores 531 | 532 | use_token_critic = exists(self.token_critic) and not force_not_use_token_critic 533 | 534 | if use_token_critic: 535 | token_critic_fn = self.token_critic.forward_with_cond_scale 536 | 537 | # negative prompting, as in paper 538 | 539 | neg_text_embeds = None 540 | if exists(negative_texts): 541 | assert len(texts) == len(negative_texts) 542 | 543 | neg_text_embeds = self.transformer.encode_text(negative_texts) 544 | demask_fn = partial(self.transformer.forward_with_neg_prompt, neg_text_embeds = neg_text_embeds) 545 | 546 | if use_token_critic: 547 | token_critic_fn = 
partial(self.token_critic.forward_with_neg_prompt, neg_text_embeds = neg_text_embeds) 548 | 549 | if self.resize_image_for_cond_image: 550 | assert exists(cond_images), 'conditioning image must be passed in to generate for super res maskgit' 551 | with torch.no_grad(): 552 | _, cond_ids, _ = self.cond_vae.encode(cond_images) 553 | 554 | self_cond_embed = None 555 | 556 | for timestep, steps_until_x0 in tqdm(zip(torch.linspace(0, 1, timesteps, device = device), reversed(range(timesteps))), total = timesteps): 557 | 558 | rand_mask_prob = self.noise_schedule(timestep) 559 | num_token_masked = max(int((rand_mask_prob * seq_len).item()), 1) 560 | 561 | masked_indices = scores.topk(num_token_masked, dim = -1).indices 562 | 563 | ids = ids.scatter(1, masked_indices, self.mask_id) 564 | 565 | logits, embed = demask_fn( 566 | ids, 567 | text_embeds = text_embeds, 568 | self_cond_embed = self_cond_embed, 569 | conditioning_token_ids = cond_ids, 570 | cond_scale = cond_scale, 571 | return_embed = True 572 | ) 573 | 574 | self_cond_embed = embed if self.self_cond else None 575 | 576 | filtered_logits = top_k(logits, topk_filter_thres) 577 | 578 | temperature = starting_temperature * (steps_until_x0 / timesteps) # temperature is annealed 579 | 580 | pred_ids = gumbel_sample(filtered_logits, temperature = temperature, dim = -1) 581 | 582 | is_mask = ids == self.mask_id 583 | 584 | ids = torch.where( 585 | is_mask, 586 | pred_ids, 587 | ids 588 | ) 589 | 590 | if use_token_critic: 591 | scores = token_critic_fn( 592 | ids, 593 | text_embeds = text_embeds, 594 | conditioning_token_ids = cond_ids, 595 | cond_scale = cond_scale 596 | ) 597 | 598 | scores = rearrange(scores, '... 1 -> ...') 599 | 600 | scores = scores + (uniform(scores.shape, device = device) - 0.5) * critic_noise_scale * (steps_until_x0 / timesteps) 601 | 602 | else: 603 | probs_without_temperature = logits.softmax(dim = -1) 604 | 605 | scores = 1 - probs_without_temperature.gather(2, pred_ids[..., None]) 606 | scores = rearrange(scores, '... 
1 -> ...') 607 | 608 | if not can_remask_prev_masked: 609 | scores = scores.masked_fill(~is_mask, -1e5) 610 | else: 611 | assert self.no_mask_token_prob > 0., 'without training with some of the non-masked tokens forced to predict, not sure if the logits will be meaningful for these token' 612 | 613 | # get ids 614 | 615 | ids = rearrange(ids, 'b (i j) -> b i j', i = fmap_size, j = fmap_size) 616 | 617 | if not exists(self.vae): 618 | return ids 619 | 620 | images = self.vae.decode_from_ids(ids) 621 | return images 622 | 623 | def forward( 624 | self, 625 | images_or_ids: torch.Tensor, 626 | ignore_index = -1, 627 | cond_images: Optional[torch.Tensor] = None, 628 | cond_token_ids: Optional[torch.Tensor] = None, 629 | texts: Optional[List[str]] = None, 630 | text_embeds: Optional[torch.Tensor] = None, 631 | cond_drop_prob = None, 632 | train_only_generator = False, 633 | sample_temperature = None 634 | ): 635 | # tokenize if needed 636 | 637 | if images_or_ids.dtype == torch.float: 638 | assert exists(self.vae), 'vqgan vae must be passed in if training from raw images' 639 | assert all([height_or_width == self.image_size for height_or_width in images_or_ids.shape[-2:]]), 'the image you passed in is not of the correct dimensions' 640 | 641 | with torch.no_grad(): 642 | _, ids, _ = self.vae.encode(images_or_ids) 643 | else: 644 | assert not self.resize_image_for_cond_image, 'you cannot pass in raw image token ids if you want the framework to autoresize image for conditioning super res transformer' 645 | ids = images_or_ids 646 | 647 | # take care of conditioning image if specified 648 | 649 | if self.resize_image_for_cond_image: 650 | cond_images_or_ids = F.interpolate(images_or_ids, self.cond_image_size, mode = 'nearest') 651 | 652 | # get some basic variables 653 | 654 | ids = rearrange(ids, 'b ... 
-> b (...)') 655 | 656 | batch, seq_len, device, cond_drop_prob = *ids.shape, ids.device, default(cond_drop_prob, self.cond_drop_prob) 657 | 658 | # tokenize conditional images if needed 659 | 660 | assert not (exists(cond_images) and exists(cond_token_ids)), 'if conditioning on low resolution, cannot pass in both images and token ids' 661 | 662 | if exists(cond_images): 663 | assert exists(self.cond_vae), 'cond vqgan vae must be passed in' 664 | assert all([height_or_width == self.cond_image_size for height_or_width in cond_images.shape[-2:]]) 665 | 666 | with torch.no_grad(): 667 | _, cond_token_ids, _ = self.cond_vae.encode(cond_images) 668 | 669 | # prepare mask 670 | 671 | rand_time = uniform((batch,), device = device) 672 | rand_mask_probs = self.noise_schedule(rand_time) 673 | num_token_masked = (seq_len * rand_mask_probs).round().clamp(min = 1) 674 | 675 | mask_id = self.mask_id 676 | batch_randperm = torch.rand((batch, seq_len), device = device).argsort(dim = -1) 677 | mask = batch_randperm < rearrange(num_token_masked, 'b -> b 1') 678 | 679 | mask_id = self.transformer.mask_id 680 | labels = torch.where(mask, ids, ignore_index) 681 | 682 | if self.no_mask_token_prob > 0.: 683 | no_mask_mask = get_mask_subset_prob(mask, self.no_mask_token_prob) 684 | mask &= ~no_mask_mask 685 | 686 | x = torch.where(mask, mask_id, ids) 687 | 688 | # get text embeddings 689 | 690 | if exists(texts): 691 | text_embeds = self.transformer.encode_text(texts) 692 | texts = None 693 | 694 | # self conditioning 695 | 696 | self_cond_embed = None 697 | 698 | if self.transformer.self_cond and random() < self.self_cond_prob: 699 | with torch.no_grad(): 700 | _, self_cond_embed = self.transformer( 701 | x, 702 | text_embeds = text_embeds, 703 | conditioning_token_ids = cond_token_ids, 704 | cond_drop_prob = 0., 705 | return_embed = True 706 | ) 707 | 708 | self_cond_embed.detach_() 709 | 710 | # get loss 711 | 712 | ce_loss, logits = self.transformer( 713 | x, 714 | text_embeds = text_embeds, 715 | self_cond_embed = self_cond_embed, 716 | conditioning_token_ids = cond_token_ids, 717 | labels = labels, 718 | cond_drop_prob = cond_drop_prob, 719 | ignore_index = ignore_index, 720 | return_logits = True 721 | ) 722 | 723 | if not exists(self.token_critic) or train_only_generator: 724 | return ce_loss 725 | 726 | # token critic loss 727 | 728 | sampled_ids = gumbel_sample(logits, temperature = default(sample_temperature, random())) 729 | 730 | critic_input = torch.where(mask, sampled_ids, x) 731 | critic_labels = (ids != critic_input).float() 732 | 733 | bce_loss = self.token_critic( 734 | critic_input, 735 | text_embeds = text_embeds, 736 | conditioning_token_ids = cond_token_ids, 737 | labels = critic_labels, 738 | cond_drop_prob = cond_drop_prob 739 | ) 740 | 741 | return ce_loss + self.critic_loss_weight * bce_loss 742 | 743 | # final Muse class 744 | 745 | @beartype 746 | class Muse(nn.Module): 747 | def __init__( 748 | self, 749 | base: MaskGit, 750 | superres: MaskGit 751 | ): 752 | super().__init__() 753 | self.base_maskgit = base.eval() 754 | 755 | assert superres.resize_image_for_cond_image 756 | self.superres_maskgit = superres.eval() 757 | 758 | @torch.no_grad() 759 | def forward( 760 | self, 761 | texts: List[str], 762 | cond_scale = 3., 763 | temperature = 1., 764 | timesteps = 18, 765 | superres_timesteps = None, 766 | return_lowres = False, 767 | return_pil_images = True 768 | ): 769 | lowres_image = self.base_maskgit.generate( 770 | texts = texts, 771 | cond_scale = cond_scale, 772 | temperature 
= temperature, 773 | timesteps = timesteps 774 | ) 775 | 776 | superres_image = self.superres_maskgit.generate( 777 | texts = texts, 778 | cond_scale = cond_scale, 779 | cond_images = lowres_image, 780 | temperature = temperature, 781 | timesteps = default(superres_timesteps, timesteps) 782 | ) 783 | 784 | if return_pil_images: 785 | lowres_image = list(map(T.ToPILImage(), lowres_image)) 786 | superres_image = list(map(T.ToPILImage(), superres_image)) 787 | 788 | if not return_lowres: 789 | return superres_image 790 | 791 | return superres_image, lowres_image 792 | -------------------------------------------------------------------------------- /muse_maskgit_pytorch/t5.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import transformers 4 | from transformers import T5Tokenizer, T5EncoderModel, T5Config 5 | 6 | from beartype import beartype 7 | from typing import List, Union 8 | 9 | transformers.logging.set_verbosity_error() 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | # config 15 | 16 | MAX_LENGTH = 256 17 | 18 | DEFAULT_T5_NAME = 'google/t5-v1_1-base' 19 | 20 | T5_CONFIGS = {} 21 | 22 | # singleton globals 23 | 24 | def get_tokenizer(name): 25 | tokenizer = T5Tokenizer.from_pretrained(name) 26 | return tokenizer 27 | 28 | def get_model(name): 29 | model = T5EncoderModel.from_pretrained(name) 30 | return model 31 | 32 | def get_model_and_tokenizer(name): 33 | global T5_CONFIGS 34 | 35 | if name not in T5_CONFIGS: 36 | T5_CONFIGS[name] = dict() 37 | if "model" not in T5_CONFIGS[name]: 38 | T5_CONFIGS[name]["model"] = get_model(name) 39 | if "tokenizer" not in T5_CONFIGS[name]: 40 | T5_CONFIGS[name]["tokenizer"] = get_tokenizer(name) 41 | 42 | return T5_CONFIGS[name]['model'], T5_CONFIGS[name]['tokenizer'] 43 | 44 | def get_encoded_dim(name): 45 | if name not in T5_CONFIGS: 46 | # avoids loading the model if we only want to get the dim 47 | config = T5Config.from_pretrained(name) 48 | T5_CONFIGS[name] = dict(config=config) 49 | elif "config" in T5_CONFIGS[name]: 50 | config = T5_CONFIGS[name]["config"] 51 | elif "model" in T5_CONFIGS[name]: 52 | config = T5_CONFIGS[name]["model"].config 53 | else: 54 | assert False 55 | return config.d_model 56 | 57 | # encoding text 58 | 59 | @beartype 60 | def t5_encode_text( 61 | texts: Union[str, List[str]], 62 | name = DEFAULT_T5_NAME, 63 | output_device = None 64 | ): 65 | if isinstance(texts, str): 66 | texts = [texts] 67 | 68 | t5, tokenizer = get_model_and_tokenizer(name) 69 | 70 | if torch.cuda.is_available(): 71 | t5 = t5.cuda() 72 | 73 | device = next(t5.parameters()).device 74 | 75 | encoded = tokenizer.batch_encode_plus( 76 | texts, 77 | return_tensors = "pt", 78 | padding = 'longest', 79 | max_length = MAX_LENGTH, 80 | truncation = True 81 | ) 82 | 83 | input_ids = encoded.input_ids.to(device) 84 | attn_mask = encoded.attention_mask.to(device) 85 | 86 | t5.eval() 87 | 88 | with torch.no_grad(): 89 | output = t5(input_ids = input_ids, attention_mask = attn_mask) 90 | encoded_text = output.last_hidden_state.detach() 91 | 92 | attn_mask = attn_mask.bool() 93 | encoded_text = encoded_text.masked_fill(~attn_mask[..., None], 0.) 
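    # note: zeroing the embeddings at padded positions (above) is what lets downstream
    # consumers recover the text mask without re-tokenizing - e.g. Transformer.forward in
    # muse_maskgit_pytorch.py rebuilds it with (text_embeds != 0).any(dim = -1)
    #
    # rough usage sketch (the width 768 is an assumption for the default google/t5-v1_1-base config):
    #   embeds = t5_encode_text(['a whale breaching from afar'])
    #   embeds.shape  # (1, <num tokens>, 768), with padded positions zeroed out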
94 | 95 | if not exists(output_device): 96 | return encoded_text 97 | 98 | encoded_text.to(output_device) 99 | return encoded_text 100 | -------------------------------------------------------------------------------- /muse_maskgit_pytorch/trainers.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | from random import choice 3 | from pathlib import Path 4 | from shutil import rmtree 5 | from functools import partial 6 | 7 | from beartype import beartype 8 | 9 | import torch 10 | from torch import nn 11 | from torch.optim import Adam 12 | from torch.utils.data import Dataset, DataLoader, random_split 13 | 14 | import torchvision.transforms as T 15 | from torchvision.datasets import ImageFolder 16 | from torchvision.utils import make_grid, save_image 17 | 18 | from muse_maskgit_pytorch.vqgan_vae import VQGanVAE 19 | 20 | from einops import rearrange 21 | 22 | from accelerate import Accelerator, DistributedType, DistributedDataParallelKwargs 23 | 24 | from ema_pytorch import EMA 25 | 26 | from PIL import Image, ImageFile 27 | ImageFile.LOAD_TRUNCATED_IMAGES = True 28 | 29 | # helper functions 30 | 31 | def exists(val): 32 | return val is not None 33 | 34 | def identity(t, *args, **kwargs): 35 | return t 36 | 37 | def noop(*args, **kwargs): 38 | pass 39 | 40 | def find_index(arr, cond): 41 | for ind, el in enumerate(arr): 42 | if cond(el): 43 | return ind 44 | return None 45 | 46 | def find_and_pop(arr, cond, default = None): 47 | ind = find_index(arr, cond) 48 | 49 | if exists(ind): 50 | return arr.pop(ind) 51 | 52 | if callable(default): 53 | return default() 54 | 55 | return default 56 | 57 | def cycle(dl): 58 | while True: 59 | for data in dl: 60 | yield data 61 | 62 | def cast_tuple(t): 63 | return t if isinstance(t, (tuple, list)) else (t,) 64 | 65 | def yes_or_no(question): 66 | answer = input(f'{question} (y/n) ') 67 | return answer.lower() in ('yes', 'y') 68 | 69 | def accum_log(log, new_logs): 70 | for key, new_value in new_logs.items(): 71 | old_value = log.get(key, 0.) 
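        # missing keys default to 0., so callers can accumulate partial values across
        # gradient-accumulation micro-batches, e.g. accum_log(logs, {'loss': loss.item() / grad_accum_every}),
        # without pre-initializing the dict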
72 | log[key] = old_value + new_value 73 | return log 74 | 75 | def pair(val): 76 | return val if isinstance(val, tuple) else (val, val) 77 | 78 | def convert_image_to_fn(img_type, image): 79 | if image.mode != img_type: 80 | return image.convert(img_type) 81 | return image 82 | 83 | # image related helpers fnuctions and dataset 84 | 85 | class ImageDataset(Dataset): 86 | def __init__( 87 | self, 88 | folder, 89 | image_size, 90 | exts = ['jpg', 'jpeg', 'png'] 91 | ): 92 | super().__init__() 93 | self.folder = folder 94 | self.image_size = image_size 95 | self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')] 96 | 97 | print(f'{len(self.paths)} training samples found at {folder}') 98 | 99 | self.transform = T.Compose([ 100 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), 101 | T.Resize(image_size), 102 | T.RandomHorizontalFlip(), 103 | T.CenterCrop(image_size), 104 | T.ToTensor() 105 | ]) 106 | 107 | def __len__(self): 108 | return len(self.paths) 109 | 110 | def __getitem__(self, index): 111 | path = self.paths[index] 112 | img = Image.open(path) 113 | return self.transform(img) 114 | 115 | # main trainer class 116 | 117 | @beartype 118 | class VQGanVAETrainer(nn.Module): 119 | def __init__( 120 | self, 121 | vae: VQGanVAE, 122 | *, 123 | folder, 124 | num_train_steps, 125 | batch_size, 126 | image_size, 127 | lr = 3e-4, 128 | grad_accum_every = 1, 129 | max_grad_norm = None, 130 | discr_max_grad_norm = None, 131 | save_results_every = 100, 132 | save_model_every = 1000, 133 | results_folder = './results', 134 | valid_frac = 0.05, 135 | random_split_seed = 42, 136 | use_ema = True, 137 | ema_beta = 0.995, 138 | ema_update_after_step = 0, 139 | ema_update_every = 1, 140 | apply_grad_penalty_every = 4, 141 | accelerate_kwargs: dict = dict() 142 | ): 143 | super().__init__() 144 | 145 | # instantiate accelerator 146 | 147 | kwargs_handlers = accelerate_kwargs.get('kwargs_handlers', []) 148 | 149 | ddp_kwargs = find_and_pop( 150 | kwargs_handlers, 151 | lambda x: isinstance(x, DistributedDataParallelKwargs), 152 | partial(DistributedDataParallelKwargs, find_unused_parameters = True) 153 | ) 154 | 155 | ddp_kwargs.find_unused_parameters = True 156 | kwargs_handlers.append(ddp_kwargs) 157 | accelerate_kwargs.update(kwargs_handlers = kwargs_handlers) 158 | 159 | self.accelerator = Accelerator(**accelerate_kwargs) 160 | 161 | # vae 162 | 163 | self.vae = vae 164 | 165 | # training params 166 | 167 | self.register_buffer('steps', torch.Tensor([0])) 168 | 169 | self.num_train_steps = num_train_steps 170 | self.batch_size = batch_size 171 | self.grad_accum_every = grad_accum_every 172 | 173 | all_parameters = set(vae.parameters()) 174 | discr_parameters = set(vae.discr.parameters()) 175 | vae_parameters = all_parameters - discr_parameters 176 | 177 | self.vae_parameters = vae_parameters 178 | 179 | # optimizers 180 | 181 | self.optim = Adam(vae_parameters, lr = lr) 182 | self.discr_optim = Adam(discr_parameters, lr = lr) 183 | 184 | self.max_grad_norm = max_grad_norm 185 | self.discr_max_grad_norm = discr_max_grad_norm 186 | 187 | # create dataset 188 | 189 | self.ds = ImageDataset(folder, image_size) 190 | 191 | # split for validation 192 | 193 | if valid_frac > 0: 194 | train_size = int((1 - valid_frac) * len(self.ds)) 195 | valid_size = len(self.ds) - train_size 196 | self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed)) 197 | self.print(f'training with dataset of 
{len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples') 198 | else: 199 | self.valid_ds = self.ds 200 | self.print(f'training with shared training and valid dataset of {len(self.ds)} samples') 201 | 202 | # dataloader 203 | 204 | self.dl = DataLoader( 205 | self.ds, 206 | batch_size = batch_size, 207 | shuffle = True 208 | ) 209 | 210 | self.valid_dl = DataLoader( 211 | self.valid_ds, 212 | batch_size = batch_size, 213 | shuffle = True 214 | ) 215 | 216 | # prepare with accelerator 217 | 218 | ( 219 | self.vae, 220 | self.optim, 221 | self.discr_optim, 222 | self.dl, 223 | self.valid_dl 224 | ) = self.accelerator.prepare( 225 | self.vae, 226 | self.optim, 227 | self.discr_optim, 228 | self.dl, 229 | self.valid_dl 230 | ) 231 | 232 | self.use_ema = use_ema 233 | 234 | if use_ema: 235 | self.ema_vae = EMA(vae, update_after_step = ema_update_after_step, update_every = ema_update_every) 236 | self.ema_vae = self.accelerator.prepare(self.ema_vae) 237 | 238 | self.dl_iter = cycle(self.dl) 239 | self.valid_dl_iter = cycle(self.valid_dl) 240 | 241 | self.save_model_every = save_model_every 242 | self.save_results_every = save_results_every 243 | 244 | self.apply_grad_penalty_every = apply_grad_penalty_every 245 | 246 | self.results_folder = Path(results_folder) 247 | 248 | if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'): 249 | rmtree(str(self.results_folder)) 250 | 251 | self.results_folder.mkdir(parents = True, exist_ok = True) 252 | 253 | def save(self, path): 254 | if not self.accelerator.is_local_main_process: 255 | return 256 | 257 | pkg = dict( 258 | model = self.accelerator.get_state_dict(self.vae), 259 | optim = self.optim.state_dict(), 260 | discr_optim = self.discr_optim.state_dict() 261 | ) 262 | torch.save(pkg, path) 263 | 264 | def load(self, path): 265 | path = Path(path) 266 | assert path.exists() 267 | pkg = torch.load(path) 268 | 269 | vae = self.accelerator.unwrap_model(self.vae) 270 | vae.load_state_dict(pkg['model']) 271 | 272 | self.optim.load_state_dict(pkg['optim']) 273 | self.discr_optim.load_state_dict(pkg['discr_optim']) 274 | 275 | def print(self, msg): 276 | self.accelerator.print(msg) 277 | 278 | @property 279 | def device(self): 280 | return self.accelerator.device 281 | 282 | @property 283 | def is_distributed(self): 284 | return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1) 285 | 286 | @property 287 | def is_main(self): 288 | return self.accelerator.is_main_process 289 | 290 | @property 291 | def is_local_main(self): 292 | return self.accelerator.is_local_main_process 293 | 294 | def train_step(self): 295 | device = self.device 296 | 297 | steps = int(self.steps.item()) 298 | apply_grad_penalty = not (steps % self.apply_grad_penalty_every) 299 | 300 | self.vae.train() 301 | discr = self.vae.module.discr if self.is_distributed else self.vae.discr 302 | if self.use_ema: 303 | ema_vae = self.ema_vae.module if self.is_distributed else self.ema_vae 304 | 305 | # logs 306 | 307 | logs = {} 308 | 309 | # update vae (generator) 310 | 311 | for _ in range(self.grad_accum_every): 312 | img = next(self.dl_iter) 313 | img = img.to(device) 314 | 315 | with self.accelerator.autocast(): 316 | loss = self.vae( 317 | img, 318 | add_gradient_penalty = apply_grad_penalty, 319 | return_loss = True 320 | ) 321 | 322 | self.accelerator.backward(loss / self.grad_accum_every) 323 | 324 | accum_log(logs, {'loss': 
325 | 
326 |         if exists(self.max_grad_norm):
327 |             self.accelerator.clip_grad_norm_(self.vae.parameters(), self.max_grad_norm)
328 | 
329 |         self.optim.step()
330 |         self.optim.zero_grad()
331 | 
332 |         # update discriminator
333 | 
334 |         if exists(discr):
335 |             self.discr_optim.zero_grad()
336 | 
337 |             for _ in range(self.grad_accum_every):
338 |                 img = next(self.dl_iter)
339 |                 img = img.to(device)
340 | 
341 |                 loss = self.vae(img, return_discr_loss = True)
342 | 
343 |                 self.accelerator.backward(loss / self.grad_accum_every)
344 | 
345 |                 accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})
346 | 
347 |             if exists(self.discr_max_grad_norm):
348 |                 self.accelerator.clip_grad_norm_(discr.parameters(), self.discr_max_grad_norm)
349 | 
350 |             self.discr_optim.step()
351 | 
352 |             # log
353 | 
354 |             self.print(f"{steps}: vae loss: {logs['loss']} - discr loss: {logs['discr_loss']}")
355 | 
356 |         # update exponential moving averaged generator
357 | 
358 |         if self.use_ema:
359 |             ema_vae.update()
360 | 
361 |         # sample results every so often
362 | 
363 |         if not (steps % self.save_results_every):
364 |             vaes_to_evaluate = ((self.vae, str(steps)),)
365 | 
366 |             if self.use_ema:
367 |                 vaes_to_evaluate = ((ema_vae.ema_model, f'{steps}.ema'),) + vaes_to_evaluate
368 | 
369 |             for model, filename in vaes_to_evaluate:
370 |                 model.eval()
371 | 
372 |                 valid_data = next(self.valid_dl_iter)
373 |                 valid_data = valid_data.to(device)
374 | 
375 |                 recons = model(valid_data, return_recons = True)
376 | 
377 |                 # save a grid of images
378 | 
379 |                 imgs_and_recons = torch.stack((valid_data, recons), dim = 0)
380 |                 imgs_and_recons = rearrange(imgs_and_recons, 'r b ... -> (b r) ...')
381 | 
382 |                 imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0., 1.)
383 |                 grid = make_grid(imgs_and_recons, nrow = 2, normalize = True, value_range = (0, 1))
384 | 
385 |                 logs['reconstructions'] = grid
386 | 
387 |                 save_image(grid, str(self.results_folder / f'{filename}.png'))
388 | 
389 |             self.print(f'{steps}: saving to {str(self.results_folder)}')
390 | 
391 |         # save model every so often
392 |         self.accelerator.wait_for_everyone()
393 |         if self.is_main and not (steps % self.save_model_every):
394 |             state_dict = self.accelerator.unwrap_model(self.vae).state_dict()
395 |             model_path = str(self.results_folder / f'vae.{steps}.pt')
396 |             self.accelerator.save(state_dict, model_path)
397 | 
398 |             if self.use_ema:
399 |                 ema_state_dict = self.accelerator.unwrap_model(self.ema_vae).state_dict()
400 |                 model_path = str(self.results_folder / f'vae.{steps}.ema.pt')
401 |                 self.accelerator.save(ema_state_dict, model_path)
402 | 
403 |             self.print(f'{steps}: saving model to {str(self.results_folder)}')
404 | 
405 |         self.steps += 1
406 |         return logs
407 | 
408 |     def train(self, log_fn = noop):
409 |         device = next(self.vae.parameters()).device
410 | 
411 |         while self.steps < self.num_train_steps:
412 |             logs = self.train_step()
413 |             log_fn(logs)
414 | 
415 |         self.print('training complete')
416 | 
--------------------------------------------------------------------------------
/muse_maskgit_pytorch/vqgan_vae.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import copy
3 | import math
4 | from math import sqrt
5 | from functools import partial, wraps
6 | 
7 | from vector_quantize_pytorch import VectorQuantize as VQ, LFQ
8 | 
9 | import torch
10 | from torch import nn, einsum
11 | import torch.nn.functional as F
12 | from torch.autograd import grad as torch_grad
13 | 
14 | import torchvision
15 | 
16 | from einops import rearrange, reduce, repeat, pack, unpack
17 | from einops.layers.torch import Rearrange
18 | 
19 | # constants
20 | 
21 | MList = nn.ModuleList
22 | 
23 | # helper functions
24 | 
25 | def exists(val):
26 |     return val is not None
27 | 
28 | def default(val, d):
29 |     return val if exists(val) else d
30 | 
31 | # decorators
32 | 
33 | def eval_decorator(fn):
34 |     def inner(model, *args, **kwargs):
35 |         was_training = model.training
36 |         model.eval()
37 |         out = fn(model, *args, **kwargs)
38 |         model.train(was_training)
39 |         return out
40 |     return inner
41 | 
42 | def remove_vgg(fn):
43 |     @wraps(fn)
44 |     def inner(self, *args, **kwargs):
45 |         has_vgg = hasattr(self, '_vgg')
46 |         if has_vgg:
47 |             vgg = self._vgg
48 |             delattr(self, '_vgg')
49 | 
50 |         out = fn(self, *args, **kwargs)
51 | 
52 |         if has_vgg:
53 |             self._vgg = vgg
54 | 
55 |         return out
56 |     return inner
57 | 
58 | # keyword argument helpers
59 | 
60 | def pick_and_pop(keys, d):
61 |     values = list(map(lambda key: d.pop(key), keys))
62 |     return dict(zip(keys, values))
63 | 
64 | def group_dict_by_key(cond, d):
65 |     return_val = [dict(),dict()]
66 |     for key in d.keys():
67 |         match = bool(cond(key))
68 |         ind = int(not match)
69 |         return_val[ind][key] = d[key]
70 |     return (*return_val,)
71 | 
72 | def string_begins_with(prefix, string_input):
73 |     return string_input.startswith(prefix)
74 | 
75 | def group_by_key_prefix(prefix, d):
76 |     return group_dict_by_key(partial(string_begins_with, prefix), d)
77 | 
78 | def groupby_prefix_and_trim(prefix, d):
79 |     kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
80 |     kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
81 |     return kwargs_without_prefix, kwargs
82 | 
83 | # tensor helper functions
84 | 
85 | def log(t, eps = 1e-10):
86 |     return torch.log(t + eps)
87 | 
88 | def gradient_penalty(images, output, weight = 10):
89 |     batch_size = images.shape[0]
90 | 
91 |     gradients = torch_grad(
92 |         outputs = output,
93 |         inputs = images,
94 |         grad_outputs = torch.ones(output.size(), device = images.device),
95 |         create_graph = True,
96 |         retain_graph = True,
97 |         only_inputs = True
98 |     )[0]
99 | 
100 |     gradients = rearrange(gradients, 'b ... -> b (...)')
101 |     return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()
102 | 
103 | def leaky_relu(p = 0.1):
104 |     return nn.LeakyReLU(p)
105 | 
106 | def safe_div(numer, denom, eps = 1e-8):
107 |     return numer / denom.clamp(min = eps)
108 | 
109 | # gan losses
110 | 
111 | def hinge_discr_loss(fake, real):
112 |     return (F.relu(1 + fake) + F.relu(1 - real)).mean()
113 | 
114 | def hinge_gen_loss(fake):
115 |     return -fake.mean()
116 | 
117 | def bce_discr_loss(fake, real):
118 |     return (-log(1 - torch.sigmoid(fake)) - log(torch.sigmoid(real))).mean()
119 | 
120 | def bce_gen_loss(fake):
121 |     return -log(torch.sigmoid(fake)).mean()
122 | 
123 | def grad_layer_wrt_loss(loss, layer):
124 |     return torch_grad(
125 |         outputs = loss,
126 |         inputs = layer,
127 |         grad_outputs = torch.ones_like(loss),
128 |         retain_graph = True
129 |     )[0].detach()
130 | 
131 | # vqgan vae
132 | 
133 | class LayerNormChan(nn.Module):
134 |     def __init__(
135 |         self,
136 |         dim,
137 |         eps = 1e-5
138 |     ):
139 |         super().__init__()
140 |         self.eps = eps
141 |         self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))
142 | 
143 |     def forward(self, x):
144 |         var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
145 |         mean = torch.mean(x, dim = 1, keepdim = True)
146 |         return (x - mean) * var.clamp(min = self.eps).rsqrt() * self.gamma
147 | 
148 | # discriminator
149 | 
150 | class Discriminator(nn.Module):
151 |     def __init__(
152 |         self,
153 |         dims,
154 |         channels = 3,
155 |         groups = 16,
156 |         init_kernel_size = 5
157 |     ):
158 |         super().__init__()
159 |         dim_pairs = zip(dims[:-1], dims[1:])
160 | 
161 |         self.layers = MList([nn.Sequential(nn.Conv2d(channels, dims[0], init_kernel_size, padding = init_kernel_size // 2), leaky_relu())])
162 | 
163 |         for dim_in, dim_out in dim_pairs:
164 |             self.layers.append(nn.Sequential(
165 |                 nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1),
166 |                 nn.GroupNorm(groups, dim_out),
167 |                 leaky_relu()
168 |             ))
169 | 
170 |         dim = dims[-1]
171 |         self.to_logits = nn.Sequential( # return 5 x 5, for PatchGAN-esque training
172 |             nn.Conv2d(dim, dim, 1),
173 |             leaky_relu(),
174 |             nn.Conv2d(dim, 1, 4)
175 |         )
176 | 
177 |     def forward(self, x):
178 |         for net in self.layers:
179 |             x = net(x)
180 | 
181 |         return self.to_logits(x)
182 | 
183 | # resnet encoder / decoder
184 | 
185 | class ResnetEncDec(nn.Module):
186 |     def __init__(
187 |         self,
188 |         dim,
189 |         *,
190 |         channels = 3,
191 |         layers = 4,
192 |         layer_mults = None,
193 |         num_resnet_blocks = 1,
194 |         resnet_groups = 16,
195 |         first_conv_kernel_size = 5
196 |     ):
197 |         super().__init__()
198 |         assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'
199 | 
200 |         self.layers = layers
201 | 
202 |         self.encoders = MList([])
203 |         self.decoders = MList([])
204 | 
205 |         layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
206 |         assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
207 | 
208 |         layer_dims = [dim * mult for mult in layer_mults]
209 |         dims = (dim, *layer_dims)
210 | 
211 |         self.encoded_dim = dims[-1]
212 | 
213 |         dim_pairs = zip(dims[:-1], dims[1:])
214 | 
215 |         append = lambda arr, t: arr.append(t)
216 |         prepend = lambda arr, t: arr.insert(0, t)
217 | 
218 |         if not isinstance(num_resnet_blocks, tuple):
219 |             num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)
220 | 
221 |         assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'
222 | 
223 |         for layer_index, (dim_in, dim_out), layer_num_resnet_blocks in zip(range(layers), dim_pairs, num_resnet_blocks):
224 |             append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
225 |             prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))
226 | 
227 |             for _ in range(layer_num_resnet_blocks):
228 |                 append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
229 |                 prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))
230 | 
231 |         prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
232 |         append(self.decoders, nn.Conv2d(dim, channels, 1))
233 | 
234 |     def get_encoded_fmap_size(self, image_size):
235 |         return image_size // (2 ** self.layers)
236 | 
237 |     @property
238 |     def last_dec_layer(self):
239 |         return self.decoders[-1].weight
240 | 
241 |     def encode(self, x):
242 |         for enc in self.encoders:
243 |             x = enc(x)
244 |         return x
245 | 
246 |     def decode(self, x):
247 |         for dec in self.decoders:
248 |             x = dec(x)
249 |         return x
250 | 
251 | class GLUResBlock(nn.Module):
252 |     def __init__(self, chan, groups = 16):
253 |         super().__init__()
254 |         self.net = nn.Sequential(
255 |             nn.Conv2d(chan, chan * 2, 3, padding = 1),
256 |             nn.GLU(dim = 1),
257 |             nn.GroupNorm(groups, chan),
258 |             nn.Conv2d(chan, chan * 2, 3, padding = 1),
259 |             nn.GLU(dim = 1),
260 |             nn.GroupNorm(groups, chan),
261 |             nn.Conv2d(chan, chan, 1)
262 |         )
263 | 
264 |     def forward(self, x):
265 |         return self.net(x) + x
266 | 
267 | class ResBlock(nn.Module):
268 |     def __init__(self, chan, groups = 16):
269 |         super().__init__()
270 |         self.net = nn.Sequential(
271 |             nn.Conv2d(chan, chan, 3, padding = 1),
272 |             nn.GroupNorm(groups, chan),
273 |             leaky_relu(),
274 |             nn.Conv2d(chan, chan, 3, padding = 1),
275 |             nn.GroupNorm(groups, chan),
276 |             leaky_relu(),
277 |             nn.Conv2d(chan, chan, 1)
278 |         )
279 | 
280 |     def forward(self, x):
281 |         return self.net(x) + x
282 | 
283 | # main vqgan-vae classes
284 | 
285 | class VQGanVAE(nn.Module):
286 |     def __init__(
287 |         self,
288 |         *,
289 |         dim,
290 |         channels = 3,
291 |         layers = 4,
292 |         l2_recon_loss = False,
293 |         use_hinge_loss = True,
294 |         vgg = None,
295 |         lookup_free_quantization = True,
296 |         codebook_size = 65536,
297 |         vq_kwargs: dict = dict(
298 |             codebook_dim = 256,
299 |             decay = 0.8,
300 |             commitment_weight = 1.,
301 |             kmeans_init = True,
302 |             use_cosine_sim = True,
303 |         ),
304 |         lfq_kwargs: dict = dict(
305 |             diversity_gamma = 4.
306 |         ),
307 |         use_vgg_and_gan = True,
308 |         discr_layers = 4,
309 |         **kwargs
310 |     ):
311 |         super().__init__()
312 |         vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs)
313 |         encdec_kwargs, kwargs = groupby_prefix_and_trim('encdec_', kwargs)
314 | 
315 |         self.channels = channels
316 |         self.codebook_size = codebook_size
317 |         self.dim_divisor = 2 ** layers
318 | 
319 |         enc_dec_klass = ResnetEncDec
320 | 
321 |         self.enc_dec = enc_dec_klass(
322 |             dim = dim,
323 |             channels = channels,
324 |             layers = layers,
325 |             **encdec_kwargs
326 |         )
327 | 
328 |         self.lookup_free_quantization = lookup_free_quantization
329 | 
330 |         if lookup_free_quantization:
331 |             self.quantizer = LFQ(
332 |                 dim = self.enc_dec.encoded_dim,
333 |                 codebook_size = codebook_size,
334 |                 **lfq_kwargs
335 |             )
336 |         else:
337 |             self.quantizer = VQ(
338 |                 dim = self.enc_dec.encoded_dim,
339 |                 codebook_size = codebook_size,
340 |                 accept_image_fmap = True,
341 |                 **vq_kwargs
342 |             )
343 | 
344 |         # reconstruction loss
345 | 
346 |         self.recon_loss_fn = F.mse_loss if l2_recon_loss else F.l1_loss
347 | 
348 |         # turn off GAN and perceptual loss if grayscale
349 | 
350 |         self._vgg = None
351 |         self.discr = None
352 |         self.use_vgg_and_gan = use_vgg_and_gan
353 | 
354 |         if not use_vgg_and_gan:
355 |             return
356 | 
357 |         # perceptual loss
358 | 
359 |         if exists(vgg):
360 |             self._vgg = vgg
361 | 
362 |         # gan related losses
363 | 
364 |         layer_mults = list(map(lambda t: 2 ** t, range(discr_layers)))
365 |         layer_dims = [dim * mult for mult in layer_mults]
366 |         dims = (dim, *layer_dims)
367 | 
368 |         self.discr = Discriminator(dims = dims, channels = channels)
369 | 
370 |         self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
371 |         self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
372 | 
373 |     @property
374 |     def device(self):
375 |         return next(self.parameters()).device
376 | 
377 |     @property
378 |     def vgg(self):
379 |         if exists(self._vgg):
380 |             return self._vgg
381 | 
382 |         vgg = torchvision.models.vgg16(pretrained = True)
383 |         vgg.classifier = nn.Sequential(*vgg.classifier[:-2])
384 |         self._vgg = vgg.to(self.device)
385 |         return self._vgg
386 | 
387 |     @property
388 |     def encoded_dim(self):
389 |         return self.enc_dec.encoded_dim
390 | 
391 |     def get_encoded_fmap_size(self, image_size):
392 |         return self.enc_dec.get_encoded_fmap_size(image_size)
393 | 
394 |     def copy_for_eval(self):
395 |         device = next(self.parameters()).device
396 |         vae_copy = copy.deepcopy(self.cpu())
397 | 
398 |         if vae_copy.use_vgg_and_gan:
399 |             del vae_copy.discr
400 |             del vae_copy._vgg
401 | 
402 |         vae_copy.eval()
403 |         return vae_copy.to(device)
404 | 
405 |     @remove_vgg
406 |     def state_dict(self, *args, **kwargs):
407 |         return super().state_dict(*args, **kwargs)
408 | 
409 |     @remove_vgg
410 |     def load_state_dict(self, *args, **kwargs):
411 |         return super().load_state_dict(*args, **kwargs)
412 | 
413 |     def save(self, path):
414 |         torch.save(self.state_dict(), path)
415 | 
416 |     def load(self, path):
417 |         path = Path(path)
418 |         assert path.exists()
419 |         state_dict = torch.load(str(path))
420 |         self.load_state_dict(state_dict)
421 | 
422 |     def encode(self, fmap):
423 |         fmap = self.enc_dec.encode(fmap)
424 |         fmap, indices, vq_aux_loss = self.quantizer(fmap)
425 |         return fmap, indices, vq_aux_loss
426 | 
427 |     def decode_from_ids(self, ids):
428 | 
429 |         if self.lookup_free_quantization:
430 |             ids, ps = pack([ids], 'b *')
431 |             fmap = self.quantizer.indices_to_codes(ids)
432 |             fmap, = unpack(fmap, ps, 'b * c')
433 |         else:
434 |             codes = self.codebook[ids]
435 |             fmap = self.quantizer.project_out(codes)
436 | 
437 |         fmap = rearrange(fmap, 'b h w c -> b c h w')
438 |         return self.decode(fmap)
439 | 
440 |     def decode(self, fmap):
441 |         return self.enc_dec.decode(fmap)
442 | 
443 |     def forward(
444 |         self,
445 |         img,
446 |         return_loss = False,
447 |         return_discr_loss = False,
448 |         return_recons = False,
449 |         add_gradient_penalty = True
450 |     ):
451 |         batch, channels, height, width, device = *img.shape, img.device
452 | 
453 |         for dim_name, size in (('height', height), ('width', width)):
454 |             assert (size % self.dim_divisor) == 0, f'{dim_name} must be divisible by {self.dim_divisor}'
455 | 
456 |         assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'
457 | 
458 |         fmap, indices, commit_loss = self.encode(img)
459 | 
460 |         fmap = self.decode(fmap)
461 | 
462 |         if not return_loss and not return_discr_loss:
463 |             return fmap
464 | 
465 |         assert return_loss ^ return_discr_loss, 'you should either return autoencoder loss or discriminator loss, but not both'
466 | 
467 |         # whether to return discriminator loss
468 | 
469 |         if return_discr_loss:
470 |             assert exists(self.discr), 'discriminator must exist to train it'
471 | 
472 |             fmap.detach_()
473 |             img.requires_grad_()
474 | 
475 |             fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
476 | 
477 |             discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
478 | 
479 |             loss = discr_loss
480 |             if add_gradient_penalty:
481 |                 loss = loss + gradient_penalty(img, img_discr_logits)
482 | 
483 |             if return_recons:
484 |                 return loss, fmap
485 | 
486 |             return loss
487 | 
488 |         # reconstruction loss
489 | 
490 |         recon_loss = self.recon_loss_fn(fmap, img)
491 | 
492 |         # early return if training on grayscale
493 | 
494 |         if not self.use_vgg_and_gan:
495 |             if return_recons:
496 |                 return recon_loss, fmap
497 | 
498 |             return recon_loss
499 | 
500 |         # perceptual loss
501 | 
502 |         img_vgg_input = img
503 |         fmap_vgg_input = fmap
504 | 
505 |         if img.shape[1] == 1:
506 |             # handle grayscale for vgg
507 |             img_vgg_input, fmap_vgg_input = map(lambda t: repeat(t, 'b 1 ... -> b c ...', c = 3), (img_vgg_input, fmap_vgg_input))
508 | 
509 |         img_vgg_feats = self.vgg(img_vgg_input)
510 |         recon_vgg_feats = self.vgg(fmap_vgg_input)
511 |         perceptual_loss = F.mse_loss(img_vgg_feats, recon_vgg_feats)
512 | 
513 |         # generator loss
514 | 
515 |         gen_loss = self.gen_loss(self.discr(fmap))
516 | 
517 |         # calculate adaptive weight
518 | 
519 |         last_dec_layer = self.enc_dec.last_dec_layer
520 | 
521 |         norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
522 |         norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
523 | 
524 |         adaptive_weight = safe_div(norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss)
525 |         adaptive_weight.clamp_(max = 1e4)
526 | 
527 |         # combine losses
528 | 
529 |         loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss
530 | 
531 |         if return_recons:
532 |             return loss, fmap
533 | 
534 |         return loss
535 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | setup(
4 |   name = 'muse-maskgit-pytorch',
5 |   packages = find_packages(exclude=[]),
6 |   version = '0.3.5',
7 |   license='MIT',
8 |   description = 'MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch',
9 |   author = 'Phil Wang',
10 |   author_email = 'lucidrains@gmail.com',
11 |   long_description_content_type = 'text/markdown',
12 |   url = 'https://github.com/lucidrains/muse-maskgit-pytorch',
13 |   keywords = [
14 |     'artificial intelligence',
15 |     'deep learning',
16 |     'transformers',
17 |     'attention mechanism',
18 |     'text-to-image'
19 |   ],
20 |   install_requires=[
21 |     'accelerate',
22 |     'beartype',
23 |     'einops>=0.7',
24 |     'ema-pytorch>=0.2.2',
25 |     'memory-efficient-attention-pytorch>=0.1.4',
26 |     'pillow',
27 |     'sentencepiece',
28 |     'torch>=1.6',
29 |     'transformers',
30 |     'torchvision',
31 |     'tqdm',
32 |     'vector-quantize-pytorch>=1.11.8'
33 |   ],
34 |   classifiers=[
35 |     'Development Status :: 4 - Beta',
36 |     'Intended Audience :: Developers',
37 |     'Topic :: Scientific/Engineering :: Artificial Intelligence',
38 |     'License :: OSI Approved :: MIT License',
39 |     'Programming Language :: Python :: 3.6',
40 |   ],
41 | )
42 | 
--------------------------------------------------------------------------------
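
Minimal training sketch (an addition, not part of the repository dump above): it wires the VQGanVAE from muse_maskgit_pytorch/vqgan_vae.py into the VQGanVAETrainer from muse_maskgit_pytorch/trainers.py. The image folder path and every hyperparameter value below are illustrative placeholders, not settings taken from the repository.

# training sketch -- path and hyperparameters are placeholders
from muse_maskgit_pytorch.vqgan_vae import VQGanVAE
from muse_maskgit_pytorch.trainers import VQGanVAETrainer

vae = VQGanVAE(
    dim = 256,                    # base channel width of the resnet encoder / decoder
    codebook_size = 65536         # lookup-free quantization (LFQ) is the default quantizer
)

trainer = VQGanVAETrainer(
    vae = vae,
    folder = '/path/to/images',   # scanned recursively for jpg / jpeg / png files
    image_size = 128,             # each side must be divisible by 2 ** layers (16 for the default 4 layers)
    batch_size = 4,
    grad_accum_every = 8,
    num_train_steps = 50000
)

trainer.train()                   # reconstruction grids and vae.<step>.pt checkpoints land in ./results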
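
Minimal encode / decode sketch for the VQGanVAE itself (also an addition; shapes and values are illustrative). It follows the forward path shown in vqgan_vae.py: encode to a quantized feature map plus discrete token ids, then decode back to pixels.

# inference-side sketch -- shapes are illustrative
import torch
from muse_maskgit_pytorch.vqgan_vae import VQGanVAE

vae = VQGanVAE(dim = 256, codebook_size = 65536)   # LFQ quantizer by default

images = torch.rand(1, 3, 128, 128)                # stand-in batch; sides must divide by vae.dim_divisor (16 here)

fmap, indices, aux_loss = vae.encode(images)       # quantized feature map, token ids, quantizer aux loss
recons = vae.decode(fmap)                          # (1, 3, 128, 128) reconstruction

recons_too = vae(images)                           # forward without flags returns the reconstruction directly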