├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── dalle_pytorch
│   ├── __init__.py
│   ├── attention.py
│   ├── dalle_pytorch.py
│   ├── data
│   │   └── bpe_simple_vocab_16e6.txt
│   ├── distributed_backends
│   │   ├── __init__.py
│   │   ├── deepspeed_backend.py
│   │   ├── distributed_backend.py
│   │   ├── dummy_backend.py
│   │   └── horovod_backend.py
│   ├── distributed_utils.py
│   ├── loader.py
│   ├── reversible.py
│   ├── tokenizer.py
│   ├── transformer.py
│   └── vae.py
├── docker
│   └── Dockerfile
├── examples
│   └── rainbow_dalle.ipynb
├── generate.py
├── images
│   ├── avocado-0.png
│   ├── avocado-1.png
│   ├── avocado-2.png
│   ├── avocado-3.png
│   ├── avocado-4.png
│   ├── avocado-5.png
│   ├── avocado-6.png
│   ├── avocado-7.png
│   ├── avocado-before.png
│   ├── avocado.png
│   ├── banner.jpg
│   ├── birds.png
│   ├── clothing.png
│   ├── cloud-0.png
│   ├── cloud-1.png
│   ├── cloud-2.png
│   ├── cloud-3.png
│   ├── cloud-4.png
│   ├── cloud-5.png
│   ├── cloud-6.png
│   ├── cloud-7.png
│   ├── cube-cloud-before.jpg
│   ├── cube-cloud.png
│   ├── cube-porcupine-before.jpg
│   ├── cube-porcupine.png
│   ├── cube-water-before.jpg
│   ├── cube-water.png
│   ├── girl-0.png
│   ├── girl-1.png
│   ├── girl-2.png
│   ├── girl-3.png
│   ├── girl-4.png
│   ├── girl-5.png
│   ├── girl-6.png
│   ├── girl-7.png
│   ├── girl-glasses-before.png
│   ├── girl-glasses.png
│   ├── landscape.png
│   ├── layouts-1.jpg
│   ├── layouts-2.jpg
│   ├── researcher-mad-before.png
│   ├── researcher-mad.png
│   └── wb.png
├── install_apex.sh
├── install_deepspeed.sh
├── setup.py
├── train_dalle.py
└── train_vae.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflows will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: Upload Python Package
5 |
6 | on:
7 | release:
8 | types: [created]
9 |
10 | jobs:
11 | deploy:
12 |
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - uses: actions/checkout@v2
17 | - name: Set up Python
18 | uses: actions/setup-python@v2
19 | with:
20 | python-version: '3.x'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install setuptools wheel twine
25 | - name: Build and publish
26 | env:
27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
29 | run: |
30 | python setup.py sdist bdist_wheel
31 | twine upload dist/*
32 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # dall-e generation outputs
2 | outputs/
3 | *.pt
4 | taming/
5 | wandb/
6 | dalle-ds-cp/
7 |
8 | # Byte-compiled / optimized / DLL files
9 | __pycache__/
10 | *.py[cod]
11 | *$py.class
12 |
13 | # C extensions
14 | *.so
15 |
16 | # Distribution / packaging
17 | .Python
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | pip-wheel-metadata/
31 | share/python-wheels/
32 | *.egg-info/
33 | .installed.cfg
34 | *.egg
35 | MANIFEST
36 |
37 | # PyInstaller
38 | # Usually these files are written by a python script from a template
39 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
40 | *.manifest
41 | *.spec
42 |
43 | # Installer logs
44 | pip-log.txt
45 | pip-delete-this-directory.txt
46 |
47 | # Unit test / coverage reports
48 | htmlcov/
49 | .tox/
50 | .nox/
51 | .coverage
52 | .coverage.*
53 | .cache
54 | nosetests.xml
55 | coverage.xml
56 | *.cover
57 | *.py,cover
58 | .hypothesis/
59 | .pytest_cache/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 | db.sqlite3
69 | db.sqlite3-journal
70 |
71 | # Flask stuff:
72 | instance/
73 | .webassets-cache
74 |
75 | # Scrapy stuff:
76 | .scrapy
77 |
78 | # Sphinx documentation
79 | docs/_build/
80 |
81 | # PyBuilder
82 | target/
83 |
84 | # Jupyter Notebook
85 | .ipynb_checkpoints
86 |
87 | # IPython
88 | profile_default/
89 | ipython_config.py
90 |
91 | # pyenv
92 | .python-version
93 |
94 | # Visual Studio Code
95 | .vscode
96 |
97 | # pipenv
98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
101 | # install all needed dependencies.
102 | #Pipfile.lock
103 |
104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
105 | __pypackages__/
106 |
107 | # Celery stuff
108 | celerybeat-schedule
109 | celerybeat.pid
110 |
111 | # SageMath parsed files
112 | *.sage.py
113 |
114 | # Environments
115 | .env
116 | .venv
117 | env/
118 | venv/
119 | ENV/
120 | env.bak/
121 | venv.bak/
122 |
123 | # Spyder project settings
124 | .spyderproject
125 | .spyproject
126 |
127 | # Rope project settings
128 | .ropeproject
129 |
130 | # mkdocs documentation
131 | /site
132 |
133 | # mypy
134 | .mypy_cache/
135 | .dmypy.json
136 | dmypy.json
137 |
138 | # Pyre type checker
139 | .pyre/
140 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | recursive-include dalle_pytorch *.txt
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DALL-3
2 |
3 | Try it out on Colab.
4 |
5 |
6 |
7 |
8 | DALL-3 is a mashup of DALLE-pytorch, VQGAN and CLIP-Guided Diffusion. The basic idea is to use a diffusion model instead of a VAE for the decoder stage, which allows us to use 16x16 tokens instead of 32x32 while maintaining comparable image quality.
9 |
10 | This DALLE model is meant to be used with https://github.com/Jack000/guided-diffusion
11 |
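To make the sequence-length saving concrete: with a 16x16 image token grid the transformer context is roughly 2.5x shorter than with the usual 32x32 grid. A back-of-the-envelope sketch (assuming the default text_seq_len of 256 from dalle_pytorch; not a benchmark):

```python
# rough sequence-length comparison, assuming text_seq_len = 256 (the dalle_pytorch default)
text_seq_len = 256

dall3_image_tokens = 16 * 16  # 256 image tokens when a diffusion model handles decoding
dalle_image_tokens = 32 * 32  # 1024 image tokens with a conventional f/8 VAE at 256px

print(text_seq_len + dall3_image_tokens)  # 512 transformer positions per sample
print(text_seq_len + dalle_image_tokens)  # 1280 transformer positions per sample
```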
12 | Minor modifications to DALLE-pytorch:
13 | - hardcoded 128px image size in DALLE (to allow mismatched VAE/DALLE image sizes)
14 | - added top-p (nucleus) filtering for sampling (see the sketch below)
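A minimal sketch of what the added top-p (nucleus) filter does to a logits vector before sampling; it mirrors the `top_p` helper in `dalle_pytorch/dalle_pytorch.py` (the threshold value below is just an example):

```python
import torch
import torch.nn.functional as F

def top_p(logits, thres = 0.9):
    # keep the smallest set of tokens whose cumulative probability reaches `thres`,
    # mask everything else to -inf, then sample from the renormalized distribution
    sorted_logits, sorted_indices = torch.sort(logits, descending = True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim = -1), dim = -1)

    sorted_remove = cum_probs > thres
    sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()  # shift right so the boundary token is kept
    sorted_remove[..., 0] = 0

    remove = sorted_remove.scatter(1, sorted_indices, sorted_remove)
    return logits.masked_fill(remove, float('-inf'))

logits = torch.randn(1, 10)                                  # one step of next-token logits
probs = F.softmax(top_p(logits, thres = 0.85), dim = -1)
sample = torch.multinomial(probs, 1)
```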
15 |
16 | Cherry-picked sample images, shown as before-diffusion / after-diffusion pairs (image files under `images/`):
17 |
21 | Prompt: A cube made of cloud, a cube with the texture of cloud
23 | Prompt: A cube made of water, a cube with the texture of water
25 | Prompt: A cube made of porcupine, a cube with the texture of porcupine
28 | Prompt: An armchair shaped like an avocado, an avocado armchair
31 | Prompt: A girl with thick glasses, a girl wearing glasses
34 | Prompt: A machine learning researcher smashes his computer in a fit of rage
36 |
37 | Non-cherry-picked images (CLIP re-ranked best 8 out of 1024; 8-image grids under `images/`):
38 |
42 | A cube made of cloud. A cube with the texture of cloud
48 | An armchair shaped like an avocado. An avocado armchair
54 | A girl with thick glasses. A girl wearing glasses
55 |
56 |
57 | ## Usage
58 | ```sh
# git clone this repo, then
59 | cd DALLE-pytorch
60 | pip install -e .
61 |
62 | # download GumbelVQ VAE model
63 | mkdir -p vqgan_gumbel_f8
64 | wget 'https://heibox.uni-heidelberg.de/f/b24d14998a8d4f19a34f/?dl=1' -O 'vqgan_gumbel_f8/model.yaml'
65 | wget 'https://heibox.uni-heidelberg.de/f/34a747d5765840b5a99d/?dl=1' -O 'vqgan_gumbel_f8/last.ckpt'
66 |
67 | # download DALL-E models
68 | wget https://dall-3.com/models/dalle/bpe.model
69 | wget https://dall-3.com/models/dalle/dalle-latest.pt
70 |
71 | # generate (optionally install OpenAI clip for --clip_sort)
72 | python generate.py --top_p 0.85 --temperature 1.0 --clip_sort --output_npy --dalle_path ./dalle-latest.pt --bpe_path bpe.model --taming --vqgan_model_path vqgan_gumbel_f8/last.ckpt --vqgan_config_path vqgan_gumbel_f8/model.yaml --text 'a girl with thick glasses. a girl wearing glasses'
73 |
74 | # post process
75 | # use the npy file as input to clip guided diffusion https://github.com/Jack000/guided-diffusion
76 |
77 | ```
78 |
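The sampling path that `generate.py` drives is also reachable directly through `DALLE.generate_images`. The sketch below only shows the call shape: the VAE, DALLE weights, and text tokens are untrained/random placeholders rather than the checkpoints downloaded above.

```python
import torch
import numpy as np
from dalle_pytorch import DALLE, DiscreteVAE

# untrained stand-in for the GumbelVQ/VQGAN decoder, purely to satisfy the DALLE constructor;
# 128px with 3 downsampling layers gives the 16x16 token grid the model expects
vae = DiscreteVAE(image_size = 128, num_layers = 3, num_tokens = 8192)
dalle = DALLE(dim = 512, vae = vae, depth = 2, heads = 8, num_text_tokens = 10000, text_seq_len = 256)

text = torch.randint(0, 10000, (1, 256))  # placeholder for BPE-encoded, zero-padded text tokens

images, tokens = dalle.generate_images(text, top_p_thresh = 0.85, temperature = 1.0, return_tokens = True)
np.save('tokens.npy', tokens.cpu().numpy())  # keep the sampled image tokens around; the real --output_npy export is handled by generate.py
```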
79 | ## Citations
80 |
81 | ```bibtex
82 | @misc{ramesh2021zeroshot,
83 | title = {Zero-Shot Text-to-Image Generation},
84 | author = {Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray and Chelsea Voss and Alec Radford and Mark Chen and Ilya Sutskever},
85 | year = {2021},
86 | eprint = {2102.12092},
87 | archivePrefix = {arXiv},
88 | primaryClass = {cs.CV}
89 | }
90 | ```
91 |
92 | ```bibtex
93 | @misc{unpublished2021clip,
94 | title = {CLIP: Connecting Text and Images},
95 | author = {Alec Radford and Ilya Sutskever and Jong Wook Kim and Gretchen Krueger and Sandhini Agarwal},
96 | year = {2021}
97 | }
98 | ```
99 |
100 | ```bibtex
101 | @misc{kitaev2020reformer,
102 | title = {Reformer: The Efficient Transformer},
103 | author = {Nikita Kitaev and Łukasz Kaiser and Anselm Levskaya},
104 | year = {2020},
105 | eprint = {2001.04451},
106 | archivePrefix = {arXiv},
107 | primaryClass = {cs.LG}
108 | }
109 | ```
110 |
111 | ```bibtex
112 | @misc{esser2021taming,
113 | title = {Taming Transformers for High-Resolution Image Synthesis},
114 | author = {Patrick Esser and Robin Rombach and Björn Ommer},
115 | year = {2021},
116 | eprint = {2012.09841},
117 | archivePrefix = {arXiv},
118 | primaryClass = {cs.CV}
119 | }
120 | ```
121 |
122 | ```bibtex
123 | @misc{ding2021cogview,
124 | title = {CogView: Mastering Text-to-Image Generation via Transformers},
125 | author = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang},
126 | year = {2021},
127 | eprint = {2105.13290},
128 | archivePrefix = {arXiv},
129 | primaryClass = {cs.CV}
130 | }
131 | ```
132 |
133 | ```bibtex
134 | @software{peng_bo_2021_5196578,
135 | author = {PENG Bo},
136 | title = {BlinkDL/RWKV-LM: 0.01},
137 | month = {aug},
138 | year = {2021},
139 | publisher = {Zenodo},
140 | version = {0.01},
141 | doi = {10.5281/zenodo.5196578},
142 | url = {https://doi.org/10.5281/zenodo.5196578}
143 | }
144 | ```
145 |
146 | ```bibtex
147 | @misc{su2021roformer,
148 | title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
149 | author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
150 | year = {2021},
151 | eprint = {2104.09864},
152 | archivePrefix = {arXiv},
153 | primaryClass = {cs.CL}
154 | }
155 | ```
156 |
157 | *Those who do not want to imitate anything, produce nothing.* - Dali
158 |
--------------------------------------------------------------------------------
/dalle_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from dalle_pytorch.dalle_pytorch import DALLE, CLIP, DiscreteVAE
2 | from dalle_pytorch.vae import OpenAIDiscreteVAE, VQGanVAE
3 |
--------------------------------------------------------------------------------
/dalle_pytorch/attention.py:
--------------------------------------------------------------------------------
1 | from inspect import isfunction
2 | from math import ceil
3 |
4 | import torch
5 | from torch import nn, einsum
6 | import torch.nn.functional as F
7 | from einops import rearrange, repeat
8 |
9 | from rotary_embedding_torch import apply_rotary_emb
10 |
11 | # helpers
12 |
13 | def exists(val):
14 | return val is not None
15 |
16 | def uniq(arr):
17 | return{el: True for el in arr}.keys()
18 |
19 | def default(val, d):
20 | if exists(val):
21 | return val
22 | return d() if isfunction(d) else d
23 |
24 | def max_neg_value(t):
25 | return -torch.finfo(t.dtype).max
26 |
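# softmax computed as softmax(t - max(t)): logits are scaled down by alpha before the max subtraction
# and scaled back up afterwards, keeping intermediates small (helps numerical stability in fp16)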
27 | def stable_softmax(t, dim = -1, alpha = 32 ** 2):
28 | t = t / alpha
29 | t = t - torch.amax(t, dim = dim, keepdim = True).detach()
30 | return (t * alpha).softmax(dim = dim)
31 |
32 | def apply_pos_emb(pos_emb, qkv):
33 | n = qkv[0].shape[-2]
34 | pos_emb = pos_emb[..., :n, :]
35 | return tuple(map(lambda t: apply_rotary_emb(pos_emb, t), qkv))
36 |
37 | # classes
38 |
39 | class Attention(nn.Module):
40 | def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False):
41 | super().__init__()
42 | inner_dim = dim_head * heads
43 | self.heads = heads
44 | self.seq_len = seq_len
45 | self.scale = dim_head ** -0.5
46 |
47 | self.stable = stable
48 | self.causal = causal
49 |
50 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
51 | self.to_out = nn.Sequential(
52 | nn.Linear(inner_dim, dim),
53 | nn.Dropout(dropout)
54 | )
55 |
56 | def forward(self, x, mask = None, rotary_pos_emb = None):
57 | b, n, _, h, device = *x.shape, self.heads, x.device
58 | softmax = torch.softmax if not self.stable else stable_softmax
59 |
60 | qkv = self.to_qkv(x).chunk(3, dim = -1)
61 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
62 |
63 | if exists(rotary_pos_emb):
64 | q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))
65 |
66 | q = q * self.scale
67 |
68 | dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
69 | mask_value = max_neg_value(dots)
70 |
71 | if exists(mask):
72 | mask = rearrange(mask, 'b j -> b () () j')
73 | dots.masked_fill_(~mask, mask_value)
74 | del mask
75 |
76 | if self.causal:
77 | i, j = dots.shape[-2:]
78 | mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
79 | dots.masked_fill_(mask, mask_value)
80 |
81 | attn = softmax(dots, dim=-1)
82 |
83 | out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
84 | out = rearrange(out, 'b h n d -> b n (h d)')
85 | out = self.to_out(out)
86 | return out
87 |
88 | # sparse attention with convolutional pattern, as mentioned in the blog post. customizable kernel size and dilation
89 |
90 | class SparseConvCausalAttention(nn.Module):
91 | def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1, heads = 8, dim_head = 64, dropout = 0., stable = False, **kwargs):
92 | super().__init__()
93 | assert kernel_size % 2 == 1, 'kernel size must be odd'
94 |
95 | inner_dim = dim_head * heads
96 | self.seq_len = seq_len
97 | self.heads = heads
98 | self.scale = dim_head ** -0.5
99 | self.image_size = image_size
100 | self.kernel_size = kernel_size
101 | self.dilation = dilation
102 |
103 | self.stable = stable
104 |
105 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
106 |
107 | self.to_out = nn.Sequential(
108 | nn.Linear(inner_dim, dim),
109 | nn.Dropout(dropout)
110 | )
111 |
112 | def forward(self, x, mask = None, rotary_pos_emb = None):
113 | b, n, _, h, img_size, kernel_size, dilation, seq_len, device = *x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, self.seq_len, x.device
114 | softmax = torch.softmax if not self.stable else stable_softmax
115 |
116 | img_seq_len = img_size ** 2
117 | text_len = seq_len + 1 - img_seq_len
118 |
119 | # padding
120 |
121 | padding = seq_len - n + 1
122 | mask = default(mask, lambda: torch.ones(b, text_len, device = device).bool())
123 |
124 | x = F.pad(x, (0, 0, 0, padding), value = 0)
125 | mask = mask[:, :text_len]
126 |
127 | # derive query / keys / values
128 |
129 | qkv = self.to_qkv(x).chunk(3, dim = -1)
130 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
131 |
132 | if exists(rotary_pos_emb):
133 | q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))
134 |
135 | q *= self.scale
136 |
137 | ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v))
138 |
139 | # text attention
140 |
141 | dots_text = einsum('b i d, b j d -> b i j', q_text, k_text)
142 | mask_value = max_neg_value(dots_text)
143 |
144 | i, j = dots_text.shape[-2:]
145 | text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
146 | dots_text.masked_fill_(text_causal_mask, mask_value)
147 |
148 | attn_text = softmax(dots_text, dim = -1)
149 | out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)
150 |
151 | # image attention
152 |
153 | effective_kernel_size = (kernel_size - 1) * dilation + 1
154 | padding = effective_kernel_size // 2
155 |
156 | k_img, v_img = map(lambda t: rearrange(t, 'b (h w) c -> b c h w', h = img_size), (k_img, v_img))
157 | k_img, v_img = map(lambda t: F.unfold(t, kernel_size, padding = padding, dilation = dilation), (k_img, v_img))
158 | k_img, v_img = map(lambda t: rearrange(t, 'b (d j) i -> b i j d', j = kernel_size ** 2), (k_img, v_img))
159 |
160 | # let image attend to all of text
161 |
162 | dots_image = einsum('b i d, b i j d -> b i j', q_img, k_img)
163 | dots_image_to_text = einsum('b i d, b j d -> b i j', q_img, k_text)
164 |
165 | # calculate causal attention for local convolution
166 |
167 | i, j = dots_image.shape[-2:]
168 | img_seq = torch.arange(img_seq_len, device = device)
169 | k_img_indices = rearrange(img_seq.float(), '(h w) -> () () h w', h = img_size)
170 | k_img_indices = F.pad(k_img_indices, (padding,) * 4, value = img_seq_len) # padding set to be max, so it is never attended to
171 | k_img_indices = F.unfold(k_img_indices, kernel_size, dilation = dilation)
172 | k_img_indices = rearrange(k_img_indices, 'b j i -> b i j')
173 |
174 | # mask image attention
175 |
176 | q_img_indices = rearrange(img_seq, 'i -> () i ()')
177 | causal_mask = q_img_indices < k_img_indices
178 |
179 | # concat text mask with image causal mask
180 |
181 | causal_mask = repeat(causal_mask, '() i j -> b i j', b = b * h)
182 | mask = repeat(mask, 'b j -> (b h) i j', i = i, h = h)
183 | mask = torch.cat((~mask, causal_mask), dim = -1)
184 |
185 | # image can attend to all of text
186 |
187 | dots = torch.cat((dots_image_to_text, dots_image), dim = -1)
188 | dots.masked_fill_(mask, mask_value)
189 |
190 | attn = softmax(dots, dim = -1)
191 |
192 | # aggregate
193 |
194 | attn_image_to_text, attn_image = attn[..., :text_len], attn[..., text_len:]
195 |
196 | out_image_to_image = einsum('b i j, b i j d -> b i d', attn_image, v_img)
197 | out_image_to_text = einsum('b i j, b j d -> b i d', attn_image_to_text, v_text)
198 |
199 | out_image = out_image_to_image + out_image_to_text
200 |
201 | # combine attended values for both text and image
202 |
203 | out = torch.cat((out_text, out_image), dim = 1)
204 |
205 | out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
206 | out = self.to_out(out)
207 | return out[:, :n]
208 |
209 | # sparse axial causal attention
210 |
211 | class SparseAxialCausalAttention(nn.Module):
212 | def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head = 64, dropout = 0., stable = False, **kwargs):
213 | super().__init__()
214 | assert axis in {0, 1}, 'axis must be either 0 (along height) or 1 (along width)'
215 | self.axis = axis
216 |
217 | inner_dim = dim_head * heads
218 | self.seq_len = seq_len
219 | self.heads = heads
220 | self.scale = dim_head ** -0.5
221 | self.image_size = image_size
222 |
223 | self.stable = stable
224 |
225 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
226 |
227 | self.to_out = nn.Sequential(
228 | nn.Linear(inner_dim, dim),
229 | nn.Dropout(dropout)
230 | )
231 |
232 | def forward(self, x, mask = None, rotary_pos_emb = None):
233 | b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device
234 | softmax = torch.softmax if not self.stable else stable_softmax
235 |
236 | img_seq_len = img_size ** 2
237 | text_len = seq_len + 1 - img_seq_len
238 |
239 | # padding
240 |
241 | padding = seq_len - n + 1
242 | mask = default(mask, lambda: torch.ones(b, text_len, device = device).bool())
243 |
244 | x = F.pad(x, (0, 0, 0, padding), value = 0)
245 | mask = mask[:, :text_len]
246 |
247 | # derive queries / keys / values
248 |
249 | qkv = self.to_qkv(x).chunk(3, dim = -1)
250 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
251 |
252 | if exists(rotary_pos_emb):
253 | q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))
254 |
255 | q *= self.scale
256 |
257 | ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v))
258 |
259 | # text attention
260 |
261 | dots_text = einsum('b i d, b j d -> b i j', q_text, k_text)
262 | mask_value = max_neg_value(dots_text)
263 |
264 | i, j = dots_text.shape[-2:]
265 | text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
266 | dots_text.masked_fill_(text_causal_mask, mask_value)
267 |
268 | attn_text = softmax(dots_text, dim = -1)
269 | out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)
270 |
271 | # image attention
272 |
273 | split_axis_einops = 'b (h w) c -> b h w c' if axis == 0 else 'b (h w) c -> b w h c'
274 | merge_axis_einops = 'b x n d -> b (x n) d' if axis == 0 else 'b x n d -> b (n x) d'
275 |
276 | # split out axis
277 |
278 | q_img, k_img, v_img = map(lambda t: rearrange(t, split_axis_einops, h = img_size), (q_img, k_img, v_img))
279 |
280 | # similarity
281 |
282 | dots_image_to_image = einsum('b x i d, b x j d -> b x i j', q_img, k_img)
283 | dots_image_to_text = einsum('b x i d, b j d -> b x i j', q_img, k_text)
284 |
285 | dots = torch.cat((dots_image_to_text, dots_image_to_image), dim = -1)
286 |
287 | # mask so image has full attention to text, but causal along axis
288 |
289 | bh, x, i, j = dots.shape
290 | causal_mask = torch.ones(i, img_size, device = device).triu_(img_size - i + 1).bool()
291 | causal_mask = repeat(causal_mask, 'i j -> b x i j', b = bh, x = x)
292 |
293 | mask = repeat(mask, 'b j -> (b h) x i j', h = h, x = x, i = i)
294 | mask = torch.cat((~mask, causal_mask), dim = -1)
295 |
296 | dots.masked_fill_(mask, mask_value)
297 |
298 | # attention.
299 |
300 | attn = softmax(dots, dim = -1)
301 |
302 | # aggregate
303 |
304 | attn_image_to_text, attn_image_to_image = attn[..., :text_len], attn[..., text_len:]
305 |
306 | out_image_to_image = einsum('b x i j, b x j d -> b x i d', attn_image_to_image, v_img)
307 | out_image_to_text = einsum('b x i j, b j d -> b x i d', attn_image_to_text, v_text)
308 |
309 | out_image = out_image_to_image + out_image_to_text
310 |
311 | # merge back axis
312 |
313 | out_image = rearrange(out_image, merge_axis_einops, x = img_size)
314 |
315 | # combine attended values for both text and image
316 |
317 | out = torch.cat((out_text, out_image), dim = 1)
318 |
319 | out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
320 | out = self.to_out(out)
321 | return out[:, :n]
322 |
323 | # microsoft sparse attention CUDA kernel
324 |
325 | class SparseAttention(Attention):
326 | def __init__(
327 | self,
328 | *args,
329 | block_size = 16,
330 | text_seq_len = 256,
331 | num_random_blocks = None,
332 | **kwargs
333 | ):
334 | super().__init__(*args, **kwargs)
335 | from deepspeed.ops.sparse_attention import SparseSelfAttention, VariableSparsityConfig
336 | self.block_size = block_size
337 |
338 | num_random_blocks = default(num_random_blocks, self.seq_len // block_size // 4)
339 | global_block_indices = list(range(ceil(text_seq_len / block_size)))
340 |
341 | self.attn_fn = SparseSelfAttention(
342 | sparsity_config = VariableSparsityConfig(
343 | num_heads = self.heads,
344 | block = self.block_size,
345 | num_random_blocks = num_random_blocks,
346 | global_block_indices = global_block_indices,
347 | attention = 'unidirectional' if self.causal else 'bidirectional'
348 | ),
349 | max_seq_length = self.seq_len,
350 | attn_mask_mode = 'add'
351 | )
352 |
353 | def forward(self, x, mask = None, rotary_pos_emb = None):
354 | b, n, _, h, device = *x.shape, self.heads, x.device
355 | remainder = n % self.block_size
356 | mask = default(mask, lambda: torch.ones(b, n, device = device).bool())
357 |
358 | if remainder > 0:
359 | padding = self.block_size - remainder
360 | x = F.pad(x, (0, 0, 0, padding), value = 0)
361 | mask = F.pad(mask, (0, padding), value = False)
362 |
363 | qkv = self.to_qkv(x).chunk(3, dim = -1)
364 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
365 |
366 | if exists(rotary_pos_emb):
367 | q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))
368 |
369 | key_pad_mask = None
370 | if exists(mask):
371 | key_pad_mask = ~mask
372 |
373 | attn_mask = None
374 | if self.causal:
375 | i, j = q.shape[-2], k.shape[-2]
376 | mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
377 | attn_mask = torch.zeros(i, j, device = device).to(q)
378 | mask_value = max_neg_value(q) / 2
379 | attn_mask.masked_fill_(mask, mask_value)
380 |
381 | out = self.attn_fn(q, k, v, attn_mask = attn_mask, key_padding_mask = key_pad_mask)
382 | out = rearrange(out, 'b h n d -> b n (h d)')
383 | out = self.to_out(out)
384 | return out[:, :n]
385 |
--------------------------------------------------------------------------------
/dalle_pytorch/dalle_pytorch.py:
--------------------------------------------------------------------------------
1 | from math import log2, sqrt
2 | import torch
3 | from torch import nn, einsum
4 | import torch.nn.functional as F
5 | import numpy as np
6 |
7 | from axial_positional_embedding import AxialPositionalEmbedding
8 | from einops import rearrange
9 |
10 | from dalle_pytorch import distributed_utils
11 | from dalle_pytorch.vae import OpenAIDiscreteVAE, VQGanVAE
12 | from dalle_pytorch.transformer import Transformer, DivideMax
13 |
14 | # helpers
15 |
16 | def exists(val):
17 | return val is not None
18 |
19 | def default(val, d):
20 | return val if exists(val) else d
21 |
22 | class always():
23 | def __init__(self, val):
24 | self.val = val
25 | def __call__(self, x, *args, **kwargs):
26 | return self.val
27 |
28 | def is_empty(t):
29 | return t.nelement() == 0
30 |
31 | def masked_mean(t, mask, dim = 1):
32 | t = t.masked_fill(~mask[:, :, None], 0.)
33 | return t.sum(dim = 1) / mask.sum(dim = 1)[..., None]
34 |
35 | def set_requires_grad(model, value):
36 | for param in model.parameters():
37 | param.requires_grad = value
38 |
39 | def eval_decorator(fn):
40 | def inner(model, *args, **kwargs):
41 | was_training = model.training
42 | model.eval()
43 | out = fn(model, *args, **kwargs)
44 | model.train(was_training)
45 | return out
46 | return inner
47 |
48 | # sampling helpers
49 |
50 | def top_k(logits, thres = 0.5):
51 | num_logits = logits.shape[-1]
52 | k = max(int((1 - thres) * num_logits), 1)
53 | val, ind = torch.topk(logits, k)
54 | probs = torch.full_like(logits, float('-inf'))
55 | probs.scatter_(1, ind, val)
56 | return probs
57 |
58 | def top_p(logits, thres = 0.9):
59 | sorted_logits, sorted_indices = torch.sort(logits, descending=True)
60 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
61 |
62 | # Remove tokens with cumulative probability above the threshold
63 | sorted_indices_to_remove = cumulative_probs > thres
64 |
65 | # Shift the indices to the right to keep also the first token above the threshold
66 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
67 | sorted_indices_to_remove[..., 0] = 0
68 |
69 | # scatter sorted tensors to original indexing
70 | indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
71 | logits[indices_to_remove] = float('-inf')
72 | return logits
73 |
74 | # discrete vae class
75 |
76 | class ResBlock(nn.Module):
77 | def __init__(self, chan):
78 | super().__init__()
79 | self.net = nn.Sequential(
80 | nn.Conv2d(chan, chan, 3, padding = 1),
81 | nn.ReLU(),
82 | nn.Conv2d(chan, chan, 3, padding = 1),
83 | nn.ReLU(),
84 | nn.Conv2d(chan, chan, 1)
85 | )
86 |
87 | def forward(self, x):
88 | return self.net(x) + x
89 |
90 | class DiscreteVAE(nn.Module):
91 | def __init__(
92 | self,
93 | image_size = 256,
94 | num_tokens = 512,
95 | codebook_dim = 512,
96 | num_layers = 3,
97 | num_resnet_blocks = 0,
98 | hidden_dim = 64,
99 | channels = 3,
100 | smooth_l1_loss = False,
101 | temperature = 0.9,
102 | straight_through = False,
103 | kl_div_loss_weight = 0.,
104 | normalization = ((0.5,) * 3, (0.5,) * 3)
105 | ):
106 | super().__init__()
107 | assert log2(image_size).is_integer(), 'image size must be a power of 2'
108 | assert num_layers >= 1, 'number of layers must be greater than or equal to 1'
109 | has_resblocks = num_resnet_blocks > 0
110 |
111 | self.image_size = image_size
112 | self.num_tokens = num_tokens
113 | self.num_layers = num_layers
114 | self.temperature = temperature
115 | self.straight_through = straight_through
116 | self.codebook = nn.Embedding(num_tokens, codebook_dim)
117 |
118 | hdim = hidden_dim
119 |
120 | enc_chans = [hidden_dim] * num_layers
121 | dec_chans = list(reversed(enc_chans))
122 |
123 | enc_chans = [channels, *enc_chans]
124 |
125 | dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
126 | dec_chans = [dec_init_chan, *dec_chans]
127 |
128 | enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))
129 |
130 | enc_layers = []
131 | dec_layers = []
132 |
133 | for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
134 | enc_layers.append(nn.Sequential(nn.Conv2d(enc_in, enc_out, 4, stride = 2, padding = 1), nn.ReLU()))
135 | dec_layers.append(nn.Sequential(nn.ConvTranspose2d(dec_in, dec_out, 4, stride = 2, padding = 1), nn.ReLU()))
136 |
137 | for _ in range(num_resnet_blocks):
138 | dec_layers.insert(0, ResBlock(dec_chans[1]))
139 | enc_layers.append(ResBlock(enc_chans[-1]))
140 |
141 | if num_resnet_blocks > 0:
142 | dec_layers.insert(0, nn.Conv2d(codebook_dim, dec_chans[1], 1))
143 |
144 | enc_layers.append(nn.Conv2d(enc_chans[-1], num_tokens, 1))
145 | dec_layers.append(nn.Conv2d(dec_chans[-1], channels, 1))
146 |
147 | self.encoder = nn.Sequential(*enc_layers)
148 | self.decoder = nn.Sequential(*dec_layers)
149 |
150 | self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
151 | self.kl_div_loss_weight = kl_div_loss_weight
152 |
153 | # take care of normalization within class
154 | self.normalization = normalization
155 |
156 | self._register_external_parameters()
157 |
158 | def _register_external_parameters(self):
159 | """Register external parameters for DeepSpeed partitioning."""
160 | if (
161 | not distributed_utils.is_distributed
162 | or not distributed_utils.using_backend(
163 | distributed_utils.DeepSpeedBackend)
164 | ):
165 | return
166 |
167 | deepspeed = distributed_utils.backend.backend_module
168 | deepspeed.zero.register_external_parameter(self, self.codebook.weight)
169 |
170 | def norm(self, images):
171 | if not exists(self.normalization):
172 | return images
173 |
174 | means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization)
175 | means, stds = map(lambda t: rearrange(t, 'c -> () c () ()'), (means, stds))
176 | images = images.clone()
177 | images.sub_(means).div_(stds)
178 | return images
179 |
180 | @torch.no_grad()
181 | @eval_decorator
182 | def get_codebook_indices(self, images):
183 | logits = self(images, return_logits = True)
184 | codebook_indices = logits.argmax(dim = 1).flatten(1)
185 | return codebook_indices
186 |
187 | def decode(
188 | self,
189 | img_seq
190 | ):
191 | image_embeds = self.codebook(img_seq)
192 | b, n, d = image_embeds.shape
193 | h = w = int(sqrt(n))
194 |
195 | image_embeds = rearrange(image_embeds, 'b (h w) d -> b d h w', h = h, w = w)
196 | images = self.decoder(image_embeds)
197 | return images
198 |
199 | def forward(
200 | self,
201 | img,
202 | return_loss = False,
203 | return_recons = False,
204 | return_logits = False,
205 | temp = None
206 | ):
207 | device, num_tokens, image_size, kl_div_loss_weight = img.device, self.num_tokens, self.image_size, self.kl_div_loss_weight
208 | assert img.shape[-1] == image_size and img.shape[-2] == image_size, f'input must have the correct image size {image_size}'
209 |
210 | img = self.norm(img)
211 |
212 | logits = self.encoder(img)
213 |
214 | if return_logits:
215 | return logits # return logits for getting hard image indices for DALL-E training
216 |
217 | temp = default(temp, self.temperature)
218 | soft_one_hot = F.gumbel_softmax(logits, tau = temp, dim = 1, hard = self.straight_through)
219 | sampled = einsum('b n h w, n d -> b d h w', soft_one_hot, self.codebook.weight)
220 | out = self.decoder(sampled)
221 |
222 | if not return_loss:
223 | return out
224 |
225 | # reconstruction loss
226 |
227 | recon_loss = self.loss_fn(img, out)
228 |
229 | # kl divergence
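# KL between the encoder's softmax distribution over codebook entries and a uniform prior,
# weighted below by kl_div_loss_weight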
230 |
231 | logits = rearrange(logits, 'b n h w -> b (h w) n')
232 | log_qy = F.log_softmax(logits, dim = -1)
233 | log_uniform = torch.log(torch.tensor([1. / num_tokens], device = device))
234 | kl_div = F.kl_div(log_uniform, log_qy, None, None, 'batchmean', log_target = True)
235 |
236 | loss = recon_loss + (kl_div * kl_div_loss_weight)
237 |
238 | if not return_recons:
239 | return loss
240 |
241 | return loss, out
242 |
243 | # main classes
244 |
245 | class CLIP(nn.Module):
246 | def __init__(
247 | self,
248 | *,
249 | dim_text = 512,
250 | dim_image = 512,
251 | dim_latent = 512,
252 | num_text_tokens = 10000,
253 | text_enc_depth = 6,
254 | text_seq_len = 256,
255 | text_heads = 8,
256 | num_visual_tokens = 512,
257 | visual_enc_depth = 6,
258 | visual_heads = 8,
259 | visual_image_size = 256,
260 | visual_patch_size = 32,
261 | channels = 3
262 | ):
263 | super().__init__()
264 | self.text_emb = nn.Embedding(num_text_tokens, dim_text)
265 | self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
266 | self.text_transformer = Transformer(causal = False, seq_len = text_seq_len, dim = dim_text, depth = text_enc_depth, heads = text_heads, rotary_emb = False)
267 | self.to_text_latent = nn.Linear(dim_text, dim_latent, bias = False)
268 |
269 | assert visual_image_size % visual_patch_size == 0, 'Image dimensions must be divisible by the patch size.'
270 | num_patches = (visual_image_size // visual_patch_size) ** 2
271 | patch_dim = channels * visual_patch_size ** 2
272 |
273 | self.visual_patch_size = visual_patch_size
274 | self.to_visual_embedding = nn.Linear(patch_dim, dim_image)
275 | self.visual_pos_emb = nn.Embedding(num_patches, dim_image)
276 | self.visual_transformer = Transformer(causal = False, seq_len = num_patches, dim = dim_image, depth = visual_enc_depth, heads = visual_heads, rotary_emb = False)
277 | self.to_visual_latent = nn.Linear(dim_image, dim_latent, bias = False)
278 |
279 | self.temperature = nn.Parameter(torch.tensor(1.))
280 |
281 | def forward(
282 | self,
283 | text,
284 | image,
285 | text_mask = None,
286 | return_loss = False
287 | ):
288 | b, device, p = text.shape[0], text.device, self.visual_patch_size
289 |
290 | text_emb = self.text_emb(text)
291 | text_emb += self.text_pos_emb(torch.arange(text.shape[1], device = device))
292 |
293 | image_patches = rearrange(image, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
294 | image_emb = self.to_visual_embedding(image_patches)
295 | image_emb += self.visual_pos_emb(torch.arange(image_emb.shape[1], device = device))
296 |
297 | enc_text = self.text_transformer(text_emb, mask = text_mask)
298 | enc_image = self.visual_transformer(image_emb)
299 |
300 | if exists(text_mask):
301 | text_latents = masked_mean(enc_text, text_mask, dim = 1)
302 | else:
303 | text_latents = enc_text.mean(dim = 1)
304 |
305 | image_latents = enc_image.mean(dim = 1)
306 |
307 | text_latents = self.to_text_latent(text_latents)
308 | image_latents = self.to_visual_latent(image_latents)
309 |
310 | text_latents, image_latents = map(lambda t: F.normalize(t, p = 2, dim = -1), (text_latents, image_latents))
311 |
312 | temp = self.temperature.exp()
313 |
314 | if not return_loss:
315 | sim = einsum('n d, n d -> n', text_latents, image_latents) * temp
316 | return sim
317 |
318 | sim = einsum('i d, j d -> i j', text_latents, image_latents) * temp
319 | labels = torch.arange(b, device = device)
320 | loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
321 | return loss
322 |
323 | # main DALL-E class
324 |
325 | class DALLE(nn.Module):
326 | def __init__(
327 | self,
328 | *,
329 | dim,
330 | vae,
331 | num_text_tokens = 10000,
332 | text_seq_len = 256,
333 | depth,
334 | heads = 8,
335 | dim_head = 64,
336 | reversible = False,
337 | attn_dropout = 0.,
338 | ff_dropout = 0,
339 | sparse_attn = False,
340 | attn_types = None,
341 | loss_img_weight = 7,
342 | stable = False,
343 | sandwich_norm = False,
344 | shift_tokens = True,
345 | rotary_emb = True
346 | ):
347 | super().__init__()
348 | assert isinstance(vae, (DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE)), 'vae must be an instance of DiscreteVAE'
349 |
350 | #image_size = vae.image_size
351 | image_size = 128
352 | num_image_tokens = vae.num_tokens
353 | #image_fmap_size = (vae.image_size // (2 ** vae.num_layers))
354 | image_fmap_size = 16
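# DALL-3 note: the dalle-side image size (128px) and the 16x16 token grid are fixed here so a
# VAE/VQGAN trained at a different resolution can be paired with the external diffusion decoder (see README)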
355 | image_seq_len = image_fmap_size ** 2
356 |
357 | num_text_tokens = num_text_tokens + text_seq_len # reserve unique padding tokens for each position (text seq len)
358 |
359 | self.text_emb = nn.Embedding(num_text_tokens, dim)
360 | self.image_emb = nn.Embedding(num_image_tokens, dim)
361 |
362 | self.text_pos_emb = nn.Embedding(text_seq_len + 1, dim) if not rotary_emb else always(0) # +1 for <bos>
363 | self.image_pos_emb = AxialPositionalEmbedding(dim, axial_shape = (image_fmap_size, image_fmap_size)) if not rotary_emb else always(0)
364 |
365 | self.num_text_tokens = num_text_tokens # for offsetting logits index and calculating cross entropy loss
366 | self.num_image_tokens = num_image_tokens
367 |
368 | self.text_seq_len = text_seq_len
369 | self.image_seq_len = image_seq_len
370 |
371 | seq_len = text_seq_len + image_seq_len
372 | total_tokens = num_text_tokens + num_image_tokens
373 | self.total_tokens = total_tokens
374 | self.total_seq_len = seq_len
375 |
376 | self.vae = vae
377 | set_requires_grad(self.vae, False) # freeze VAE from being trained
378 |
379 | self.transformer = Transformer(
380 | dim = dim,
381 | causal = True,
382 | seq_len = seq_len,
383 | depth = depth,
384 | heads = heads,
385 | dim_head = dim_head,
386 | reversible = reversible,
387 | attn_dropout = attn_dropout,
388 | ff_dropout = ff_dropout,
389 | attn_types = attn_types,
390 | image_fmap_size = image_fmap_size,
391 | sparse_attn = sparse_attn,
392 | stable = stable,
393 | sandwich_norm = sandwich_norm,
394 | shift_tokens = shift_tokens,
395 | rotary_emb = rotary_emb
396 | )
397 |
398 | self.stable = stable
399 |
400 | if stable:
401 | self.norm_by_max = DivideMax(dim = -1)
402 |
403 | self.to_logits = nn.Sequential(
404 | nn.LayerNorm(dim),
405 | nn.Linear(dim, self.total_tokens),
406 | )
407 |
408 | seq_range = torch.arange(seq_len)
409 | logits_range = torch.arange(total_tokens)
410 |
411 | seq_range = rearrange(seq_range, 'n -> () n ()')
412 | logits_range = rearrange(logits_range, 'd -> () () d')
413 |
414 | logits_mask = (
415 | ((seq_range >= text_seq_len) & (logits_range < num_text_tokens)) |
416 | ((seq_range < text_seq_len) & (logits_range >= num_text_tokens))
417 | )
418 |
419 | self.register_buffer('logits_mask', logits_mask, persistent=False)
420 | self.loss_img_weight = loss_img_weight
421 |
422 |
423 | @torch.no_grad()
424 | @eval_decorator
425 | def generate_texts(
426 | self,
427 | tokenizer,
428 | text = None,
429 | *,
430 | filter_thres = 0.5,
431 | temperature = 1.
432 | ):
433 | text_seq_len = self.text_seq_len
434 | if text is None or text == "":
435 | text_tokens = torch.tensor([[0]]).cuda()
436 | else:
437 | text_tokens = torch.tensor(tokenizer.tokenizer.encode(text)).cuda().unsqueeze(0)
438 |
439 | for _ in range(text_tokens.shape[1], text_seq_len):
440 | device = text_tokens.device
441 |
442 | tokens = self.text_emb(text_tokens)
443 | tokens += self.text_pos_emb(torch.arange(text_tokens.shape[1], device = device))
444 |
445 | seq_len = tokens.shape[1]
446 |
447 | output_transf = self.transformer(tokens)
448 |
449 | if self.stable:
450 | output_transf = self.norm_by_max(output_transf)
451 |
452 | logits = self.to_logits(output_transf)
453 |
454 | # mask logits to make sure text predicts text (except last token), and image predicts image
455 |
456 | logits_mask = self.logits_mask[:, :seq_len]
457 | max_neg_value = -torch.finfo(logits.dtype).max
458 | logits.masked_fill_(logits_mask, max_neg_value)
459 | logits = logits[:, -1, :]
460 |
461 | filtered_logits = top_k(logits, thres = filter_thres)
462 | probs = F.softmax(filtered_logits / temperature, dim = -1)
463 | sample = torch.multinomial(probs, 1)
464 |
465 | text_tokens = torch.cat((text_tokens, sample), dim=-1)
466 |
467 | padding_tokens = set(np.arange(self.text_seq_len) + (self.num_text_tokens - self.text_seq_len))
468 | texts = [tokenizer.tokenizer.decode(text_token, pad_tokens=padding_tokens) for text_token in text_tokens]
469 | return text_tokens, texts
470 |
471 | @torch.no_grad()
472 | @eval_decorator
473 | def generate_images(
474 | self,
475 | text,
476 | *,
477 | clip = None,
478 | mask = None,
479 | top_k_thresh = None,
480 | top_p_thresh = None,
481 | temperature = 1.,
482 | img = None,
483 | num_init_img_tokens = None,
484 | return_tokens = False
485 | ):
486 | vae, text_seq_len, image_seq_len, num_text_tokens = self.vae, self.text_seq_len, self.image_seq_len, self.num_text_tokens
487 | total_len = text_seq_len + image_seq_len
488 |
489 | text = text[:, :text_seq_len] # make sure text is within bounds
490 | out = text
491 |
492 | if exists(img):
493 | image_size = vae.image_size
494 | assert img.shape[1] == 3 and img.shape[2] == image_size and img.shape[3] == image_size, f'input image must have the correct image size {image_size}'
495 |
496 | indices = vae.get_codebook_indices(img)
497 | num_img_tokens = default(num_init_img_tokens, int(0.4375 * image_seq_len)) # OpenAI used 14 * 32 initial tokens to prime
498 | assert num_img_tokens < image_seq_len, 'number of initial image tokens for priming must be less than the total image token sequence length'
499 |
500 | indices = indices[:, :num_img_tokens]
501 | out = torch.cat((out, indices), dim = -1)
502 |
503 | for cur_len in range(out.shape[1], total_len):
504 | is_image = cur_len >= text_seq_len
505 |
506 | text, image = out[:, :text_seq_len], out[:, text_seq_len:]
507 |
508 | logits = self(text, image, mask = mask)[:, -1, :]
509 |
510 | if top_k_thresh is not None:
511 | filtered_logits = top_k(logits, thres = top_k_thresh)
512 | else:
513 | filtered_logits = top_p(logits, thres = top_p_thresh)
514 |
515 | probs = F.softmax(filtered_logits / temperature, dim = -1)
516 | sample = torch.multinomial(probs, 1)
517 |
518 | sample -= (num_text_tokens if is_image else 0) # offset sampled token if it is an image token, since logit space is composed of text and then image tokens
519 | out = torch.cat((out, sample), dim=-1)
520 |
521 | if out.shape[1] <= text_seq_len:
522 | mask = F.pad(mask, (0, 1), value = True)
523 |
524 | text_seq = out[:, :text_seq_len]
525 |
526 | img_seq = out[:, -image_seq_len:]
527 | images = vae.decode(img_seq)
528 |
529 | if exists(clip):
530 | scores = clip(text_seq, images, return_loss = False)
531 | return images, scores
532 |
533 | if return_tokens:
534 | return images, img_seq
535 | else:
536 | return images, None
537 |
538 | def forward(
539 | self,
540 | text,
541 | image = None,
542 | mask = None,
543 | return_loss = False
544 | ):
545 | assert text.shape[-1] == self.text_seq_len, f'the length {text.shape[-1]} of the text tokens you passed in does not have the correct length ({self.text_seq_len})'
546 | device, total_seq_len = text.device, self.total_seq_len
547 |
548 | # make sure padding in text tokens get unique padding token id
549 |
550 | text_range = torch.arange(self.text_seq_len, device = device) + (self.num_text_tokens - self.text_seq_len)
551 | text = torch.where(text == 0, text_range, text)
552 |
553 | # add <bos>
554 |
555 | text = F.pad(text, (1, 0), value = 0)
556 |
557 | tokens = self.text_emb(text)
558 | tokens += self.text_pos_emb(torch.arange(text.shape[1], device = device))
559 |
560 | seq_len = tokens.shape[1]
561 |
562 | if exists(image) and not is_empty(image):
563 | is_raw_image = len(image.shape) == 4
564 |
565 | if is_raw_image:
566 | image_size = self.vae.image_size
567 | assert tuple(image.shape[1:]) == (3, image_size, image_size), f'invalid image of dimensions {image.shape} passed in during training'
568 |
569 | image = self.vae.get_codebook_indices(image)
570 |
571 | image_len = image.shape[1]
572 | image_emb = self.image_emb(image)
573 |
574 | image_emb += self.image_pos_emb(image_emb)
575 |
576 | tokens = torch.cat((tokens, image_emb), dim = 1)
577 |
578 | seq_len += image_len
579 |
580 | # when training, if the length exceeds the total text + image length
581 | # remove the last token, since it needs not to be trained
582 |
583 | if tokens.shape[1] > total_seq_len:
584 | seq_len -= 1
585 | tokens = tokens[:, :-1]
586 |
587 | if self.stable:
588 | alpha = 0.1
589 | tokens = tokens * alpha + tokens.detach() * (1 - alpha)
590 |
591 | out = self.transformer(tokens)
592 |
593 | if self.stable:
594 | out = self.norm_by_max(out)
595 |
596 | logits = self.to_logits(out)
597 |
598 | # mask logits to make sure text predicts text (except last token), and image predicts image
599 |
600 | logits_mask = self.logits_mask[:, :seq_len]
601 | max_neg_value = -torch.finfo(logits.dtype).max
602 | logits.masked_fill_(logits_mask, max_neg_value)
603 |
604 | if not return_loss:
605 | return logits
606 |
607 | assert exists(image), 'when training, image must be supplied'
608 |
609 | offsetted_image = image + self.num_text_tokens
610 | labels = torch.cat((text[:, 1:], offsetted_image), dim = 1)
611 |
612 | logits = rearrange(logits, 'b n c -> b c n')
613 |
614 | loss_text = F.cross_entropy(logits[:, :, :self.text_seq_len], labels[:, :self.text_seq_len])
615 | loss_img = F.cross_entropy(logits[:, :, self.text_seq_len:], labels[:, self.text_seq_len:])
616 |
617 | loss = (loss_text + self.loss_img_weight * loss_img) / (self.loss_img_weight + 1)
618 | return loss
619 |
--------------------------------------------------------------------------------
/dalle_pytorch/distributed_backends/__init__.py:
--------------------------------------------------------------------------------
1 | from .deepspeed_backend import DeepSpeedBackend
2 | from .distributed_backend import DistributedBackend
3 | from .dummy_backend import DummyBackend
4 | from .horovod_backend import HorovodBackend
5 |
--------------------------------------------------------------------------------
/dalle_pytorch/distributed_backends/deepspeed_backend.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | import torch
5 |
6 | from .distributed_backend import DistributedBackend
7 |
8 |
9 | class DeepSpeedBackend(DistributedBackend):
10 | """Distributed backend using the DeepSpeed engine."""
11 |
12 | BACKEND_MODULE_NAME = 'deepspeed'
13 | BACKEND_NAME = 'DeepSpeed'
14 |
15 | def wrap_arg_parser(self, parser):
16 | if not self.has_backend():
17 | parser.add_argument(
18 | '--deepspeed',
19 | type=lambda _: False,
20 | help=(
21 | 'whether to use DeepSpeed '
22 | "(ignored since it's not available)"
23 | ),
24 | )
25 | else:
26 | parser = self.backend_module.add_config_arguments(parser)
27 |
28 | parser.add_argument(
29 | '--local_rank',
30 | type=int,
31 | default=-1,
32 | help='local rank passed from distributed launcher',
33 | )
34 | return parser
35 |
36 | def _initialize(self):
37 | self.backend_module.init_distributed()
38 | if torch.cuda.is_available():
39 | torch.cuda.set_device(self._get_local_rank())
40 |
41 | @staticmethod
42 | def _require_torch_distributed_init():
43 | """Raise an error when `torch.distributed` has not been
44 | initialized yet.
45 | """
46 | assert torch.distributed.is_initialized(), \
47 | ('`torch.distributed` is not initialized; please call '
48 | '`DeepSpeedBackend.initialize` at the start of your script')
49 |
50 | def _get_world_size(self):
51 | self._require_torch_distributed_init()
52 | return torch.distributed.get_world_size()
53 |
54 | def _get_rank(self):
55 | self._require_torch_distributed_init()
56 | return torch.distributed.get_rank()
57 |
58 | def _get_local_rank(self):
59 | self._require_torch_distributed_init()
60 | return int(os.environ['LOCAL_RANK'])
61 |
62 | def _local_barrier(self):
63 | self._require_torch_distributed_init()
64 | torch.distributed.barrier()
65 |
66 | def _check_args(self, args, optimizer, lr_scheduler, kwargs):
67 | """Return an appropriate optimizer and learning rate scheduler
68 | after checking the values passed to `distribute`.
69 | """
70 | self._check_argvs(args, optimizer, lr_scheduler, kwargs)
71 | (optimizer, lr_scheduler) = self._check_config(
72 | args, optimizer, lr_scheduler, kwargs)
73 | return (optimizer, lr_scheduler)
74 |
75 | def _check_argvs(self, args, optimizer, lr_scheduler, kwargs):
76 | """Apply several sanity checks to the given command
77 | line arguments.
78 | """
79 | has_json_config = (hasattr(args, 'deepspeed_config')
80 | and args.deepspeed_config is not None)
81 | has_dict_config = 'config_params' in kwargs
82 | if (
83 | # No config given
84 | (not has_json_config and not has_dict_config)
85 | # JSON config file does not exist
86 | or (not has_dict_config
87 | and not os.path.isfile(args.deepspeed_config))
88 | ):
89 | # Let DeepSpeed handle these argument errors.
90 | return
91 |
92 | if not args.deepspeed:
93 | print(
94 | 'WARNING: DeepSpeed backend was selected; setting '
95 | '`args.deepspeed = True`'
96 | )
97 | args.deepspeed = True
98 |
99 | if has_json_config and has_dict_config:
100 | print(
101 | 'WARNING: DeepSpeed config was given as both JSON file and '
102 | 'Python dictionary. Python dictionary takes precedence.'
103 | )
104 |
105 | def _check_config(self, args, optimizer, lr_scheduler, kwargs):
106 | """Return an appropriate optimizer and learning rate scheduler
107 | for the DeepSpeed configuration.
108 | """
109 | if 'config_params' in kwargs:
110 | config = kwargs['config_params']
111 | else:
112 | with open(args.deepspeed_config, 'r') as json_config_file:
113 | config = json.load(json_config_file)
114 |
115 | if 'optimizer' in config and optimizer is not None:
116 | print(
117 | 'WARNING: Optimizer encountered in both DeepSpeed config and '
118 | 'keyword arguments. Optimizer in DeepSpeed config '
119 | 'takes precedence.'
120 | )
121 | optimizer = None
122 |
123 | if 'scheduler' in config and lr_scheduler is not None:
124 | print(
125 | 'WARNING: Learning rate scheduler encountered in both '
126 | 'DeepSpeed config and keyword arguments. Learning rate '
127 | 'scheduler in DeepSpeed config takes precedence.'
128 | )
129 | # For the LR scheduler, the JSON config already has
130 | # precedence. We do this for forward compatibility.
131 | lr_scheduler = None
132 |
133 | return (optimizer, lr_scheduler)
134 |
135 | def _distribute(
136 | self,
137 | args=None,
138 | model=None,
139 | optimizer=None,
140 | model_parameters=None,
141 | training_data=None,
142 | lr_scheduler=None,
143 | **kwargs,
144 | ):
145 | """Return a distributed model engine, optimizer, dataloader, and
146 | learning rate scheduler. These are obtained by wrapping the
147 | given values with the backend.
148 |
149 | For the other or other possible arguments,
150 | see `deepspeed.initialize`.
151 | """
152 | (optimizer, lr_scheduler) = self._check_args(
153 | args, optimizer, lr_scheduler, kwargs)
154 |
155 | return self.backend_module.initialize(
156 | args=args,
157 | model=model,
158 | optimizer=optimizer,
159 | model_parameters=model_parameters,
160 | training_data=training_data,
161 | lr_scheduler=lr_scheduler,
162 | **kwargs,
163 | )
164 |
165 | def _average_all(self, tensor):
166 | self._require_torch_distributed_init()
167 | # We copy because modification happens in-place
168 | averaged = tensor.detach().clone()
169 | # We use `all_reduce` because it is better supported than `reduce`
170 | torch.distributed.all_reduce(averaged, torch.distributed.ReduceOp.SUM)
171 | return averaged / self.get_world_size()
172 |
--------------------------------------------------------------------------------
/dalle_pytorch/distributed_backends/distributed_backend.py:
--------------------------------------------------------------------------------
1 | """
2 | An abstract backend for distributed deep learning.
3 |
4 | Provides several standard utility methods under a common API.
5 | Please check the documentation of the class `DistributedBackend` for
6 | details to implement a new backend.
7 | """
8 |
9 | from importlib import import_module
10 |
11 |
12 | class DistributedBackend:
13 | """An abstract backend class for distributed deep learning.
14 |
15 | Provides several standard utility methods under a common API.
16 | Variables that must be overridden:
17 | - BACKEND_MODULE_NAME
18 | - BACKEND_NAME
19 | Methods that must be overridden:
20 | - wrap_arg_parser
21 | - _initialize
22 | - _get_world_size
23 | - _get_rank
24 | - _get_local_rank
25 | - _local_barrier
26 | - _distribute
27 | - _average_all
28 | """
29 |
30 | BACKEND_MODULE_NAME = None
31 | """Name of the module to import for the backend."""
32 | BACKEND_NAME = None
33 | """Name of the backend for printing."""
34 |
35 | ROOT_RANK = 0
36 |
37 | backend_module = None
38 | """The module to access the backend."""
39 | is_initialized = False
40 | """Whether the backend is initialized."""
41 |
42 | def __init__(self):
43 | if self.BACKEND_MODULE_NAME is None:
44 | raise NotImplementedError('BACKEND_MODULE_NAME is not set')
45 | if self.BACKEND_NAME is None:
46 | raise NotImplementedError('BACKEND_NAME is not set')
47 |
48 | def has_backend(self):
49 | """Return whether the backend module is now imported."""
50 | try:
51 | self.backend_module = import_module(self.BACKEND_MODULE_NAME)
52 | except ModuleNotFoundError:
53 | return False
54 | return True
55 |
56 | def check_batch_size(self, batch_size):
57 | """Check whether the batch size makes sense for distribution."""
58 | assert batch_size >= self.get_world_size(), \
59 | (f"batch size can't be smaller than number of processes "
60 | f'({batch_size} < {self.get_world_size()})')
61 |
62 | def wrap_arg_parser(self, parser):
63 | """Add arguments to support optional distributed backend usage."""
64 | raise NotImplementedError
65 |
66 | def initialize(self):
67 | """Initialize the distributed backend."""
68 | self._initialize()
69 | self.is_initialized = True
70 |
71 | def _initialize(self):
72 | """Initialize the distributed backend."""
73 | raise NotImplementedError
74 |
75 | def require_init(self):
76 | """Raise an error when the backend has not been initialized yet."""
77 | assert self.is_initialized, \
78 | (f'{self.BACKEND_NAME} backend has not been initialized; please call '
79 | f'`distributed_utils.initialize` at the start of your script to '
80 | f'allow optional distributed usage')
81 |
82 | def get_world_size(self):
83 | """Return the amount of distributed processes."""
84 | self.require_init()
85 | return self._get_world_size()
86 |
87 | def _get_world_size(self):
88 | """Return the amount of distributed processes."""
89 | raise NotImplementedError
90 |
91 | def get_rank(self):
92 | """Return the global rank of the calling worker process."""
93 | self.require_init()
94 | return self._get_rank()
95 |
96 | def _get_rank(self):
97 | """Return the global rank of the calling worker process."""
98 | raise NotImplementedError
99 |
100 | def get_local_rank(self):
101 | """Return the local rank of the calling worker process.
102 | The local rank is the rank based on a single node's processes.
103 | """
104 | self.require_init()
105 | return self._get_local_rank()
106 |
107 | def _get_local_rank(self):
108 | """Return the local rank of the calling worker process.
109 | The local rank is the rank based on a single node's processes.
110 | """
111 | raise NotImplementedError
112 |
113 | def is_root_worker(self):
114 | """Return whether the calling worker has the root rank."""
115 | return self.get_rank() == self.ROOT_RANK
116 |
117 | def is_local_root_worker(self):
118 | """Return whether the calling worker has the root rank on this node."""
119 | return self.get_local_rank() == self.ROOT_RANK
120 |
121 | def local_barrier(self):
122 | """Wait until all processes on this node have called this function."""
123 | self.require_init()
124 | self._local_barrier()
125 |
126 | def _local_barrier(self):
127 | """Wait until all processes on this node have called this function."""
128 | raise NotImplementedError
129 |
130 | def distribute(
131 | self,
132 | args=None,
133 | model=None,
134 | optimizer=None,
135 | model_parameters=None,
136 | training_data=None,
137 | lr_scheduler=None,
138 | **kwargs,
139 | ):
140 | """Return a distributed model engine, optimizer, dataloader, and
141 | learning rate scheduler. These are obtained by wrapping the
142 | given values with the backend.
143 | """
144 | self.require_init()
145 | return self._distribute(
146 | args,
147 | model,
148 | optimizer,
149 | model_parameters,
150 | training_data,
151 | lr_scheduler,
152 | **kwargs,
153 | )
154 |
155 | def _distribute(
156 | self,
157 | args=None,
158 | model=None,
159 | optimizer=None,
160 | model_parameters=None,
161 | training_data=None,
162 | lr_scheduler=None,
163 | **kwargs,
164 | ):
165 | """Return a distributed model engine, optimizer, dataloader, and
166 | learning rate scheduler. These are obtained by wrapping the
167 | given values with the backend.
168 | """
169 | raise NotImplementedError
170 |
171 | def average_all(self, tensor):
172 | """Return the average of `tensor` over all workers."""
173 | self.require_init()
174 | return self._average_all(tensor)
175 |
176 | def _average_all(self, tensor):
177 | """Return the average of `tensor` over all workers."""
178 | raise NotImplementedError
179 |
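Usage note (annotation, not part of the file above): a minimal sketch of what a new backend subclass could look like, following the override list in the class docstring. The `TorchDistributedBackend` name and the use of `torch.distributed` are assumptions for illustration only; they are not part of this repository, and a real implementation would also wrap the model (e.g. in DistributedDataParallel) inside `_distribute`.

import os

from dalle_pytorch.distributed_backends.distributed_backend import DistributedBackend


class TorchDistributedBackend(DistributedBackend):
    """Hypothetical backend built on `torch.distributed` collectives (sketch only)."""

    BACKEND_MODULE_NAME = 'torch.distributed'
    BACKEND_NAME = 'TorchDistributed'

    def wrap_arg_parser(self, parser):
        return parser

    def _initialize(self):
        # Assumes RANK/WORLD_SIZE/MASTER_ADDR env vars are set, e.g. by torchrun.
        self.backend_module.init_process_group(backend='gloo')

    def _get_world_size(self):
        return self.backend_module.get_world_size()

    def _get_rank(self):
        return self.backend_module.get_rank()

    def _get_local_rank(self):
        return int(os.environ.get('LOCAL_RANK', 0))

    def _local_barrier(self):
        # A global barrier, which is sufficient for how it is used here.
        self.backend_module.barrier()

    def _distribute(self, _args=None, model=None, optimizer=None,
                    _model_parameters=None, training_data=None,
                    lr_scheduler=None, **_kwargs):
        # Sketch only: a real backend would wrap `model` for data parallelism here.
        return (model, optimizer, training_data, lr_scheduler)

    def _average_all(self, tensor):
        averaged = tensor.clone()
        self.backend_module.all_reduce(averaged)   # sums across workers by default
        return averaged / self.get_world_size()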
--------------------------------------------------------------------------------
/dalle_pytorch/distributed_backends/dummy_backend.py:
--------------------------------------------------------------------------------
1 | from .distributed_backend import DistributedBackend
2 |
3 |
4 | class DummyBackend(DistributedBackend):
5 | """Acts like a distributed backend.
6 |
7 | Used as a stand-in replacement to obtain a non-distributed program.
8 | """
9 |
10 |     # We define this so we can use `super().__init__`, but we still want
11 |     # importing this fake module name to throw an error.
12 | BACKEND_MODULE_NAME = 'NO MODULE'
13 | BACKEND_NAME = 'Dummy'
14 |
15 | def has_backend(self):
16 | return True
17 |
18 | def wrap_arg_parser(self, parser):
19 | return parser
20 |
21 | def _initialize(self):
22 | pass
23 |
24 | def _get_world_size(self):
25 | return 1
26 |
27 | def _get_rank(self):
28 | return self.ROOT_RANK
29 |
30 | def _get_local_rank(self):
31 | return self.ROOT_RANK
32 |
33 | def _local_barrier(self):
34 | pass
35 |
36 | def _distribute(
37 | self,
38 | _args=None,
39 | model=None,
40 | optimizer=None,
41 | _model_parameters=None,
42 | training_data=None,
43 | lr_scheduler=None,
44 | **_kwargs,
45 | ):
46 | """Return the model, optimizer, dataloader, and learning rate scheduler
47 | as is.
48 | """
49 | return (model, optimizer, training_data, lr_scheduler)
50 |
51 | def _average_all(self, tensor):
52 | return tensor
53 |
--------------------------------------------------------------------------------
/dalle_pytorch/distributed_backends/horovod_backend.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .distributed_backend import DistributedBackend
4 |
5 |
6 | class HorovodBackend(DistributedBackend):
7 | """Distributed backend using Horovod."""
8 |
9 | BACKEND_MODULE_NAME = 'horovod.torch'
10 | BACKEND_NAME = 'Horovod'
11 |
12 | def wrap_arg_parser(self, parser):
13 | return parser
14 |
15 | def check_batch_size(self, batch_size):
16 | # Horovod uses the local batch size to determine the effective
17 | # batch size.
18 | pass
19 |
20 | def _initialize(self):
21 | self.backend_module.init()
22 | if torch.cuda.is_available():
23 | torch.cuda.set_device(self._get_local_rank())
24 |
25 | def _get_world_size(self):
26 | return self.backend_module.size()
27 |
28 | def _get_rank(self):
29 | return self.backend_module.rank()
30 |
31 | def _get_local_rank(self):
32 | return self.backend_module.local_rank()
33 |
34 | def _local_barrier(self):
35 | # Actually a global barrier but works for our purposes.
36 | self.backend_module.join()
37 |
38 | def _distribute(
39 | self,
40 | _args=None,
41 | model=None,
42 | optimizer=None,
43 | _model_parameters=None,
44 | training_data=None,
45 | lr_scheduler=None,
46 | **_kwargs,
47 | ):
48 | optimizer = self.backend_module.DistributedOptimizer(optimizer)
49 | self.backend_module.broadcast_parameters(
50 | model.state_dict(), root_rank=self.ROOT_RANK)
51 | self.backend_module.broadcast_optimizer_state(
52 | optimizer, root_rank=self.ROOT_RANK)
53 | return (model, optimizer, training_data, lr_scheduler)
54 |
55 | def _average_all(self, tensor):
56 | # Reduce op is average by default
57 | averaged = self.backend_module.allreduce(tensor)
58 | return averaged
59 |
--------------------------------------------------------------------------------
/dalle_pytorch/distributed_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility functions for optional distributed execution.
3 |
4 | To use,
5 | 1. set the `BACKENDS` to the ones you want to make available,
6 | 2. in the script, wrap the argument parser with `wrap_arg_parser`,
7 | 3. in the script, set and use the backend by calling
8 | `set_backend_from_args`.
9 |
10 | You can check whether a backend is in use with the `using_backend`
11 | function.
12 | """
13 |
14 | from dalle_pytorch.distributed_backends import \
15 | DeepSpeedBackend, \
16 | DummyBackend, \
17 | HorovodBackend
18 |
19 | _DEFAULT_BACKEND = DummyBackend()
20 | """Which backend to use by default. Assumed to be _not_ distributed."""
21 |
22 | BACKENDS = [
23 | _DEFAULT_BACKEND,
24 | DeepSpeedBackend(),
25 | HorovodBackend(),
26 | ]
27 |
28 | is_distributed = None
29 | """Whether we are distributed."""
30 | backend = None
31 | """Backend in usage."""
32 |
33 |
34 | def wrap_arg_parser(parser):
35 | """Add arguments to support optional distributed backend usage."""
36 | parser.add_argument(
37 | '--distributed_backend',
38 | '--distr_backend',
39 | type=str,
40 | default=None,
41 | help='which distributed backend to use. Do not distribute by default',
42 | )
43 | for distr_backend in BACKENDS:
44 | parser = distr_backend.wrap_arg_parser(parser)
45 | return parser
46 |
47 |
48 | def set_backend_from_args(args):
49 | """Set and return the backend based on the given `args`."""
50 | global is_distributed, backend
51 |
52 | # Handle this specially for backwards compatibility.
53 | if args.deepspeed:
54 | args.distributed_backend = DeepSpeedBackend.BACKEND_NAME
55 |
56 | if not args.distributed_backend:
57 | is_distributed = False
58 | backend = _DEFAULT_BACKEND
59 | return backend
60 |
61 | backend_name = args.distributed_backend.lower()
62 | for distr_backend in BACKENDS:
63 | if distr_backend.BACKEND_NAME.lower() == backend_name:
64 | backend = distr_backend
65 | if not backend.has_backend():
66 | raise ModuleNotFoundError(
67 | f'{backend.BACKEND_NAME} backend selected but '
68 | 'module not available'
69 | )
70 |
71 | print(f'Using {backend.BACKEND_NAME} for distributed execution')
72 | is_distributed = True
73 | return backend
74 |
75 | raise ValueError(
76 | 'unknown backend; please check `distributed_utils.BACKENDS`')
77 |
78 |
79 | def require_set_backend():
80 | """Raise an `AssertionError` when the backend has not been set."""
81 | assert backend is not None, (
82 | 'distributed backend is not set. Please call '
83 | '`distributed_utils.set_backend_from_args` at the start of your script'
84 | )
85 |
86 |
87 | def using_backend(test_backend):
88 | """Return whether the backend is set to `test_backend`.
89 |
90 | `test_backend` may be a string of the name of the backend or
91 | its class.
92 | """
93 | require_set_backend()
94 | if isinstance(test_backend, str):
95 | return backend.BACKEND_NAME == test_backend
96 | return isinstance(backend, test_backend)
97 |
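Usage sketch (annotation, not part of the file above), following the three steps in the module docstring. It assumes the DeepSpeed backend's `wrap_arg_parser` adds the `--deepspeed` flag that `set_backend_from_args` reads for backwards compatibility.

import argparse

from dalle_pytorch import distributed_utils

parser = argparse.ArgumentParser()
parser = distributed_utils.wrap_arg_parser(parser)
args = parser.parse_args()

# Pick and initialize the backend (the DummyBackend if no flag was given).
distr_backend = distributed_utils.set_backend_from_args(args)
distr_backend.initialize()

if distr_backend.is_root_worker():
    print('world size:', distr_backend.get_world_size())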
--------------------------------------------------------------------------------
/dalle_pytorch/loader.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from random import randint, choice
3 |
4 | import PIL
5 |
6 | from torch.utils.data import Dataset
7 | from torchvision import transforms as T
8 |
9 |
10 | class TextImageDataset(Dataset):
11 | def __init__(self,
12 | folder,
13 | text_len=256,
14 | image_size=128,
15 | truncate_captions=False,
16 | resize_ratio=0.75,
17 | tokenizer=None,
18 | shuffle=False
19 | ):
20 | """
21 | @param folder: Folder containing images and text files matched by their paths' respective "stem"
22 |         @param truncate_captions: If True, captions that are too long will be truncated rather than raising an exception.
23 | """
24 | super().__init__()
25 | self.shuffle = shuffle
26 | path = Path(folder)
27 |
28 | text_files = [*path.glob('**/*.txt')]
29 | image_files = [
30 | *path.glob('**/*.png'), *path.glob('**/*.jpg'),
31 | *path.glob('**/*.jpeg'), *path.glob('**/*.bmp')
32 | ]
33 |
34 | text_files = {text_file.stem: text_file for text_file in text_files}
35 | image_files = {image_file.stem: image_file for image_file in image_files}
36 |
37 | keys = (image_files.keys() & text_files.keys())
38 |
39 | self.keys = list(keys)
40 | self.text_files = {k: v for k, v in text_files.items() if k in keys}
41 | self.image_files = {k: v for k, v in image_files.items() if k in keys}
42 | self.text_len = text_len
43 | self.truncate_captions = truncate_captions
44 | self.resize_ratio = resize_ratio
45 | self.tokenizer = tokenizer
46 | self.image_transform = T.Compose([
47 | T.Lambda(lambda img: img.convert('RGB')
48 | if img.mode != 'RGB' else img),
49 | T.RandomResizedCrop(image_size,
50 | scale=(self.resize_ratio, 1.),
51 | ratio=(1., 1.)),
52 | T.ToTensor()
53 | ])
54 |
55 | def __len__(self):
56 | return len(self.keys)
57 |
58 | def random_sample(self):
59 | return self.__getitem__(randint(0, self.__len__() - 1))
60 |
61 | def sequential_sample(self, ind):
62 | if ind >= self.__len__() - 1:
63 | return self.__getitem__(0)
64 | return self.__getitem__(ind + 1)
65 |
66 | def skip_sample(self, ind):
67 | if self.shuffle:
68 | return self.random_sample()
69 | return self.sequential_sample(ind=ind)
70 |
71 | def __getitem__(self, ind):
72 | key = self.keys[ind]
73 |
74 | text_file = self.text_files[key]
75 | image_file = self.image_files[key]
76 |
77 | descriptions = text_file.read_text().split('\n')
78 | descriptions = list(filter(lambda t: len(t) > 0, descriptions))
79 | try:
80 | description = choice(descriptions)
81 | except IndexError as zero_captions_in_file_ex:
82 | print(f"An exception occurred trying to load file {text_file}.")
83 | print(f"Skipping index {ind}")
84 | return self.skip_sample(ind)
85 |
86 | tokenized_text = self.tokenizer.tokenize(
87 | description,
88 | self.text_len,
89 | truncate_text=self.truncate_captions
90 | ).squeeze(0)
91 | try:
92 | image_tensor = self.image_transform(PIL.Image.open(image_file))
93 | except (PIL.UnidentifiedImageError, OSError) as corrupt_image_exceptions:
94 | print(f"An exception occurred trying to load file {image_file}.")
95 | print(f"Skipping index {ind}")
96 | return self.skip_sample(ind)
97 |
98 | # Success
99 | return tokenized_text, image_tensor
100 |
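Usage sketch (annotation, not part of the file above): constructing the dataset with the module-level SimpleTokenizer from dalle_pytorch/tokenizer.py and a standard DataLoader. The './data' folder is a placeholder assumed to contain matching image/text pairs such as 0001.jpg and 0001.txt.

from torch.utils.data import DataLoader

from dalle_pytorch.loader import TextImageDataset
from dalle_pytorch.tokenizer import tokenizer

ds = TextImageDataset(
    './data',                 # placeholder folder with matching stems
    text_len=256,
    image_size=256,
    truncate_captions=True,
    tokenizer=tokenizer,
    shuffle=True,
)

dl = DataLoader(ds, batch_size=4, shuffle=True)
text, images = next(iter(dl))  # text: (4, 256) token ids, images: (4, 3, 256, 256)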
--------------------------------------------------------------------------------
/dalle_pytorch/reversible.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from operator import itemgetter
4 | from torch.autograd.function import Function
5 | from torch.utils.checkpoint import get_device_states, set_device_states
6 |
7 | # for routing arguments into the functions of the reversible layer
8 | def route_args(router, args, depth):
9 | routed_args = [(dict(), dict()) for _ in range(depth)]
10 | matched_keys = [key for key in args.keys() if key in router]
11 |
12 | for key in matched_keys:
13 | val = args[key]
14 |         for layer_ind, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
15 |             new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
16 |             routed_args[layer_ind] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
17 | return routed_args
18 |
19 | # following example for saving and setting rng here https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html
20 | class Deterministic(nn.Module):
21 | def __init__(self, net):
22 | super().__init__()
23 | self.net = net
24 | self.cpu_state = None
25 | self.cuda_in_fwd = None
26 | self.gpu_devices = None
27 | self.gpu_states = None
28 |
29 | def record_rng(self, *args):
30 | self.cpu_state = torch.get_rng_state()
31 | if torch.cuda._initialized:
32 | self.cuda_in_fwd = True
33 | self.gpu_devices, self.gpu_states = get_device_states(*args)
34 |
35 | def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
36 | if record_rng:
37 | self.record_rng(*args)
38 |
39 | if not set_rng:
40 | return self.net(*args, **kwargs)
41 |
42 | rng_devices = []
43 | if self.cuda_in_fwd:
44 | rng_devices = self.gpu_devices
45 |
46 | with torch.random.fork_rng(devices=rng_devices, enabled=True):
47 | torch.set_rng_state(self.cpu_state)
48 | if self.cuda_in_fwd:
49 | set_device_states(self.gpu_devices, self.gpu_states)
50 | return self.net(*args, **kwargs)
51 |
52 | # heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
53 | # once multi-GPU is confirmed working, refactor and send PR back to source
54 | class ReversibleBlock(nn.Module):
55 | def __init__(self, f, g):
56 | super().__init__()
57 | self.f = Deterministic(f)
58 | self.g = Deterministic(g)
59 |
60 | def forward(self, x, f_args = {}, g_args = {}):
61 | x1, x2 = torch.chunk(x, 2, dim=2)
62 | y1, y2 = None, None
63 |
64 | with torch.no_grad():
65 | y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
66 | y2 = x2 + self.g(y1, record_rng=self.training, **g_args)
67 |
68 | return torch.cat([y1, y2], dim=2)
69 |
70 | def backward_pass(self, y, dy, f_args = {}, g_args = {}):
71 | y1, y2 = torch.chunk(y, 2, dim=2)
72 | del y
73 |
74 | dy1, dy2 = torch.chunk(dy, 2, dim=2)
75 | del dy
76 |
77 | with torch.enable_grad():
78 | y1.requires_grad = True
79 | gy1 = self.g(y1, set_rng=True, **g_args)
80 | torch.autograd.backward(gy1, dy2)
81 |
82 | with torch.no_grad():
83 | x2 = y2 - gy1
84 | del y2, gy1
85 |
86 | dx1 = dy1 + y1.grad
87 | del dy1
88 | y1.grad = None
89 |
90 | with torch.enable_grad():
91 | x2.requires_grad = True
92 | fx2 = self.f(x2, set_rng=True, **f_args)
93 | torch.autograd.backward(fx2, dx1, retain_graph=True)
94 |
95 | with torch.no_grad():
96 | x1 = y1 - fx2
97 | del y1, fx2
98 |
99 | dx2 = dy2 + x2.grad
100 | del dy2
101 | x2.grad = None
102 |
103 | x = torch.cat([x1, x2.detach()], dim=2)
104 | dx = torch.cat([dx1, dx2], dim=2)
105 |
106 | return x, dx
107 |
108 | class _ReversibleFunction(Function):
109 | @staticmethod
110 | def forward(ctx, x, blocks, args):
111 | ctx.args = args
112 | for block, kwarg in zip(blocks, args):
113 | x = block(x, **kwarg)
114 | ctx.y = x.detach()
115 | ctx.blocks = blocks
116 | return x
117 |
118 | @staticmethod
119 | def backward(ctx, dy):
120 | y = ctx.y
121 | args = ctx.args
122 | for block, kwargs in zip(ctx.blocks[::-1], args[::-1]):
123 | y, dy = block.backward_pass(y, dy, **kwargs)
124 | return dy, None, None
125 |
126 | class SequentialSequence(nn.Module):
127 | def __init__(self, layers, args_route = {}, layer_dropout = 0.):
128 | super().__init__()
129 | assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
130 | self.layers = layers
131 | self.args_route = args_route
132 | self.layer_dropout = layer_dropout
133 |
134 | def forward(self, x, **kwargs):
135 | args = route_args(self.args_route, kwargs, len(self.layers))
136 | layers_and_args = list(zip(self.layers, args))
137 |
138 | for (f, g), (f_args, g_args) in layers_and_args:
139 | x = x + f(x, **f_args)
140 | x = x + g(x, **g_args)
141 | return x
142 |
143 | class ReversibleSequence(nn.Module):
144 | def __init__(self, blocks, args_route = {}):
145 | super().__init__()
146 | self.args_route = args_route
147 | self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks])
148 |
149 | def forward(self, x, **kwargs):
150 | x = torch.cat([x, x], dim=-1)
151 |
152 | blocks = self.blocks
153 | args = route_args(self.args_route, kwargs, len(blocks))
154 | args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args))
155 |
156 | out = _ReversibleFunction.apply(x, blocks, args)
157 | return torch.stack(out.chunk(2, dim=-1)).mean(dim=0)
158 |
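Usage sketch (annotation, not part of the file above): each entry passed to ReversibleSequence is an (f, g) pair of dimension-preserving modules; the sequence doubles the feature dimension internally, runs the reversible blocks, and averages the two halves on the way out. The toy modules below are placeholders chosen only to exercise the shapes.

import torch
from torch import nn

from dalle_pytorch.reversible import ReversibleSequence

dim, depth = 32, 2
blocks = nn.ModuleList([
    nn.ModuleList([
        nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim)),  # f
        nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim)),  # g
    ])
    for _ in range(depth)
])

rev = ReversibleSequence(blocks)

x = torch.randn(1, 16, dim, requires_grad=True)
out = rev(x)           # (1, 16, dim)
out.mean().backward()  # activations are recomputed block by block in backward_pass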
--------------------------------------------------------------------------------
/dalle_pytorch/tokenizer.py:
--------------------------------------------------------------------------------
1 | # taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
2 | # to give users a quick and easy start to training DALL-E without doing their own BPE
3 |
4 | import torch
5 |
6 | import youtokentome as yttm
7 | from tokenizers import Tokenizer
8 | from tokenizers.processors import ByteLevel
9 | from transformers import BertTokenizer
10 |
11 | import html
12 | import os
13 | from functools import lru_cache
14 | from pathlib import Path
15 | import ftfy
16 | import regex as re
17 |
18 | # OpenAI simple tokenizer
19 |
20 | @lru_cache()
21 | def default_bpe():
22 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "data/bpe_simple_vocab_16e6.txt")
23 |
24 | @lru_cache()
25 | def bytes_to_unicode():
26 | bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2 ** 8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2 ** 8 + n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 | def get_pairs(word):
38 | pairs = set()
39 | prev_char = word[0]
40 | for char in word[1:]:
41 | pairs.add((prev_char, char))
42 | prev_char = char
43 | return pairs
44 |
45 | def basic_clean(text):
46 | text = ftfy.fix_text(text)
47 | text = html.unescape(html.unescape(text))
48 | return text.strip()
49 |
50 | def whitespace_clean(text):
51 | text = re.sub(r'\s+', ' ', text)
52 | text = text.strip()
53 | return text
54 |
55 | class SimpleTokenizer(object):
56 | def __init__(self, bpe_path = default_bpe()):
57 | self.byte_encoder = bytes_to_unicode()
58 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
59 | merges = Path(bpe_path).read_text(encoding='utf8').split('\n')
60 | merges = merges[1:49152 - 256 - 2 + 1]
61 | merges = [tuple(merge.split()) for merge in merges]
62 | vocab = list(bytes_to_unicode().values())
63 |         vocab = vocab + [v + '</w>' for v in vocab]
64 | for merge in merges:
65 | vocab.append(''.join(merge))
66 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
67 |
68 | self.vocab_size = 49408
69 |
70 | self.encoder = dict(zip(vocab, range(len(vocab))))
71 | self.decoder = {v: k for k, v in self.encoder.items()}
72 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
73 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
74 | self.pat = re.compile(
75 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
76 | re.IGNORECASE)
77 |
78 | def bpe(self, token):
79 | if token in self.cache:
80 | return self.cache[token]
81 |         word = tuple(token[:-1]) + (token[-1] + '</w>',)
82 | pairs = get_pairs(word)
83 |
84 | if not pairs:
85 |             return token + '</w>'
86 |
87 | while True:
88 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
89 | if bigram not in self.bpe_ranks:
90 | break
91 | first, second = bigram
92 | new_word = []
93 | i = 0
94 | while i < len(word):
95 | try:
96 | j = word.index(first, i)
97 | new_word.extend(word[i:j])
98 | i = j
99 | except:
100 | new_word.extend(word[i:])
101 | break
102 |
103 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
104 | new_word.append(first + second)
105 | i += 2
106 | else:
107 | new_word.append(word[i])
108 | i += 1
109 | new_word = tuple(new_word)
110 | word = new_word
111 | if len(word) == 1:
112 | break
113 | else:
114 | pairs = get_pairs(word)
115 | word = ' '.join(word)
116 | self.cache[token] = word
117 | return word
118 |
119 | def encode(self, text):
120 | bpe_tokens = []
121 | text = whitespace_clean(basic_clean(text)).lower()
122 | for token in re.findall(self.pat, text):
123 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
124 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
125 | return bpe_tokens
126 |
127 | def decode(self, tokens, remove_start_end = True, pad_tokens = {}):
128 | if torch.is_tensor(tokens):
129 | tokens = tokens.tolist()
130 |
131 | if remove_start_end:
132 |             tokens = [token for token in tokens if token not in (49406, 49407, 0)]
133 |         text = ''.join([self.decoder[token] for token in tokens if token not in pad_tokens])
134 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
135 | return text
136 |
137 | def tokenize(self, texts, context_length = 256, truncate_text = False):
138 | if isinstance(texts, str):
139 | texts = [texts]
140 |
141 | all_tokens = [self.encode(text) for text in texts]
142 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
143 |
144 | for i, tokens in enumerate(all_tokens):
145 | if len(tokens) > context_length:
146 | if truncate_text:
147 | tokens = tokens[:context_length]
148 | else:
149 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
150 | result[i, :len(tokens)] = torch.tensor(tokens)
151 |
152 | return result
153 |
154 | tokenizer = SimpleTokenizer()
155 |
156 | # huggingface tokenizer
157 |
158 | class HugTokenizer:
159 | def __init__(self, bpe_path = None):
160 | bpe_path = Path(bpe_path)
161 | assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist'
162 | tokenizer = Tokenizer.from_file(str(bpe_path))
163 | tokenizer.post_processor = ByteLevel(trim_offsets = True)
164 | self.tokenizer = tokenizer
165 | self.vocab_size = tokenizer.get_vocab_size()
166 |
167 | def decode(self, tokens, pad_tokens = {}):
168 | if torch.is_tensor(tokens):
169 | tokens = tokens.tolist()
170 | ignore_ids = pad_tokens.union({0})
171 | tokens = [token for token in tokens if token not in ignore_ids]
172 | return self.tokenizer.decode(tokens, skip_special_tokens = True)
173 |
174 | def encode(self, text):
175 | return self.tokenizer.encode(text).ids
176 |
177 | def tokenize(self, texts, context_length = 256, truncate_text = False):
178 | if isinstance(texts, str):
179 | texts = [texts]
180 |
181 | all_tokens = [self.encode(text) for text in texts]
182 |
183 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
184 | for i, tokens in enumerate(all_tokens):
185 | if len(tokens) > context_length:
186 | if truncate_text:
187 | tokens = tokens[:context_length]
188 | else:
189 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
190 | result[i, :len(tokens)] = torch.tensor(tokens)
191 |
192 | return result
193 |
194 | # chinese tokenizer
195 |
196 | class ChineseTokenizer:
197 | def __init__(self):
198 | tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
199 | self.tokenizer = tokenizer
200 | self.vocab_size = tokenizer.vocab_size
201 |
202 | def decode(self, tokens, pad_tokens = {}):
203 | if torch.is_tensor(tokens):
204 | tokens = tokens.tolist()
205 |
206 | ignore_ids = pad_tokens.union({0})
207 | tokens = [token for token in tokens if token not in ignore_ids]
208 | return self.tokenizer.decode(tokens)
209 |
210 | def encode(self, text):
211 | return torch.tensor(self.tokenizer.encode(text, add_special_tokens = False))
212 |
213 | def tokenize(self, texts, context_length = 256, truncate_text = False):
214 | if isinstance(texts, str):
215 | texts = [texts]
216 |
217 | all_tokens = [self.encode(text) for text in texts]
218 |
219 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
220 | for i, tokens in enumerate(all_tokens):
221 | if len(tokens) > context_length:
222 | if truncate_text:
223 | tokens = tokens[:context_length]
224 | else:
225 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
226 | result[i, :len(tokens)] = torch.tensor(tokens)
227 |
228 | return result
229 |
230 | # yttm tokenizer
231 |
232 | class YttmTokenizer:
233 | def __init__(self, bpe_path = None):
234 | bpe_path = Path(bpe_path)
235 | assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist'
236 |
237 | tokenizer = yttm.BPE(model = str(bpe_path))
238 | self.tokenizer = tokenizer
239 | self.vocab_size = tokenizer.vocab_size()
240 |
241 | def decode(self, tokens, pad_tokens = {}):
242 | if torch.is_tensor(tokens):
243 | tokens = tokens.tolist()
244 |
245 | return self.tokenizer.decode(tokens, ignore_ids = pad_tokens.union({0}))
246 |
247 | def encode(self, texts):
248 | encoded = self.tokenizer.encode(texts, output_type = yttm.OutputType.ID)
249 | return list(map(torch.tensor, encoded))
250 |
251 | def tokenize(self, texts, context_length = 256, truncate_text = False):
252 | if isinstance(texts, str):
253 | texts = [texts]
254 |
255 | all_tokens = self.encode(texts)
256 |
257 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
258 | for i, tokens in enumerate(all_tokens):
259 | if len(tokens) > context_length:
260 | if truncate_text:
261 | tokens = tokens[:context_length]
262 | else:
263 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
264 | result[i, :len(tokens)] = torch.tensor(tokens)
265 |
266 | return result
267 |
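Usage sketch (annotation, not part of the file above) for the module-level SimpleTokenizer instance; the caption is an arbitrary example and the decoded string only approximately reproduces the input (lower-cased, with normalized whitespace).

from dalle_pytorch.tokenizer import tokenizer

codes = tokenizer.tokenize('A cube made of porcupine', context_length=256, truncate_text=True)
print(codes.shape)                 # torch.Size([1, 256]), zero-padded on the right
print(tokenizer.decode(codes[0]))  # roughly 'a cube made of porcupine'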
--------------------------------------------------------------------------------
/dalle_pytorch/transformer.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from itertools import islice, cycle
3 |
4 | import torch
5 | from torch import nn, einsum
6 | import torch.nn.functional as F
7 | from einops import rearrange
8 |
9 | from dalle_pytorch.reversible import ReversibleSequence, SequentialSequence
10 | from dalle_pytorch.attention import Attention, SparseAttention, SparseConvCausalAttention, SparseAxialCausalAttention
11 |
12 | from rotary_embedding_torch import RotaryEmbedding, broadcat
13 | from g_mlp_pytorch import gMLPBlock
14 |
15 | # helpers
16 |
17 | def exists(val):
18 | return val is not None
19 |
20 | def default(val, d):
21 | return val if exists(val) else d
22 |
23 | def cast_tuple(val, depth = 1):
24 | if isinstance(val, list):
25 | val = tuple(val)
26 | return val if isinstance(val, tuple) else (val,) * depth
27 |
28 | # classes
29 |
30 | class DivideMax(nn.Module):
31 | def __init__(self, dim):
32 | super().__init__()
33 | self.dim = dim
34 |
35 | def forward(self, x):
36 | maxes = x.amax(dim = self.dim, keepdim = True).detach()
37 | return x / maxes
38 |
39 | # https://arxiv.org/abs/2103.17239
40 | class LayerScale(nn.Module):
41 | def __init__(self, dim, depth, fn):
42 | super().__init__()
43 | if depth <= 18:
44 | init_eps = 0.1
45 | elif depth > 18 and depth <= 24:
46 | init_eps = 1e-5
47 | else:
48 | init_eps = 1e-6
49 |
50 | scale = torch.zeros(1, 1, dim).fill_(init_eps)
51 | self.scale = nn.Parameter(scale)
52 | self.fn = fn
53 | def forward(self, x, **kwargs):
54 | return self.fn(x, **kwargs) * self.scale
55 |
56 | # layer norm
57 |
58 | class PreNorm(nn.Module):
59 | def __init__(self, dim, fn, sandwich = False):
60 | super().__init__()
61 | self.norm = nn.LayerNorm(dim)
62 | self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity()
63 | self.fn = fn
64 |
65 | def forward(self, x, **kwargs):
66 | x = self.norm(x)
67 | x = self.fn(x, **kwargs)
68 | return self.norm_out(x)
69 |
70 | # feed forward
71 |
72 | class GEGLU(nn.Module):
73 | def forward(self, x):
74 | x, gates = x.chunk(2, dim = -1)
75 | return x * F.gelu(gates)
76 |
77 | class FeedForward(nn.Module):
78 | def __init__(self, dim, dropout = 0., mult = 4.):
79 | super().__init__()
80 | self.net = nn.Sequential(
81 | nn.Linear(dim, dim * mult * 2),
82 | GEGLU(),
83 | nn.Dropout(dropout),
84 | nn.Linear(dim * mult, dim)
85 | )
86 |
87 | def forward(self, x):
88 | return self.net(x)
89 |
90 | # token shift classes
91 |
92 | class PreShiftToken(nn.Module):
93 | def __init__(self, fn, image_size, seq_len):
94 | super().__init__()
95 | self.fn = fn
96 | self.image_size = image_size
97 | self.seq_len = seq_len
98 |
99 | def forward(self, x, **kwargs):
100 | n = x.shape[1]
101 | seq_len, image_size = self.seq_len, self.image_size
102 | img_seq_len = image_size ** 2
103 | text_len = seq_len - img_seq_len + 1
104 | padding = seq_len - n + 1
105 |
106 | # get text and image tokens
107 |
108 | x_text, x_img = x[:, :text_len], x[:, text_len:]
109 | x_img = F.pad(x_img, (0, 0, 0, padding))
110 | x_img = rearrange(x_img, 'b (h w) d -> b h w d', h = image_size)
111 |
112 | # shift 1 from the left for text tokens
113 |
114 | x_text_shift, x_text_pass = x_text.chunk(2, dim = -1)
115 | x_text_shift = F.pad(x_text_shift, (0, 0, 1, -1))
116 | x_text = torch.cat((x_text_shift, x_text_pass), dim = -1)
117 |
118 | # shift from top, left for image tokens
119 |
120 | x_img_shift_top, x_img_shift_left, *x_img_pass = x_img.chunk(4, dim = -1)
121 | x_img_shift_left = F.pad(x_img_shift_left, (0, 0, 1, -1))
122 | x_img_shift_top = F.pad(x_img_shift_top, (0, 0, 0, 0, 1, -1))
123 | x_img = torch.cat((x_img_shift_top, x_img_shift_left, *x_img_pass), dim = -1)
124 |
125 | # merge text and image sequence back together
126 |
127 | x_img = rearrange(x_img, 'b h w d -> b (h w) d')
128 | x = torch.cat((x_text, x_img[:, :-padding]), dim = 1)
129 | return self.fn(x, **kwargs)
130 |
131 | # main transformer class
132 |
133 | class Transformer(nn.Module):
134 | def __init__(
135 | self,
136 | *,
137 | dim,
138 | depth,
139 | seq_len,
140 | reversible = False,
141 | causal = True,
142 | heads = 8,
143 | dim_head = 64,
144 | ff_mult = 4,
145 | attn_dropout = 0.,
146 | ff_dropout = 0.,
147 | attn_types = None,
148 | image_fmap_size = None,
149 | sparse_attn = False,
150 | stable = False,
151 | sandwich_norm = False,
152 | shift_tokens = False,
153 | rotary_emb = True
154 | ):
155 | super().__init__()
156 | layers = nn.ModuleList([])
157 | sparse_layer = cast_tuple(sparse_attn, depth)
158 |
159 | attn_types = default(attn_types, ('full',))
160 | attn_types = cast_tuple(attn_types)
161 | attn_type_layer = islice(cycle(attn_types), depth)
162 |
163 | for ind, sparse_attn, attn_type in zip(range(depth), sparse_layer, attn_type_layer):
164 | if attn_type == 'full':
165 | attn_class = partial(Attention, stable = stable)
166 | elif attn_type == 'sparse':
167 | attn_class = SparseAttention
168 | elif attn_type == 'axial_row':
169 | attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_fmap_size, stable = stable)
170 | elif attn_type == 'axial_col':
171 | attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_fmap_size, stable = stable)
172 | elif attn_type == 'conv_like':
173 | attn_class = partial(SparseConvCausalAttention, seq_len = seq_len, image_size = image_fmap_size, stable = stable)
174 | elif attn_type == 'mlp':
175 | attn_class = partial(gMLPBlock, seq_len = seq_len)
176 | else:
177 | raise ValueError(f'attention type "{attn_type}" is not valid')
178 |
179 | if attn_type != 'mlp':
180 | attn = attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
181 | else:
182 | attn = attn_class(dim = dim, causal = causal, dim_ff = dim * 4)
183 |
184 | ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)
185 |
186 | if shift_tokens:
187 | attn, ff = map(lambda t: PreShiftToken(t, image_size = image_fmap_size, seq_len = seq_len), (attn, ff))
188 |
189 | layers.append(nn.ModuleList([
190 | LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich = sandwich_norm)),
191 | LayerScale(dim, ind + 1, PreNorm(dim, ff, sandwich = sandwich_norm))
192 | ]))
193 |
194 | execute_type = ReversibleSequence if reversible else SequentialSequence
195 | route_attn = ((True, False),) * depth
196 | attn_route_map = {'mask': route_attn, 'rotary_pos_emb': route_attn}
197 |
198 | self.layers = execute_type(layers, args_route = attn_route_map)
199 |
200 | # generate positional embeddings for rotary
201 |
202 | pos_emb = None
203 | if rotary_emb:
204 | assert 'mlp' not in attn_types, 'you cannot use gMLPs if rotary embedding is turned on'
205 |
206 | rot_dim = dim_head // 3
207 | img_seq_len = (image_fmap_size ** 2)
208 | text_len = seq_len - img_seq_len + 1
209 |
210 | text_pos_emb = RotaryEmbedding(dim = rot_dim)
211 | img_axial_pos_emb = RotaryEmbedding(dim = rot_dim, freqs_for = 'pixel')
212 |
213 | text_freqs = text_pos_emb(torch.arange(text_len))
214 | img_to_text_freqs = text_pos_emb(torch.full((img_seq_len,), 8192)) # image is given a position far away from text
215 | text_freqs = torch.cat((text_freqs, img_to_text_freqs), dim = 0)
216 |
217 | img_freqs_axial = img_axial_pos_emb(torch.linspace(-1, 1, steps = image_fmap_size))
218 | img_freqs = broadcat((rearrange(img_freqs_axial, 'i d -> i () d'), rearrange(img_freqs_axial, 'j d -> () j d')), dim = -1)
219 | img_freqs = rearrange(img_freqs, 'h w d -> (h w) d')
220 |
221 | text_axial_freqs = img_axial_pos_emb(torch.full((text_len,), -10.)) # text is given a position of -10 apart from the image axial positions, which is from range [-1, 1]
222 | text_axial_freqs = torch.cat((text_axial_freqs, text_axial_freqs), dim = -1)
223 | img_freqs = torch.cat((text_axial_freqs, img_freqs), dim = 0)
224 |
225 | pos_emb = torch.cat((text_freqs, img_freqs), dim = -1)
226 | pos_emb = rearrange(pos_emb[:-1], 'n d -> () () n d')
227 |
228 | self.register_buffer('pos_emb', pos_emb)
229 |
230 | def forward(self, x, **kwargs):
231 | return self.layers(x, rotary_pos_emb = self.pos_emb, **kwargs)
232 |
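Shape-check sketch (annotation, not part of the file above). The sizes are arbitrary; as in DALL-E, `seq_len` is the text length plus `image_fmap_size ** 2`, and rotary embeddings are switched off here to keep the example minimal, so this is a sketch rather than the configuration the training scripts use.

import torch

from dalle_pytorch.transformer import Transformer

fmap = 8                    # 8 x 8 = 64 image tokens
seq_len = 64 + fmap ** 2    # 64 text tokens + 64 image tokens

model = Transformer(
    dim=256,
    depth=2,
    seq_len=seq_len,
    heads=4,
    dim_head=64,
    attn_types=('full',),
    image_fmap_size=fmap,
    rotary_emb=False,       # assumption: keeps the example free of rotary-embedding details
)

x = torch.randn(1, seq_len, 256)
out = model(x)              # (1, seq_len, 256)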
--------------------------------------------------------------------------------
/dalle_pytorch/vae.py:
--------------------------------------------------------------------------------
1 | import io
2 | import sys
3 | import os
4 | import requests
5 | import PIL
6 | import warnings
7 | import hashlib
8 | import urllib
9 | import yaml
10 | from pathlib import Path
11 | from tqdm import tqdm
12 | from math import sqrt, log
13 | from omegaconf import OmegaConf
14 | from taming.models.vqgan import VQModel, GumbelVQ
15 | import importlib
16 |
17 | import torch
18 | from torch import nn
19 | import torch.nn.functional as F
20 |
21 | from einops import rearrange
22 |
23 | from dalle_pytorch import distributed_utils
24 |
25 | # constants
26 |
27 | CACHE_PATH = os.path.expanduser("~/.cache/dalle")
28 |
29 | OPENAI_VAE_ENCODER_PATH = 'https://cdn.openai.com/dall-e/encoder.pkl'
30 | OPENAI_VAE_DECODER_PATH = 'https://cdn.openai.com/dall-e/decoder.pkl'
31 |
32 | VQGAN_VAE_PATH = 'https://heibox.uni-heidelberg.de/f/140747ba53464f49b476/?dl=1'
33 | VQGAN_VAE_CONFIG_PATH = 'https://heibox.uni-heidelberg.de/f/6ecf2af6c658432c8298/?dl=1'
34 |
35 | # helpers methods
36 |
37 | def exists(val):
38 | return val is not None
39 |
40 | def default(val, d):
41 | return val if exists(val) else d
42 |
43 | def load_model(path):
44 | with open(path, 'rb') as f:
45 | return torch.load(f, map_location = torch.device('cpu'))
46 |
47 | def map_pixels(x, eps = 0.1):
48 | return (1 - 2 * eps) * x + eps
49 |
50 | def unmap_pixels(x, eps = 0.1):
51 | return torch.clamp((x - eps) / (1 - 2 * eps), 0, 1)
52 |
53 | def download(url, filename = None, root = CACHE_PATH):
54 | if (
55 | not distributed_utils.is_distributed
56 | or distributed_utils.backend.is_local_root_worker()
57 | ):
58 | os.makedirs(root, exist_ok = True)
59 | filename = default(filename, os.path.basename(url))
60 |
61 | download_target = os.path.join(root, filename)
62 | download_target_tmp = os.path.join(root, f'tmp.{filename}')
63 |
64 | if os.path.exists(download_target) and not os.path.isfile(download_target):
65 | raise RuntimeError(f"{download_target} exists and is not a regular file")
66 |
67 | if (
68 | distributed_utils.is_distributed
69 | and not distributed_utils.backend.is_local_root_worker()
70 | and not os.path.isfile(download_target)
71 | ):
72 | # If the file doesn't exist yet, wait until it's downloaded by the root worker.
73 | distributed_utils.backend.local_barrier()
74 |
75 | if os.path.isfile(download_target):
76 | return download_target
77 |
78 | with urllib.request.urlopen(url) as source, open(download_target_tmp, "wb") as output:
79 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80) as loop:
80 | while True:
81 | buffer = source.read(8192)
82 | if not buffer:
83 | break
84 |
85 | output.write(buffer)
86 | loop.update(len(buffer))
87 |
88 | os.rename(download_target_tmp, download_target)
89 | if (
90 | distributed_utils.is_distributed
91 | and distributed_utils.backend.is_local_root_worker()
92 | ):
93 | distributed_utils.backend.local_barrier()
94 | return download_target
95 |
96 | def make_contiguous(module):
97 | with torch.no_grad():
98 | for param in module.parameters():
99 | param.set_(param.contiguous())
100 |
101 | # pretrained Discrete VAE from OpenAI
102 |
103 | class OpenAIDiscreteVAE(nn.Module):
104 | def __init__(self):
105 | super().__init__()
106 |
107 | self.enc = load_model(download(OPENAI_VAE_ENCODER_PATH))
108 | self.dec = load_model(download(OPENAI_VAE_DECODER_PATH))
109 | make_contiguous(self)
110 |
111 | self.num_layers = 3
112 | self.image_size = 256
113 | self.num_tokens = 8192
114 |
115 | @torch.no_grad()
116 | def get_codebook_indices(self, img):
117 | img = map_pixels(img)
118 | z_logits = self.enc.blocks(img)
119 | z = torch.argmax(z_logits, dim = 1)
120 | return rearrange(z, 'b h w -> b (h w)')
121 |
122 | def decode(self, img_seq):
123 | b, n = img_seq.shape
124 | img_seq = rearrange(img_seq, 'b (h w) -> b h w', h = int(sqrt(n)))
125 |
126 | z = F.one_hot(img_seq, num_classes = self.num_tokens)
127 | z = rearrange(z, 'b h w c -> b c h w').float()
128 | x_stats = self.dec(z).float()
129 | x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))
130 | return x_rec
131 |
132 | def forward(self, img):
133 |         raise NotImplementedError
134 |
135 | # VQGAN from Taming Transformers paper
136 | # https://arxiv.org/abs/2012.09841
137 |
138 | def get_obj_from_str(string, reload=False):
139 | module, cls = string.rsplit(".", 1)
140 | if reload:
141 | module_imp = importlib.import_module(module)
142 | importlib.reload(module_imp)
143 | return getattr(importlib.import_module(module, package=None), cls)
144 |
145 | def instantiate_from_config(config):
146 | if not "target" in config:
147 | raise KeyError("Expected key `target` to instantiate.")
148 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
149 |
150 | class VQGanVAE(nn.Module):
151 | def __init__(self, vqgan_model_path=None, vqgan_config_path=None):
152 | super().__init__()
153 |
154 | if vqgan_model_path is None:
155 | model_filename = 'vqgan.1024.model.ckpt'
156 | config_filename = 'vqgan.1024.config.yml'
157 | download(VQGAN_VAE_CONFIG_PATH, config_filename)
158 | download(VQGAN_VAE_PATH, model_filename)
159 | config_path = str(Path(CACHE_PATH) / config_filename)
160 | model_path = str(Path(CACHE_PATH) / model_filename)
161 | else:
162 | model_path = vqgan_model_path
163 | config_path = vqgan_config_path
164 |
165 | config = OmegaConf.load(config_path)
166 |
167 | model = instantiate_from_config(config["model"])
168 |
169 | state = torch.load(model_path, map_location = 'cpu')['state_dict']
170 | model.load_state_dict(state, strict = False)
171 |
172 | print(f"Loaded VQGAN from {model_path} and {config_path}")
173 |
174 | self.model = model
175 |
176 | # f as used in https://github.com/CompVis/taming-transformers#overview-of-pretrained-models
177 | f = config.model.params.ddconfig.resolution / config.model.params.ddconfig.attn_resolutions[0]
178 | self.num_layers = int(log(f)/log(2))
179 | self.image_size = 256
180 | self.num_tokens = config.model.params.n_embed
181 | self.is_gumbel = isinstance(self.model, GumbelVQ)
182 |
183 | self._register_external_parameters()
184 |
185 | def _register_external_parameters(self):
186 | """Register external parameters for DeepSpeed partitioning."""
187 | if (
188 | not distributed_utils.is_distributed
189 | or not distributed_utils.using_backend(
190 | distributed_utils.DeepSpeedBackend)
191 | ):
192 | return
193 |
194 | deepspeed = distributed_utils.backend.backend_module
195 | deepspeed.zero.register_external_parameter(
196 | self, self.model.quantize.embed.weight if self.is_gumbel else self.model.quantize.embedding.weight)
197 |
198 | @torch.no_grad()
199 | def get_codebook_indices(self, img):
200 | b = img.shape[0]
201 | img = (2 * img) - 1
202 | _, _, [_, _, indices] = self.model.encode(img)
203 | if self.is_gumbel:
204 | return rearrange(indices, 'b h w -> b (h w)', b=b)
205 | return rearrange(indices, '(b n) -> b n', b = b)
206 |
207 | def decode(self, img_seq):
208 | b, n = img_seq.shape
209 | one_hot_indices = F.one_hot(img_seq, num_classes = self.num_tokens).float()
210 | z = one_hot_indices @ self.model.quantize.embed.weight if self.is_gumbel \
211 | else (one_hot_indices @ self.model.quantize.embedding.weight)
212 |
213 | z = rearrange(z, 'b (h w) c -> b c h w', h = int(sqrt(n)))
214 | img = self.model.decode(z)
215 |
216 | img = (img.clamp(-1., 1.) + 1) * 0.5
217 | return img
218 |
219 | def forward(self, img):
220 |         raise NotImplementedError
221 |
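Round-trip sketch (annotation, not part of the file above). Instantiating OpenAIDiscreteVAE downloads the pretrained encoder/decoder weights into ~/.cache/dalle on first use; the random image is a placeholder used only to show the shapes.

import torch

from dalle_pytorch.vae import OpenAIDiscreteVAE

vae = OpenAIDiscreteVAE()

images = torch.rand(1, 3, vae.image_size, vae.image_size)  # values in [0, 1]
codes = vae.get_codebook_indices(images)                   # (1, 1024) ids in [0, 8192)
recon = vae.decode(codes)                                  # (1, 3, 256, 256)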
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 |
2 | ARG IMG_TAG=1.8.1-cuda10.2-cudnn7-devel
3 | ARG IMG_REPO=pytorch
4 |
5 | FROM pytorch/$IMG_REPO:$IMG_TAG
6 |
7 | RUN apt-get -y update && apt-get -y install git gcc llvm-9-dev cmake libaio-dev vim wget
8 |
9 | RUN git clone https://github.com/microsoft/DeepSpeed.git /tmp/DeepSpeed
10 | RUN cd /tmp/DeepSpeed && DS_BUILD_OPS=1 ./install.sh -r
11 | RUN pip install git+https://github.com/lucidrains/DALLE-pytorch.git
12 |
13 | WORKDIR dalle
14 |
--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | from tqdm import tqdm
4 |
5 | # torch
6 |
7 | import torch
8 |
9 | from einops import repeat
10 |
11 | # vision imports
12 |
13 | from PIL import Image
14 | from torchvision.utils import make_grid, save_image
15 | from torchvision.transforms import functional as TF
16 |
17 | # dalle related classes and utils
18 |
19 | from dalle_pytorch import DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE, DALLE
20 | from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, YttmTokenizer, ChineseTokenizer
21 |
22 | import numpy as np
23 |
24 | # argument parsing
25 |
26 | parser = argparse.ArgumentParser()
27 |
28 | parser.add_argument('--dalle_path', type = str, required = True,
29 | help='path to your trained DALL-E')
30 |
31 | parser.add_argument('--vqgan_model_path', type=str, default = None,
32 | help='path to your trained VQGAN weights. This should be a .ckpt file. (only valid when taming option is enabled)')
33 |
34 | parser.add_argument('--vqgan_config_path', type=str, default = None,
35 | help='path to your trained VQGAN config. This should be a .yaml file. (only valid when taming option is enabled)')
36 |
37 | parser.add_argument('--text', type = str, required = True,
38 | help='your text prompt')
39 |
40 | parser.add_argument('--num_images', type = int, default = 12, required = False,
41 | help='number of images')
42 |
43 | parser.add_argument('--batch_size', type = int, default = 4, required = False,
44 | help='batch size')
45 |
46 | parser.add_argument('--top_k', type = float, default = None, required = False,
47 | help='top k filter threshold')
48 |
49 | parser.add_argument('--top_p', type = float, default = None, required = False,
50 | help='top p filter threshold')
51 |
52 | parser.add_argument('--temperature', type = float, default = 1.0, required = False,
53 | help='sampling temperature')
54 |
55 | parser.add_argument('--outputs_dir', type = str, default = './outputs', required = False,
56 | help='output directory')
57 |
58 | parser.add_argument('--bpe_path', type = str,
59 | help='path to your huggingface BPE json file')
60 |
61 | parser.add_argument('--clip_sort', dest='clip_sort', action = 'store_true')
62 |
63 | parser.add_argument('--hug', dest='hug', action = 'store_true')
64 |
65 | parser.add_argument('--chinese', dest='chinese', action = 'store_true')
66 |
67 | parser.add_argument('--taming', dest='taming', action='store_true')
68 |
69 | parser.add_argument('--output_npy', dest='output_npy', action='store_true')
70 |
71 | parser.add_argument('--gentxt', dest='gentxt', action='store_true')
72 |
73 | args = parser.parse_args()
74 |
75 | if args.clip_sort:
76 | # load OpenAI clip
77 | import clip
78 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
79 | clip_model, clip_preprocess = clip.load('ViT-B/16', jit=False)
80 | clip_model.eval().requires_grad_(False).to(device)
81 |
82 | # helper fns
83 |
84 | def exists(val):
85 | return val is not None
86 |
87 | # tokenizer
88 |
89 | if exists(args.bpe_path):
90 | klass = HugTokenizer if args.hug else YttmTokenizer
91 | tokenizer = klass(args.bpe_path)
92 | elif args.chinese:
93 | tokenizer = ChineseTokenizer()
94 |
95 | # load DALL-E
96 |
97 | dalle_path = Path(args.dalle_path)
98 |
99 | assert dalle_path.exists(), 'trained DALL-E must exist'
100 |
101 | load_obj = torch.load(str(dalle_path))
102 | dalle_params, vae_params, weights = load_obj.pop('hparams'), load_obj.pop('vae_params'), load_obj.pop('weights')
103 |
104 | dalle_params.pop('vae', None) # cleanup later
105 |
106 | if args.taming:
107 | vae = VQGanVAE(args.vqgan_model_path, args.vqgan_config_path)
108 | elif vae_params is not None:
109 | vae = DiscreteVAE(**vae_params)
110 | else:
111 | vae = OpenAIDiscreteVAE()
112 |
113 | dalle = DALLE(vae = vae, **dalle_params).cuda()
114 |
115 | dalle.load_state_dict(weights)
116 |
117 | # generate images
118 |
119 | image_size = vae.image_size
120 |
121 | texts = args.text.split('|')
122 |
123 | for j, text in tqdm(enumerate(texts)):
124 |
125 | text = text.lower()
126 |
127 | if args.gentxt:
128 | text_tokens, gen_texts = dalle.generate_texts(tokenizer, text=text, filter_thres = args.top_k)
129 | text = gen_texts[0]
130 | else:
131 | text_tokens = tokenizer.tokenize([text], dalle.text_seq_len).cuda()
132 |
133 | text_tokens = repeat(text_tokens, '() n -> b n', b = args.num_images)
134 |
135 | outputs = []
136 | tokens = []
137 |
138 | for text_chunk in tqdm(text_tokens.split(args.batch_size), desc = f'generating images for - {text}'):
139 | if args.top_k is not None:
140 | output, tok = dalle.generate_images(text_chunk, temperature=args.temperature, top_k_thresh = args.top_k, return_tokens = args.output_npy)
141 | elif args.top_p is not None:
142 | output, tok = dalle.generate_images(text_chunk, temperature=args.temperature, top_p_thresh = args.top_p, return_tokens = args.output_npy)
143 | else:
144 | output, tok = dalle.generate_images(text_chunk, temperature=1.0, top_p_thresh = 0.9, return_tokens = args.output_npy)
145 |
146 | outputs.append(output)
147 |
148 | if tok is not None:
149 | tokens.append(tok)
150 |
151 | outputs = torch.cat(outputs)
152 |
153 | if len(tokens) > 0:
154 | tokens = torch.cat(tokens).cpu().detach().numpy()
155 |
156 | # save all images
157 | file_name = text
158 | outputs_dir = Path(args.outputs_dir) / file_name.replace(' ', '_')[:(100)]
159 | outputs_dir.mkdir(parents = True, exist_ok = True)
160 |
161 | if not args.clip_sort:
162 |         for i, image in tqdm(enumerate(outputs), desc = 'saving images'):
163 |             save_image(image, outputs_dir / f'{i}.jpg', normalize=True)
164 |             if args.output_npy:
165 |                 with open(outputs_dir / f'{i}.npy', 'wb') as f:
166 |                     np.save(f, tokens[i])
167 |         with open(outputs_dir / 'caption.txt', 'w') as f:
168 |             f.write(file_name)
169 | else:
170 | images_sorted = []
171 | for i, image in enumerate(outputs):
172 | image = image.clamp(-1,1)
173 | pimg = TF.to_pil_image(image)
174 | image = image.add(1).div(2)
175 |
176 |
177 | text_features = clip_model.encode_text(clip.tokenize(text).to(device))
178 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
179 |
180 | image_features = clip_model.encode_image(clip_preprocess(pimg).unsqueeze(0).to(device))
181 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
182 |
183 | similarity = torch.nn.functional.cosine_similarity(image_features, text_features, dim=-1)
184 |
185 | tup = (image, similarity.item(), tokens[i] if args.output_npy else None)
186 | images_sorted.append(tup)
187 |
188 | images_sorted.sort(key=lambda x:x[1], reverse=True)
189 |
190 | for i, image in enumerate(images_sorted):
191 | save_image(image[0], outputs_dir / f'{i}-{image[1]}.png', normalize=True)
192 | if args.output_npy:
193 | with open(outputs_dir / f'{i}.npy', 'wb') as f:
194 | np.save(f, image[2])
195 |
196 | print(f'created {args.num_images} images at "{str(outputs_dir)}"')
197 |
--------------------------------------------------------------------------------
/images/avocado-0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/avocado-0.png
--------------------------------------------------------------------------------
/images/avocado-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/avocado-1.png
--------------------------------------------------------------------------------
/images/avocado-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/avocado-2.png
--------------------------------------------------------------------------------
/images/avocado-3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/avocado-3.png
--------------------------------------------------------------------------------
/images/avocado-4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/avocado-4.png
--------------------------------------------------------------------------------
/images/avocado-5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/avocado-5.png
--------------------------------------------------------------------------------
/images/avocado-6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/avocado-6.png
--------------------------------------------------------------------------------
/images/avocado-7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/avocado-7.png
--------------------------------------------------------------------------------
/images/avocado-before.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/avocado-before.png
--------------------------------------------------------------------------------
/images/avocado.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/avocado.png
--------------------------------------------------------------------------------
/images/banner.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/banner.jpg
--------------------------------------------------------------------------------
/images/birds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/birds.png
--------------------------------------------------------------------------------
/images/clothing.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/clothing.png
--------------------------------------------------------------------------------
/images/cloud-0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/cloud-0.png
--------------------------------------------------------------------------------
/images/cloud-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/cloud-1.png
--------------------------------------------------------------------------------
/images/cloud-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/cloud-2.png
--------------------------------------------------------------------------------
/images/cloud-3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/cloud-3.png
--------------------------------------------------------------------------------
/images/cloud-4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/cloud-4.png
--------------------------------------------------------------------------------
/images/cloud-5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/cloud-5.png
--------------------------------------------------------------------------------
/images/cloud-6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/cloud-6.png
--------------------------------------------------------------------------------
/images/cloud-7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/cloud-7.png
--------------------------------------------------------------------------------
/images/cube-cloud-before.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/cube-cloud-before.jpg
--------------------------------------------------------------------------------
/images/cube-cloud.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/cube-cloud.png
--------------------------------------------------------------------------------
/images/cube-porcupine-before.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/cube-porcupine-before.jpg
--------------------------------------------------------------------------------
/images/cube-porcupine.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/cube-porcupine.png
--------------------------------------------------------------------------------
/images/cube-water-before.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/cube-water-before.jpg
--------------------------------------------------------------------------------
/images/cube-water.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/cube-water.png
--------------------------------------------------------------------------------
/images/girl-0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/girl-0.png
--------------------------------------------------------------------------------
/images/girl-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/girl-1.png
--------------------------------------------------------------------------------
/images/girl-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/girl-2.png
--------------------------------------------------------------------------------
/images/girl-3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/girl-3.png
--------------------------------------------------------------------------------
/images/girl-4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/girl-4.png
--------------------------------------------------------------------------------
/images/girl-5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/girl-5.png
--------------------------------------------------------------------------------
/images/girl-6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/girl-6.png
--------------------------------------------------------------------------------
/images/girl-7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/girl-7.png
--------------------------------------------------------------------------------
/images/girl-glasses-before.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/girl-glasses-before.png
--------------------------------------------------------------------------------
/images/girl-glasses.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/girl-glasses.png
--------------------------------------------------------------------------------
/images/landscape.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/landscape.png
--------------------------------------------------------------------------------
/images/layouts-1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/layouts-1.jpg
--------------------------------------------------------------------------------
/images/layouts-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/layouts-2.jpg
--------------------------------------------------------------------------------
/images/researcher-mad-before.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/researcher-mad-before.png
--------------------------------------------------------------------------------
/images/researcher-mad.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/researcher-mad.png
--------------------------------------------------------------------------------
/images/wb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack000/DALLE-pytorch/7b6f2e1f9ea7efa0a8f2e85e47489b88e5919c93/images/wb.png
--------------------------------------------------------------------------------
/install_apex.sh:
--------------------------------------------------------------------------------
1 | git clone https://github.com/NVIDIA/apex.git /tmp/apex
2 | cd /tmp/apex && pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
3 |
--------------------------------------------------------------------------------
/install_deepspeed.sh:
--------------------------------------------------------------------------------
1 | sudo apt-get -y install llvm-9-dev cmake
2 | git clone https://github.com/microsoft/DeepSpeed.git /tmp/Deepspeed
3 | cd /tmp/Deepspeed && DS_BUILD_SPARSE_ATTN=1 ./install.sh -s
4 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'dalle-pytorch',
5 | packages = find_packages(),
6 | include_package_data = True,
7 | version = '1.1.4',
8 | license='MIT',
9 | description = 'DALL-E - Pytorch',
10 | author = 'Phil Wang',
11 | author_email = 'lucidrains@gmail.com',
12 | url = 'https://github.com/lucidrains/dalle-pytorch',
13 | keywords = [
14 | 'artificial intelligence',
15 | 'attention mechanism',
16 | 'transformers',
17 | 'text-to-image'
18 | ],
19 | install_requires=[
20 | 'axial_positional_embedding',
21 | 'DALL-E',
22 | 'einops>=0.3.2',
23 | 'ftfy',
24 | 'g-mlp-pytorch',
25 | 'pillow',
26 | 'regex',
27 | 'rotary-embedding-torch',
28 | 'taming-transformers-rom1504',
29 | 'tokenizers',
30 | 'torch>=1.6',
31 | 'torchvision',
32 | 'transformers',
33 | 'tqdm',
34 | 'youtokentome',
35 | 'WebDataset'
36 | ],
37 | classifiers=[
38 | 'Development Status :: 4 - Beta',
39 | 'Intended Audience :: Developers',
40 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
41 | 'License :: OSI Approved :: MIT License',
42 | 'Programming Language :: Python :: 3.6',
43 | ],
44 | )
45 |
--------------------------------------------------------------------------------
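
For reference, a minimal sketch of driving the package installed by this setup.py from Python. It only uses constructor arguments and calls that appear in train_vae.py and train_dalle.py below; the omitted DALLE arguments are assumed to keep their defaults, and the vocabulary size and tensor shapes are illustrative assumptions rather than values taken from the repository.

    import torch
    from dalle_pytorch import DiscreteVAE, DALLE

    # discrete VAE, using the same hyperparameter names as train_vae.py
    vae = DiscreteVAE(
        image_size = 128,
        num_layers = 3,
        num_tokens = 8192,
        codebook_dim = 512,
        hidden_dim = 256,
        num_resnet_blocks = 2
    )

    # DALL-E wired to that VAE, mirroring part of the dalle_params dict in train_dalle.py
    dalle = DALLE(
        vae = vae,
        num_text_tokens = 10000,  # assumption; train_dalle.py uses tokenizer.vocab_size
        text_seq_len = 256,
        dim = 512,
        depth = 2,
        heads = 8,
        dim_head = 64
    )

    text = torch.randint(0, 10000, (1, 256))  # assumed shape: (batch, text_seq_len)
    images = torch.rand(1, 3, 128, 128)       # assumed shape: (batch, 3, image_size, image_size)

    loss = dalle(text, images, return_loss = True)
    loss.backward()
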
/train_dalle.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | import time
4 | from glob import glob
5 | import os
6 | import shutil
7 |
8 | import torch
9 | import wandb # Quit early if user doesn't have wandb installed.
10 | from torch.nn.utils import clip_grad_norm_
11 | from torch.optim import Adam
12 | from torch.optim.lr_scheduler import ReduceLROnPlateau
13 | from torch.utils.data import DataLoader
14 |
15 | from dalle_pytorch import OpenAIDiscreteVAE, VQGanVAE, DiscreteVAE, DALLE
16 | from dalle_pytorch import distributed_utils
17 | from dalle_pytorch.loader import TextImageDataset
18 | from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, ChineseTokenizer, YttmTokenizer
19 |
20 | # libraries needed for webdataset support
21 | import webdataset as wds
22 | from torchvision import transforms as T
23 | from PIL import Image
24 | from io import BytesIO
25 |
26 |
27 | # argument parsing
28 |
29 | parser = argparse.ArgumentParser()
30 |
31 | group = parser.add_mutually_exclusive_group(required=False)
32 |
33 | group.add_argument('--vae_path', type=str,
34 | help='path to your trained discrete VAE')
35 |
36 | group.add_argument('--dalle_path', type=str,
37 | help='path to your partially trained DALL-E')
38 |
39 | parser.add_argument('--vqgan_model_path', type=str, default = None,
40 | help='path to your trained VQGAN weights. This should be a .ckpt file. (only valid when taming option is enabled)')
41 |
42 | parser.add_argument('--vqgan_config_path', type=str, default = None,
43 | help='path to your trained VQGAN config. This should be a .yaml file. (only valid when taming option is enabled)')
44 |
45 | parser.add_argument('--image_text_folder', type=str, required=True,
46 | help='path to your folder of images and text for learning the DALL-E')
47 |
48 | parser.add_argument(
49 | '--wds',
50 | type = str,
51 | default='',
52 | help = 'Comma separated list of WebDataset (1) image and (2) text column names. Must contain 2 values, e.g. img,cap.'
53 | )
54 |
55 | parser.add_argument('--truncate_captions', dest='truncate_captions', action='store_true',
56 | help='Captions passed in which exceed the max token length will be truncated if this is set.')
57 |
58 | parser.add_argument('--random_resize_crop_lower_ratio', dest='resize_ratio', type=float, default=0.75,
59 | help='Random resized crop lower ratio')
60 |
61 | parser.add_argument('--chinese', dest='chinese', action='store_true')
62 |
63 | parser.add_argument('--taming', dest='taming', action='store_true')
64 |
65 | parser.add_argument('--hug', dest='hug', action='store_true')
66 |
67 | parser.add_argument('--bpe_path', type=str,
68 | help='path to your BPE json file')
69 |
70 | parser.add_argument('--dalle_output_file_name', type=str, default = "dalle",
71 | help='output_file_name')
72 |
73 | parser.add_argument('--fp16', action='store_true',
74 | help='(experimental) - Enable DeepSpeed 16 bit precision. Reduces VRAM.')
75 |
76 |
77 | parser.add_argument('--amp', action='store_true',
78 | help='Apex "O1" automatic mixed precision. More stable than 16 bit precision. Can\'t be used in conjunction with deepspeed zero stages 1-3.')
79 |
80 | parser.add_argument('--wandb_name', default='dalle_train_transformer',
81 | help='Name W&B will use when saving results.\ne.g. `--wandb_name "coco2017-full-sparse"`')
82 |
83 | parser.add_argument('--wandb_entity', default=None,
84 | help='(optional) Name of W&B team/entity to log to.')
85 |
86 | parser.add_argument('--stable_softmax', dest='stable_softmax', action='store_true',
87 | help='Prevent values from becoming too large during softmax. Helps with stability in fp16 and Mixture of Quantization training.')
88 |
89 | parser = distributed_utils.wrap_arg_parser(parser)
90 |
91 | train_group = parser.add_argument_group('Training settings')
92 |
93 | train_group.add_argument('--flops_profiler', dest = 'flops_profiler', action='store_true', help = 'Exits after printing detailed flops/runtime analysis of forward/backward')
94 |
95 | train_group.add_argument('--epochs', default = 20, type = int, help = 'Number of epochs')
96 |
97 | train_group.add_argument('--save_every_n_steps', default = 1000, type = int, help = 'Save a checkpoint every n steps')
98 |
99 | train_group.add_argument('--keep_n_checkpoints', default = None, type = int, help = '(Careful) Deletes old deepspeed checkpoints if there are more than n')
100 |
101 | train_group.add_argument('--batch_size', default = 4, type = int, help = 'Batch size')
102 |
103 | train_group.add_argument('--ga_steps', default = 1, type = int, help = 'Number of gradient accumulation steps per optimizer step. DeepSpeed only.')
104 |
105 | train_group.add_argument('--learning_rate', default = 3e-4, type = float, help = 'Learning rate')
106 |
107 | train_group.add_argument('--clip_grad_norm', default = 0.5, type = float, help = 'Clip gradient norm')
108 |
109 | train_group.add_argument('--lr_decay', dest = 'lr_decay', action = 'store_true')
110 |
111 | model_group = parser.add_argument_group('Model settings')
112 |
113 | model_group.add_argument('--dim', default = 512, type = int, help = 'Model dimension')
114 |
115 | model_group.add_argument('--text_seq_len', default = 256, type = int, help = 'Text sequence length')
116 |
117 | model_group.add_argument('--depth', default = 2, type = int, help = 'Model depth')
118 |
119 | model_group.add_argument('--heads', default = 8, type = int, help = 'Model number of heads')
120 |
121 | model_group.add_argument('--dim_head', default = 64, type = int, help = 'Model head dimension')
122 |
123 | train_group.add_argument('--ff_dropout', default = 0.0, type = float, help = 'Feed forward dropout.')
124 |
125 | train_group.add_argument('--attn_dropout', default = 0.0, type = float, help = 'Attention dropout.')
126 |
127 | model_group.add_argument('--reversible', dest = 'reversible', action='store_true')
128 |
129 | model_group.add_argument('--loss_img_weight', default = 7, type = int, help = 'Image loss weight')
130 |
131 | model_group.add_argument('--attn_types', default = 'full', type = str, help = 'comma separated list of attention types. attention type can be: full or sparse or axial_row or axial_col or conv_like.')
132 |
133 | model_group.add_argument('--shift_tokens', help = 'Use the shift tokens feature', action = 'store_true')
134 |
135 | model_group.add_argument('--rotary_emb', help = 'Use rotary embeddings', action = 'store_true')
136 |
137 | args = parser.parse_args()
138 |
139 | # helpers
140 |
141 | def exists(val):
142 | return val is not None
143 |
144 | def get_trainable_params(model):
145 | return [params for params in model.parameters() if params.requires_grad]
146 |
147 | def cp_path_to_dir(cp_path, tag):
148 | """Convert a checkpoint path to a directory with `tag` inserted.
149 | If `cp_path` is already a directory, return it unchanged.
150 | """
151 | if not isinstance(cp_path, Path):
152 | cp_path = Path(cp_path)
153 | if cp_path.is_dir():
154 | return cp_path
155 | path_sans_extension = cp_path.parent / cp_path.stem
156 | cp_dir = Path(f'{path_sans_extension}-{tag}-cp')
157 | return cp_dir
158 |
159 | # constants
160 | WEBDATASET_IMAGE_TEXT_COLUMNS = tuple(args.wds.split(','))
161 | ENABLE_WEBDATASET = len(WEBDATASET_IMAGE_TEXT_COLUMNS) == 2
162 |
163 | DALLE_OUTPUT_FILE_NAME = args.dalle_output_file_name + ".pt"
164 |
165 | VAE_PATH = args.vae_path
166 | VQGAN_MODEL_PATH = args.vqgan_model_path
167 | VQGAN_CONFIG_PATH = args.vqgan_config_path
168 | DALLE_PATH = args.dalle_path
169 | RESUME = exists(DALLE_PATH)
170 |
171 | EPOCHS = args.epochs
172 | BATCH_SIZE = args.batch_size
173 |
174 | LEARNING_RATE = args.learning_rate
175 | GRAD_CLIP_NORM = args.clip_grad_norm
176 | LR_DECAY = args.lr_decay
177 | SAVE_EVERY_N_STEPS = args.save_every_n_steps
178 | KEEP_N_CHECKPOINTS = args.keep_n_checkpoints
179 |
180 | MODEL_DIM = args.dim
181 | TEXT_SEQ_LEN = args.text_seq_len
182 | DEPTH = args.depth
183 | HEADS = args.heads
184 | DIM_HEAD = args.dim_head
185 | REVERSIBLE = args.reversible
186 | LOSS_IMG_WEIGHT = args.loss_img_weight
187 | FF_DROPOUT = args.ff_dropout
188 | ATTN_DROPOUT = args.attn_dropout
189 | STABLE = args.stable_softmax
190 | SHIFT_TOKENS = args.shift_tokens
191 | ROTARY_EMB = args.rotary_emb
192 |
193 | ATTN_TYPES = tuple(args.attn_types.split(','))
194 |
195 | DEEPSPEED_CP_AUX_FILENAME = 'auxiliary.pt'
196 |
197 | if not ENABLE_WEBDATASET:
198 | # quit early if you used the wrong folder name
199 | assert Path(args.image_text_folder).exists(), f'The path {args.image_text_folder} was not found.'
200 | else:
201 | # quit early if no tar files were found
202 | if Path(args.image_text_folder).is_dir():
203 | DATASET = [str(p) for p in Path(args.image_text_folder).glob("**/*") if ".tar" in str(p).lower()] # .name
204 | assert len(DATASET) > 0, 'The directory ({}) does not contain any WebDataset/.tar files.'.format(args.image_text_folder)
205 | print('Found {} WebDataset .tar(.gz) file(s) under given path {}!'.format(len(DATASET), args.image_text_folder))
206 |     elif ('http://' in args.image_text_folder.lower()) or ('https://' in args.image_text_folder.lower()):
207 |         DATASET = f"pipe:curl -L -s {args.image_text_folder} || true"
208 |         print('Found http(s) link under given path {}!'.format(args.image_text_folder))
209 | elif 'gs://' in args.image_text_folder.lower():
210 | DATASET = f"pipe:gsutil cat {args.image_text_folder} || true"
211 |         print('Found GCS link under given path {}!'.format(args.image_text_folder))
212 | elif '.tar' in args.image_text_folder:
213 | DATASET = args.image_text_folder
214 | print('Found WebDataset .tar(.gz) file under given path {}!'.format(args.image_text_folder))
215 | else:
216 | raise Exception('No folder, no .tar(.gz) and no url pointing to tar files provided under {}.'.format(args.image_text_folder))
217 |
218 | # initialize distributed backend
219 |
220 | distr_backend = distributed_utils.set_backend_from_args(args)
221 | distr_backend.initialize()
222 |
223 | using_deepspeed = \
224 | distributed_utils.using_backend(distributed_utils.DeepSpeedBackend)
225 |
226 | # tokenizer
227 |
228 | if exists(args.bpe_path):
229 | klass = HugTokenizer if args.hug else YttmTokenizer
230 | tokenizer = klass(args.bpe_path)
231 | elif args.chinese:
232 | tokenizer = ChineseTokenizer()
233 |
234 | # reconstitute vae
235 | if RESUME:
236 | dalle_path = Path(DALLE_PATH)
237 | if using_deepspeed:
238 | cp_dir = cp_path_to_dir(dalle_path, 'ds')
239 | assert cp_dir.is_dir(), \
240 | f'DeepSpeed checkpoint directory {cp_dir} not found'
241 | dalle_path = cp_dir / DEEPSPEED_CP_AUX_FILENAME
242 | else:
243 | assert dalle_path.exists(), 'DALL-E model file does not exist'
244 | loaded_obj = torch.load(str(dalle_path), map_location='cpu')
245 |
246 | dalle_params, vae_params, weights = loaded_obj['hparams'], loaded_obj['vae_params'], loaded_obj['weights']
247 | opt_state = loaded_obj.get('opt_state')
248 | scheduler_state = loaded_obj.get('scheduler_state')
249 |
250 | if vae_params is not None:
251 | vae = DiscreteVAE(**vae_params)
252 | else:
253 | if args.taming:
254 | vae = VQGanVAE(VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
255 | else:
256 | vae = OpenAIDiscreteVAE()
257 |
258 | dalle_params = dict(
259 | **dalle_params
260 | )
261 | IMAGE_SIZE = vae.image_size
262 | resume_epoch = loaded_obj.get('epoch', 0)
263 | else:
264 | if exists(VAE_PATH):
265 | vae_path = Path(VAE_PATH)
266 | assert vae_path.exists(), 'VAE model file does not exist'
267 | assert not vae_path.is_dir(), \
268 | ('Cannot load VAE model from directory; please use a '
269 | 'standard *.pt checkpoint. '
270 | 'Currently, merging a DeepSpeed-partitioned VAE into a DALLE '
271 | 'model is not supported.')
272 |
273 | loaded_obj = torch.load(str(vae_path))
274 |
275 | vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']
276 |
277 | vae = DiscreteVAE(**vae_params)
278 | vae.load_state_dict(weights)
279 | else:
280 | if distr_backend.is_root_worker():
281 | print('using pretrained VAE for encoding images to tokens')
282 | vae_params = None
283 |
284 | if args.taming:
285 | vae = VQGanVAE(VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
286 | else:
287 | vae = OpenAIDiscreteVAE()
288 |
289 | IMAGE_SIZE = vae.image_size
290 |
291 | dalle_params = dict(
292 | num_text_tokens=tokenizer.vocab_size,
293 | text_seq_len=TEXT_SEQ_LEN,
294 | dim=MODEL_DIM,
295 | depth=DEPTH,
296 | heads=HEADS,
297 | dim_head=DIM_HEAD,
298 | reversible=REVERSIBLE,
299 | loss_img_weight=LOSS_IMG_WEIGHT,
300 | attn_types=ATTN_TYPES,
301 | ff_dropout=FF_DROPOUT,
302 | attn_dropout=ATTN_DROPOUT,
303 | stable=STABLE,
304 | shift_tokens=SHIFT_TOKENS,
305 | rotary_emb=ROTARY_EMB,
306 | )
307 | resume_epoch = 0
308 |
309 | # configure OpenAI VAE for float16s
310 |
311 | if isinstance(vae, OpenAIDiscreteVAE) and args.fp16:
312 | vae.enc.blocks.output.conv.use_float16 = True
313 |
314 |
315 | # helpers
316 |
317 | def group_weight(model):
318 | group_decay, group_no_decay = [], []
319 | for params in model.named_parameters():
320 | if 'transformer' in params[0]:
321 | if 'bias' in params[0] or 'norm' in params[0]:
322 | group_no_decay.append(params[1])
323 | continue
324 | group_decay.append(params[1])
325 |
326 | assert len(list(model.parameters())) == len(group_decay) + len(group_no_decay)
327 | groups = [dict(params=group_decay), dict(params=group_no_decay, weight_decay=.0)]
328 | return groups
329 |
330 |
331 | # create dataset and dataloader
332 |
333 | is_shuffle = not distributed_utils.using_backend(distributed_utils.HorovodBackend)
334 |
335 | imagepreproc = T.Compose([
336 | T.Lambda(lambda img: img.convert('RGB')
337 | if img.mode != 'RGB' else img),
338 | T.RandomResizedCrop(IMAGE_SIZE,
339 | scale=(args.resize_ratio, 1.),
340 | ratio=(1., 1.)),
341 | T.ToTensor(),
342 | ])
343 |
344 | def imagetransform(b):
345 | return Image.open(BytesIO(b))
346 |
347 | def tokenize(s):
348 | return tokenizer.tokenize(
349 | s.decode('utf-8'),
350 | TEXT_SEQ_LEN,
351 | truncate_text=args.truncate_captions).squeeze(0)
352 |
353 | if ENABLE_WEBDATASET:
354 | DATASET_SIZE = int(1e9) # You need to set a nominal length for the Dataset in order to avoid warnings from DataLoader
355 |
356 | myimg, mycap = WEBDATASET_IMAGE_TEXT_COLUMNS
357 | image_text_mapping = {
358 | myimg: imagetransform,
359 | mycap: tokenize
360 | }
361 | image_mapping = {
362 | myimg: imagepreproc
363 | }
364 |
365 | def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available.
366 | if mycap not in item:
367 | return False
368 | if myimg not in item:
369 | return False
370 | return True
371 |
372 | w_dataset = wds.WebDataset(DATASET, handler=wds.warn_and_continue)
373 | filtered_dataset = w_dataset.select(filter_dataset)
374 | ds = filtered_dataset.map_dict(**image_text_mapping).map_dict(**image_mapping).to_tuple(mycap, myimg).batched(BATCH_SIZE, partial=True)
375 | else:
376 | ds = TextImageDataset(
377 | args.image_text_folder,
378 | text_len=TEXT_SEQ_LEN,
379 | image_size=IMAGE_SIZE,
380 | resize_ratio=args.resize_ratio,
381 | truncate_captions=args.truncate_captions,
382 | tokenizer=tokenizer,
383 | shuffle=is_shuffle,
384 | )
385 | assert len(ds) > 0, 'dataset is empty'
386 |
387 | if distr_backend.is_root_worker():
388 | if not ENABLE_WEBDATASET:
389 | print(f'{len(ds)} image-text pairs found for training')
390 |
391 | if not is_shuffle:
392 | data_sampler = torch.utils.data.distributed.DistributedSampler(
393 | ds,
394 | num_replicas=distr_backend.get_world_size(),
395 | rank=distr_backend.get_rank()
396 | )
397 | else:
398 | data_sampler = None
399 |
400 | if ENABLE_WEBDATASET:
401 | # WebLoader for WebDataset and DeepSpeed compatibility
402 | dl = wds.WebLoader(ds, batch_size=None, shuffle=False) # optionally add num_workers=2 (n) argument
403 | number_of_batches = DATASET_SIZE // (BATCH_SIZE * distr_backend.get_world_size())
404 | dl = dl.repeat(2).slice(number_of_batches)
405 | dl.length = number_of_batches
406 | else:
407 | # Regular DataLoader for image-text-folder datasets
408 | dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=is_shuffle, drop_last=True, sampler=data_sampler)
409 |
410 |
411 | # initialize DALL-E
412 |
413 | dalle = DALLE(vae=vae, **dalle_params)
414 | if not using_deepspeed:
415 | if args.fp16:
416 | dalle = dalle.half()
417 | dalle = dalle.cuda()
418 |
419 | if RESUME and not using_deepspeed:
420 | dalle.load_state_dict(weights)
421 |
422 | # optimizer
423 |
424 | opt = Adam(get_trainable_params(dalle), lr=LEARNING_RATE)
425 | if RESUME and opt_state:
426 | opt.load_state_dict(opt_state)
427 |
428 | if LR_DECAY:
429 | scheduler = ReduceLROnPlateau(
430 | opt,
431 | mode="min",
432 | factor=0.5,
433 | patience=10,
434 | cooldown=10,
435 | min_lr=1e-6,
436 | verbose=True,
437 | )
438 | if RESUME and scheduler_state:
439 | scheduler.load_state_dict(scheduler_state)
440 | else:
441 | scheduler = None
442 |
443 | if distr_backend.is_root_worker():
444 | # experiment tracker
445 |
446 | model_config = dict(
447 | depth=DEPTH,
448 | heads=HEADS,
449 | dim_head=DIM_HEAD
450 | )
451 |
452 | run = wandb.init(
453 | project=args.wandb_name,
454 | entity=args.wandb_entity,
455 | resume=False,
456 | config=model_config,
457 | )
458 |
459 | # distribute
460 |
461 | distr_backend.check_batch_size(BATCH_SIZE)
462 | deepspeed_config = {
463 | 'train_batch_size': BATCH_SIZE,
464 | 'gradient_accumulation_steps': args.ga_steps,
465 | 'gradient_clipping': GRAD_CLIP_NORM,
466 | 'fp16': {
467 | 'enabled': args.fp16,
468 | },
469 | 'amp': {
470 | 'enabled': args.amp,
471 | 'opt_level': 'O1',
472 | },
473 | "flops_profiler": {
474 | "enabled": args.flops_profiler,
475 | "profile_step": 200,
476 | "module_depth": -1,
477 | "top_modules": 1,
478 | "detailed": True,
479 | "output_file": None # TODO Can't get this to work.
480 | },
481 | }
482 |
483 | if deepspeed_config.get('zero_optimization', {}).get('stage', 0) >= 2:
484 | print(f"Checkpoints made with DeepSpeed ZeRO Stages 2 and 3 will be stored in deepspeed checkpoint folder")
485 | print(f"As such, they will require DeepSpeed as a dependency in order to resume from or generate with.")
486 | print("See the deespeed conversion script for details on how to convert your ZeRO stage 2/3 checkpoint to a single file.")
487 | print("If using a single GPU, consider running with apex automatic mixed precision instead for a similar speedup to ZeRO.")
488 | time.sleep(2)
489 |
490 | (distr_dalle, distr_opt, distr_dl, distr_scheduler) = distr_backend.distribute(
491 | args=args,
492 | model=dalle,
493 | optimizer=opt,
494 | model_parameters=get_trainable_params(dalle),
495 | training_data=(
496 | (None if ENABLE_WEBDATASET else ds)
497 | if using_deepspeed
498 | else dl
499 | ),
500 | # Do not pass the LR scheduler to DeepSpeed so we can manually
501 | # advance it.
502 | lr_scheduler=scheduler if LR_DECAY and not using_deepspeed else None,
503 | config_params=deepspeed_config,
504 | )
505 | # Prefer scheduler in `deepspeed_config`.
506 | if LR_DECAY and distr_scheduler is None:
507 | distr_scheduler = scheduler
508 | avoid_model_calls = using_deepspeed and args.fp16
509 |
510 | if RESUME and using_deepspeed:
511 | distr_dalle.load_checkpoint(str(cp_dir))
512 |
513 |
514 | def save_model(path, epoch=0):
515 | save_obj = {
516 | 'hparams': dalle_params,
517 | 'vae_params': vae_params,
518 | 'epoch': epoch,
519 | }
520 | if using_deepspeed:
521 | cp_dir = cp_path_to_dir(path, 'ds')
522 |
523 | if KEEP_N_CHECKPOINTS is not None and distr_backend.is_root_worker():
524 | checkpoints = sorted(glob(str(cp_dir / "global*")), key=os.path.getmtime, reverse=True)
525 | for checkpoint in checkpoints[KEEP_N_CHECKPOINTS:]:
526 | shutil.rmtree(checkpoint)
527 |
528 | distr_dalle.save_checkpoint(cp_dir, client_state=save_obj)
529 |
530 | if not distr_backend.is_root_worker():
531 | return
532 |
533 | # Save auxiliary values so we can reuse the standard routine
534 | # for loading.
535 | save_obj = {
536 | **save_obj,
537 | # Save a nonsense value that directs the user to
538 | # further help.
539 | 'weights': (
540 | 'To get a working standard checkpoint, '
541 | 'look into consolidating DeepSpeed checkpoints.'
542 | ),
543 | }
544 | torch.save(save_obj, str(cp_dir / DEEPSPEED_CP_AUX_FILENAME))
545 | if deepspeed_config.get('zero_optimization', {}).get('stage', 0) >= 2: # see https://github.com/lucidrains/DALLE-pytorch/wiki/DeepSpeed-Checkpoints
546 | return
547 |
548 | if not distr_backend.is_root_worker():
549 | return
550 |
551 | save_obj = {
552 | **save_obj,
553 | 'weights': dalle.state_dict(),
554 | 'opt_state': opt.state_dict(),
555 | }
556 | save_obj['scheduler_state'] = (scheduler.state_dict() if scheduler else None)
557 | torch.save(save_obj, path)
558 |
559 | # training
560 |
561 | # Saves a checkpoint before training begins to fail early when mis-configured.
562 | # See https://github.com/lucidrains/DALLE-pytorch/wiki/DeepSpeed-Checkpoints
563 | save_model(DALLE_OUTPUT_FILE_NAME, epoch=resume_epoch)
564 | for epoch in range(resume_epoch, EPOCHS):
565 | if data_sampler:
566 | data_sampler.set_epoch(epoch)
567 | for i, (text, images) in enumerate((dl if ENABLE_WEBDATASET else distr_dl)):
568 | if i % 10 == 0 and distr_backend.is_root_worker():
569 | t = time.time()
570 | if args.fp16:
571 | images = images.half()
572 | text, images = map(lambda t: t.cuda(), (text, images))
573 |
574 | loss = distr_dalle(text, images, return_loss=True)
575 |
576 | if using_deepspeed:
577 | distr_dalle.backward(loss)
578 | distr_dalle.step()
579 | # Gradients are automatically zeroed after the step
580 | else:
581 | loss.backward()
582 | clip_grad_norm_(distr_dalle.parameters(), GRAD_CLIP_NORM)
583 | distr_opt.step()
584 | distr_opt.zero_grad()
585 |
586 | # Collective loss, averaged
587 | avg_loss = distr_backend.average_all(loss)
588 |
589 | log = {}
590 |
591 | if i % 10 == 0 and distr_backend.is_root_worker():
592 | print(epoch, i, f'loss - {avg_loss.item()}')
593 |
594 | log = {
595 | **log,
596 | 'epoch': epoch,
597 | 'iter': i,
598 | 'loss': avg_loss.item()
599 | }
600 |
601 | if i % SAVE_EVERY_N_STEPS == 0:
602 | save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch)
603 |
604 | if i % 100 == 0:
605 | if distr_backend.is_root_worker():
606 | sample_text = text[:1]
607 | token_list = sample_text.masked_select(sample_text != 0).tolist()
608 | decoded_text = tokenizer.decode(token_list)
609 |
610 | if not avoid_model_calls:
611 | # CUDA index errors when we don't guard this
612 |                     image = dalle.generate_images(text[:1], top_p_thres=0.9) # top-p sampling at 0.9
613 |
614 |
615 | log = {
616 | **log,
617 | }
618 | if not avoid_model_calls:
619 | log['image'] = wandb.Image(image, caption=decoded_text)
620 |
621 | if i % 10 == 9 and distr_backend.is_root_worker():
622 | sample_per_sec = BATCH_SIZE * 10 / (time.time() - t)
623 | log["sample_per_sec"] = sample_per_sec
624 | print(epoch, i, f'sample_per_sec - {sample_per_sec}')
625 |
626 | if i == 201 and args.flops_profiler:
627 | raise StopIteration("Profiler has finished running. Stopping training early.")
628 |
629 | if distr_backend.is_root_worker():
630 | wandb.log(log)
631 |
632 | if LR_DECAY:
633 | distr_scheduler.step(avg_loss)
634 |
635 | save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch)
636 |
637 | if distr_backend.is_root_worker():
638 | # save trained model to wandb as an artifact every epoch's end
639 |
640 | model_artifact = wandb.Artifact('trained-dalle', type='model', metadata=dict(model_config))
641 | model_artifact.add_file(DALLE_OUTPUT_FILE_NAME)
642 | run.log_artifact(model_artifact)
643 |
644 | save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch)
645 | if distr_backend.is_root_worker():
646 | wandb.save(DALLE_OUTPUT_FILE_NAME)
647 | model_artifact = wandb.Artifact('trained-dalle', type='model', metadata=dict(model_config))
648 | model_artifact.add_file(DALLE_OUTPUT_FILE_NAME)
649 | run.log_artifact(model_artifact)
650 |
651 | wandb.finish()
652 |
--------------------------------------------------------------------------------
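
The non-DeepSpeed checkpoint written by save_model() above (dalle.pt by default) can be reloaded the same way the RESUME branch of this script does; the sketch below simply restates that branch, keeping the pretrained-VAE fallback. DeepSpeed ZeRO stage 2/3 runs instead produce a checkpoint directory whose auxiliary.pt intentionally carries no weights, as noted in save_model().

    import torch
    from dalle_pytorch import DiscreteVAE, OpenAIDiscreteVAE, DALLE

    loaded_obj = torch.load('dalle.pt', map_location='cpu')
    dalle_params, vae_params, weights = loaded_obj['hparams'], loaded_obj['vae_params'], loaded_obj['weights']

    # rebuild the VAE exactly as the resume path does
    vae = DiscreteVAE(**vae_params) if vae_params is not None else OpenAIDiscreteVAE()

    dalle = DALLE(vae=vae, **dalle_params)
    dalle.load_state_dict(weights)
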
/train_vae.py:
--------------------------------------------------------------------------------
1 | import math
2 | from math import sqrt
3 | import argparse
4 | from pathlib import Path
5 |
6 | # torch
7 |
8 | import torch
9 | from torch.optim import Adam
10 | from torch.optim.lr_scheduler import ExponentialLR
11 |
12 | # vision imports
13 |
14 | from torchvision import transforms as T
15 | from torch.utils.data import DataLoader
16 | from torchvision.datasets import ImageFolder
17 | from torchvision.utils import make_grid, save_image
18 |
19 | # dalle classes and utils
20 |
21 | from dalle_pytorch import distributed_utils
22 | from dalle_pytorch import DiscreteVAE
23 |
24 | # argument parsing
25 |
26 | parser = argparse.ArgumentParser()
27 |
28 | parser.add_argument('--image_folder', type = str, required = True,
29 | help='path to your folder of images for learning the discrete VAE and its codebook')
30 |
31 | parser.add_argument('--image_size', type = int, required = False, default = 128,
32 | help='image size')
33 |
34 | parser = distributed_utils.wrap_arg_parser(parser)
35 |
36 |
37 | train_group = parser.add_argument_group('Training settings')
38 |
39 | train_group.add_argument('--epochs', type = int, default = 20, help = 'number of epochs')
40 |
41 | train_group.add_argument('--batch_size', type = int, default = 8, help = 'batch size')
42 |
43 | train_group.add_argument('--learning_rate', type = float, default = 1e-3, help = 'learning rate')
44 |
45 | train_group.add_argument('--lr_decay_rate', type = float, default = 0.98, help = 'learning rate decay')
46 |
47 | train_group.add_argument('--starting_temp', type = float, default = 1., help = 'starting temperature')
48 |
49 | train_group.add_argument('--temp_min', type = float, default = 0.5, help = 'minimum temperature to anneal to')
50 |
51 | train_group.add_argument('--anneal_rate', type = float, default = 1e-6, help = 'temperature annealing rate')
52 |
53 | train_group.add_argument('--num_images_save', type = int, default = 4, help = 'number of images to save')
54 |
55 | model_group = parser.add_argument_group('Model settings')
56 |
57 | model_group.add_argument('--num_tokens', type = int, default = 8192, help = 'number of image tokens')
58 |
59 | model_group.add_argument('--num_layers', type = int, default = 3, help = 'number of layers (should be 3 or above)')
60 |
61 | model_group.add_argument('--num_resnet_blocks', type = int, default = 2, help = 'number of residual net blocks')
62 |
63 | model_group.add_argument('--smooth_l1_loss', dest = 'smooth_l1_loss', action = 'store_true')
64 |
65 | model_group.add_argument('--emb_dim', type = int, default = 512, help = 'embedding dimension')
66 |
67 | model_group.add_argument('--hidden_dim', type = int, default = 256, help = 'hidden dimension')
68 |
69 | model_group.add_argument('--kl_loss_weight', type = float, default = 0., help = 'KL loss weight')
70 |
71 | args = parser.parse_args()
72 |
73 | # constants
74 |
75 | IMAGE_SIZE = args.image_size
76 | IMAGE_PATH = args.image_folder
77 |
78 | EPOCHS = args.epochs
79 | BATCH_SIZE = args.batch_size
80 | LEARNING_RATE = args.learning_rate
81 | LR_DECAY_RATE = args.lr_decay_rate
82 |
83 | NUM_TOKENS = args.num_tokens
84 | NUM_LAYERS = args.num_layers
85 | NUM_RESNET_BLOCKS = args.num_resnet_blocks
86 | SMOOTH_L1_LOSS = args.smooth_l1_loss
87 | EMB_DIM = args.emb_dim
88 | HIDDEN_DIM = args.hidden_dim
89 | KL_LOSS_WEIGHT = args.kl_loss_weight
90 |
91 | STARTING_TEMP = args.starting_temp
92 | TEMP_MIN = args.temp_min
93 | ANNEAL_RATE = args.anneal_rate
94 |
95 | NUM_IMAGES_SAVE = args.num_images_save
96 |
97 | # initialize distributed backend
98 |
99 | distr_backend = distributed_utils.set_backend_from_args(args)
100 | distr_backend.initialize()
101 |
102 | using_deepspeed = \
103 | distributed_utils.using_backend(distributed_utils.DeepSpeedBackend)
104 |
105 | # data
106 |
107 | ds = ImageFolder(
108 | IMAGE_PATH,
109 | T.Compose([
110 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
111 | T.Resize(IMAGE_SIZE),
112 | T.CenterCrop(IMAGE_SIZE),
113 | T.ToTensor()
114 | ])
115 | )
116 |
117 | if distributed_utils.using_backend(distributed_utils.HorovodBackend):
118 | data_sampler = torch.utils.data.distributed.DistributedSampler(
119 | ds, num_replicas=distr_backend.get_world_size(),
120 | rank=distr_backend.get_rank())
121 | else:
122 | data_sampler = None
123 |
124 | dl = DataLoader(ds, BATCH_SIZE, shuffle = not data_sampler, sampler=data_sampler)
125 |
126 | vae_params = dict(
127 | image_size = IMAGE_SIZE,
128 | num_layers = NUM_LAYERS,
129 | num_tokens = NUM_TOKENS,
130 | codebook_dim = EMB_DIM,
131 | hidden_dim = HIDDEN_DIM,
132 | num_resnet_blocks = NUM_RESNET_BLOCKS
133 | )
134 |
135 | vae = DiscreteVAE(
136 | **vae_params,
137 | smooth_l1_loss = SMOOTH_L1_LOSS,
138 | kl_div_loss_weight = KL_LOSS_WEIGHT
139 | )
140 | if not using_deepspeed:
141 | vae = vae.cuda()
142 |
143 |
144 | assert len(ds) > 0, 'folder does not contain any images'
145 | if distr_backend.is_root_worker():
146 | print(f'{len(ds)} images found for training')
147 |
148 | # optimizer
149 |
150 | opt = Adam(vae.parameters(), lr = LEARNING_RATE)
151 | sched = ExponentialLR(optimizer = opt, gamma = LR_DECAY_RATE)
152 |
153 |
154 | if distr_backend.is_root_worker():
155 | # weights & biases experiment tracking
156 |
157 | import wandb
158 |
159 | model_config = dict(
160 | num_tokens = NUM_TOKENS,
161 | smooth_l1_loss = SMOOTH_L1_LOSS,
162 | num_resnet_blocks = NUM_RESNET_BLOCKS,
163 | kl_loss_weight = KL_LOSS_WEIGHT
164 | )
165 |
166 | run = wandb.init(
167 | project = 'dalle_train_vae',
168 | job_type = 'train_model',
169 | config = model_config
170 | )
171 |
172 | # distribute
173 |
174 | distr_backend.check_batch_size(BATCH_SIZE)
175 | deepspeed_config = {'train_batch_size': BATCH_SIZE}
176 |
177 | (distr_vae, distr_opt, distr_dl, distr_sched) = distr_backend.distribute(
178 | args=args,
179 | model=vae,
180 | optimizer=opt,
181 | model_parameters=vae.parameters(),
182 | training_data=ds if using_deepspeed else dl,
183 | lr_scheduler=sched if not using_deepspeed else None,
184 | config_params=deepspeed_config,
185 | )
186 |
187 | using_deepspeed_sched = False
188 | # Prefer scheduler in `deepspeed_config`.
189 | if distr_sched is None:
190 | distr_sched = sched
191 | elif using_deepspeed:
192 | # We are using a DeepSpeed LR scheduler and want to let DeepSpeed
193 | # handle its scheduling.
194 | using_deepspeed_sched = True
195 |
196 | def save_model(path):
197 | save_obj = {
198 | 'hparams': vae_params,
199 | }
200 | if using_deepspeed:
201 | cp_path = Path(path)
202 | path_sans_extension = cp_path.parent / cp_path.stem
203 | cp_dir = str(path_sans_extension) + '-ds-cp'
204 |
205 | distr_vae.save_checkpoint(cp_dir, client_state=save_obj)
206 | # We do not return so we do get a "normal" checkpoint to refer to.
207 |
208 | if not distr_backend.is_root_worker():
209 | return
210 |
211 | save_obj = {
212 | **save_obj,
213 | 'weights': vae.state_dict()
214 | }
215 |
216 | torch.save(save_obj, path)
217 |
218 | # starting temperature
219 |
220 | global_step = 0
221 | temp = STARTING_TEMP
222 |
223 | for epoch in range(EPOCHS):
224 | for i, (images, _) in enumerate(distr_dl):
225 | images = images.cuda()
226 |
227 | loss, recons = distr_vae(
228 | images,
229 | return_loss = True,
230 | return_recons = True,
231 | temp = temp
232 | )
233 |
234 | if using_deepspeed:
235 | # Gradients are automatically zeroed after the step
236 | distr_vae.backward(loss)
237 | distr_vae.step()
238 | else:
239 | distr_opt.zero_grad()
240 | loss.backward()
241 | distr_opt.step()
242 |
243 | logs = {}
244 |
245 | if i % 100 == 0:
246 | if distr_backend.is_root_worker():
247 | k = NUM_IMAGES_SAVE
248 |
249 | with torch.no_grad():
250 | codes = vae.get_codebook_indices(images[:k])
251 | hard_recons = vae.decode(codes)
252 |
253 | images, recons = map(lambda t: t[:k], (images, recons))
254 | images, recons, hard_recons, codes = map(lambda t: t.detach().cpu(), (images, recons, hard_recons, codes))
255 | images, recons, hard_recons = map(lambda t: make_grid(t.float(), nrow = int(sqrt(k)), normalize = True, range = (-1, 1)), (images, recons, hard_recons))
256 |
257 | logs = {
258 | **logs,
259 | 'sample images': wandb.Image(images, caption = 'original images'),
260 | 'reconstructions': wandb.Image(recons, caption = 'reconstructions'),
261 | 'hard reconstructions': wandb.Image(hard_recons, caption = 'hard reconstructions'),
262 | 'codebook_indices': wandb.Histogram(codes),
263 | 'temperature': temp
264 | }
265 |
266 | wandb.save('./vae.pt')
267 |         save_model('./vae.pt')
268 |
269 | # temperature anneal
270 |
271 | temp = max(temp * math.exp(-ANNEAL_RATE * global_step), TEMP_MIN)
272 |
273 | # lr decay
274 |
275 | # Do not advance schedulers from `deepspeed_config`.
276 | if not using_deepspeed_sched:
277 | distr_sched.step()
278 |
279 | # Collective loss, averaged
280 | avg_loss = distr_backend.average_all(loss)
281 |
282 | if distr_backend.is_root_worker():
283 | if i % 10 == 0:
284 | lr = distr_sched.get_last_lr()[0]
285 | print(epoch, i, f'lr - {lr:6f} loss - {avg_loss.item()}')
286 |
287 | logs = {
288 | **logs,
289 | 'epoch': epoch,
290 | 'iter': i,
291 | 'loss': avg_loss.item(),
292 | 'lr': lr
293 | }
294 |
295 | wandb.log(logs)
296 | global_step += 1
297 |
298 | if distr_backend.is_root_worker():
299 | # save trained model to wandb as an artifact every epoch's end
300 |
301 | model_artifact = wandb.Artifact('trained-vae', type = 'model', metadata = dict(model_config))
302 | model_artifact.add_file('vae.pt')
303 | run.log_artifact(model_artifact)
304 |
305 | if distr_backend.is_root_worker():
306 | # save final vae and cleanup
307 |
308 | save_model('./vae-final.pt')
309 | wandb.save('./vae-final.pt')
310 |
311 | model_artifact = wandb.Artifact('trained-vae', type = 'model', metadata = dict(model_config))
312 | model_artifact.add_file('vae-final.pt')
313 | run.log_artifact(model_artifact)
314 |
315 | wandb.finish()
316 |
--------------------------------------------------------------------------------
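
The two VAE calls used in the logging block above, get_codebook_indices() and decode(), are also the natural way to use a trained vae.pt outside this script. Below is a minimal sketch of that round trip, assuming the checkpoint layout written by save_model() here; the batch of random images and the [0, 1] value range are assumptions based on the ToTensor() transform above.

    import torch
    from dalle_pytorch import DiscreteVAE

    loaded_obj = torch.load('./vae.pt', map_location='cpu')
    vae = DiscreteVAE(**loaded_obj['hparams'])
    vae.load_state_dict(loaded_obj['weights'])
    vae.eval()

    images = torch.rand(4, 3, 128, 128)  # assumed: image_size 128, values in [0, 1]

    with torch.no_grad():
        codes = vae.get_codebook_indices(images)  # discrete codebook indices, as in the logging block
        hard_recons = vae.decode(codes)           # reconstructions from those indices
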