├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── dalle_pytorch ├── __init__.py ├── attention.py ├── dalle_pytorch.py ├── data │ └── bpe_simple_vocab_16e6.txt ├── distributed_backends │ ├── __init__.py │ ├── deepspeed_backend.py │ ├── distributed_backend.py │ ├── dummy_backend.py │ └── horovod_backend.py ├── distributed_utils.py ├── loader.py ├── reversible.py ├── tokenizer.py ├── transformer.py └── vae.py ├── docker └── Dockerfile ├── examples └── rainbow_dalle.ipynb ├── generate.py ├── images ├── avocado-0.png ├── avocado-1.png ├── avocado-2.png ├── avocado-3.png ├── avocado-4.png ├── avocado-5.png ├── avocado-6.png ├── avocado-7.png ├── avocado-before.png ├── avocado.png ├── banner.jpg ├── birds.png ├── clothing.png ├── cloud-0.png ├── cloud-1.png ├── cloud-2.png ├── cloud-3.png ├── cloud-4.png ├── cloud-5.png ├── cloud-6.png ├── cloud-7.png ├── cube-cloud-before.jpg ├── cube-cloud.png ├── cube-porcupine-before.jpg ├── cube-porcupine.png ├── cube-water-before.jpg ├── cube-water.png ├── girl-0.png ├── girl-1.png ├── girl-2.png ├── girl-3.png ├── girl-4.png ├── girl-5.png ├── girl-6.png ├── girl-7.png ├── girl-glasses-before.png ├── girl-glasses.png ├── landscape.png ├── layouts-1.jpg ├── layouts-2.jpg ├── researcher-mad-before.png ├── researcher-mad.png └── wb.png ├── install_apex.sh ├── install_deepspeed.sh ├── setup.py ├── train_dalle.py └── train_vae.py /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.x' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 29 | run: | 30 | python setup.py sdist bdist_wheel 31 | twine upload dist/* 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # dall-e generation outputs 2 | outputs/ 3 | *.pt 4 | taming/ 5 | wandb/ 6 | dalle-ds-cp/ 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | pip-wheel-metadata/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # Visual Studio Code 95 | .vscode 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include dalle_pytorch *.txt 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DALL-3 2 | 3 | try it out on colab 4 | 5 | Dall-3 colab link 6 | 7 | 8 | DALL-3 is a mashup of DALLE-pytorch, VQGAN and Clip-Guided Diffusion. The basic idea is to use a diffusion model instead of VAE for the decoder stage, which allows us to use 16x16 tokens instead of 32x32 while maintaining comparable image quality. 9 | 10 | This DALLE model is meant to be used with https://github.com/Jack000/guided-diffusion 11 | 12 | minor modifications to DALLE-pytorch: 13 | - hardcoded 128px image size in dalle (in order to use mismatched VAE/DALLE image sizes) 14 | - added top-p filtering 15 | 16 | Cherry picked sample images: 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 |
Before diffusion | After diffusion
a cube made of cloud - before | a cube made of cloud
Prompt: A cube made of cloud, a cube with the texture of cloud
a cube made of water - before | a cube made of water
Prompt: A cube made of water, a cube with the texture of water
a cube made of porcupine - before | a cube made of porcupine
Prompt: A cube made of porcupine, a cube with the texture of porcupine
an armchair shaped like an avocado, an avocado armchair - before | an armchair shaped like an avocado, an avocado armchair
Prompt: An armchair shaped like an avocado, an avocado armchair
a girl with thick glasses - before | a girl with thick glasses
Prompt: A girl with thick glasses, a girl wearing glasses
a machine learning researcher smashes his computer in a fit of rage - before | a machine learning researcher smashes his computer in a fit of rage
Prompt: A machine learning researcher smashes his computer in a fit of rage
36 | 37 | Non-cherry-picked images (CLIP re-ranked best 8 out of 1024): 38 | 39 | 40 | 41 | 42 | 43 |
A cube made of cloud. A cube with the texture of cloud
44 | 45 | 46 | 47 | 48 | 49 |
An armchair shaped like an avocado. An avocado armchair
50 | 51 | 52 | 53 | 54 | 55 |
A girl with thick glasses. A girl wearing glasses
56 | 57 | ## Usage 58 | ```# git clone this repo, then 59 | cd DALLE-pytorch 60 | pip install -e . 61 | 62 | # download GumbelVQ VAE model 63 | mkdir -p vqgan_gumbel_f8 64 | wget 'https://heibox.uni-heidelberg.de/f/b24d14998a8d4f19a34f/?dl=1' -O 'vqgan_gumbel_f8/model.yaml' 65 | wget 'https://heibox.uni-heidelberg.de/f/34a747d5765840b5a99d/?dl=1' -O 'vqgan_gumbel_f8/last.ckpt' 66 | 67 | # download DALL-E models 68 | wget https://dall-3.com/models/dalle/bpe.model 69 | wget https://dall-3.com/models/dalle/dalle-latest.pt 70 | 71 | # generate (optionally install OpenAI clip for --clip_sort) 72 | python generate.py --top_p 0.85 --temperature 1.0 --clip_sort --output_npy --dalle_path ./dalle-latest.pt --bpe_path bpe.model --taming --vqgan_model_path vqgan_gumbel_f8/last.ckpt --vqgan_config_path vqgan_gumbel_f8/model.yaml --text 'a girl with thick glasses. a girl wearing glasses' 73 | 74 | # post process 75 | # use the npy file as input to clip guided diffusion https://github.com/Jack000/guided-diffusion 76 | 77 | ``` 78 | 79 | ## Citations 80 | 81 | ```bibtex 82 | @misc{ramesh2021zeroshot, 83 | title = {Zero-Shot Text-to-Image Generation}, 84 | author = {Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray and Chelsea Voss and Alec Radford and Mark Chen and Ilya Sutskever}, 85 | year = {2021}, 86 | eprint = {2102.12092}, 87 | archivePrefix = {arXiv}, 88 | primaryClass = {cs.CV} 89 | } 90 | ``` 91 | 92 | ```bibtex 93 | @misc{unpublished2021clip, 94 | title = {CLIP: Connecting Text and Images}, 95 | author = {Alec Radford, Ilya Sutskever, Jong Wook Kim, Gretchen Krueger, Sandhini Agarwal}, 96 | year = {2021} 97 | } 98 | ``` 99 | 100 | ```bibtex 101 | @misc{kitaev2020reformer, 102 | title = {Reformer: The Efficient Transformer}, 103 | author = {Nikita Kitaev and Łukasz Kaiser and Anselm Levskaya}, 104 | year = {2020}, 105 | eprint = {2001.04451}, 106 | archivePrefix = {arXiv}, 107 | primaryClass = {cs.LG} 108 | } 109 | ``` 110 | 111 | ```bibtex 112 | @misc{esser2021taming, 113 | title = {Taming Transformers for High-Resolution Image Synthesis}, 114 | author = {Patrick Esser and Robin Rombach and Björn Ommer}, 115 | year = {2021}, 116 | eprint = {2012.09841}, 117 | archivePrefix = {arXiv}, 118 | primaryClass = {cs.CV} 119 | } 120 | ``` 121 | 122 | ```bibtex 123 | @misc{ding2021cogview, 124 | title = {CogView: Mastering Text-to-Image Generation via Transformers}, 125 | author = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang}, 126 | year = {2021}, 127 | eprint = {2105.13290}, 128 | archivePrefix = {arXiv}, 129 | primaryClass = {cs.CV} 130 | } 131 | ``` 132 | 133 | ```bibtex 134 | @software{peng_bo_2021_5196578, 135 | author = {PENG Bo}, 136 | title = {BlinkDL/RWKV-LM: 0.01}, 137 | month = {aug}, 138 | year = {2021}, 139 | publisher = {Zenodo}, 140 | version = {0.01}, 141 | doi = {10.5281/zenodo.5196578}, 142 | url = {https://doi.org/10.5281/zenodo.5196578} 143 | } 144 | ``` 145 | 146 | ```bibtex 147 | @misc{su2021roformer, 148 | title = {RoFormer: Enhanced Transformer with Rotary Position Embedding}, 149 | author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu}, 150 | year = {2021}, 151 | eprint = {2104.09864}, 152 | archivePrefix = {arXiv}, 153 | primaryClass = {cs.CL} 154 | } 155 | ``` 156 | 157 | *Those who do not want to imitate anything, produce nothing.* - Dali 158 | -------------------------------------------------------------------------------- 
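The `--top_p` flag in the generate command above enables nucleus sampling, the filtering step added in this fork. A minimal sketch of what that filtering does is shown below; it mirrors the `top_p` helper in `dalle_pytorch/dalle_pytorch.py`, while the function name `nucleus_filter` and the example logit shape here are illustrative only.

```python
import torch
import torch.nn.functional as F

def nucleus_filter(logits, thres = 0.9):
    # keep the smallest set of tokens whose cumulative probability exceeds `thres`;
    # everything else is masked to -inf before sampling
    sorted_logits, sorted_indices = torch.sort(logits, descending = True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim = -1), dim = -1)
    remove = cum_probs > thres
    remove[..., 1:] = remove[..., :-1].clone()  # always keep the single most likely token
    remove[..., 0] = 0
    indices_to_remove = remove.scatter(1, sorted_indices, remove)
    return logits.masked_fill(indices_to_remove, float('-inf'))

# filter a batch of next-token logits, then sample (as generate.py does with --top_p 0.85)
logits = torch.randn(1, 100)
probs = F.softmax(nucleus_filter(logits, thres = 0.85), dim = -1)
sample = torch.multinomial(probs, 1)
```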
/dalle_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from dalle_pytorch.dalle_pytorch import DALLE, CLIP, DiscreteVAE 2 | from dalle_pytorch.vae import OpenAIDiscreteVAE, VQGanVAE 3 | -------------------------------------------------------------------------------- /dalle_pytorch/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | from math import ceil 3 | 4 | import torch 5 | from torch import nn, einsum 6 | import torch.nn.functional as F 7 | from einops import rearrange, repeat 8 | 9 | from rotary_embedding_torch import apply_rotary_emb 10 | 11 | # helpers 12 | 13 | def exists(val): 14 | return val is not None 15 | 16 | def uniq(arr): 17 | return{el: True for el in arr}.keys() 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | def max_neg_value(t): 25 | return -torch.finfo(t.dtype).max 26 | 27 | def stable_softmax(t, dim = -1, alpha = 32 ** 2): 28 | t = t / alpha 29 | t = t - torch.amax(t, dim = dim, keepdim = True).detach() 30 | return (t * alpha).softmax(dim = dim) 31 | 32 | def apply_pos_emb(pos_emb, qkv): 33 | n = qkv[0].shape[-2] 34 | pos_emb = pos_emb[..., :n, :] 35 | return tuple(map(lambda t: apply_rotary_emb(pos_emb, t), qkv)) 36 | 37 | # classes 38 | 39 | class Attention(nn.Module): 40 | def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False): 41 | super().__init__() 42 | inner_dim = dim_head * heads 43 | self.heads = heads 44 | self.seq_len = seq_len 45 | self.scale = dim_head ** -0.5 46 | 47 | self.stable = stable 48 | self.causal = causal 49 | 50 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 51 | self.to_out = nn.Sequential( 52 | nn.Linear(inner_dim, dim), 53 | nn.Dropout(dropout) 54 | ) 55 | 56 | def forward(self, x, mask = None, rotary_pos_emb = None): 57 | b, n, _, h, device = *x.shape, self.heads, x.device 58 | softmax = torch.softmax if not self.stable else stable_softmax 59 | 60 | qkv = self.to_qkv(x).chunk(3, dim = -1) 61 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 62 | 63 | if exists(rotary_pos_emb): 64 | q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v)) 65 | 66 | q = q * self.scale 67 | 68 | dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) 69 | mask_value = max_neg_value(dots) 70 | 71 | if exists(mask): 72 | mask = rearrange(mask, 'b j -> b () () j') 73 | dots.masked_fill_(~mask, mask_value) 74 | del mask 75 | 76 | if self.causal: 77 | i, j = dots.shape[-2:] 78 | mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool() 79 | dots.masked_fill_(mask, mask_value) 80 | 81 | attn = softmax(dots, dim=-1) 82 | 83 | out = torch.einsum('b h i j, b h j d -> b h i d', attn, v) 84 | out = rearrange(out, 'b h n d -> b n (h d)') 85 | out = self.to_out(out) 86 | return out 87 | 88 | # sparse attention with convolutional pattern, as mentioned in the blog post. 
customizable kernel size and dilation 89 | 90 | class SparseConvCausalAttention(nn.Module): 91 | def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1, heads = 8, dim_head = 64, dropout = 0., stable = False, **kwargs): 92 | super().__init__() 93 | assert kernel_size % 2 == 1, 'kernel size must be odd' 94 | 95 | inner_dim = dim_head * heads 96 | self.seq_len = seq_len 97 | self.heads = heads 98 | self.scale = dim_head ** -0.5 99 | self.image_size = image_size 100 | self.kernel_size = kernel_size 101 | self.dilation = dilation 102 | 103 | self.stable = stable 104 | 105 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 106 | 107 | self.to_out = nn.Sequential( 108 | nn.Linear(inner_dim, dim), 109 | nn.Dropout(dropout) 110 | ) 111 | 112 | def forward(self, x, mask = None, rotary_pos_emb = None): 113 | b, n, _, h, img_size, kernel_size, dilation, seq_len, device = *x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, self.seq_len, x.device 114 | softmax = torch.softmax if not self.stable else stable_softmax 115 | 116 | img_seq_len = img_size ** 2 117 | text_len = seq_len + 1 - img_seq_len 118 | 119 | # padding 120 | 121 | padding = seq_len - n + 1 122 | mask = default(mask, lambda: torch.ones(b, text_len, device = device).bool()) 123 | 124 | x = F.pad(x, (0, 0, 0, padding), value = 0) 125 | mask = mask[:, :text_len] 126 | 127 | # derive query / keys / values 128 | 129 | qkv = self.to_qkv(x).chunk(3, dim = -1) 130 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv) 131 | 132 | if exists(rotary_pos_emb): 133 | q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v)) 134 | 135 | q *= self.scale 136 | 137 | ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v)) 138 | 139 | # text attention 140 | 141 | dots_text = einsum('b i d, b j d -> b i j', q_text, k_text) 142 | mask_value = max_neg_value(dots_text) 143 | 144 | i, j = dots_text.shape[-2:] 145 | text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool() 146 | dots_text.masked_fill_(text_causal_mask, mask_value) 147 | 148 | attn_text = softmax(dots_text, dim = -1) 149 | out_text = einsum('b i j, b j d -> b i d', attn_text, v_text) 150 | 151 | # image attention 152 | 153 | effective_kernel_size = (kernel_size - 1) * dilation + 1 154 | padding = effective_kernel_size // 2 155 | 156 | k_img, v_img = map(lambda t: rearrange(t, 'b (h w) c -> b c h w', h = img_size), (k_img, v_img)) 157 | k_img, v_img = map(lambda t: F.unfold(t, kernel_size, padding = padding, dilation = dilation), (k_img, v_img)) 158 | k_img, v_img = map(lambda t: rearrange(t, 'b (d j) i -> b i j d', j = kernel_size ** 2), (k_img, v_img)) 159 | 160 | # let image attend to all of text 161 | 162 | dots_image = einsum('b i d, b i j d -> b i j', q_img, k_img) 163 | dots_image_to_text = einsum('b i d, b j d -> b i j', q_img, k_text) 164 | 165 | # calculate causal attention for local convolution 166 | 167 | i, j = dots_image.shape[-2:] 168 | img_seq = torch.arange(img_seq_len, device = device) 169 | k_img_indices = rearrange(img_seq.float(), '(h w) -> () () h w', h = img_size) 170 | k_img_indices = F.pad(k_img_indices, (padding,) * 4, value = img_seq_len) # padding set to be max, so it is never attended to 171 | k_img_indices = F.unfold(k_img_indices, kernel_size, dilation = dilation) 172 | k_img_indices = rearrange(k_img_indices, 'b j i -> b i j') 173 | 174 | # mask image attention 175 | 176 | q_img_indices = rearrange(img_seq, 'i -> 
() i ()') 177 | causal_mask = q_img_indices < k_img_indices 178 | 179 | # concat text mask with image causal mask 180 | 181 | causal_mask = repeat(causal_mask, '() i j -> b i j', b = b * h) 182 | mask = repeat(mask, 'b j -> (b h) i j', i = i, h = h) 183 | mask = torch.cat((~mask, causal_mask), dim = -1) 184 | 185 | # image can attend to all of text 186 | 187 | dots = torch.cat((dots_image_to_text, dots_image), dim = -1) 188 | dots.masked_fill_(mask, mask_value) 189 | 190 | attn = softmax(dots, dim = -1) 191 | 192 | # aggregate 193 | 194 | attn_image_to_text, attn_image = attn[..., :text_len], attn[..., text_len:] 195 | 196 | out_image_to_image = einsum('b i j, b i j d -> b i d', attn_image, v_img) 197 | out_image_to_text = einsum('b i j, b j d -> b i d', attn_image_to_text, v_text) 198 | 199 | out_image = out_image_to_image + out_image_to_text 200 | 201 | # combine attended values for both text and image 202 | 203 | out = torch.cat((out_text, out_image), dim = 1) 204 | 205 | out = rearrange(out, '(b h) n d -> b n (h d)', h = h) 206 | out = self.to_out(out) 207 | return out[:, :n] 208 | 209 | # sparse axial causal attention 210 | 211 | class SparseAxialCausalAttention(nn.Module): 212 | def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head = 64, dropout = 0., stable = False, **kwargs): 213 | super().__init__() 214 | assert axis in {0, 1}, 'axis must be either 0 (along height) or 1 (along width)' 215 | self.axis = axis 216 | 217 | inner_dim = dim_head * heads 218 | self.seq_len = seq_len 219 | self.heads = heads 220 | self.scale = dim_head ** -0.5 221 | self.image_size = image_size 222 | 223 | self.stable = stable 224 | 225 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 226 | 227 | self.to_out = nn.Sequential( 228 | nn.Linear(inner_dim, dim), 229 | nn.Dropout(dropout) 230 | ) 231 | 232 | def forward(self, x, mask = None, rotary_pos_emb = None): 233 | b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device 234 | softmax = torch.softmax if not self.stable else stable_softmax 235 | 236 | img_seq_len = img_size ** 2 237 | text_len = seq_len + 1 - img_seq_len 238 | 239 | # padding 240 | 241 | padding = seq_len - n + 1 242 | mask = default(mask, lambda: torch.ones(b, text_len, device = device).bool()) 243 | 244 | x = F.pad(x, (0, 0, 0, padding), value = 0) 245 | mask = mask[:, :text_len] 246 | 247 | # derive queries / keys / values 248 | 249 | qkv = self.to_qkv(x).chunk(3, dim = -1) 250 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv) 251 | 252 | if exists(rotary_pos_emb): 253 | q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v)) 254 | 255 | q *= self.scale 256 | 257 | ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v)) 258 | 259 | # text attention 260 | 261 | dots_text = einsum('b i d, b j d -> b i j', q_text, k_text) 262 | mask_value = max_neg_value(dots_text) 263 | 264 | i, j = dots_text.shape[-2:] 265 | text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool() 266 | dots_text.masked_fill_(text_causal_mask, mask_value) 267 | 268 | attn_text = softmax(dots_text, dim = -1) 269 | out_text = einsum('b i j, b j d -> b i d', attn_text, v_text) 270 | 271 | # image attention 272 | 273 | split_axis_einops = 'b (h w) c -> b h w c' if axis == 0 else 'b (h w) c -> b w h c' 274 | merge_axis_einops = 'b x n d -> b (x n) d' if axis == 0 else 'b x n d -> b (n x) d' 275 | 276 | # split out axis 
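        # with axis = 0 the 'b (h w) c -> b h w c' rearrange groups image tokens by row, so each
        # query attends only to keys/values within its own row; with axis = 1 the grid is transposed
        # and attention runs per column instead. causality along the chosen axis is enforced by the
        # triangular mask constructed further below, while every image token still sees all of the text.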
277 | 278 | q_img, k_img, v_img = map(lambda t: rearrange(t, split_axis_einops, h = img_size), (q_img, k_img, v_img)) 279 | 280 | # similarity 281 | 282 | dots_image_to_image = einsum('b x i d, b x j d -> b x i j', q_img, k_img) 283 | dots_image_to_text = einsum('b x i d, b j d -> b x i j', q_img, k_text) 284 | 285 | dots = torch.cat((dots_image_to_text, dots_image_to_image), dim = -1) 286 | 287 | # mask so image has full attention to text, but causal along axis 288 | 289 | bh, x, i, j = dots.shape 290 | causal_mask = torch.ones(i, img_size, device = device).triu_(img_size - i + 1).bool() 291 | causal_mask = repeat(causal_mask, 'i j -> b x i j', b = bh, x = x) 292 | 293 | mask = repeat(mask, 'b j -> (b h) x i j', h = h, x = x, i = i) 294 | mask = torch.cat((~mask, causal_mask), dim = -1) 295 | 296 | dots.masked_fill_(mask, mask_value) 297 | 298 | # attention. 299 | 300 | attn = softmax(dots, dim = -1) 301 | 302 | # aggregate 303 | 304 | attn_image_to_text, attn_image_to_image = attn[..., :text_len], attn[..., text_len:] 305 | 306 | out_image_to_image = einsum('b x i j, b x j d -> b x i d', attn_image_to_image, v_img) 307 | out_image_to_text = einsum('b x i j, b j d -> b x i d', attn_image_to_text, v_text) 308 | 309 | out_image = out_image_to_image + out_image_to_text 310 | 311 | # merge back axis 312 | 313 | out_image = rearrange(out_image, merge_axis_einops, x = img_size) 314 | 315 | # combine attended values for both text and image 316 | 317 | out = torch.cat((out_text, out_image), dim = 1) 318 | 319 | out = rearrange(out, '(b h) n d -> b n (h d)', h = h) 320 | out = self.to_out(out) 321 | return out[:, :n] 322 | 323 | # microsoft sparse attention CUDA kernel 324 | 325 | class SparseAttention(Attention): 326 | def __init__( 327 | self, 328 | *args, 329 | block_size = 16, 330 | text_seq_len = 256, 331 | num_random_blocks = None, 332 | **kwargs 333 | ): 334 | super().__init__(*args, **kwargs) 335 | from deepspeed.ops.sparse_attention import SparseSelfAttention, VariableSparsityConfig 336 | self.block_size = block_size 337 | 338 | num_random_blocks = default(num_random_blocks, self.seq_len // block_size // 4) 339 | global_block_indices = list(range(ceil(text_seq_len / block_size))) 340 | 341 | self.attn_fn = SparseSelfAttention( 342 | sparsity_config = VariableSparsityConfig( 343 | num_heads = self.heads, 344 | block = self.block_size, 345 | num_random_blocks = num_random_blocks, 346 | global_block_indices = global_block_indices, 347 | attention = 'unidirectional' if self.causal else 'bidirectional' 348 | ), 349 | max_seq_length = self.seq_len, 350 | attn_mask_mode = 'add' 351 | ) 352 | 353 | def forward(self, x, mask = None, rotary_pos_emb = None): 354 | b, n, _, h, device = *x.shape, self.heads, x.device 355 | remainder = n % self.block_size 356 | mask = default(mask, lambda: torch.ones(b, n, device = device).bool()) 357 | 358 | if remainder > 0: 359 | padding = self.block_size - remainder 360 | x = F.pad(x, (0, 0, 0, padding), value = 0) 361 | mask = F.pad(mask, (0, padding), value = False) 362 | 363 | qkv = self.to_qkv(x).chunk(3, dim = -1) 364 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 365 | 366 | if exists(rotary_pos_emb): 367 | q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v)) 368 | 369 | key_pad_mask = None 370 | if exists(mask): 371 | key_pad_mask = ~mask 372 | 373 | attn_mask = None 374 | if self.causal: 375 | i, j = q.shape[-2], k.shape[-2] 376 | mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool() 377 | attn_mask = torch.zeros(i, j, 
device = device).to(q) 378 | mask_value = max_neg_value(q) / 2 379 | attn_mask.masked_fill_(mask, mask_value) 380 | 381 | out = self.attn_fn(q, k, v, attn_mask = attn_mask, key_padding_mask = key_pad_mask) 382 | out = rearrange(out, 'b h n d -> b n (h d)') 383 | out = self.to_out(out) 384 | return out[:, :n] 385 | -------------------------------------------------------------------------------- /dalle_pytorch/dalle_pytorch.py: -------------------------------------------------------------------------------- 1 | from math import log2, sqrt 2 | import torch 3 | from torch import nn, einsum 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | from axial_positional_embedding import AxialPositionalEmbedding 8 | from einops import rearrange 9 | 10 | from dalle_pytorch import distributed_utils 11 | from dalle_pytorch.vae import OpenAIDiscreteVAE, VQGanVAE 12 | from dalle_pytorch.transformer import Transformer, DivideMax 13 | 14 | # helpers 15 | 16 | def exists(val): 17 | return val is not None 18 | 19 | def default(val, d): 20 | return val if exists(val) else d 21 | 22 | class always(): 23 | def __init__(self, val): 24 | self.val = val 25 | def __call__(self, x, *args, **kwargs): 26 | return self.val 27 | 28 | def is_empty(t): 29 | return t.nelement() == 0 30 | 31 | def masked_mean(t, mask, dim = 1): 32 | t = t.masked_fill(~mask[:, :, None], 0.) 33 | return t.sum(dim = 1) / mask.sum(dim = 1)[..., None] 34 | 35 | def set_requires_grad(model, value): 36 | for param in model.parameters(): 37 | param.requires_grad = value 38 | 39 | def eval_decorator(fn): 40 | def inner(model, *args, **kwargs): 41 | was_training = model.training 42 | model.eval() 43 | out = fn(model, *args, **kwargs) 44 | model.train(was_training) 45 | return out 46 | return inner 47 | 48 | # sampling helpers 49 | 50 | def top_k(logits, thres = 0.5): 51 | num_logits = logits.shape[-1] 52 | k = max(int((1 - thres) * num_logits), 1) 53 | val, ind = torch.topk(logits, k) 54 | probs = torch.full_like(logits, float('-inf')) 55 | probs.scatter_(1, ind, val) 56 | return probs 57 | 58 | def top_p(logits, thres = 0.9): 59 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 60 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 61 | 62 | # Remove tokens with cumulative probability above the threshold 63 | sorted_indices_to_remove = cumulative_probs > thres 64 | 65 | # Shift the indices to the right to keep also the first token above the threshold 66 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 67 | sorted_indices_to_remove[..., 0] = 0 68 | 69 | # scatter sorted tensors to original indexing 70 | indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove) 71 | logits[indices_to_remove] = float('-inf') 72 | return logits 73 | 74 | # discrete vae class 75 | 76 | class ResBlock(nn.Module): 77 | def __init__(self, chan): 78 | super().__init__() 79 | self.net = nn.Sequential( 80 | nn.Conv2d(chan, chan, 3, padding = 1), 81 | nn.ReLU(), 82 | nn.Conv2d(chan, chan, 3, padding = 1), 83 | nn.ReLU(), 84 | nn.Conv2d(chan, chan, 1) 85 | ) 86 | 87 | def forward(self, x): 88 | return self.net(x) + x 89 | 90 | class DiscreteVAE(nn.Module): 91 | def __init__( 92 | self, 93 | image_size = 256, 94 | num_tokens = 512, 95 | codebook_dim = 512, 96 | num_layers = 3, 97 | num_resnet_blocks = 0, 98 | hidden_dim = 64, 99 | channels = 3, 100 | smooth_l1_loss = False, 101 | temperature = 0.9, 102 | straight_through = False, 103 | 
kl_div_loss_weight = 0., 104 | normalization = ((0.5,) * 3, (0.5,) * 3) 105 | ): 106 | super().__init__() 107 | assert log2(image_size).is_integer(), 'image size must be a power of 2' 108 | assert num_layers >= 1, 'number of layers must be greater than or equal to 1' 109 | has_resblocks = num_resnet_blocks > 0 110 | 111 | self.image_size = image_size 112 | self.num_tokens = num_tokens 113 | self.num_layers = num_layers 114 | self.temperature = temperature 115 | self.straight_through = straight_through 116 | self.codebook = nn.Embedding(num_tokens, codebook_dim) 117 | 118 | hdim = hidden_dim 119 | 120 | enc_chans = [hidden_dim] * num_layers 121 | dec_chans = list(reversed(enc_chans)) 122 | 123 | enc_chans = [channels, *enc_chans] 124 | 125 | dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0] 126 | dec_chans = [dec_init_chan, *dec_chans] 127 | 128 | enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans)) 129 | 130 | enc_layers = [] 131 | dec_layers = [] 132 | 133 | for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io): 134 | enc_layers.append(nn.Sequential(nn.Conv2d(enc_in, enc_out, 4, stride = 2, padding = 1), nn.ReLU())) 135 | dec_layers.append(nn.Sequential(nn.ConvTranspose2d(dec_in, dec_out, 4, stride = 2, padding = 1), nn.ReLU())) 136 | 137 | for _ in range(num_resnet_blocks): 138 | dec_layers.insert(0, ResBlock(dec_chans[1])) 139 | enc_layers.append(ResBlock(enc_chans[-1])) 140 | 141 | if num_resnet_blocks > 0: 142 | dec_layers.insert(0, nn.Conv2d(codebook_dim, dec_chans[1], 1)) 143 | 144 | enc_layers.append(nn.Conv2d(enc_chans[-1], num_tokens, 1)) 145 | dec_layers.append(nn.Conv2d(dec_chans[-1], channels, 1)) 146 | 147 | self.encoder = nn.Sequential(*enc_layers) 148 | self.decoder = nn.Sequential(*dec_layers) 149 | 150 | self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss 151 | self.kl_div_loss_weight = kl_div_loss_weight 152 | 153 | # take care of normalization within class 154 | self.normalization = normalization 155 | 156 | self._register_external_parameters() 157 | 158 | def _register_external_parameters(self): 159 | """Register external parameters for DeepSpeed partitioning.""" 160 | if ( 161 | not distributed_utils.is_distributed 162 | or not distributed_utils.using_backend( 163 | distributed_utils.DeepSpeedBackend) 164 | ): 165 | return 166 | 167 | deepspeed = distributed_utils.backend.backend_module 168 | deepspeed.zero.register_external_parameter(self, self.codebook.weight) 169 | 170 | def norm(self, images): 171 | if not exists(self.normalization): 172 | return images 173 | 174 | means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization) 175 | means, stds = map(lambda t: rearrange(t, 'c -> () c () ()'), (means, stds)) 176 | images = images.clone() 177 | images.sub_(means).div_(stds) 178 | return images 179 | 180 | @torch.no_grad() 181 | @eval_decorator 182 | def get_codebook_indices(self, images): 183 | logits = self(images, return_logits = True) 184 | codebook_indices = logits.argmax(dim = 1).flatten(1) 185 | return codebook_indices 186 | 187 | def decode( 188 | self, 189 | img_seq 190 | ): 191 | image_embeds = self.codebook(img_seq) 192 | b, n, d = image_embeds.shape 193 | h = w = int(sqrt(n)) 194 | 195 | image_embeds = rearrange(image_embeds, 'b (h w) d -> b d h w', h = h, w = w) 196 | images = self.decoder(image_embeds) 197 | return images 198 | 199 | def forward( 200 | self, 201 | img, 202 | return_loss = False, 203 | return_recons = False, 204 | return_logits = 
False, 205 | temp = None 206 | ): 207 | device, num_tokens, image_size, kl_div_loss_weight = img.device, self.num_tokens, self.image_size, self.kl_div_loss_weight 208 | assert img.shape[-1] == image_size and img.shape[-2] == image_size, f'input must have the correct image size {image_size}' 209 | 210 | img = self.norm(img) 211 | 212 | logits = self.encoder(img) 213 | 214 | if return_logits: 215 | return logits # return logits for getting hard image indices for DALL-E training 216 | 217 | temp = default(temp, self.temperature) 218 | soft_one_hot = F.gumbel_softmax(logits, tau = temp, dim = 1, hard = self.straight_through) 219 | sampled = einsum('b n h w, n d -> b d h w', soft_one_hot, self.codebook.weight) 220 | out = self.decoder(sampled) 221 | 222 | if not return_loss: 223 | return out 224 | 225 | # reconstruction loss 226 | 227 | recon_loss = self.loss_fn(img, out) 228 | 229 | # kl divergence 230 | 231 | logits = rearrange(logits, 'b n h w -> b (h w) n') 232 | log_qy = F.log_softmax(logits, dim = -1) 233 | log_uniform = torch.log(torch.tensor([1. / num_tokens], device = device)) 234 | kl_div = F.kl_div(log_uniform, log_qy, None, None, 'batchmean', log_target = True) 235 | 236 | loss = recon_loss + (kl_div * kl_div_loss_weight) 237 | 238 | if not return_recons: 239 | return loss 240 | 241 | return loss, out 242 | 243 | # main classes 244 | 245 | class CLIP(nn.Module): 246 | def __init__( 247 | self, 248 | *, 249 | dim_text = 512, 250 | dim_image = 512, 251 | dim_latent = 512, 252 | num_text_tokens = 10000, 253 | text_enc_depth = 6, 254 | text_seq_len = 256, 255 | text_heads = 8, 256 | num_visual_tokens = 512, 257 | visual_enc_depth = 6, 258 | visual_heads = 8, 259 | visual_image_size = 256, 260 | visual_patch_size = 32, 261 | channels = 3 262 | ): 263 | super().__init__() 264 | self.text_emb = nn.Embedding(num_text_tokens, dim_text) 265 | self.text_pos_emb = nn.Embedding(text_seq_len, dim_text) 266 | self.text_transformer = Transformer(causal = False, seq_len = text_seq_len, dim = dim_text, depth = text_enc_depth, heads = text_heads, rotary_emb = False) 267 | self.to_text_latent = nn.Linear(dim_text, dim_latent, bias = False) 268 | 269 | assert visual_image_size % visual_patch_size == 0, 'Image dimensions must be divisible by the patch size.' 
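        # with the defaults (visual_image_size = 256, visual_patch_size = 32, channels = 3) this
        # yields (256 // 32) ** 2 = 64 patches, each flattened to 3 * 32 * 32 = 3072 values
        # before the linear projection down to dim_image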
270 | num_patches = (visual_image_size // visual_patch_size) ** 2 271 | patch_dim = channels * visual_patch_size ** 2 272 | 273 | self.visual_patch_size = visual_patch_size 274 | self.to_visual_embedding = nn.Linear(patch_dim, dim_image) 275 | self.visual_pos_emb = nn.Embedding(num_patches, dim_image) 276 | self.visual_transformer = Transformer(causal = False, seq_len = num_patches, dim = dim_image, depth = visual_enc_depth, heads = visual_heads, rotary_emb = False) 277 | self.to_visual_latent = nn.Linear(dim_image, dim_latent, bias = False) 278 | 279 | self.temperature = nn.Parameter(torch.tensor(1.)) 280 | 281 | def forward( 282 | self, 283 | text, 284 | image, 285 | text_mask = None, 286 | return_loss = False 287 | ): 288 | b, device, p = text.shape[0], text.device, self.visual_patch_size 289 | 290 | text_emb = self.text_emb(text) 291 | text_emb += self.text_pos_emb(torch.arange(text.shape[1], device = device)) 292 | 293 | image_patches = rearrange(image, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p) 294 | image_emb = self.to_visual_embedding(image_patches) 295 | image_emb += self.visual_pos_emb(torch.arange(image_emb.shape[1], device = device)) 296 | 297 | enc_text = self.text_transformer(text_emb, mask = text_mask) 298 | enc_image = self.visual_transformer(image_emb) 299 | 300 | if exists(text_mask): 301 | text_latents = masked_mean(enc_text, text_mask, dim = 1) 302 | else: 303 | text_latents = enc_text.mean(dim = 1) 304 | 305 | image_latents = enc_image.mean(dim = 1) 306 | 307 | text_latents = self.to_text_latent(text_latents) 308 | image_latents = self.to_visual_latent(image_latents) 309 | 310 | text_latents, image_latents = map(lambda t: F.normalize(t, p = 2, dim = -1), (text_latents, image_latents)) 311 | 312 | temp = self.temperature.exp() 313 | 314 | if not return_loss: 315 | sim = einsum('n d, n d -> n', text_latents, image_latents) * temp 316 | return sim 317 | 318 | sim = einsum('i d, j d -> i j', text_latents, image_latents) * temp 319 | labels = torch.arange(b, device = device) 320 | loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2 321 | return loss 322 | 323 | # main DALL-E class 324 | 325 | class DALLE(nn.Module): 326 | def __init__( 327 | self, 328 | *, 329 | dim, 330 | vae, 331 | num_text_tokens = 10000, 332 | text_seq_len = 256, 333 | depth, 334 | heads = 8, 335 | dim_head = 64, 336 | reversible = False, 337 | attn_dropout = 0., 338 | ff_dropout = 0, 339 | sparse_attn = False, 340 | attn_types = None, 341 | loss_img_weight = 7, 342 | stable = False, 343 | sandwich_norm = False, 344 | shift_tokens = True, 345 | rotary_emb = True 346 | ): 347 | super().__init__() 348 | assert isinstance(vae, (DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE)), 'vae must be an instance of DiscreteVAE' 349 | 350 | #image_size = vae.image_size 351 | image_size = 128 352 | num_image_tokens = vae.num_tokens 353 | #image_fmap_size = (vae.image_size // (2 ** vae.num_layers)) 354 | image_fmap_size = 16 355 | image_seq_len = image_fmap_size ** 2 356 | 357 | num_text_tokens = num_text_tokens + text_seq_len # reserve unique padding tokens for each position (text seq len) 358 | 359 | self.text_emb = nn.Embedding(num_text_tokens, dim) 360 | self.image_emb = nn.Embedding(num_image_tokens, dim) 361 | 362 | self.text_pos_emb = nn.Embedding(text_seq_len + 1, dim) if not rotary_emb else always(0) # +1 for 363 | self.image_pos_emb = AxialPositionalEmbedding(dim, axial_shape = (image_fmap_size, image_fmap_size)) if not rotary_emb else always(0) 364 | 365 | 
self.num_text_tokens = num_text_tokens # for offsetting logits index and calculating cross entropy loss 366 | self.num_image_tokens = num_image_tokens 367 | 368 | self.text_seq_len = text_seq_len 369 | self.image_seq_len = image_seq_len 370 | 371 | seq_len = text_seq_len + image_seq_len 372 | total_tokens = num_text_tokens + num_image_tokens 373 | self.total_tokens = total_tokens 374 | self.total_seq_len = seq_len 375 | 376 | self.vae = vae 377 | set_requires_grad(self.vae, False) # freeze VAE from being trained 378 | 379 | self.transformer = Transformer( 380 | dim = dim, 381 | causal = True, 382 | seq_len = seq_len, 383 | depth = depth, 384 | heads = heads, 385 | dim_head = dim_head, 386 | reversible = reversible, 387 | attn_dropout = attn_dropout, 388 | ff_dropout = ff_dropout, 389 | attn_types = attn_types, 390 | image_fmap_size = image_fmap_size, 391 | sparse_attn = sparse_attn, 392 | stable = stable, 393 | sandwich_norm = sandwich_norm, 394 | shift_tokens = shift_tokens, 395 | rotary_emb = rotary_emb 396 | ) 397 | 398 | self.stable = stable 399 | 400 | if stable: 401 | self.norm_by_max = DivideMax(dim = -1) 402 | 403 | self.to_logits = nn.Sequential( 404 | nn.LayerNorm(dim), 405 | nn.Linear(dim, self.total_tokens), 406 | ) 407 | 408 | seq_range = torch.arange(seq_len) 409 | logits_range = torch.arange(total_tokens) 410 | 411 | seq_range = rearrange(seq_range, 'n -> () n ()') 412 | logits_range = rearrange(logits_range, 'd -> () () d') 413 | 414 | logits_mask = ( 415 | ((seq_range >= text_seq_len) & (logits_range < num_text_tokens)) | 416 | ((seq_range < text_seq_len) & (logits_range >= num_text_tokens)) 417 | ) 418 | 419 | self.register_buffer('logits_mask', logits_mask, persistent=False) 420 | self.loss_img_weight = loss_img_weight 421 | 422 | 423 | @torch.no_grad() 424 | @eval_decorator 425 | def generate_texts( 426 | self, 427 | tokenizer, 428 | text = None, 429 | *, 430 | filter_thres = 0.5, 431 | temperature = 1. 
432 | ): 433 | text_seq_len = self.text_seq_len 434 | if text is None or text == "": 435 | text_tokens = torch.tensor([[0]]).cuda() 436 | else: 437 | text_tokens = torch.tensor(tokenizer.tokenizer.encode(text)).cuda().unsqueeze(0) 438 | 439 | for _ in range(text_tokens.shape[1], text_seq_len): 440 | device = text_tokens.device 441 | 442 | tokens = self.text_emb(text_tokens) 443 | tokens += self.text_pos_emb(torch.arange(text_tokens.shape[1], device = device)) 444 | 445 | seq_len = tokens.shape[1] 446 | 447 | output_transf = self.transformer(tokens) 448 | 449 | if self.stable: 450 | output_transf = self.norm_by_max(output_transf) 451 | 452 | logits = self.to_logits(output_transf) 453 | 454 | # mask logits to make sure text predicts text (except last token), and image predicts image 455 | 456 | logits_mask = self.logits_mask[:, :seq_len] 457 | max_neg_value = -torch.finfo(logits.dtype).max 458 | logits.masked_fill_(logits_mask, max_neg_value) 459 | logits = logits[:, -1, :] 460 | 461 | filtered_logits = top_k(logits, thres = filter_thres) 462 | probs = F.softmax(filtered_logits / temperature, dim = -1) 463 | sample = torch.multinomial(probs, 1) 464 | 465 | text_tokens = torch.cat((text_tokens, sample), dim=-1) 466 | 467 | padding_tokens = set(np.arange(self.text_seq_len) + (self.num_text_tokens - self.text_seq_len)) 468 | texts = [tokenizer.tokenizer.decode(text_token, pad_tokens=padding_tokens) for text_token in text_tokens] 469 | return text_tokens, texts 470 | 471 | @torch.no_grad() 472 | @eval_decorator 473 | def generate_images( 474 | self, 475 | text, 476 | *, 477 | clip = None, 478 | mask = None, 479 | top_k_thresh = None, 480 | top_p_thresh = None, 481 | temperature = 1., 482 | img = None, 483 | num_init_img_tokens = None, 484 | return_tokens = False 485 | ): 486 | vae, text_seq_len, image_seq_len, num_text_tokens = self.vae, self.text_seq_len, self.image_seq_len, self.num_text_tokens 487 | total_len = text_seq_len + image_seq_len 488 | 489 | text = text[:, :text_seq_len] # make sure text is within bounds 490 | out = text 491 | 492 | if exists(img): 493 | image_size = vae.image_size 494 | assert img.shape[1] == 3 and img.shape[2] == image_size and img.shape[3] == image_size, f'input image must have the correct image size {image_size}' 495 | 496 | indices = vae.get_codebook_indices(img) 497 | num_img_tokens = default(num_init_img_tokens, int(0.4375 * image_seq_len)) # OpenAI used 14 * 32 initial tokens to prime 498 | assert num_img_tokens < image_seq_len, 'number of initial image tokens for priming must be less than the total image token sequence length' 499 | 500 | indices = indices[:, :num_img_tokens] 501 | out = torch.cat((out, indices), dim = -1) 502 | 503 | for cur_len in range(out.shape[1], total_len): 504 | is_image = cur_len >= text_seq_len 505 | 506 | text, image = out[:, :text_seq_len], out[:, text_seq_len:] 507 | 508 | logits = self(text, image, mask = mask)[:, -1, :] 509 | 510 | if top_k_thresh is not None: 511 | filtered_logits = top_k(logits, thres = top_k_thresh) 512 | else: 513 | filtered_logits = top_p(logits, thres = top_p_thresh) 514 | 515 | probs = F.softmax(filtered_logits / temperature, dim = -1) 516 | sample = torch.multinomial(probs, 1) 517 | 518 | sample -= (num_text_tokens if is_image else 0) # offset sampled token if it is an image token, since logit space is composed of text and then image tokens 519 | out = torch.cat((out, sample), dim=-1) 520 | 521 | if out.shape[1] <= text_seq_len: 522 | mask = F.pad(mask, (0, 1), value = True) 523 | 524 | text_seq = 
out[:, :text_seq_len] 525 | 526 | img_seq = out[:, -image_seq_len:] 527 | images = vae.decode(img_seq) 528 | 529 | if exists(clip): 530 | scores = clip(text_seq, images, return_loss = False) 531 | return images, scores 532 | 533 | if return_tokens: 534 | return images, img_seq 535 | else: 536 | return images, None 537 | 538 | def forward( 539 | self, 540 | text, 541 | image = None, 542 | mask = None, 543 | return_loss = False 544 | ): 545 | assert text.shape[-1] == self.text_seq_len, f'the length {text.shape[-1]} of the text tokens you passed in does not have the correct length ({self.text_seq_len})' 546 | device, total_seq_len = text.device, self.total_seq_len 547 | 548 | # make sure padding in text tokens get unique padding token id 549 | 550 | text_range = torch.arange(self.text_seq_len, device = device) + (self.num_text_tokens - self.text_seq_len) 551 | text = torch.where(text == 0, text_range, text) 552 | 553 | # add 554 | 555 | text = F.pad(text, (1, 0), value = 0) 556 | 557 | tokens = self.text_emb(text) 558 | tokens += self.text_pos_emb(torch.arange(text.shape[1], device = device)) 559 | 560 | seq_len = tokens.shape[1] 561 | 562 | if exists(image) and not is_empty(image): 563 | is_raw_image = len(image.shape) == 4 564 | 565 | if is_raw_image: 566 | image_size = self.vae.image_size 567 | assert tuple(image.shape[1:]) == (3, image_size, image_size), f'invalid image of dimensions {image.shape} passed in during training' 568 | 569 | image = self.vae.get_codebook_indices(image) 570 | 571 | image_len = image.shape[1] 572 | image_emb = self.image_emb(image) 573 | 574 | image_emb += self.image_pos_emb(image_emb) 575 | 576 | tokens = torch.cat((tokens, image_emb), dim = 1) 577 | 578 | seq_len += image_len 579 | 580 | # when training, if the length exceeds the total text + image length 581 | # remove the last token, since it needs not to be trained 582 | 583 | if tokens.shape[1] > total_seq_len: 584 | seq_len -= 1 585 | tokens = tokens[:, :-1] 586 | 587 | if self.stable: 588 | alpha = 0.1 589 | tokens = tokens * alpha + tokens.detach() * (1 - alpha) 590 | 591 | out = self.transformer(tokens) 592 | 593 | if self.stable: 594 | out = self.norm_by_max(out) 595 | 596 | logits = self.to_logits(out) 597 | 598 | # mask logits to make sure text predicts text (except last token), and image predicts image 599 | 600 | logits_mask = self.logits_mask[:, :seq_len] 601 | max_neg_value = -torch.finfo(logits.dtype).max 602 | logits.masked_fill_(logits_mask, max_neg_value) 603 | 604 | if not return_loss: 605 | return logits 606 | 607 | assert exists(image), 'when training, image must be supplied' 608 | 609 | offsetted_image = image + self.num_text_tokens 610 | labels = torch.cat((text[:, 1:], offsetted_image), dim = 1) 611 | 612 | logits = rearrange(logits, 'b n c -> b c n') 613 | 614 | loss_text = F.cross_entropy(logits[:, :, :self.text_seq_len], labels[:, :self.text_seq_len]) 615 | loss_img = F.cross_entropy(logits[:, :, self.text_seq_len:], labels[:, self.text_seq_len:]) 616 | 617 | loss = (loss_text + self.loss_img_weight * loss_img) / (self.loss_img_weight + 1) 618 | return loss 619 | -------------------------------------------------------------------------------- /dalle_pytorch/distributed_backends/__init__.py: -------------------------------------------------------------------------------- 1 | from .deepspeed_backend import DeepSpeedBackend 2 | from .distributed_backend import DistributedBackend 3 | from .dummy_backend import DummyBackend 4 | from .horovod_backend import HorovodBackend 5 | 
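# Usage sketch (not part of this module): the classes exported above share the
# DistributedBackend interface defined in distributed_backend.py -- initialize(),
# distribute(), average_all(), etc. The toy model / optimizer / dataloader below are
# placeholders for illustration only; the training scripts normally select and wire up
# a backend through dalle_pytorch/distributed_utils.py instead of instantiating one directly.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from dalle_pytorch.distributed_backends import DummyBackend, HorovodBackend

backend = HorovodBackend()
if not backend.has_backend():          # fall back when horovod is not installed
    backend = DummyBackend()
backend.initialize()

model = nn.Linear(8, 1)
opt = torch.optim.Adam(model.parameters(), lr = 1e-3)
data = DataLoader(TensorDataset(torch.randn(32, 8), torch.randn(32, 1)), batch_size = 4)

# wrap model / optimizer / dataloader for the chosen backend (a pass-through for DummyBackend)
model, opt, data, _ = backend.distribute(model = model, optimizer = opt, training_data = data)

for x, y in data:
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()
    avg_loss = backend.average_all(loss.detach())  # average across workers for logging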
-------------------------------------------------------------------------------- /dalle_pytorch/distributed_backends/deepspeed_backend.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import torch 5 | 6 | from .distributed_backend import DistributedBackend 7 | 8 | 9 | class DeepSpeedBackend(DistributedBackend): 10 | """Distributed backend using the DeepSpeed engine.""" 11 | 12 | BACKEND_MODULE_NAME = 'deepspeed' 13 | BACKEND_NAME = 'DeepSpeed' 14 | 15 | def wrap_arg_parser(self, parser): 16 | if not self.has_backend(): 17 | parser.add_argument( 18 | '--deepspeed', 19 | type=lambda _: False, 20 | help=( 21 | 'whether to use DeepSpeed ' 22 | "(ignored since it's not available)" 23 | ), 24 | ) 25 | else: 26 | parser = self.backend_module.add_config_arguments(parser) 27 | 28 | parser.add_argument( 29 | '--local_rank', 30 | type=int, 31 | default=-1, 32 | help='local rank passed from distributed launcher', 33 | ) 34 | return parser 35 | 36 | def _initialize(self): 37 | self.backend_module.init_distributed() 38 | if torch.cuda.is_available(): 39 | torch.cuda.set_device(self._get_local_rank()) 40 | 41 | @staticmethod 42 | def _require_torch_distributed_init(): 43 | """Raise an error when `torch.distributed` has not been 44 | initialized yet. 45 | """ 46 | assert torch.distributed.is_initialized(), \ 47 | ('`torch.distributed` is not initialized; please call ' 48 | '`DeepSpeedBackend.initialize` at the start of your script') 49 | 50 | def _get_world_size(self): 51 | self._require_torch_distributed_init() 52 | return torch.distributed.get_world_size() 53 | 54 | def _get_rank(self): 55 | self._require_torch_distributed_init() 56 | return torch.distributed.get_rank() 57 | 58 | def _get_local_rank(self): 59 | self._require_torch_distributed_init() 60 | return int(os.environ['LOCAL_RANK']) 61 | 62 | def _local_barrier(self): 63 | self._require_torch_distributed_init() 64 | torch.distributed.barrier() 65 | 66 | def _check_args(self, args, optimizer, lr_scheduler, kwargs): 67 | """Return an appropriate optimizer and learning rate scheduler 68 | after checking the values passed to `distribute`. 69 | """ 70 | self._check_argvs(args, optimizer, lr_scheduler, kwargs) 71 | (optimizer, lr_scheduler) = self._check_config( 72 | args, optimizer, lr_scheduler, kwargs) 73 | return (optimizer, lr_scheduler) 74 | 75 | def _check_argvs(self, args, optimizer, lr_scheduler, kwargs): 76 | """Apply several sanity checks to the given command 77 | line arguments. 78 | """ 79 | has_json_config = (hasattr(args, 'deepspeed_config') 80 | and args.deepspeed_config is not None) 81 | has_dict_config = 'config_params' in kwargs 82 | if ( 83 | # No config given 84 | (not has_json_config and not has_dict_config) 85 | # JSON config file does not exist 86 | or (not has_dict_config 87 | and not os.path.isfile(args.deepspeed_config)) 88 | ): 89 | # Let DeepSpeed handle these argument errors. 90 | return 91 | 92 | if not args.deepspeed: 93 | print( 94 | 'WARNING: DeepSpeed backend was selected; setting ' 95 | '`args.deepspeed = True`' 96 | ) 97 | args.deepspeed = True 98 | 99 | if has_json_config and has_dict_config: 100 | print( 101 | 'WARNING: DeepSpeed config was given as both JSON file and ' 102 | 'Python dictionary. Python dictionary takes precedence.' 103 | ) 104 | 105 | def _check_config(self, args, optimizer, lr_scheduler, kwargs): 106 | """Return an appropriate optimizer and learning rate scheduler 107 | for the DeepSpeed configuration. 
108 | """ 109 | if 'config_params' in kwargs: 110 | config = kwargs['config_params'] 111 | else: 112 | with open(args.deepspeed_config, 'r') as json_config_file: 113 | config = json.load(json_config_file) 114 | 115 | if 'optimizer' in config and optimizer is not None: 116 | print( 117 | 'WARNING: Optimizer encountered in both DeepSpeed config and ' 118 | 'keyword arguments. Optimizer in DeepSpeed config ' 119 | 'takes precedence.' 120 | ) 121 | optimizer = None 122 | 123 | if 'scheduler' in config and lr_scheduler is not None: 124 | print( 125 | 'WARNING: Learning rate scheduler encountered in both ' 126 | 'DeepSpeed config and keyword arguments. Learning rate ' 127 | 'scheduler in DeepSpeed config takes precedence.' 128 | ) 129 | # For the LR scheduler, the JSON config already has 130 | # precedence. We do this for forward compatibility. 131 | lr_scheduler = None 132 | 133 | return (optimizer, lr_scheduler) 134 | 135 | def _distribute( 136 | self, 137 | args=None, 138 | model=None, 139 | optimizer=None, 140 | model_parameters=None, 141 | training_data=None, 142 | lr_scheduler=None, 143 | **kwargs, 144 | ): 145 | """Return a distributed model engine, optimizer, dataloader, and 146 | learning rate scheduler. These are obtained by wrapping the 147 | given values with the backend. 148 | 149 | For the other or other possible arguments, 150 | see `deepspeed.initialize`. 151 | """ 152 | (optimizer, lr_scheduler) = self._check_args( 153 | args, optimizer, lr_scheduler, kwargs) 154 | 155 | return self.backend_module.initialize( 156 | args=args, 157 | model=model, 158 | optimizer=optimizer, 159 | model_parameters=model_parameters, 160 | training_data=training_data, 161 | lr_scheduler=lr_scheduler, 162 | **kwargs, 163 | ) 164 | 165 | def _average_all(self, tensor): 166 | self._require_torch_distributed_init() 167 | # We copy because modification happens in-place 168 | averaged = tensor.detach().clone() 169 | # We use `all_reduce` because it is better supported than `reduce` 170 | torch.distributed.all_reduce(averaged, torch.distributed.ReduceOp.SUM) 171 | return averaged / self.get_world_size() 172 | -------------------------------------------------------------------------------- /dalle_pytorch/distributed_backends/distributed_backend.py: -------------------------------------------------------------------------------- 1 | """ 2 | An abstract backend for distributed deep learning. 3 | 4 | Provides several standard utility methods under a common API. 5 | Please check the documentation of the class `DistributedBackend` for 6 | details to implement a new backend. 7 | """ 8 | 9 | from importlib import import_module 10 | 11 | 12 | class DistributedBackend: 13 | """An abstract backend class for distributed deep learning. 14 | 15 | Provides several standard utility methods under a common API. 
16 | Variables that must be overridden: 17 | - BACKEND_MODULE_NAME 18 | - BACKEND_NAME 19 | Methods that must be overridden: 20 | - wrap_arg_parser 21 | - _initialize 22 | - _get_world_size 23 | - _get_rank 24 | - _get_local_rank 25 | - _local_barrier 26 | - _distribute 27 | - _average_all 28 | """ 29 | 30 | BACKEND_MODULE_NAME = None 31 | """Name of the module to import for the backend.""" 32 | BACKEND_NAME = None 33 | """Name of the backend for printing.""" 34 | 35 | ROOT_RANK = 0 36 | 37 | backend_module = None 38 | """The module to access the backend.""" 39 | is_initialized = False 40 | """Whether the backend is initialized.""" 41 | 42 | def __init__(self): 43 | if self.BACKEND_MODULE_NAME is None: 44 | raise NotImplementedError('BACKEND_MODULE_NAME is not set') 45 | if self.BACKEND_NAME is None: 46 | raise NotImplementedError('BACKEND_NAME is not set') 47 | 48 | def has_backend(self): 49 | """Return whether the backend module is now imported.""" 50 | try: 51 | self.backend_module = import_module(self.BACKEND_MODULE_NAME) 52 | except ModuleNotFoundError: 53 | return False 54 | return True 55 | 56 | def check_batch_size(self, batch_size): 57 | """Check whether the batch size makes sense for distribution.""" 58 | assert batch_size >= self.get_world_size(), \ 59 | (f"batch size can't be smaller than number of processes " 60 | f'({batch_size} < {self.get_world_size()})') 61 | 62 | def wrap_arg_parser(self, parser): 63 | """Add arguments to support optional distributed backend usage.""" 64 | raise NotImplementedError 65 | 66 | def initialize(self): 67 | """Initialize the distributed backend.""" 68 | self._initialize() 69 | self.is_initialized = True 70 | 71 | def _initialize(self): 72 | """Initialize the distributed backend.""" 73 | raise NotImplementedError 74 | 75 | def require_init(self): 76 | """Raise an error when the backend has not been initialized yet.""" 77 | assert self.is_initialized, \ 78 | (f'{BACKEND_NAME} backend has not been initialized; please call ' 79 | f'`distributed_utils.initialize` at the start of your script to ' 80 | f'allow optional distributed usage') 81 | 82 | def get_world_size(self): 83 | """Return the amount of distributed processes.""" 84 | self.require_init() 85 | return self._get_world_size() 86 | 87 | def _get_world_size(self): 88 | """Return the amount of distributed processes.""" 89 | raise NotImplementedError 90 | 91 | def get_rank(self): 92 | """Return the global rank of the calling worker process.""" 93 | self.require_init() 94 | return self._get_rank() 95 | 96 | def _get_rank(self): 97 | """Return the global rank of the calling worker process.""" 98 | raise NotImplementedError 99 | 100 | def get_local_rank(self): 101 | """Return the local rank of the calling worker process. 102 | The local rank is the rank based on a single node's processes. 103 | """ 104 | self.require_init() 105 | return self._get_local_rank() 106 | 107 | def _get_local_rank(self): 108 | """Return the local rank of the calling worker process. 109 | The local rank is the rank based on a single node's processes. 
110 | """ 111 | raise NotImplementedError 112 | 113 | def is_root_worker(self): 114 | """Return whether the calling worker has the root rank.""" 115 | return self.get_rank() == self.ROOT_RANK 116 | 117 | def is_local_root_worker(self): 118 | """Return whether the calling worker has the root rank on this node.""" 119 | return self.get_local_rank() == self.ROOT_RANK 120 | 121 | def local_barrier(self): 122 | """Wait until all processes on this node have called this function.""" 123 | self.require_init() 124 | self._local_barrier() 125 | 126 | def _local_barrier(self): 127 | """Wait until all processes on this node have called this function.""" 128 | raise NotImplementedError 129 | 130 | def distribute( 131 | self, 132 | args=None, 133 | model=None, 134 | optimizer=None, 135 | model_parameters=None, 136 | training_data=None, 137 | lr_scheduler=None, 138 | **kwargs, 139 | ): 140 | """Return a distributed model engine, optimizer, dataloader, and 141 | learning rate scheduler. These are obtained by wrapping the 142 | given values with the backend. 143 | """ 144 | self.require_init() 145 | return self._distribute( 146 | args, 147 | model, 148 | optimizer, 149 | model_parameters, 150 | training_data, 151 | lr_scheduler, 152 | **kwargs, 153 | ) 154 | 155 | def _distribute( 156 | self, 157 | args=None, 158 | model=None, 159 | optimizer=None, 160 | model_parameters=None, 161 | training_data=None, 162 | lr_scheduler=None, 163 | **kwargs, 164 | ): 165 | """Return a distributed model engine, optimizer, dataloader, and 166 | learning rate scheduler. These are obtained by wrapping the 167 | given values with the backend. 168 | """ 169 | raise NotImplementedError 170 | 171 | def average_all(self, tensor): 172 | """Return the average of `tensor` over all workers.""" 173 | self.require_init() 174 | return self._average_all(tensor) 175 | 176 | def _average_all(self, tensor): 177 | """Return the average of `tensor` over all workers.""" 178 | raise NotImplementedError 179 | -------------------------------------------------------------------------------- /dalle_pytorch/distributed_backends/dummy_backend.py: -------------------------------------------------------------------------------- 1 | from .distributed_backend import DistributedBackend 2 | 3 | 4 | class DummyBackend(DistributedBackend): 5 | """Acts like a distributed backend. 6 | 7 | Used as a stand-in replacement to obtain a non-distributed program. 8 | """ 9 | 10 | # We define this so we can use `super().__init__` but want this to 11 | # throw an error upon import. 12 | BACKEND_MODULE_NAME = 'NO MODULE' 13 | BACKEND_NAME = 'Dummy' 14 | 15 | def has_backend(self): 16 | return True 17 | 18 | def wrap_arg_parser(self, parser): 19 | return parser 20 | 21 | def _initialize(self): 22 | pass 23 | 24 | def _get_world_size(self): 25 | return 1 26 | 27 | def _get_rank(self): 28 | return self.ROOT_RANK 29 | 30 | def _get_local_rank(self): 31 | return self.ROOT_RANK 32 | 33 | def _local_barrier(self): 34 | pass 35 | 36 | def _distribute( 37 | self, 38 | _args=None, 39 | model=None, 40 | optimizer=None, 41 | _model_parameters=None, 42 | training_data=None, 43 | lr_scheduler=None, 44 | **_kwargs, 45 | ): 46 | """Return the model, optimizer, dataloader, and learning rate scheduler 47 | as is. 
48 | """ 49 | return (model, optimizer, training_data, lr_scheduler) 50 | 51 | def _average_all(self, tensor): 52 | return tensor 53 | -------------------------------------------------------------------------------- /dalle_pytorch/distributed_backends/horovod_backend.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .distributed_backend import DistributedBackend 4 | 5 | 6 | class HorovodBackend(DistributedBackend): 7 | """Distributed backend using Horovod.""" 8 | 9 | BACKEND_MODULE_NAME = 'horovod.torch' 10 | BACKEND_NAME = 'Horovod' 11 | 12 | def wrap_arg_parser(self, parser): 13 | return parser 14 | 15 | def check_batch_size(self, batch_size): 16 | # Horovod uses the local batch size to determine the effective 17 | # batch size. 18 | pass 19 | 20 | def _initialize(self): 21 | self.backend_module.init() 22 | if torch.cuda.is_available(): 23 | torch.cuda.set_device(self._get_local_rank()) 24 | 25 | def _get_world_size(self): 26 | return self.backend_module.size() 27 | 28 | def _get_rank(self): 29 | return self.backend_module.rank() 30 | 31 | def _get_local_rank(self): 32 | return self.backend_module.local_rank() 33 | 34 | def _local_barrier(self): 35 | # Actually a global barrier but works for our purposes. 36 | self.backend_module.join() 37 | 38 | def _distribute( 39 | self, 40 | _args=None, 41 | model=None, 42 | optimizer=None, 43 | _model_parameters=None, 44 | training_data=None, 45 | lr_scheduler=None, 46 | **_kwargs, 47 | ): 48 | optimizer = self.backend_module.DistributedOptimizer(optimizer) 49 | self.backend_module.broadcast_parameters( 50 | model.state_dict(), root_rank=self.ROOT_RANK) 51 | self.backend_module.broadcast_optimizer_state( 52 | optimizer, root_rank=self.ROOT_RANK) 53 | return (model, optimizer, training_data, lr_scheduler) 54 | 55 | def _average_all(self, tensor): 56 | # Reduce op is average by default 57 | averaged = self.backend_module.allreduce(tensor) 58 | return averaged 59 | -------------------------------------------------------------------------------- /dalle_pytorch/distributed_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for optional distributed execution. 3 | 4 | To use, 5 | 1. set the `BACKENDS` to the ones you want to make available, 6 | 2. in the script, wrap the argument parser with `wrap_arg_parser`, 7 | 3. in the script, set and use the backend by calling 8 | `set_backend_from_args`. 9 | 10 | You can check whether a backend is in use with the `using_backend` 11 | function. 12 | """ 13 | 14 | from dalle_pytorch.distributed_backends import \ 15 | DeepSpeedBackend, \ 16 | DummyBackend, \ 17 | HorovodBackend 18 | 19 | _DEFAULT_BACKEND = DummyBackend() 20 | """Which backend to use by default. Assumed to be _not_ distributed.""" 21 | 22 | BACKENDS = [ 23 | _DEFAULT_BACKEND, 24 | DeepSpeedBackend(), 25 | HorovodBackend(), 26 | ] 27 | 28 | is_distributed = None 29 | """Whether we are distributed.""" 30 | backend = None 31 | """Backend in usage.""" 32 | 33 | 34 | def wrap_arg_parser(parser): 35 | """Add arguments to support optional distributed backend usage.""" 36 | parser.add_argument( 37 | '--distributed_backend', 38 | '--distr_backend', 39 | type=str, 40 | default=None, 41 | help='which distributed backend to use. 
Do not distribute by default', 42 | ) 43 | for distr_backend in BACKENDS: 44 | parser = distr_backend.wrap_arg_parser(parser) 45 | return parser 46 | 47 | 48 | def set_backend_from_args(args): 49 | """Set and return the backend based on the given `args`.""" 50 | global is_distributed, backend 51 | 52 | # Handle this specially for backwards compatibility. 53 | if args.deepspeed: 54 | args.distributed_backend = DeepSpeedBackend.BACKEND_NAME 55 | 56 | if not args.distributed_backend: 57 | is_distributed = False 58 | backend = _DEFAULT_BACKEND 59 | return backend 60 | 61 | backend_name = args.distributed_backend.lower() 62 | for distr_backend in BACKENDS: 63 | if distr_backend.BACKEND_NAME.lower() == backend_name: 64 | backend = distr_backend 65 | if not backend.has_backend(): 66 | raise ModuleNotFoundError( 67 | f'{backend.BACKEND_NAME} backend selected but ' 68 | 'module not available' 69 | ) 70 | 71 | print(f'Using {backend.BACKEND_NAME} for distributed execution') 72 | is_distributed = True 73 | return backend 74 | 75 | raise ValueError( 76 | 'unknown backend; please check `distributed_utils.BACKENDS`') 77 | 78 | 79 | def require_set_backend(): 80 | """Raise an `AssertionError` when the backend has not been set.""" 81 | assert backend is not None, ( 82 | 'distributed backend is not set. Please call ' 83 | '`distributed_utils.set_backend_from_args` at the start of your script' 84 | ) 85 | 86 | 87 | def using_backend(test_backend): 88 | """Return whether the backend is set to `test_backend`. 89 | 90 | `test_backend` may be a string of the name of the backend or 91 | its class. 92 | """ 93 | require_set_backend() 94 | if isinstance(test_backend, str): 95 | return backend.BACKEND_NAME == test_backend 96 | return isinstance(backend, test_backend) 97 | -------------------------------------------------------------------------------- /dalle_pytorch/loader.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from random import randint, choice 3 | 4 | import PIL 5 | 6 | from torch.utils.data import Dataset 7 | from torchvision import transforms as T 8 | 9 | 10 | class TextImageDataset(Dataset): 11 | def __init__(self, 12 | folder, 13 | text_len=256, 14 | image_size=128, 15 | truncate_captions=False, 16 | resize_ratio=0.75, 17 | tokenizer=None, 18 | shuffle=False 19 | ): 20 | """ 21 | @param folder: Folder containing images and text files matched by their paths' respective "stem" 22 | @param truncate_captions: Rather than throw an exception, captions which are too long will be truncated. 
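@param text_len: Context length (in tokens) passed to the tokenizer for each caption.
@param image_size: Side length of the square crop produced by the image transform.
@param resize_ratio: Lower bound on the area fraction kept by the random resized crop.
@param tokenizer: Tokenizer object exposing `tokenize(text, context_length, truncate_text)`.
@param shuffle: When a sample must be skipped (empty caption or unreadable image), draw a random replacement instead of the next sequential one.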
23 | """ 24 | super().__init__() 25 | self.shuffle = shuffle 26 | path = Path(folder) 27 | 28 | text_files = [*path.glob('**/*.txt')] 29 | image_files = [ 30 | *path.glob('**/*.png'), *path.glob('**/*.jpg'), 31 | *path.glob('**/*.jpeg'), *path.glob('**/*.bmp') 32 | ] 33 | 34 | text_files = {text_file.stem: text_file for text_file in text_files} 35 | image_files = {image_file.stem: image_file for image_file in image_files} 36 | 37 | keys = (image_files.keys() & text_files.keys()) 38 | 39 | self.keys = list(keys) 40 | self.text_files = {k: v for k, v in text_files.items() if k in keys} 41 | self.image_files = {k: v for k, v in image_files.items() if k in keys} 42 | self.text_len = text_len 43 | self.truncate_captions = truncate_captions 44 | self.resize_ratio = resize_ratio 45 | self.tokenizer = tokenizer 46 | self.image_transform = T.Compose([ 47 | T.Lambda(lambda img: img.convert('RGB') 48 | if img.mode != 'RGB' else img), 49 | T.RandomResizedCrop(image_size, 50 | scale=(self.resize_ratio, 1.), 51 | ratio=(1., 1.)), 52 | T.ToTensor() 53 | ]) 54 | 55 | def __len__(self): 56 | return len(self.keys) 57 | 58 | def random_sample(self): 59 | return self.__getitem__(randint(0, self.__len__() - 1)) 60 | 61 | def sequential_sample(self, ind): 62 | if ind >= self.__len__() - 1: 63 | return self.__getitem__(0) 64 | return self.__getitem__(ind + 1) 65 | 66 | def skip_sample(self, ind): 67 | if self.shuffle: 68 | return self.random_sample() 69 | return self.sequential_sample(ind=ind) 70 | 71 | def __getitem__(self, ind): 72 | key = self.keys[ind] 73 | 74 | text_file = self.text_files[key] 75 | image_file = self.image_files[key] 76 | 77 | descriptions = text_file.read_text().split('\n') 78 | descriptions = list(filter(lambda t: len(t) > 0, descriptions)) 79 | try: 80 | description = choice(descriptions) 81 | except IndexError as zero_captions_in_file_ex: 82 | print(f"An exception occurred trying to load file {text_file}.") 83 | print(f"Skipping index {ind}") 84 | return self.skip_sample(ind) 85 | 86 | tokenized_text = self.tokenizer.tokenize( 87 | description, 88 | self.text_len, 89 | truncate_text=self.truncate_captions 90 | ).squeeze(0) 91 | try: 92 | image_tensor = self.image_transform(PIL.Image.open(image_file)) 93 | except (PIL.UnidentifiedImageError, OSError) as corrupt_image_exceptions: 94 | print(f"An exception occurred trying to load file {image_file}.") 95 | print(f"Skipping index {ind}") 96 | return self.skip_sample(ind) 97 | 98 | # Success 99 | return tokenized_text, image_tensor 100 | -------------------------------------------------------------------------------- /dalle_pytorch/reversible.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from operator import itemgetter 4 | from torch.autograd.function import Function 5 | from torch.utils.checkpoint import get_device_states, set_device_states 6 | 7 | # for routing arguments into the functions of the reversible layer 8 | def route_args(router, args, depth): 9 | routed_args = [(dict(), dict()) for _ in range(depth)] 10 | matched_keys = [key for key in args.keys() if key in router] 11 | 12 | for key in matched_keys: 13 | val = args[key] 14 | for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])): 15 | new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes) 16 | routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args}) 17 | return routed_args 18 | 19 | # following example for saving and setting 
rng here https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html 20 | class Deterministic(nn.Module): 21 | def __init__(self, net): 22 | super().__init__() 23 | self.net = net 24 | self.cpu_state = None 25 | self.cuda_in_fwd = None 26 | self.gpu_devices = None 27 | self.gpu_states = None 28 | 29 | def record_rng(self, *args): 30 | self.cpu_state = torch.get_rng_state() 31 | if torch.cuda._initialized: 32 | self.cuda_in_fwd = True 33 | self.gpu_devices, self.gpu_states = get_device_states(*args) 34 | 35 | def forward(self, *args, record_rng = False, set_rng = False, **kwargs): 36 | if record_rng: 37 | self.record_rng(*args) 38 | 39 | if not set_rng: 40 | return self.net(*args, **kwargs) 41 | 42 | rng_devices = [] 43 | if self.cuda_in_fwd: 44 | rng_devices = self.gpu_devices 45 | 46 | with torch.random.fork_rng(devices=rng_devices, enabled=True): 47 | torch.set_rng_state(self.cpu_state) 48 | if self.cuda_in_fwd: 49 | set_device_states(self.gpu_devices, self.gpu_states) 50 | return self.net(*args, **kwargs) 51 | 52 | # heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py 53 | # once multi-GPU is confirmed working, refactor and send PR back to source 54 | class ReversibleBlock(nn.Module): 55 | def __init__(self, f, g): 56 | super().__init__() 57 | self.f = Deterministic(f) 58 | self.g = Deterministic(g) 59 | 60 | def forward(self, x, f_args = {}, g_args = {}): 61 | x1, x2 = torch.chunk(x, 2, dim=2) 62 | y1, y2 = None, None 63 | 64 | with torch.no_grad(): 65 | y1 = x1 + self.f(x2, record_rng=self.training, **f_args) 66 | y2 = x2 + self.g(y1, record_rng=self.training, **g_args) 67 | 68 | return torch.cat([y1, y2], dim=2) 69 | 70 | def backward_pass(self, y, dy, f_args = {}, g_args = {}): 71 | y1, y2 = torch.chunk(y, 2, dim=2) 72 | del y 73 | 74 | dy1, dy2 = torch.chunk(dy, 2, dim=2) 75 | del dy 76 | 77 | with torch.enable_grad(): 78 | y1.requires_grad = True 79 | gy1 = self.g(y1, set_rng=True, **g_args) 80 | torch.autograd.backward(gy1, dy2) 81 | 82 | with torch.no_grad(): 83 | x2 = y2 - gy1 84 | del y2, gy1 85 | 86 | dx1 = dy1 + y1.grad 87 | del dy1 88 | y1.grad = None 89 | 90 | with torch.enable_grad(): 91 | x2.requires_grad = True 92 | fx2 = self.f(x2, set_rng=True, **f_args) 93 | torch.autograd.backward(fx2, dx1, retain_graph=True) 94 | 95 | with torch.no_grad(): 96 | x1 = y1 - fx2 97 | del y1, fx2 98 | 99 | dx2 = dy2 + x2.grad 100 | del dy2 101 | x2.grad = None 102 | 103 | x = torch.cat([x1, x2.detach()], dim=2) 104 | dx = torch.cat([dx1, dx2], dim=2) 105 | 106 | return x, dx 107 | 108 | class _ReversibleFunction(Function): 109 | @staticmethod 110 | def forward(ctx, x, blocks, args): 111 | ctx.args = args 112 | for block, kwarg in zip(blocks, args): 113 | x = block(x, **kwarg) 114 | ctx.y = x.detach() 115 | ctx.blocks = blocks 116 | return x 117 | 118 | @staticmethod 119 | def backward(ctx, dy): 120 | y = ctx.y 121 | args = ctx.args 122 | for block, kwargs in zip(ctx.blocks[::-1], args[::-1]): 123 | y, dy = block.backward_pass(y, dy, **kwargs) 124 | return dy, None, None 125 | 126 | class SequentialSequence(nn.Module): 127 | def __init__(self, layers, args_route = {}, layer_dropout = 0.): 128 | super().__init__() 129 | assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers' 130 | self.layers = layers 131 | self.args_route = args_route 132 | self.layer_dropout = layer_dropout 133 | 134 | def forward(self, x, **kwargs): 135 | args = 
route_args(self.args_route, kwargs, len(self.layers)) 136 | layers_and_args = list(zip(self.layers, args)) 137 | 138 | for (f, g), (f_args, g_args) in layers_and_args: 139 | x = x + f(x, **f_args) 140 | x = x + g(x, **g_args) 141 | return x 142 | 143 | class ReversibleSequence(nn.Module): 144 | def __init__(self, blocks, args_route = {}): 145 | super().__init__() 146 | self.args_route = args_route 147 | self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks]) 148 | 149 | def forward(self, x, **kwargs): 150 | x = torch.cat([x, x], dim=-1) 151 | 152 | blocks = self.blocks 153 | args = route_args(self.args_route, kwargs, len(blocks)) 154 | args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args)) 155 | 156 | out = _ReversibleFunction.apply(x, blocks, args) 157 | return torch.stack(out.chunk(2, dim=-1)).mean(dim=0) 158 | -------------------------------------------------------------------------------- /dalle_pytorch/tokenizer.py: -------------------------------------------------------------------------------- 1 | # take from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py 2 | # to give users a quick easy start to training DALL-E without doing BPE 3 | 4 | import torch 5 | 6 | import youtokentome as yttm 7 | from tokenizers import Tokenizer 8 | from tokenizers.processors import ByteLevel 9 | from transformers import BertTokenizer 10 | 11 | import html 12 | import os 13 | from functools import lru_cache 14 | from pathlib import Path 15 | import ftfy 16 | import regex as re 17 | 18 | # OpenAI simple tokenizer 19 | 20 | @lru_cache() 21 | def default_bpe(): 22 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "data/bpe_simple_vocab_16e6.txt") 23 | 24 | @lru_cache() 25 | def bytes_to_unicode(): 26 | bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2 ** 8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2 ** 8 + n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | def get_pairs(word): 38 | pairs = set() 39 | prev_char = word[0] 40 | for char in word[1:]: 41 | pairs.add((prev_char, char)) 42 | prev_char = char 43 | return pairs 44 | 45 | def basic_clean(text): 46 | text = ftfy.fix_text(text) 47 | text = html.unescape(html.unescape(text)) 48 | return text.strip() 49 | 50 | def whitespace_clean(text): 51 | text = re.sub(r'\s+', ' ', text) 52 | text = text.strip() 53 | return text 54 | 55 | class SimpleTokenizer(object): 56 | def __init__(self, bpe_path = default_bpe()): 57 | self.byte_encoder = bytes_to_unicode() 58 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 59 | merges = Path(bpe_path).read_text(encoding='utf8').split('\n') 60 | merges = merges[1:49152 - 256 - 2 + 1] 61 | merges = [tuple(merge.split()) for merge in merges] 62 | vocab = list(bytes_to_unicode().values()) 63 | vocab = vocab + [v + '' for v in vocab] 64 | for merge in merges: 65 | vocab.append(''.join(merge)) 66 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 67 | 68 | self.vocab_size = 49408 69 | 70 | self.encoder = dict(zip(vocab, range(len(vocab)))) 71 | self.decoder = {v: k for k, v in self.encoder.items()} 72 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 73 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 74 | self.pat = re.compile( 75 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", 76 | 
re.IGNORECASE) 77 | 78 | def bpe(self, token): 79 | if token in self.cache: 80 | return self.cache[token] 81 | word = tuple(token[:-1]) + (token[-1] + '',) 82 | pairs = get_pairs(word) 83 | 84 | if not pairs: 85 | return token + '' 86 | 87 | while True: 88 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 89 | if bigram not in self.bpe_ranks: 90 | break 91 | first, second = bigram 92 | new_word = [] 93 | i = 0 94 | while i < len(word): 95 | try: 96 | j = word.index(first, i) 97 | new_word.extend(word[i:j]) 98 | i = j 99 | except: 100 | new_word.extend(word[i:]) 101 | break 102 | 103 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 104 | new_word.append(first + second) 105 | i += 2 106 | else: 107 | new_word.append(word[i]) 108 | i += 1 109 | new_word = tuple(new_word) 110 | word = new_word 111 | if len(word) == 1: 112 | break 113 | else: 114 | pairs = get_pairs(word) 115 | word = ' '.join(word) 116 | self.cache[token] = word 117 | return word 118 | 119 | def encode(self, text): 120 | bpe_tokens = [] 121 | text = whitespace_clean(basic_clean(text)).lower() 122 | for token in re.findall(self.pat, text): 123 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 124 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 125 | return bpe_tokens 126 | 127 | def decode(self, tokens, remove_start_end = True, pad_tokens = {}): 128 | if torch.is_tensor(tokens): 129 | tokens = tokens.tolist() 130 | 131 | if remove_start_end: 132 | tokens = [token for token in tokens if token not in (49406, 40407, 0)] 133 | text = ''.join([self.decoder[token] for token in tokens if token not in pad_tokens]) 134 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 135 | return text 136 | 137 | def tokenize(self, texts, context_length = 256, truncate_text = False): 138 | if isinstance(texts, str): 139 | texts = [texts] 140 | 141 | all_tokens = [self.encode(text) for text in texts] 142 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 143 | 144 | for i, tokens in enumerate(all_tokens): 145 | if len(tokens) > context_length: 146 | if truncate_text: 147 | tokens = tokens[:context_length] 148 | else: 149 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 150 | result[i, :len(tokens)] = torch.tensor(tokens) 151 | 152 | return result 153 | 154 | tokenizer = SimpleTokenizer() 155 | 156 | # huggingface tokenizer 157 | 158 | class HugTokenizer: 159 | def __init__(self, bpe_path = None): 160 | bpe_path = Path(bpe_path) 161 | assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist' 162 | tokenizer = Tokenizer.from_file(str(bpe_path)) 163 | tokenizer.post_processor = ByteLevel(trim_offsets = True) 164 | self.tokenizer = tokenizer 165 | self.vocab_size = tokenizer.get_vocab_size() 166 | 167 | def decode(self, tokens, pad_tokens = {}): 168 | if torch.is_tensor(tokens): 169 | tokens = tokens.tolist() 170 | ignore_ids = pad_tokens.union({0}) 171 | tokens = [token for token in tokens if token not in ignore_ids] 172 | return self.tokenizer.decode(tokens, skip_special_tokens = True) 173 | 174 | def encode(self, text): 175 | return self.tokenizer.encode(text).ids 176 | 177 | def tokenize(self, texts, context_length = 256, truncate_text = False): 178 | if isinstance(texts, str): 179 | texts = [texts] 180 | 181 | all_tokens = [self.encode(text) for text in texts] 182 | 183 | result = 
torch.zeros(len(all_tokens), context_length, dtype=torch.long) 184 | for i, tokens in enumerate(all_tokens): 185 | if len(tokens) > context_length: 186 | if truncate_text: 187 | tokens = tokens[:context_length] 188 | else: 189 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 190 | result[i, :len(tokens)] = torch.tensor(tokens) 191 | 192 | return result 193 | 194 | # chinese tokenizer 195 | 196 | class ChineseTokenizer: 197 | def __init__(self): 198 | tokenizer = BertTokenizer.from_pretrained('bert-base-chinese') 199 | self.tokenizer = tokenizer 200 | self.vocab_size = tokenizer.vocab_size 201 | 202 | def decode(self, tokens, pad_tokens = {}): 203 | if torch.is_tensor(tokens): 204 | tokens = tokens.tolist() 205 | 206 | ignore_ids = pad_tokens.union({0}) 207 | tokens = [token for token in tokens if token not in ignore_ids] 208 | return self.tokenizer.decode(tokens) 209 | 210 | def encode(self, text): 211 | return torch.tensor(self.tokenizer.encode(text, add_special_tokens = False)) 212 | 213 | def tokenize(self, texts, context_length = 256, truncate_text = False): 214 | if isinstance(texts, str): 215 | texts = [texts] 216 | 217 | all_tokens = [self.encode(text) for text in texts] 218 | 219 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 220 | for i, tokens in enumerate(all_tokens): 221 | if len(tokens) > context_length: 222 | if truncate_text: 223 | tokens = tokens[:context_length] 224 | else: 225 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 226 | result[i, :len(tokens)] = torch.tensor(tokens) 227 | 228 | return result 229 | 230 | # yttm tokenizer 231 | 232 | class YttmTokenizer: 233 | def __init__(self, bpe_path = None): 234 | bpe_path = Path(bpe_path) 235 | assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist' 236 | 237 | tokenizer = yttm.BPE(model = str(bpe_path)) 238 | self.tokenizer = tokenizer 239 | self.vocab_size = tokenizer.vocab_size() 240 | 241 | def decode(self, tokens, pad_tokens = {}): 242 | if torch.is_tensor(tokens): 243 | tokens = tokens.tolist() 244 | 245 | return self.tokenizer.decode(tokens, ignore_ids = pad_tokens.union({0})) 246 | 247 | def encode(self, texts): 248 | encoded = self.tokenizer.encode(texts, output_type = yttm.OutputType.ID) 249 | return list(map(torch.tensor, encoded)) 250 | 251 | def tokenize(self, texts, context_length = 256, truncate_text = False): 252 | if isinstance(texts, str): 253 | texts = [texts] 254 | 255 | all_tokens = self.encode(texts) 256 | 257 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 258 | for i, tokens in enumerate(all_tokens): 259 | if len(tokens) > context_length: 260 | if truncate_text: 261 | tokens = tokens[:context_length] 262 | else: 263 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 264 | result[i, :len(tokens)] = torch.tensor(tokens) 265 | 266 | return result 267 | -------------------------------------------------------------------------------- /dalle_pytorch/transformer.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from itertools import islice, cycle 3 | 4 | import torch 5 | from torch import nn, einsum 6 | import torch.nn.functional as F 7 | from einops import rearrange 8 | 9 | from dalle_pytorch.reversible import ReversibleSequence, SequentialSequence 10 | from dalle_pytorch.attention import Attention, SparseAttention, SparseConvCausalAttention, 
SparseAxialCausalAttention 11 | 12 | from rotary_embedding_torch import RotaryEmbedding, broadcat 13 | from g_mlp_pytorch import gMLPBlock 14 | 15 | # helpers 16 | 17 | def exists(val): 18 | return val is not None 19 | 20 | def default(val, d): 21 | return val if exists(val) else d 22 | 23 | def cast_tuple(val, depth = 1): 24 | if isinstance(val, list): 25 | val = tuple(val) 26 | return val if isinstance(val, tuple) else (val,) * depth 27 | 28 | # classes 29 | 30 | class DivideMax(nn.Module): 31 | def __init__(self, dim): 32 | super().__init__() 33 | self.dim = dim 34 | 35 | def forward(self, x): 36 | maxes = x.amax(dim = self.dim, keepdim = True).detach() 37 | return x / maxes 38 | 39 | # https://arxiv.org/abs/2103.17239 40 | class LayerScale(nn.Module): 41 | def __init__(self, dim, depth, fn): 42 | super().__init__() 43 | if depth <= 18: 44 | init_eps = 0.1 45 | elif depth > 18 and depth <= 24: 46 | init_eps = 1e-5 47 | else: 48 | init_eps = 1e-6 49 | 50 | scale = torch.zeros(1, 1, dim).fill_(init_eps) 51 | self.scale = nn.Parameter(scale) 52 | self.fn = fn 53 | def forward(self, x, **kwargs): 54 | return self.fn(x, **kwargs) * self.scale 55 | 56 | # layer norm 57 | 58 | class PreNorm(nn.Module): 59 | def __init__(self, dim, fn, sandwich = False): 60 | super().__init__() 61 | self.norm = nn.LayerNorm(dim) 62 | self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity() 63 | self.fn = fn 64 | 65 | def forward(self, x, **kwargs): 66 | x = self.norm(x) 67 | x = self.fn(x, **kwargs) 68 | return self.norm_out(x) 69 | 70 | # feed forward 71 | 72 | class GEGLU(nn.Module): 73 | def forward(self, x): 74 | x, gates = x.chunk(2, dim = -1) 75 | return x * F.gelu(gates) 76 | 77 | class FeedForward(nn.Module): 78 | def __init__(self, dim, dropout = 0., mult = 4.): 79 | super().__init__() 80 | self.net = nn.Sequential( 81 | nn.Linear(dim, dim * mult * 2), 82 | GEGLU(), 83 | nn.Dropout(dropout), 84 | nn.Linear(dim * mult, dim) 85 | ) 86 | 87 | def forward(self, x): 88 | return self.net(x) 89 | 90 | # token shift classes 91 | 92 | class PreShiftToken(nn.Module): 93 | def __init__(self, fn, image_size, seq_len): 94 | super().__init__() 95 | self.fn = fn 96 | self.image_size = image_size 97 | self.seq_len = seq_len 98 | 99 | def forward(self, x, **kwargs): 100 | n = x.shape[1] 101 | seq_len, image_size = self.seq_len, self.image_size 102 | img_seq_len = image_size ** 2 103 | text_len = seq_len - img_seq_len + 1 104 | padding = seq_len - n + 1 105 | 106 | # get text and image tokens 107 | 108 | x_text, x_img = x[:, :text_len], x[:, text_len:] 109 | x_img = F.pad(x_img, (0, 0, 0, padding)) 110 | x_img = rearrange(x_img, 'b (h w) d -> b h w d', h = image_size) 111 | 112 | # shift 1 from the left for text tokens 113 | 114 | x_text_shift, x_text_pass = x_text.chunk(2, dim = -1) 115 | x_text_shift = F.pad(x_text_shift, (0, 0, 1, -1)) 116 | x_text = torch.cat((x_text_shift, x_text_pass), dim = -1) 117 | 118 | # shift from top, left for image tokens 119 | 120 | x_img_shift_top, x_img_shift_left, *x_img_pass = x_img.chunk(4, dim = -1) 121 | x_img_shift_left = F.pad(x_img_shift_left, (0, 0, 1, -1)) 122 | x_img_shift_top = F.pad(x_img_shift_top, (0, 0, 0, 0, 1, -1)) 123 | x_img = torch.cat((x_img_shift_top, x_img_shift_left, *x_img_pass), dim = -1) 124 | 125 | # merge text and image sequence back together 126 | 127 | x_img = rearrange(x_img, 'b h w d -> b (h w) d') 128 | x = torch.cat((x_text, x_img[:, :-padding]), dim = 1) 129 | return self.fn(x, **kwargs) 130 | 131 | # main transformer class 132 | 133 | class 
Transformer(nn.Module): 134 | def __init__( 135 | self, 136 | *, 137 | dim, 138 | depth, 139 | seq_len, 140 | reversible = False, 141 | causal = True, 142 | heads = 8, 143 | dim_head = 64, 144 | ff_mult = 4, 145 | attn_dropout = 0., 146 | ff_dropout = 0., 147 | attn_types = None, 148 | image_fmap_size = None, 149 | sparse_attn = False, 150 | stable = False, 151 | sandwich_norm = False, 152 | shift_tokens = False, 153 | rotary_emb = True 154 | ): 155 | super().__init__() 156 | layers = nn.ModuleList([]) 157 | sparse_layer = cast_tuple(sparse_attn, depth) 158 | 159 | attn_types = default(attn_types, ('full',)) 160 | attn_types = cast_tuple(attn_types) 161 | attn_type_layer = islice(cycle(attn_types), depth) 162 | 163 | for ind, sparse_attn, attn_type in zip(range(depth), sparse_layer, attn_type_layer): 164 | if attn_type == 'full': 165 | attn_class = partial(Attention, stable = stable) 166 | elif attn_type == 'sparse': 167 | attn_class = SparseAttention 168 | elif attn_type == 'axial_row': 169 | attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_fmap_size, stable = stable) 170 | elif attn_type == 'axial_col': 171 | attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_fmap_size, stable = stable) 172 | elif attn_type == 'conv_like': 173 | attn_class = partial(SparseConvCausalAttention, seq_len = seq_len, image_size = image_fmap_size, stable = stable) 174 | elif attn_type == 'mlp': 175 | attn_class = partial(gMLPBlock, seq_len = seq_len) 176 | else: 177 | raise ValueError(f'attention type "{attn_type}" is not valid') 178 | 179 | if attn_type != 'mlp': 180 | attn = attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout) 181 | else: 182 | attn = attn_class(dim = dim, causal = causal, dim_ff = dim * 4) 183 | 184 | ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout) 185 | 186 | if shift_tokens: 187 | attn, ff = map(lambda t: PreShiftToken(t, image_size = image_fmap_size, seq_len = seq_len), (attn, ff)) 188 | 189 | layers.append(nn.ModuleList([ 190 | LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich = sandwich_norm)), 191 | LayerScale(dim, ind + 1, PreNorm(dim, ff, sandwich = sandwich_norm)) 192 | ])) 193 | 194 | execute_type = ReversibleSequence if reversible else SequentialSequence 195 | route_attn = ((True, False),) * depth 196 | attn_route_map = {'mask': route_attn, 'rotary_pos_emb': route_attn} 197 | 198 | self.layers = execute_type(layers, args_route = attn_route_map) 199 | 200 | # generate positional embeddings for rotary 201 | 202 | pos_emb = None 203 | if rotary_emb: 204 | assert 'mlp' not in attn_types, 'you cannot use gMLPs if rotary embedding is turned on' 205 | 206 | rot_dim = dim_head // 3 207 | img_seq_len = (image_fmap_size ** 2) 208 | text_len = seq_len - img_seq_len + 1 209 | 210 | text_pos_emb = RotaryEmbedding(dim = rot_dim) 211 | img_axial_pos_emb = RotaryEmbedding(dim = rot_dim, freqs_for = 'pixel') 212 | 213 | text_freqs = text_pos_emb(torch.arange(text_len)) 214 | img_to_text_freqs = text_pos_emb(torch.full((img_seq_len,), 8192)) # image is given a position far away from text 215 | text_freqs = torch.cat((text_freqs, img_to_text_freqs), dim = 0) 216 | 217 | img_freqs_axial = img_axial_pos_emb(torch.linspace(-1, 1, steps = image_fmap_size)) 218 | img_freqs = broadcat((rearrange(img_freqs_axial, 'i d -> i () d'), rearrange(img_freqs_axial, 'j d -> () j d')), dim = -1) 219 | img_freqs = rearrange(img_freqs, 'h w d -> (h w) 
d') 220 | 221 | text_axial_freqs = img_axial_pos_emb(torch.full((text_len,), -10.)) # text is given a position of -10 apart from the image axial positions, which is from range [-1, 1] 222 | text_axial_freqs = torch.cat((text_axial_freqs, text_axial_freqs), dim = -1) 223 | img_freqs = torch.cat((text_axial_freqs, img_freqs), dim = 0) 224 | 225 | pos_emb = torch.cat((text_freqs, img_freqs), dim = -1) 226 | pos_emb = rearrange(pos_emb[:-1], 'n d -> () () n d') 227 | 228 | self.register_buffer('pos_emb', pos_emb) 229 | 230 | def forward(self, x, **kwargs): 231 | return self.layers(x, rotary_pos_emb = self.pos_emb, **kwargs) 232 | -------------------------------------------------------------------------------- /dalle_pytorch/vae.py: -------------------------------------------------------------------------------- 1 | import io 2 | import sys 3 | import os 4 | import requests 5 | import PIL 6 | import warnings 7 | import hashlib 8 | import urllib 9 | import yaml 10 | from pathlib import Path 11 | from tqdm import tqdm 12 | from math import sqrt, log 13 | from omegaconf import OmegaConf 14 | from taming.models.vqgan import VQModel, GumbelVQ 15 | import importlib 16 | 17 | import torch 18 | from torch import nn 19 | import torch.nn.functional as F 20 | 21 | from einops import rearrange 22 | 23 | from dalle_pytorch import distributed_utils 24 | 25 | # constants 26 | 27 | CACHE_PATH = os.path.expanduser("~/.cache/dalle") 28 | 29 | OPENAI_VAE_ENCODER_PATH = 'https://cdn.openai.com/dall-e/encoder.pkl' 30 | OPENAI_VAE_DECODER_PATH = 'https://cdn.openai.com/dall-e/decoder.pkl' 31 | 32 | VQGAN_VAE_PATH = 'https://heibox.uni-heidelberg.de/f/140747ba53464f49b476/?dl=1' 33 | VQGAN_VAE_CONFIG_PATH = 'https://heibox.uni-heidelberg.de/f/6ecf2af6c658432c8298/?dl=1' 34 | 35 | # helpers methods 36 | 37 | def exists(val): 38 | return val is not None 39 | 40 | def default(val, d): 41 | return val if exists(val) else d 42 | 43 | def load_model(path): 44 | with open(path, 'rb') as f: 45 | return torch.load(f, map_location = torch.device('cpu')) 46 | 47 | def map_pixels(x, eps = 0.1): 48 | return (1 - 2 * eps) * x + eps 49 | 50 | def unmap_pixels(x, eps = 0.1): 51 | return torch.clamp((x - eps) / (1 - 2 * eps), 0, 1) 52 | 53 | def download(url, filename = None, root = CACHE_PATH): 54 | if ( 55 | not distributed_utils.is_distributed 56 | or distributed_utils.backend.is_local_root_worker() 57 | ): 58 | os.makedirs(root, exist_ok = True) 59 | filename = default(filename, os.path.basename(url)) 60 | 61 | download_target = os.path.join(root, filename) 62 | download_target_tmp = os.path.join(root, f'tmp.{filename}') 63 | 64 | if os.path.exists(download_target) and not os.path.isfile(download_target): 65 | raise RuntimeError(f"{download_target} exists and is not a regular file") 66 | 67 | if ( 68 | distributed_utils.is_distributed 69 | and not distributed_utils.backend.is_local_root_worker() 70 | and not os.path.isfile(download_target) 71 | ): 72 | # If the file doesn't exist yet, wait until it's downloaded by the root worker. 
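# Non-root workers wait at this barrier while the local root performs
# the download below; the root releases them at its matching
# `local_barrier()` call once the temporary file has been renamed
# into place.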
73 | distributed_utils.backend.local_barrier() 74 | 75 | if os.path.isfile(download_target): 76 | return download_target 77 | 78 | with urllib.request.urlopen(url) as source, open(download_target_tmp, "wb") as output: 79 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80) as loop: 80 | while True: 81 | buffer = source.read(8192) 82 | if not buffer: 83 | break 84 | 85 | output.write(buffer) 86 | loop.update(len(buffer)) 87 | 88 | os.rename(download_target_tmp, download_target) 89 | if ( 90 | distributed_utils.is_distributed 91 | and distributed_utils.backend.is_local_root_worker() 92 | ): 93 | distributed_utils.backend.local_barrier() 94 | return download_target 95 | 96 | def make_contiguous(module): 97 | with torch.no_grad(): 98 | for param in module.parameters(): 99 | param.set_(param.contiguous()) 100 | 101 | # pretrained Discrete VAE from OpenAI 102 | 103 | class OpenAIDiscreteVAE(nn.Module): 104 | def __init__(self): 105 | super().__init__() 106 | 107 | self.enc = load_model(download(OPENAI_VAE_ENCODER_PATH)) 108 | self.dec = load_model(download(OPENAI_VAE_DECODER_PATH)) 109 | make_contiguous(self) 110 | 111 | self.num_layers = 3 112 | self.image_size = 256 113 | self.num_tokens = 8192 114 | 115 | @torch.no_grad() 116 | def get_codebook_indices(self, img): 117 | img = map_pixels(img) 118 | z_logits = self.enc.blocks(img) 119 | z = torch.argmax(z_logits, dim = 1) 120 | return rearrange(z, 'b h w -> b (h w)') 121 | 122 | def decode(self, img_seq): 123 | b, n = img_seq.shape 124 | img_seq = rearrange(img_seq, 'b (h w) -> b h w', h = int(sqrt(n))) 125 | 126 | z = F.one_hot(img_seq, num_classes = self.num_tokens) 127 | z = rearrange(z, 'b h w c -> b c h w').float() 128 | x_stats = self.dec(z).float() 129 | x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3])) 130 | return x_rec 131 | 132 | def forward(self, img): 133 | raise NotImplemented 134 | 135 | # VQGAN from Taming Transformers paper 136 | # https://arxiv.org/abs/2012.09841 137 | 138 | def get_obj_from_str(string, reload=False): 139 | module, cls = string.rsplit(".", 1) 140 | if reload: 141 | module_imp = importlib.import_module(module) 142 | importlib.reload(module_imp) 143 | return getattr(importlib.import_module(module, package=None), cls) 144 | 145 | def instantiate_from_config(config): 146 | if not "target" in config: 147 | raise KeyError("Expected key `target` to instantiate.") 148 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 149 | 150 | class VQGanVAE(nn.Module): 151 | def __init__(self, vqgan_model_path=None, vqgan_config_path=None): 152 | super().__init__() 153 | 154 | if vqgan_model_path is None: 155 | model_filename = 'vqgan.1024.model.ckpt' 156 | config_filename = 'vqgan.1024.config.yml' 157 | download(VQGAN_VAE_CONFIG_PATH, config_filename) 158 | download(VQGAN_VAE_PATH, model_filename) 159 | config_path = str(Path(CACHE_PATH) / config_filename) 160 | model_path = str(Path(CACHE_PATH) / model_filename) 161 | else: 162 | model_path = vqgan_model_path 163 | config_path = vqgan_config_path 164 | 165 | config = OmegaConf.load(config_path) 166 | 167 | model = instantiate_from_config(config["model"]) 168 | 169 | state = torch.load(model_path, map_location = 'cpu')['state_dict'] 170 | model.load_state_dict(state, strict = False) 171 | 172 | print(f"Loaded VQGAN from {model_path} and {config_path}") 173 | 174 | self.model = model 175 | 176 | # f as used in https://github.com/CompVis/taming-transformers#overview-of-pretrained-models 177 | f = config.model.params.ddconfig.resolution / 
config.model.params.ddconfig.attn_resolutions[0] 178 | self.num_layers = int(log(f)/log(2)) 179 | self.image_size = 256 180 | self.num_tokens = config.model.params.n_embed 181 | self.is_gumbel = isinstance(self.model, GumbelVQ) 182 | 183 | self._register_external_parameters() 184 | 185 | def _register_external_parameters(self): 186 | """Register external parameters for DeepSpeed partitioning.""" 187 | if ( 188 | not distributed_utils.is_distributed 189 | or not distributed_utils.using_backend( 190 | distributed_utils.DeepSpeedBackend) 191 | ): 192 | return 193 | 194 | deepspeed = distributed_utils.backend.backend_module 195 | deepspeed.zero.register_external_parameter( 196 | self, self.model.quantize.embed.weight if self.is_gumbel else self.model.quantize.embedding.weight) 197 | 198 | @torch.no_grad() 199 | def get_codebook_indices(self, img): 200 | b = img.shape[0] 201 | img = (2 * img) - 1 202 | _, _, [_, _, indices] = self.model.encode(img) 203 | if self.is_gumbel: 204 | return rearrange(indices, 'b h w -> b (h w)', b=b) 205 | return rearrange(indices, '(b n) -> b n', b = b) 206 | 207 | def decode(self, img_seq): 208 | b, n = img_seq.shape 209 | one_hot_indices = F.one_hot(img_seq, num_classes = self.num_tokens).float() 210 | z = one_hot_indices @ self.model.quantize.embed.weight if self.is_gumbel \ 211 | else (one_hot_indices @ self.model.quantize.embedding.weight) 212 | 213 | z = rearrange(z, 'b (h w) c -> b c h w', h = int(sqrt(n))) 214 | img = self.model.decode(z) 215 | 216 | img = (img.clamp(-1., 1.) + 1) * 0.5 217 | return img 218 | 219 | def forward(self, img): 220 | raise NotImplemented 221 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | 2 | ARG IMG_TAG=1.8.1-cuda10.2-cudnn7-devel 3 | ARG IMG_REPO=pytorch 4 | 5 | FROM pytorch/$IMG_REPO:$IMG_TAG 6 | 7 | RUN apt-get -y update && apt-get -y install git gcc llvm-9-dev cmake libaio-dev vim wget 8 | 9 | RUN git clone https://github.com/microsoft/DeepSpeed.git /tmp/DeepSpeed 10 | RUN cd /tmp/DeepSpeed && DS_BUILD_OPS=1 ./install.sh -r 11 | RUN pip install git+https://github.com/lucidrains/DALLE-pytorch.git 12 | 13 | WORKDIR dalle 14 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | from tqdm import tqdm 4 | 5 | # torch 6 | 7 | import torch 8 | 9 | from einops import repeat 10 | 11 | # vision imports 12 | 13 | from PIL import Image 14 | from torchvision.utils import make_grid, save_image 15 | from torchvision.transforms import functional as TF 16 | 17 | # dalle related classes and utils 18 | 19 | from dalle_pytorch import DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE, DALLE 20 | from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, YttmTokenizer, ChineseTokenizer 21 | 22 | import numpy as np 23 | 24 | # argument parsing 25 | 26 | parser = argparse.ArgumentParser() 27 | 28 | parser.add_argument('--dalle_path', type = str, required = True, 29 | help='path to your trained DALL-E') 30 | 31 | parser.add_argument('--vqgan_model_path', type=str, default = None, 32 | help='path to your trained VQGAN weights. This should be a .ckpt file. (only valid when taming option is enabled)') 33 | 34 | parser.add_argument('--vqgan_config_path', type=str, default = None, 35 | help='path to your trained VQGAN config. 
This should be a .yaml file. (only valid when taming option is enabled)') 36 | 37 | parser.add_argument('--text', type = str, required = True, 38 | help='your text prompt') 39 | 40 | parser.add_argument('--num_images', type = int, default = 12, required = False, 41 | help='number of images') 42 | 43 | parser.add_argument('--batch_size', type = int, default = 4, required = False, 44 | help='batch size') 45 | 46 | parser.add_argument('--top_k', type = float, default = None, required = False, 47 | help='top k filter threshold') 48 | 49 | parser.add_argument('--top_p', type = float, default = None, required = False, 50 | help='top p filter threshold') 51 | 52 | parser.add_argument('--temperature', type = float, default = 1.0, required = False, 53 | help='sampling temperature') 54 | 55 | parser.add_argument('--outputs_dir', type = str, default = './outputs', required = False, 56 | help='output directory') 57 | 58 | parser.add_argument('--bpe_path', type = str, 59 | help='path to your huggingface BPE json file') 60 | 61 | parser.add_argument('--clip_sort', dest='clip_sort', action = 'store_true') 62 | 63 | parser.add_argument('--hug', dest='hug', action = 'store_true') 64 | 65 | parser.add_argument('--chinese', dest='chinese', action = 'store_true') 66 | 67 | parser.add_argument('--taming', dest='taming', action='store_true') 68 | 69 | parser.add_argument('--output_npy', dest='output_npy', action='store_true') 70 | 71 | parser.add_argument('--gentxt', dest='gentxt', action='store_true') 72 | 73 | args = parser.parse_args() 74 | 75 | if args.clip_sort: 76 | # load OpenAI clip 77 | import clip 78 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 79 | clip_model, clip_preprocess = clip.load('ViT-B/16', jit=False) 80 | clip_model.eval().requires_grad_(False).to(device) 81 | 82 | # helper fns 83 | 84 | def exists(val): 85 | return val is not None 86 | 87 | # tokenizer 88 | 89 | if exists(args.bpe_path): 90 | klass = HugTokenizer if args.hug else YttmTokenizer 91 | tokenizer = klass(args.bpe_path) 92 | elif args.chinese: 93 | tokenizer = ChineseTokenizer() 94 | 95 | # load DALL-E 96 | 97 | dalle_path = Path(args.dalle_path) 98 | 99 | assert dalle_path.exists(), 'trained DALL-E must exist' 100 | 101 | load_obj = torch.load(str(dalle_path)) 102 | dalle_params, vae_params, weights = load_obj.pop('hparams'), load_obj.pop('vae_params'), load_obj.pop('weights') 103 | 104 | dalle_params.pop('vae', None) # cleanup later 105 | 106 | if args.taming: 107 | vae = VQGanVAE(args.vqgan_model_path, args.vqgan_config_path) 108 | elif vae_params is not None: 109 | vae = DiscreteVAE(**vae_params) 110 | else: 111 | vae = OpenAIDiscreteVAE() 112 | 113 | dalle = DALLE(vae = vae, **dalle_params).cuda() 114 | 115 | dalle.load_state_dict(weights) 116 | 117 | # generate images 118 | 119 | image_size = vae.image_size 120 | 121 | texts = args.text.split('|') 122 | 123 | for j, text in tqdm(enumerate(texts)): 124 | 125 | text = text.lower() 126 | 127 | if args.gentxt: 128 | text_tokens, gen_texts = dalle.generate_texts(tokenizer, text=text, filter_thres = args.top_k) 129 | text = gen_texts[0] 130 | else: 131 | text_tokens = tokenizer.tokenize([text], dalle.text_seq_len).cuda() 132 | 133 | text_tokens = repeat(text_tokens, '() n -> b n', b = args.num_images) 134 | 135 | outputs = [] 136 | tokens = [] 137 | 138 | for text_chunk in tqdm(text_tokens.split(args.batch_size), desc = f'generating images for - {text}'): 139 | if args.top_k is not None: 140 | output, tok = dalle.generate_images(text_chunk, 
temperature=args.temperature, top_k_thresh = args.top_k, return_tokens = args.output_npy) 141 | elif args.top_p is not None: 142 | output, tok = dalle.generate_images(text_chunk, temperature=args.temperature, top_p_thresh = args.top_p, return_tokens = args.output_npy) 143 | else: 144 | output, tok = dalle.generate_images(text_chunk, temperature=1.0, top_p_thresh = 0.9, return_tokens = args.output_npy) 145 | 146 | outputs.append(output) 147 | 148 | if tok is not None: 149 | tokens.append(tok) 150 | 151 | outputs = torch.cat(outputs) 152 | 153 | if len(tokens) > 0: 154 | tokens = torch.cat(tokens).cpu().detach().numpy() 155 | 156 | # save all images 157 | file_name = text 158 | outputs_dir = Path(args.outputs_dir) / file_name.replace(' ', '_')[:(100)] 159 | outputs_dir.mkdir(parents = True, exist_ok = True) 160 | 161 | if not args.clip_sort: 162 | for i, image in tqdm(enumerate(outputs), desc = 'saving images'): 163 | save_image(image, outputs_dir / f'{i}.jpg', normalize=True) 164 | with open(outputs_dir / 'caption.txt', 'w') as f: 165 | f.write(file_name) 166 | if args.output_npy: 167 | with open(outputs_dir / f'{i}.npy', 'wb') as f: 168 | np.save(f, tokens[i]) 169 | else: 170 | images_sorted = [] 171 | for i, image in enumerate(outputs): 172 | image = image.clamp(-1,1) 173 | pimg = TF.to_pil_image(image) 174 | image = image.add(1).div(2) 175 | 176 | 177 | text_features = clip_model.encode_text(clip.tokenize(text).to(device)) 178 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 179 | 180 | image_features = clip_model.encode_image(clip_preprocess(pimg).unsqueeze(0).to(device)) 181 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 182 | 183 | similarity = torch.nn.functional.cosine_similarity(image_features, text_features, dim=-1) 184 | 185 | tup = (image, similarity.item(), tokens[i] if args.output_npy else None) 186 | images_sorted.append(tup) 187 | 188 | images_sorted.sort(key=lambda x:x[1], reverse=True) 189 | 190 | for i, image in enumerate(images_sorted): 191 | save_image(image[0], outputs_dir / f'{i}-{image[1]}.png', normalize=True) 192 | if args.output_npy: 193 | with open(outputs_dir / f'{i}.npy', 'wb') as f: 194 | np.save(f, image[2]) 195 | 196 | print(f'created {args.num_images} images at "{str(outputs_dir)}"') 197 | -------------------------------------------------------------------------------- /images/avocado-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/avocado-0.png -------------------------------------------------------------------------------- /images/avocado-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/avocado-1.png -------------------------------------------------------------------------------- /images/avocado-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/avocado-2.png -------------------------------------------------------------------------------- /images/avocado-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/avocado-3.png 
-------------------------------------------------------------------------------- /images/avocado-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/avocado-4.png -------------------------------------------------------------------------------- /images/avocado-5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/avocado-5.png -------------------------------------------------------------------------------- /images/avocado-6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/avocado-6.png -------------------------------------------------------------------------------- /images/avocado-7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/avocado-7.png -------------------------------------------------------------------------------- /images/avocado-before.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/avocado-before.png -------------------------------------------------------------------------------- /images/avocado.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/avocado.png -------------------------------------------------------------------------------- /images/banner.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/banner.jpg -------------------------------------------------------------------------------- /images/birds.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/birds.png -------------------------------------------------------------------------------- /images/clothing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/clothing.png -------------------------------------------------------------------------------- /images/cloud-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/cloud-0.png -------------------------------------------------------------------------------- /images/cloud-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/cloud-1.png -------------------------------------------------------------------------------- /images/cloud-2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/cloud-2.png -------------------------------------------------------------------------------- /images/cloud-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/cloud-3.png -------------------------------------------------------------------------------- /images/cloud-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/cloud-4.png -------------------------------------------------------------------------------- /images/cloud-5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/cloud-5.png -------------------------------------------------------------------------------- /images/cloud-6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/cloud-6.png -------------------------------------------------------------------------------- /images/cloud-7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/cloud-7.png -------------------------------------------------------------------------------- /images/cube-cloud-before.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/cube-cloud-before.jpg -------------------------------------------------------------------------------- /images/cube-cloud.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/cube-cloud.png -------------------------------------------------------------------------------- /images/cube-porcupine-before.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/cube-porcupine-before.jpg -------------------------------------------------------------------------------- /images/cube-porcupine.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/cube-porcupine.png -------------------------------------------------------------------------------- /images/cube-water-before.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/cube-water-before.jpg -------------------------------------------------------------------------------- /images/cube-water.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/cube-water.png 
-------------------------------------------------------------------------------- /images/girl-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/girl-0.png -------------------------------------------------------------------------------- /images/girl-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/girl-1.png -------------------------------------------------------------------------------- /images/girl-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/girl-2.png -------------------------------------------------------------------------------- /images/girl-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/girl-3.png -------------------------------------------------------------------------------- /images/girl-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/girl-4.png -------------------------------------------------------------------------------- /images/girl-5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/girl-5.png -------------------------------------------------------------------------------- /images/girl-6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/girl-6.png -------------------------------------------------------------------------------- /images/girl-7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/girl-7.png -------------------------------------------------------------------------------- /images/girl-glasses-before.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/girl-glasses-before.png -------------------------------------------------------------------------------- /images/girl-glasses.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/girl-glasses.png -------------------------------------------------------------------------------- /images/landscape.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/landscape.png -------------------------------------------------------------------------------- /images/layouts-1.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/layouts-1.jpg -------------------------------------------------------------------------------- /images/layouts-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/layouts-2.jpg -------------------------------------------------------------------------------- /images/researcher-mad-before.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/researcher-mad-before.png -------------------------------------------------------------------------------- /images/researcher-mad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/researcher-mad.png -------------------------------------------------------------------------------- /images/wb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/wb.png -------------------------------------------------------------------------------- /install_apex.sh: -------------------------------------------------------------------------------- 1 | git clone https://github.com/NVIDIA/apex.git /tmp/apex 2 | cd /tmp/apex && pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ 3 | -------------------------------------------------------------------------------- /install_deepspeed.sh: -------------------------------------------------------------------------------- 1 | sudo apt-get -y install llvm-9-dev cmake 2 | git clone https://github.com/microsoft/DeepSpeed.git /tmp/Deepspeed 3 | cd /tmp/Deepspeed && DS_BUILD_SPARSE_ATTN=1 ./install.sh -s 4 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'dalle-pytorch', 5 | packages = find_packages(), 6 | include_package_data = True, 7 | version = '1.1.4', 8 | license='MIT', 9 | description = 'DALL-E - Pytorch', 10 | author = 'Phil Wang', 11 | author_email = 'lucidrains@gmail.com', 12 | url = 'https://github.com/lucidrains/dalle-pytorch', 13 | keywords = [ 14 | 'artificial intelligence', 15 | 'attention mechanism', 16 | 'transformers', 17 | 'text-to-image' 18 | ], 19 | install_requires=[ 20 | 'axial_positional_embedding', 21 | 'DALL-E', 22 | 'einops>=0.3.2', 23 | 'ftfy', 24 | 'g-mlp-pytorch', 25 | 'pillow', 26 | 'regex', 27 | 'rotary-embedding-torch', 28 | 'taming-transformers-rom1504', 29 | 'tokenizers', 30 | 'torch>=1.6', 31 | 'torchvision', 32 | 'transformers', 33 | 'tqdm', 34 | 'youtokentome', 35 | 'WebDataset' 36 | ], 37 | classifiers=[ 38 | 'Development Status :: 4 - Beta', 39 | 'Intended Audience :: Developers', 40 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 41 | 'License :: OSI Approved :: MIT License', 42 | 'Programming Language :: Python :: 3.6', 43 | ], 44 | ) 45 | -------------------------------------------------------------------------------- /train_dalle.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import time 4 | from glob import glob 5 | import os 6 | import shutil 7 | 8 | import torch 9 | import wandb # Quit early if user doesn't have wandb installed. 10 | from torch.nn.utils import clip_grad_norm_ 11 | from torch.optim import Adam 12 | from torch.optim.lr_scheduler import ReduceLROnPlateau 13 | from torch.utils.data import DataLoader 14 | 15 | from dalle_pytorch import OpenAIDiscreteVAE, VQGanVAE, DiscreteVAE, DALLE 16 | from dalle_pytorch import distributed_utils 17 | from dalle_pytorch.loader import TextImageDataset 18 | from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, ChineseTokenizer, YttmTokenizer 19 | 20 | # libraries needed for webdataset support 21 | import webdataset as wds 22 | from torchvision import transforms as T 23 | from PIL import Image 24 | from io import BytesIO 25 | 26 | 27 | # argument parsing 28 | 29 | parser = argparse.ArgumentParser() 30 | 31 | group = parser.add_mutually_exclusive_group(required=False) 32 | 33 | group.add_argument('--vae_path', type=str, 34 | help='path to your trained discrete VAE') 35 | 36 | group.add_argument('--dalle_path', type=str, 37 | help='path to your partially trained DALL-E') 38 | 39 | parser.add_argument('--vqgan_model_path', type=str, default = None, 40 | help='path to your trained VQGAN weights. This should be a .ckpt file. (only valid when taming option is enabled)') 41 | 42 | parser.add_argument('--vqgan_config_path', type=str, default = None, 43 | help='path to your trained VQGAN config. This should be a .yaml file. (only valid when taming option is enabled)') 44 | 45 | parser.add_argument('--image_text_folder', type=str, required=True, 46 | help='path to your folder of images and text for learning the DALL-E') 47 | 48 | parser.add_argument( 49 | '--wds', 50 | type = str, 51 | default='', 52 | help = 'Comma separated list of WebDataset (1) image and (2) text column names. Must contain 2 values, e.g. img,cap.' 53 | ) 54 | 55 | parser.add_argument('--truncate_captions', dest='truncate_captions', action='store_true', 56 | help='Captions passed in which exceed the max token length will be truncated if this is set.') 57 | 58 | parser.add_argument('--random_resize_crop_lower_ratio', dest='resize_ratio', type=float, default=0.75, 59 | help='Random resized crop lower ratio') 60 | 61 | parser.add_argument('--chinese', dest='chinese', action='store_true') 62 | 63 | parser.add_argument('--taming', dest='taming', action='store_true') 64 | 65 | parser.add_argument('--hug', dest='hug', action='store_true') 66 | 67 | parser.add_argument('--bpe_path', type=str, 68 | help='path to your BPE json file') 69 | 70 | parser.add_argument('--dalle_output_file_name', type=str, default = "dalle", 71 | help='output_file_name') 72 | 73 | parser.add_argument('--fp16', action='store_true', 74 | help='(experimental) - Enable DeepSpeed 16 bit precision. Reduces VRAM.') 75 | 76 | 77 | parser.add_argument('--amp', action='store_true', 78 | help='Apex "O1" automatic mixed precision. More stable than 16 bit precision. Can\'t be used in conjunction with deepspeed zero stages 1-3.') 79 | 80 | parser.add_argument('--wandb_name', default='dalle_train_transformer', 81 | help='Name W&B will use when saving results.\ne.g. 
`--wandb_name "coco2017-full-sparse"`') 82 | 83 | parser.add_argument('--wandb_entity', default=None, 84 | help='(optional) Name of W&B team/entity to log to.') 85 | 86 | parser.add_argument('--stable_softmax', dest='stable_softmax', action='store_true', 87 | help='Prevent values from becoming too large during softmax. Helps with stability in fp16 and Mixture of Quantization training.') 88 | 89 | parser = distributed_utils.wrap_arg_parser(parser) 90 | 91 | train_group = parser.add_argument_group('Training settings') 92 | 93 | train_group.add_argument('--flops_profiler', dest = 'flops_profiler', action='store_true', help = 'Exits after printing detailed flops/runtime analysis of forward/backward') 94 | 95 | train_group.add_argument('--epochs', default = 20, type = int, help = 'Number of epochs') 96 | 97 | train_group.add_argument('--save_every_n_steps', default = 1000, type = int, help = 'Save a checkpoint every n steps') 98 | 99 | train_group.add_argument('--keep_n_checkpoints', default = None, type = int, help = '(Careful) Deletes old deepspeed checkpoints if there are more than n') 100 | 101 | train_group.add_argument('--batch_size', default = 4, type = int, help = 'Batch size') 102 | 103 | train_group.add_argument('--ga_steps', default = 1, type = int, help = 'Number of steps to accumulate gradients across per each iteration. DeepSpeed only.') 104 | 105 | train_group.add_argument('--learning_rate', default = 3e-4, type = float, help = 'Learning rate') 106 | 107 | train_group.add_argument('--clip_grad_norm', default = 0.5, type = float, help = 'Clip gradient norm') 108 | 109 | train_group.add_argument('--lr_decay', dest = 'lr_decay', action = 'store_true') 110 | 111 | model_group = parser.add_argument_group('Model settings') 112 | 113 | model_group.add_argument('--dim', default = 512, type = int, help = 'Model dimension') 114 | 115 | model_group.add_argument('--text_seq_len', default = 256, type = int, help = 'Text sequence length') 116 | 117 | model_group.add_argument('--depth', default = 2, type = int, help = 'Model depth') 118 | 119 | model_group.add_argument('--heads', default = 8, type = int, help = 'Model number of heads') 120 | 121 | model_group.add_argument('--dim_head', default = 64, type = int, help = 'Model head dimension') 122 | 123 | train_group.add_argument('--ff_dropout', default = 0.0, type = float, help = 'Feed forward dropout.') 124 | 125 | train_group.add_argument('--attn_dropout', default = 0.0, type = float, help = 'Feed forward dropout.') 126 | 127 | model_group.add_argument('--reversible', dest = 'reversible', action='store_true') 128 | 129 | model_group.add_argument('--loss_img_weight', default = 7, type = int, help = 'Image loss weight') 130 | 131 | model_group.add_argument('--attn_types', default = 'full', type = str, help = 'comma separated list of attention types. attention type can be: full or sparse or axial_row or axial_col or conv_like.') 132 | 133 | model_group.add_argument('--shift_tokens', help = 'Use the shift tokens feature', action = 'store_true') 134 | 135 | model_group.add_argument('--rotary_emb', help = 'Use rotary embeddings', action = 'store_true') 136 | 137 | args = parser.parse_args() 138 | 139 | # helpers 140 | 141 | def exists(val): 142 | return val is not None 143 | 144 | def get_trainable_params(model): 145 | return [params for params in model.parameters() if params.requires_grad] 146 | 147 | def cp_path_to_dir(cp_path, tag): 148 | """Convert a checkpoint path to a directory with `tag` inserted. 
149 | If `cp_path` is already a directory, return it unchanged. 150 | """ 151 | if not isinstance(cp_path, Path): 152 | cp_path = Path(cp_path) 153 | if cp_path.is_dir(): 154 | return cp_path 155 | path_sans_extension = cp_path.parent / cp_path.stem 156 | cp_dir = Path(f'{path_sans_extension}-{tag}-cp') 157 | return cp_dir 158 | 159 | # constants 160 | WEBDATASET_IMAGE_TEXT_COLUMNS = tuple(args.wds.split(',')) 161 | ENABLE_WEBDATASET = True if len(WEBDATASET_IMAGE_TEXT_COLUMNS) == 2 else False 162 | 163 | DALLE_OUTPUT_FILE_NAME = args.dalle_output_file_name + ".pt" 164 | 165 | VAE_PATH = args.vae_path 166 | VQGAN_MODEL_PATH = args.vqgan_model_path 167 | VQGAN_CONFIG_PATH = args.vqgan_config_path 168 | DALLE_PATH = args.dalle_path 169 | RESUME = exists(DALLE_PATH) 170 | 171 | EPOCHS = args.epochs 172 | BATCH_SIZE = args.batch_size 173 | 174 | LEARNING_RATE = args.learning_rate 175 | GRAD_CLIP_NORM = args.clip_grad_norm 176 | LR_DECAY = args.lr_decay 177 | SAVE_EVERY_N_STEPS = args.save_every_n_steps 178 | KEEP_N_CHECKPOINTS = args.keep_n_checkpoints 179 | 180 | MODEL_DIM = args.dim 181 | TEXT_SEQ_LEN = args.text_seq_len 182 | DEPTH = args.depth 183 | HEADS = args.heads 184 | DIM_HEAD = args.dim_head 185 | REVERSIBLE = args.reversible 186 | LOSS_IMG_WEIGHT = args.loss_img_weight 187 | FF_DROPOUT = args.ff_dropout 188 | ATTN_DROPOUT = args.attn_dropout 189 | STABLE = args.stable_softmax 190 | SHIFT_TOKENS = args.shift_tokens 191 | ROTARY_EMB = args.rotary_emb 192 | 193 | ATTN_TYPES = tuple(args.attn_types.split(',')) 194 | 195 | DEEPSPEED_CP_AUX_FILENAME = 'auxiliary.pt' 196 | 197 | if not ENABLE_WEBDATASET: 198 | # quit early if you used the wrong folder name 199 | assert Path(args.image_text_folder).exists(), f'The path {args.image_text_folder} was not found.' 
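# (Reading guide for the else-branch that follows: when WebDataset column names are supplied via
#  --wds, --image_text_folder is resolved into a WebDataset source instead: a directory scanned
#  for .tar files, an http(s):// or gs:// URL streamed through curl/gsutil, or a single .tar(.gz) path.)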
200 | else: 201 | # quit early if no tar files were found 202 | if Path(args.image_text_folder).is_dir(): 203 | DATASET = [str(p) for p in Path(args.image_text_folder).glob("**/*") if ".tar" in str(p).lower()] # .name 204 | assert len(DATASET) > 0, 'The directory ({}) does not contain any WebDataset/.tar files.'.format(args.image_text_folder) 205 | print('Found {} WebDataset .tar(.gz) file(s) under given path {}!'.format(len(DATASET), args.image_text_folder)) 206 | elif ('http://' in args.image_text_folder.lower()) | ('https://' in args.image_text_folder.lower()): 207 | DATASET = f"pipe:curl -L -s {args.image_text_folder} || true" 208 | print('Found {} http(s) link under given path!'.format(len(DATASET), args.image_text_folder)) 209 | elif 'gs://' in args.image_text_folder.lower(): 210 | DATASET = f"pipe:gsutil cat {args.image_text_folder} || true" 211 | print('Found {} GCS link under given path!'.format(len(DATASET), args.image_text_folder)) 212 | elif '.tar' in args.image_text_folder: 213 | DATASET = args.image_text_folder 214 | print('Found WebDataset .tar(.gz) file under given path {}!'.format(args.image_text_folder)) 215 | else: 216 | raise Exception('No folder, no .tar(.gz) and no url pointing to tar files provided under {}.'.format(args.image_text_folder)) 217 | 218 | # initialize distributed backend 219 | 220 | distr_backend = distributed_utils.set_backend_from_args(args) 221 | distr_backend.initialize() 222 | 223 | using_deepspeed = \ 224 | distributed_utils.using_backend(distributed_utils.DeepSpeedBackend) 225 | 226 | # tokenizer 227 | 228 | if exists(args.bpe_path): 229 | klass = HugTokenizer if args.hug else YttmTokenizer 230 | tokenizer = klass(args.bpe_path) 231 | elif args.chinese: 232 | tokenizer = ChineseTokenizer() 233 | 234 | # reconstitute vae 235 | if RESUME: 236 | dalle_path = Path(DALLE_PATH) 237 | if using_deepspeed: 238 | cp_dir = cp_path_to_dir(dalle_path, 'ds') 239 | assert cp_dir.is_dir(), \ 240 | f'DeepSpeed checkpoint directory {cp_dir} not found' 241 | dalle_path = cp_dir / DEEPSPEED_CP_AUX_FILENAME 242 | else: 243 | assert dalle_path.exists(), 'DALL-E model file does not exist' 244 | loaded_obj = torch.load(str(dalle_path), map_location='cpu') 245 | 246 | dalle_params, vae_params, weights = loaded_obj['hparams'], loaded_obj['vae_params'], loaded_obj['weights'] 247 | opt_state = loaded_obj.get('opt_state') 248 | scheduler_state = loaded_obj.get('scheduler_state') 249 | 250 | if vae_params is not None: 251 | vae = DiscreteVAE(**vae_params) 252 | else: 253 | if args.taming: 254 | vae = VQGanVAE(VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH) 255 | else: 256 | vae = OpenAIDiscreteVAE() 257 | 258 | dalle_params = dict( 259 | **dalle_params 260 | ) 261 | IMAGE_SIZE = vae.image_size 262 | resume_epoch = loaded_obj.get('epoch', 0) 263 | else: 264 | if exists(VAE_PATH): 265 | vae_path = Path(VAE_PATH) 266 | assert vae_path.exists(), 'VAE model file does not exist' 267 | assert not vae_path.is_dir(), \ 268 | ('Cannot load VAE model from directory; please use a ' 269 | 'standard *.pt checkpoint. 
' 270 | 'Currently, merging a DeepSpeed-partitioned VAE into a DALLE ' 271 | 'model is not supported.') 272 | 273 | loaded_obj = torch.load(str(vae_path)) 274 | 275 | vae_params, weights = loaded_obj['hparams'], loaded_obj['weights'] 276 | 277 | vae = DiscreteVAE(**vae_params) 278 | vae.load_state_dict(weights) 279 | else: 280 | if distr_backend.is_root_worker(): 281 | print('using pretrained VAE for encoding images to tokens') 282 | vae_params = None 283 | 284 | if args.taming: 285 | vae = VQGanVAE(VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH) 286 | else: 287 | vae = OpenAIDiscreteVAE() 288 | 289 | IMAGE_SIZE = vae.image_size 290 | 291 | dalle_params = dict( 292 | num_text_tokens=tokenizer.vocab_size, 293 | text_seq_len=TEXT_SEQ_LEN, 294 | dim=MODEL_DIM, 295 | depth=DEPTH, 296 | heads=HEADS, 297 | dim_head=DIM_HEAD, 298 | reversible=REVERSIBLE, 299 | loss_img_weight=LOSS_IMG_WEIGHT, 300 | attn_types=ATTN_TYPES, 301 | ff_dropout=FF_DROPOUT, 302 | attn_dropout=ATTN_DROPOUT, 303 | stable=STABLE, 304 | shift_tokens=SHIFT_TOKENS, 305 | rotary_emb=ROTARY_EMB, 306 | ) 307 | resume_epoch = 0 308 | 309 | # configure OpenAI VAE for float16s 310 | 311 | if isinstance(vae, OpenAIDiscreteVAE) and args.fp16: 312 | vae.enc.blocks.output.conv.use_float16 = True 313 | 314 | 315 | # helpers 316 | 317 | def group_weight(model): 318 | group_decay, group_no_decay = [], [] 319 | for params in model.named_parameters(): 320 | if 'transformer' in params[0]: 321 | if 'bias' in params[0] or 'norm' in params[0]: 322 | group_no_decay.append(params[1]) 323 | continue 324 | group_decay.append(params[1]) 325 | 326 | assert len(list(model.parameters())) == len(group_decay) + len(group_no_decay) 327 | groups = [dict(params=group_decay), dict(params=group_no_decay, weight_decay=.0)] 328 | return groups 329 | 330 | 331 | # create dataset and dataloader 332 | 333 | is_shuffle = not distributed_utils.using_backend(distributed_utils.HorovodBackend) 334 | 335 | imagepreproc = T.Compose([ 336 | T.Lambda(lambda img: img.convert('RGB') 337 | if img.mode != 'RGB' else img), 338 | T.RandomResizedCrop(IMAGE_SIZE, 339 | scale=(args.resize_ratio, 1.), 340 | ratio=(1., 1.)), 341 | T.ToTensor(), 342 | ]) 343 | 344 | def imagetransform(b): 345 | return Image.open(BytesIO(b)) 346 | 347 | def tokenize(s): 348 | return tokenizer.tokenize( 349 | s.decode('utf-8'), 350 | TEXT_SEQ_LEN, 351 | truncate_text=args.truncate_captions).squeeze(0) 352 | 353 | if ENABLE_WEBDATASET: 354 | DATASET_SIZE = int(1e9) # You need to set a nominal length for the Dataset in order to avoid warnings from DataLoader 355 | 356 | myimg, mycap = WEBDATASET_IMAGE_TEXT_COLUMNS 357 | image_text_mapping = { 358 | myimg: imagetransform, 359 | mycap: tokenize 360 | } 361 | image_mapping = { 362 | myimg: imagepreproc 363 | } 364 | 365 | def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available. 
366 | if mycap not in item: 367 | return False 368 | if myimg not in item: 369 | return False 370 | return True 371 | 372 | w_dataset = wds.WebDataset(DATASET, handler=wds.warn_and_continue) 373 | filtered_dataset = w_dataset.select(filter_dataset) 374 | ds = filtered_dataset.map_dict(**image_text_mapping).map_dict(**image_mapping).to_tuple(mycap, myimg).batched(BATCH_SIZE, partial=True) 375 | else: 376 | ds = TextImageDataset( 377 | args.image_text_folder, 378 | text_len=TEXT_SEQ_LEN, 379 | image_size=IMAGE_SIZE, 380 | resize_ratio=args.resize_ratio, 381 | truncate_captions=args.truncate_captions, 382 | tokenizer=tokenizer, 383 | shuffle=is_shuffle, 384 | ) 385 | assert len(ds) > 0, 'dataset is empty' 386 | 387 | if distr_backend.is_root_worker(): 388 | if not ENABLE_WEBDATASET: 389 | print(f'{len(ds)} image-text pairs found for training') 390 | 391 | if not is_shuffle: 392 | data_sampler = torch.utils.data.distributed.DistributedSampler( 393 | ds, 394 | num_replicas=distr_backend.get_world_size(), 395 | rank=distr_backend.get_rank() 396 | ) 397 | else: 398 | data_sampler = None 399 | 400 | if ENABLE_WEBDATASET: 401 | # WebLoader for WebDataset and DeepSpeed compatibility 402 | dl = wds.WebLoader(ds, batch_size=None, shuffle=False) # optionally add num_workers=2 (n) argument 403 | number_of_batches = DATASET_SIZE // (BATCH_SIZE * distr_backend.get_world_size()) 404 | dl = dl.repeat(2).slice(number_of_batches) 405 | dl.length = number_of_batches 406 | else: 407 | # Regular DataLoader for image-text-folder datasets 408 | dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=is_shuffle, drop_last=True, sampler=data_sampler) 409 | 410 | 411 | # initialize DALL-E 412 | 413 | dalle = DALLE(vae=vae, **dalle_params) 414 | if not using_deepspeed: 415 | if args.fp16: 416 | dalle = dalle.half() 417 | dalle = dalle.cuda() 418 | 419 | if RESUME and not using_deepspeed: 420 | dalle.load_state_dict(weights) 421 | 422 | # optimizer 423 | 424 | opt = Adam(get_trainable_params(dalle), lr=LEARNING_RATE) 425 | if RESUME and opt_state: 426 | opt.load_state_dict(opt_state) 427 | 428 | if LR_DECAY: 429 | scheduler = ReduceLROnPlateau( 430 | opt, 431 | mode="min", 432 | factor=0.5, 433 | patience=10, 434 | cooldown=10, 435 | min_lr=1e-6, 436 | verbose=True, 437 | ) 438 | if RESUME and scheduler_state: 439 | scheduler.load_state_dict(scheduler_state) 440 | else: 441 | scheduler = None 442 | 443 | if distr_backend.is_root_worker(): 444 | # experiment tracker 445 | 446 | model_config = dict( 447 | depth=DEPTH, 448 | heads=HEADS, 449 | dim_head=DIM_HEAD 450 | ) 451 | 452 | run = wandb.init( 453 | project=args.wandb_name, 454 | entity=args.wandb_entity, 455 | resume=False, 456 | config=model_config, 457 | ) 458 | 459 | # distribute 460 | 461 | distr_backend.check_batch_size(BATCH_SIZE) 462 | deepspeed_config = { 463 | 'train_batch_size': BATCH_SIZE, 464 | 'gradient_accumulation_steps': args.ga_steps, 465 | 'gradient_clipping': GRAD_CLIP_NORM, 466 | 'fp16': { 467 | 'enabled': args.fp16, 468 | }, 469 | 'amp': { 470 | 'enabled': args.amp, 471 | 'opt_level': 'O1', 472 | }, 473 | "flops_profiler": { 474 | "enabled": args.flops_profiler, 475 | "profile_step": 200, 476 | "module_depth": -1, 477 | "top_modules": 1, 478 | "detailed": True, 479 | "output_file": None # TODO Can't get this to work. 
480 | }, 481 | } 482 | 483 | if deepspeed_config.get('zero_optimization', {}).get('stage', 0) >= 2: 484 | print(f"Checkpoints made with DeepSpeed ZeRO Stages 2 and 3 will be stored in deepspeed checkpoint folder") 485 | print(f"As such, they will require DeepSpeed as a dependency in order to resume from or generate with.") 486 | print("See the deespeed conversion script for details on how to convert your ZeRO stage 2/3 checkpoint to a single file.") 487 | print("If using a single GPU, consider running with apex automatic mixed precision instead for a similar speedup to ZeRO.") 488 | time.sleep(2) 489 | 490 | (distr_dalle, distr_opt, distr_dl, distr_scheduler) = distr_backend.distribute( 491 | args=args, 492 | model=dalle, 493 | optimizer=opt, 494 | model_parameters=get_trainable_params(dalle), 495 | training_data=( 496 | (None if ENABLE_WEBDATASET else ds) 497 | if using_deepspeed 498 | else dl 499 | ), 500 | # Do not pass the LR scheduler to DeepSpeed so we can manually 501 | # advance it. 502 | lr_scheduler=scheduler if LR_DECAY and not using_deepspeed else None, 503 | config_params=deepspeed_config, 504 | ) 505 | # Prefer scheduler in `deepspeed_config`. 506 | if LR_DECAY and distr_scheduler is None: 507 | distr_scheduler = scheduler 508 | avoid_model_calls = using_deepspeed and args.fp16 509 | 510 | if RESUME and using_deepspeed: 511 | distr_dalle.load_checkpoint(str(cp_dir)) 512 | 513 | 514 | def save_model(path, epoch=0): 515 | save_obj = { 516 | 'hparams': dalle_params, 517 | 'vae_params': vae_params, 518 | 'epoch': epoch, 519 | } 520 | if using_deepspeed: 521 | cp_dir = cp_path_to_dir(path, 'ds') 522 | 523 | if KEEP_N_CHECKPOINTS is not None and distr_backend.is_root_worker(): 524 | checkpoints = sorted(glob(str(cp_dir / "global*")), key=os.path.getmtime, reverse=True) 525 | for checkpoint in checkpoints[KEEP_N_CHECKPOINTS:]: 526 | shutil.rmtree(checkpoint) 527 | 528 | distr_dalle.save_checkpoint(cp_dir, client_state=save_obj) 529 | 530 | if not distr_backend.is_root_worker(): 531 | return 532 | 533 | # Save auxiliary values so we can reuse the standard routine 534 | # for loading. 535 | save_obj = { 536 | **save_obj, 537 | # Save a nonsense value that directs the user to 538 | # further help. 539 | 'weights': ( 540 | 'To get a working standard checkpoint, ' 541 | 'look into consolidating DeepSpeed checkpoints.' 542 | ), 543 | } 544 | torch.save(save_obj, str(cp_dir / DEEPSPEED_CP_AUX_FILENAME)) 545 | if deepspeed_config.get('zero_optimization', {}).get('stage', 0) >= 2: # see https://github.com/lucidrains/DALLE-pytorch/wiki/DeepSpeed-Checkpoints 546 | return 547 | 548 | if not distr_backend.is_root_worker(): 549 | return 550 | 551 | save_obj = { 552 | **save_obj, 553 | 'weights': dalle.state_dict(), 554 | 'opt_state': opt.state_dict(), 555 | } 556 | save_obj['scheduler_state'] = (scheduler.state_dict() if scheduler else None) 557 | torch.save(save_obj, path) 558 | 559 | # training 560 | 561 | # Saves a checkpoint before training begins to fail early when mis-configured. 
562 | # See https://github.com/lucidrains/DALLE-pytorch/wiki/DeepSpeed-Checkpoints 563 | save_model(DALLE_OUTPUT_FILE_NAME, epoch=resume_epoch) 564 | for epoch in range(resume_epoch, EPOCHS): 565 | if data_sampler: 566 | data_sampler.set_epoch(epoch) 567 | for i, (text, images) in enumerate((dl if ENABLE_WEBDATASET else distr_dl)): 568 | if i % 10 == 0 and distr_backend.is_root_worker(): 569 | t = time.time() 570 | if args.fp16: 571 | images = images.half() 572 | text, images = map(lambda t: t.cuda(), (text, images)) 573 | 574 | loss = distr_dalle(text, images, return_loss=True) 575 | 576 | if using_deepspeed: 577 | distr_dalle.backward(loss) 578 | distr_dalle.step() 579 | # Gradients are automatically zeroed after the step 580 | else: 581 | loss.backward() 582 | clip_grad_norm_(distr_dalle.parameters(), GRAD_CLIP_NORM) 583 | distr_opt.step() 584 | distr_opt.zero_grad() 585 | 586 | # Collective loss, averaged 587 | avg_loss = distr_backend.average_all(loss) 588 | 589 | log = {} 590 | 591 | if i % 10 == 0 and distr_backend.is_root_worker(): 592 | print(epoch, i, f'loss - {avg_loss.item()}') 593 | 594 | log = { 595 | **log, 596 | 'epoch': epoch, 597 | 'iter': i, 598 | 'loss': avg_loss.item() 599 | } 600 | 601 | if i % SAVE_EVERY_N_STEPS == 0: 602 | save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch) 603 | 604 | if i % 100 == 0: 605 | if distr_backend.is_root_worker(): 606 | sample_text = text[:1] 607 | token_list = sample_text.masked_select(sample_text != 0).tolist() 608 | decoded_text = tokenizer.decode(token_list) 609 | 610 | if not avoid_model_calls: 611 | # CUDA index errors when we don't guard this 612 | image = dalle.generate_images(text[:1], top_p_thres=0.9) # topp sampling at 0.9 613 | 614 | 615 | log = { 616 | **log, 617 | } 618 | if not avoid_model_calls: 619 | log['image'] = wandb.Image(image, caption=decoded_text) 620 | 621 | if i % 10 == 9 and distr_backend.is_root_worker(): 622 | sample_per_sec = BATCH_SIZE * 10 / (time.time() - t) 623 | log["sample_per_sec"] = sample_per_sec 624 | print(epoch, i, f'sample_per_sec - {sample_per_sec}') 625 | 626 | if i == 201 and args.flops_profiler: 627 | raise StopIteration("Profiler has finished running. 
Stopping training early.") 628 | 629 | if distr_backend.is_root_worker(): 630 | wandb.log(log) 631 | 632 | if LR_DECAY: 633 | distr_scheduler.step(avg_loss) 634 | 635 | save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch) 636 | 637 | if distr_backend.is_root_worker(): 638 | # save trained model to wandb as an artifact every epoch's end 639 | 640 | model_artifact = wandb.Artifact('trained-dalle', type='model', metadata=dict(model_config)) 641 | model_artifact.add_file(DALLE_OUTPUT_FILE_NAME) 642 | run.log_artifact(model_artifact) 643 | 644 | save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch) 645 | if distr_backend.is_root_worker(): 646 | wandb.save(DALLE_OUTPUT_FILE_NAME) 647 | model_artifact = wandb.Artifact('trained-dalle', type='model', metadata=dict(model_config)) 648 | model_artifact.add_file(DALLE_OUTPUT_FILE_NAME) 649 | run.log_artifact(model_artifact) 650 | 651 | wandb.finish() 652 | -------------------------------------------------------------------------------- /train_vae.py: -------------------------------------------------------------------------------- 1 | import math 2 | from math import sqrt 3 | import argparse 4 | from pathlib import Path 5 | 6 | # torch 7 | 8 | import torch 9 | from torch.optim import Adam 10 | from torch.optim.lr_scheduler import ExponentialLR 11 | 12 | # vision imports 13 | 14 | from torchvision import transforms as T 15 | from torch.utils.data import DataLoader 16 | from torchvision.datasets import ImageFolder 17 | from torchvision.utils import make_grid, save_image 18 | 19 | # dalle classes and utils 20 | 21 | from dalle_pytorch import distributed_utils 22 | from dalle_pytorch import DiscreteVAE 23 | 24 | # argument parsing 25 | 26 | parser = argparse.ArgumentParser() 27 | 28 | parser.add_argument('--image_folder', type = str, required = True, 29 | help='path to your folder of images for learning the discrete VAE and its codebook') 30 | 31 | parser.add_argument('--image_size', type = int, required = False, default = 128, 32 | help='image size') 33 | 34 | parser = distributed_utils.wrap_arg_parser(parser) 35 | 36 | 37 | train_group = parser.add_argument_group('Training settings') 38 | 39 | train_group.add_argument('--epochs', type = int, default = 20, help = 'number of epochs') 40 | 41 | train_group.add_argument('--batch_size', type = int, default = 8, help = 'batch size') 42 | 43 | train_group.add_argument('--learning_rate', type = float, default = 1e-3, help = 'learning rate') 44 | 45 | train_group.add_argument('--lr_decay_rate', type = float, default = 0.98, help = 'learning rate decay') 46 | 47 | train_group.add_argument('--starting_temp', type = float, default = 1., help = 'starting temperature') 48 | 49 | train_group.add_argument('--temp_min', type = float, default = 0.5, help = 'minimum temperature to anneal to') 50 | 51 | train_group.add_argument('--anneal_rate', type = float, default = 1e-6, help = 'temperature annealing rate') 52 | 53 | train_group.add_argument('--num_images_save', type = int, default = 4, help = 'number of images to save') 54 | 55 | model_group = parser.add_argument_group('Model settings') 56 | 57 | model_group.add_argument('--num_tokens', type = int, default = 8192, help = 'number of image tokens') 58 | 59 | model_group.add_argument('--num_layers', type = int, default = 3, help = 'number of layers (should be 3 or above)') 60 | 61 | model_group.add_argument('--num_resnet_blocks', type = int, default = 2, help = 'number of residual net blocks') 62 | 63 | model_group.add_argument('--smooth_l1_loss', dest = 'smooth_l1_loss', action = 
'store_true') 64 | 65 | model_group.add_argument('--emb_dim', type = int, default = 512, help = 'embedding dimension') 66 | 67 | model_group.add_argument('--hidden_dim', type = int, default = 256, help = 'hidden dimension') 68 | 69 | model_group.add_argument('--kl_loss_weight', type = float, default = 0., help = 'KL loss weight') 70 | 71 | args = parser.parse_args() 72 | 73 | # constants 74 | 75 | IMAGE_SIZE = args.image_size 76 | IMAGE_PATH = args.image_folder 77 | 78 | EPOCHS = args.epochs 79 | BATCH_SIZE = args.batch_size 80 | LEARNING_RATE = args.learning_rate 81 | LR_DECAY_RATE = args.lr_decay_rate 82 | 83 | NUM_TOKENS = args.num_tokens 84 | NUM_LAYERS = args.num_layers 85 | NUM_RESNET_BLOCKS = args.num_resnet_blocks 86 | SMOOTH_L1_LOSS = args.smooth_l1_loss 87 | EMB_DIM = args.emb_dim 88 | HIDDEN_DIM = args.hidden_dim 89 | KL_LOSS_WEIGHT = args.kl_loss_weight 90 | 91 | STARTING_TEMP = args.starting_temp 92 | TEMP_MIN = args.temp_min 93 | ANNEAL_RATE = args.anneal_rate 94 | 95 | NUM_IMAGES_SAVE = args.num_images_save 96 | 97 | # initialize distributed backend 98 | 99 | distr_backend = distributed_utils.set_backend_from_args(args) 100 | distr_backend.initialize() 101 | 102 | using_deepspeed = \ 103 | distributed_utils.using_backend(distributed_utils.DeepSpeedBackend) 104 | 105 | # data 106 | 107 | ds = ImageFolder( 108 | IMAGE_PATH, 109 | T.Compose([ 110 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), 111 | T.Resize(IMAGE_SIZE), 112 | T.CenterCrop(IMAGE_SIZE), 113 | T.ToTensor() 114 | ]) 115 | ) 116 | 117 | if distributed_utils.using_backend(distributed_utils.HorovodBackend): 118 | data_sampler = torch.utils.data.distributed.DistributedSampler( 119 | ds, num_replicas=distr_backend.get_world_size(), 120 | rank=distr_backend.get_rank()) 121 | else: 122 | data_sampler = None 123 | 124 | dl = DataLoader(ds, BATCH_SIZE, shuffle = not data_sampler, sampler=data_sampler) 125 | 126 | vae_params = dict( 127 | image_size = IMAGE_SIZE, 128 | num_layers = NUM_LAYERS, 129 | num_tokens = NUM_TOKENS, 130 | codebook_dim = EMB_DIM, 131 | hidden_dim = HIDDEN_DIM, 132 | num_resnet_blocks = NUM_RESNET_BLOCKS 133 | ) 134 | 135 | vae = DiscreteVAE( 136 | **vae_params, 137 | smooth_l1_loss = SMOOTH_L1_LOSS, 138 | kl_div_loss_weight = KL_LOSS_WEIGHT 139 | ) 140 | if not using_deepspeed: 141 | vae = vae.cuda() 142 | 143 | 144 | assert len(ds) > 0, 'folder does not contain any images' 145 | if distr_backend.is_root_worker(): 146 | print(f'{len(ds)} images found for training') 147 | 148 | # optimizer 149 | 150 | opt = Adam(vae.parameters(), lr = LEARNING_RATE) 151 | sched = ExponentialLR(optimizer = opt, gamma = LR_DECAY_RATE) 152 | 153 | 154 | if distr_backend.is_root_worker(): 155 | # weights & biases experiment tracking 156 | 157 | import wandb 158 | 159 | model_config = dict( 160 | num_tokens = NUM_TOKENS, 161 | smooth_l1_loss = SMOOTH_L1_LOSS, 162 | num_resnet_blocks = NUM_RESNET_BLOCKS, 163 | kl_loss_weight = KL_LOSS_WEIGHT 164 | ) 165 | 166 | run = wandb.init( 167 | project = 'dalle_train_vae', 168 | job_type = 'train_model', 169 | config = model_config 170 | ) 171 | 172 | # distribute 173 | 174 | distr_backend.check_batch_size(BATCH_SIZE) 175 | deepspeed_config = {'train_batch_size': BATCH_SIZE} 176 | 177 | (distr_vae, distr_opt, distr_dl, distr_sched) = distr_backend.distribute( 178 | args=args, 179 | model=vae, 180 | optimizer=opt, 181 | model_parameters=vae.parameters(), 182 | training_data=ds if using_deepspeed else dl, 183 | lr_scheduler=sched if not using_deepspeed else 
None, 184 | config_params=deepspeed_config, 185 | ) 186 | 187 | using_deepspeed_sched = False 188 | # Prefer scheduler in `deepspeed_config`. 189 | if distr_sched is None: 190 | distr_sched = sched 191 | elif using_deepspeed: 192 | # We are using a DeepSpeed LR scheduler and want to let DeepSpeed 193 | # handle its scheduling. 194 | using_deepspeed_sched = True 195 | 196 | def save_model(path): 197 | save_obj = { 198 | 'hparams': vae_params, 199 | } 200 | if using_deepspeed: 201 | cp_path = Path(path) 202 | path_sans_extension = cp_path.parent / cp_path.stem 203 | cp_dir = str(path_sans_extension) + '-ds-cp' 204 | 205 | distr_vae.save_checkpoint(cp_dir, client_state=save_obj) 206 | # We do not return so we do get a "normal" checkpoint to refer to. 207 | 208 | if not distr_backend.is_root_worker(): 209 | return 210 | 211 | save_obj = { 212 | **save_obj, 213 | 'weights': vae.state_dict() 214 | } 215 | 216 | torch.save(save_obj, path) 217 | 218 | # starting temperature 219 | 220 | global_step = 0 221 | temp = STARTING_TEMP 222 | 223 | for epoch in range(EPOCHS): 224 | for i, (images, _) in enumerate(distr_dl): 225 | images = images.cuda() 226 | 227 | loss, recons = distr_vae( 228 | images, 229 | return_loss = True, 230 | return_recons = True, 231 | temp = temp 232 | ) 233 | 234 | if using_deepspeed: 235 | # Gradients are automatically zeroed after the step 236 | distr_vae.backward(loss) 237 | distr_vae.step() 238 | else: 239 | distr_opt.zero_grad() 240 | loss.backward() 241 | distr_opt.step() 242 | 243 | logs = {} 244 | 245 | if i % 100 == 0: 246 | if distr_backend.is_root_worker(): 247 | k = NUM_IMAGES_SAVE 248 | 249 | with torch.no_grad(): 250 | codes = vae.get_codebook_indices(images[:k]) 251 | hard_recons = vae.decode(codes) 252 | 253 | images, recons = map(lambda t: t[:k], (images, recons)) 254 | images, recons, hard_recons, codes = map(lambda t: t.detach().cpu(), (images, recons, hard_recons, codes)) 255 | images, recons, hard_recons = map(lambda t: make_grid(t.float(), nrow = int(sqrt(k)), normalize = True, range = (-1, 1)), (images, recons, hard_recons)) 256 | 257 | logs = { 258 | **logs, 259 | 'sample images': wandb.Image(images, caption = 'original images'), 260 | 'reconstructions': wandb.Image(recons, caption = 'reconstructions'), 261 | 'hard reconstructions': wandb.Image(hard_recons, caption = 'hard reconstructions'), 262 | 'codebook_indices': wandb.Histogram(codes), 263 | 'temperature': temp 264 | } 265 | 266 | wandb.save('./vae.pt') 267 | save_model(f'./vae.pt') 268 | 269 | # temperature anneal 270 | 271 | temp = max(temp * math.exp(-ANNEAL_RATE * global_step), TEMP_MIN) 272 | 273 | # lr decay 274 | 275 | # Do not advance schedulers from `deepspeed_config`. 
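# (When DeepSpeed owns the scheduler, i.e. it was defined in `deepspeed_config`, the engine
#  already advances it as part of distr_vae.step(), so stepping it again here would apply the
#  learning-rate decay twice.)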
276 | if not using_deepspeed_sched: 277 | distr_sched.step() 278 | 279 | # Collective loss, averaged 280 | avg_loss = distr_backend.average_all(loss) 281 | 282 | if distr_backend.is_root_worker(): 283 | if i % 10 == 0: 284 | lr = distr_sched.get_last_lr()[0] 285 | print(epoch, i, f'lr - {lr:6f} loss - {avg_loss.item()}') 286 | 287 | logs = { 288 | **logs, 289 | 'epoch': epoch, 290 | 'iter': i, 291 | 'loss': avg_loss.item(), 292 | 'lr': lr 293 | } 294 | 295 | wandb.log(logs) 296 | global_step += 1 297 | 298 | if distr_backend.is_root_worker(): 299 | # save trained model to wandb as an artifact every epoch's end 300 | 301 | model_artifact = wandb.Artifact('trained-vae', type = 'model', metadata = dict(model_config)) 302 | model_artifact.add_file('vae.pt') 303 | run.log_artifact(model_artifact) 304 | 305 | if distr_backend.is_root_worker(): 306 | # save final vae and cleanup 307 | 308 | save_model('./vae-final.pt') 309 | wandb.save('./vae-final.pt') 310 | 311 | model_artifact = wandb.Artifact('trained-vae', type = 'model', metadata = dict(model_config)) 312 | model_artifact.add_file('vae-final.pt') 313 | run.log_artifact(model_artifact) 314 | 315 | wandb.finish() 316 | --------------------------------------------------------------------------------
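For reference, a minimal, hypothetical sketch (not part of the repository) of reloading a standard checkpoint written by train_dalle.py's save_model() and sampling from it. It assumes training used the default BPE tokenizer, did not use the --taming/VQGAN path, and produced a real 'weights' state dict, i.e. it was not a DeepSpeed ZeRO stage 2/3 checkpoint, which first needs to be consolidated into a single .pt file as the comments in save_model() note. The file name and caption are illustrative only; the generate_images() call simply mirrors the preview call in the training loop above.

import torch
from dalle_pytorch import DALLE, DiscreteVAE, OpenAIDiscreteVAE
from dalle_pytorch.tokenizer import tokenizer

# 'dalle.pt' is the default --dalle_output_file_name plus the '.pt' suffix.
loaded = torch.load('dalle.pt', map_location='cpu')
dalle_params = loaded['hparams']
vae_params = loaded['vae_params']
weights = loaded['weights']

# Rebuild the VAE the same way train_dalle.py does on resume: a custom DiscreteVAE when its
# hyperparameters were stored, otherwise the pretrained OpenAI discrete VAE.
vae = DiscreteVAE(**vae_params) if vae_params is not None else OpenAIDiscreteVAE()

dalle = DALLE(vae=vae, **dalle_params)
dalle.load_state_dict(weights)
dalle = dalle.cuda().eval()

# Tokenize a caption to the stored text sequence length and sample images.
text = tokenizer.tokenize('a cloud in the shape of a cube',
                          dalle_params['text_seq_len']).cuda()
images = dalle.generate_images(text, top_p_thres=0.9)  # top-p sampling at 0.9, as in the training loop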