├── dalle_pytorch
│   ├── __init__.py
│   ├── reversible.py
│   ├── transformer.py
│   └── dalle_pytorch.py
├── images
│   └── landscape.png
├── install_deepspeed.sh
├── setup.py
├── .github
│   └── workflows
│       └── python-publish.yml
├── LICENSE
├── Vocabulary.py
├── mixVAEcuda.py
├── .gitignore
├── genDALLE.py
├── trainVAE.py
├── README.md
└── trainDALLE.py
/dalle_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from dalle_pytorch.dalle_pytorch import DALLE, CLIP, DiscreteVAE
2 |
--------------------------------------------------------------------------------
/images/landscape.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/htoyryla/DALLE-pytorch/HEAD/images/landscape.png
--------------------------------------------------------------------------------
/install_deepspeed.sh:
--------------------------------------------------------------------------------
1 | sudo apt-get -y install llvm-9-dev cmake
2 | git clone https://github.com/microsoft/DeepSpeed.git /tmp/Deepspeed
3 | cd /tmp/Deepspeed && DS_BUILD_SPARSE_ATTN=1 ./install.sh -s
4 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'dalle-pytorch',
5 | packages = find_packages(),
6 | version = '0.0.36',
7 | license='MIT',
8 | description = 'DALL-E - Pytorch',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | url = 'https://github.com/lucidrains/dalle-pytorch',
12 | keywords = [
13 | 'artificial intelligence',
14 | 'attention mechanism',
15 | 'transformers',
16 | 'text-to-image'
17 | ],
18 | install_requires=[
19 | 'axial_positional_embedding',
20 | 'einops>=0.3',
21 | 'torch>=1.6'
22 | ],
23 | classifiers=[
24 | 'Development Status :: 4 - Beta',
25 | 'Intended Audience :: Developers',
26 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
27 | 'License :: OSI Approved :: MIT License',
28 | 'Programming Language :: Python :: 3.6',
29 | ],
30 | )
31 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: Upload Python Package
5 |
6 | on:
7 | release:
8 | types: [created]
9 |
10 | jobs:
11 | deploy:
12 |
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - uses: actions/checkout@v2
17 | - name: Set up Python
18 | uses: actions/setup-python@v2
19 | with:
20 | python-version: '3.x'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install setuptools wheel twine
25 | - name: Build and publish
26 | env:
27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
29 | run: |
30 | python setup.py sdist bdist_wheel
31 | twine upload dist/*
32 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Vocabulary.py:
--------------------------------------------------------------------------------
1 | # adapted from https://www.kdnuggets.com/2019/11/create-vocabulary-nlp-tasks-python.html
2 |
3 | class Vocabulary:
4 | #PAD_token = 0 # Used for padding short sentences
5 | #SOS_token = 1 # Start-of-sentence token
6 | #EOS_token = 2 # End-of-sentence token
7 |
8 | def __init__(self, name):
9 | self.name = name
10 | self.word2index = {}
11 | self.word2count = {}
12 | self.index2word = {0: "PAD", 1: "SOS", 2: "EOS"}
13 | self.num_words = 3
14 | self.num_sentences = 0
15 | self.longest_sentence = 0
16 |
17 | def add_word(self, word):
18 | if word not in self.word2index:
19 | # First entry of word into vocabulary
20 | self.word2index[word] = self.num_words
21 | self.word2count[word] = 1
22 | self.index2word[self.num_words] = word
23 | self.num_words += 1
24 | else:
25 | # Word exists; increase word count
26 | self.word2count[word] += 1
27 |
28 | def add_sentence(self, sentence):
29 | sentence_len = 0
30 | for word in sentence.split(' '):
31 | sentence_len += 1
32 | self.add_word(word)
33 | if sentence_len > self.longest_sentence:
34 | # This is the longest sentence
35 | self.longest_sentence = sentence_len
36 | # Count the number of sentences
37 | self.num_sentences += 1
38 |
39 | def to_word(self, index):
40 | return self.index2word[index]
41 |
42 | def to_index(self, word):
43 | return self.word2index[word]
--------------------------------------------------------------------------------
/mixVAEcuda.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, optim
3 | from torchvision import datasets, transforms
4 | from torchvision.utils import save_image
5 | from torch.utils.data import DataLoader
6 | from dalle_pytorch import DiscreteVAE
7 |
8 | imgSize = 256
9 | load_epoch = 280
10 |
11 | vae = DiscreteVAE(
12 | image_size = imgSize,
13 | num_layers = 3,
14 | channels = 3,
15 | num_tokens = 2048,
16 | codebook_dim = 1024,
17 | hidden_dim = 128
18 | )
19 |
20 | vae_dict = torch.load("./models/dvae-"+str(load_epoch)+".pth")
21 | vae.load_state_dict(vae_dict)
22 | vae.cuda()
23 |
24 | batchSize = 12
25 | n_epochs = 500
26 | log_interval = 20
27 | #images = torch.randn(4, 3, 256, 256)
28 |
29 | t = transforms.Compose([
30 | transforms.Resize(imgSize),
31 | transforms.RandomHorizontalFlip(),
32 | transforms.ToTensor(),
33 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) #(0.267, 0.233, 0.234))
34 | ])
35 |
36 | train_set = datasets.ImageFolder('./imagedata', transform=t, target_transform=None)
37 |
38 | train_loader = DataLoader(dataset=train_set, num_workers=1, batch_size=batchSize, shuffle=True)
39 |
40 |
41 |
42 | for batch_idx, (images, _) in enumerate(train_loader):
43 | images = images.cuda()
44 | codes = vae.get_codebook_indices(images)
45 | sample1 = vae.decode(codes)
46 | #save_image(sample.view(-1, 3, imgSize, imgSize),
47 | # 'results/recon_sample_' + str(batch_idx) + '.png', normalize=True)
48 |     for i in range(0, 8):                # assumes each batch holds at least 8 images
49 |         j = i + 1
50 |         j = j % 8
51 |         codes[i,512:] = codes[j,512:]     # replace the bottom half of image i's codes with image j's
52 |     sample2 = vae.decode(codes)           # decode the mixed code sequences
53 | grid = torch.cat([images[:8], sample1[:8], sample2[:8]])
54 | save_image(grid.view(-1, 3, imgSize, imgSize),
55 | 'mixed/mixed_epoch_' +str(load_epoch) + "_"+ str(batch_idx) + '.png', normalize=True)
56 | #break
57 |
58 |
59 |
60 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/genDALLE.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from dalle_pytorch import DiscreteVAE, DALLE
3 | from torchvision.io import read_image
4 | import torchvision.transforms as transforms
5 | import torch.optim as optim
6 | from torchvision.utils import save_image
7 | import time
8 | import sys
9 |
10 | # vae
11 |
12 | load_epoch = 390
13 | vaename = "vae-cdim256"
14 |
15 | # general
16 |
17 | imgSize = 256
18 | batchSize = 12
19 | n_epochs = 100
20 | log_interval = 10
21 | lr = 2e-5
22 |
23 | #dalle
24 |
25 | dalle_epoch = 220
26 | #loadfn = ""
27 | #start_epoch = 0
28 | name = "vae-cdim256"
29 | loadfn = "./models/dalle_"+name+"-"+str(dalle_epoch)+".pth"
30 |
31 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
32 |
33 | tf = transforms.Compose([
34 | #transforms.Resize(imgSize),
35 | #transforms.RandomHorizontalFlip(),
36 | #transforms.ToTensor(),
37 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) #(0.267, 0.233, 0.234))
38 | ])
39 |
40 | vae = DiscreteVAE(
41 | image_size = 256,
42 | num_layers = 3,
43 | num_tokens = 2048,
44 | codebook_dim = 256,
45 | hidden_dim = 128,
46 | temperature = 0.9
47 | )
48 |
49 | # load pretrained vae
50 |
51 | vae_dict = torch.load("./models/"+vaename+"-"+str(load_epoch)+".pth")
52 | vae.load_state_dict(vae_dict)
53 | vae.to(device)
54 |
55 | dalle = DALLE(
56 | dim = 256, #512,
57 | vae = vae, # automatically infer (1) image sequence length and (2) number of image tokens
58 | num_text_tokens = 10000, # vocab size for text
59 | text_seq_len = 256, # text sequence length
60 | depth = 6, # should be 64
61 | heads = 8, # attention heads
62 | dim_head = 64, # attention head dimension
63 | attn_dropout = 0.1, # attention dropout
64 | ff_dropout = 0.1 # feedforward dropout
65 | )
66 |
67 |
68 | # load pretrained dalle if continuing training
69 |
70 | dalle_dict = torch.load(loadfn)
71 | dalle.load_state_dict(dalle_dict)
72 |
73 | dalle.to(device)
74 |
75 | # get image and text data
76 |
77 | lf = open("od-captionsonly.txt", "r") # file contains captions only, one caption per line
78 |
79 | # build vocabulary
80 |
81 | from Vocabulary import Vocabulary
82 |
83 | vocab = Vocabulary("captions")
84 |
85 | captions = []
86 | for lin in lf:
87 | captions.append(lin)
88 |
89 | for caption in captions:
90 | vocab.add_sentence(caption)
91 |
92 | def tokenizer(text): # create a tokenizer function
93 | return text.split(' ')
94 |
95 | inp_text = sys.argv[1]
96 | print(inp_text)
97 | tokens = tokenizer(inp_text)
98 | codes = []
99 | for t in tokens:
100 | codes.append(vocab.to_index(t))
101 |
102 | print(codes)
103 | c_tokens = [0]*256 # fill to match text_seq_len
104 | c_tokens[:len(codes)] = codes
105 |
106 | text = torch.LongTensor(c_tokens).unsqueeze(0).to(device) # a minibatch of one caption, zero-padded to text_seq_len
107 | mask = torch.ones_like(text).bool().to(device)
108 | oimgs = dalle.generate_images(text, mask = mask)
109 | ts = int(time.time())
110 | print(inp_text, ts)
111 | save_image(oimgs,
112 | 'results/gendalle'+name+'_epoch_' + str(dalle_epoch) + '-' +str(ts)+'.png', normalize=True)
113 |
114 |
--------------------------------------------------------------------------------
/trainVAE.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from torch import nn, optim
4 | from torchvision import datasets, transforms
5 | from torchvision.utils import save_image
6 | from torch.utils.data import DataLoader
7 | import torch.nn.functional as F
8 | from dalle_pytorch import DiscreteVAE
9 | from torch.nn.utils import clip_grad_norm_
10 |
11 | parser = argparse.ArgumentParser(description='train VAE for DALLE-pytorch')
12 | parser.add_argument('--batchSize', type=int, default=24, help='batch size for training (default: 24)')
13 | parser.add_argument('--dataPath', type=str, default="./imagedata", help='path to imageFolder (default: ./imagedata)')
14 | parser.add_argument('--imageSize', type=int, default=256, help='image size for training (default: 256)')
15 | parser.add_argument('--n_epochs', type=int, default=500, help='number of epochs (default: 500)')
16 | parser.add_argument('--lr', type=float, default=1e-4, help='learning rate (default: 1e-4)')
17 | parser.add_argument('--tempsched', action='store_true', default=False, help='use temperature scheduling')
18 | parser.add_argument('--temperature', type=float, default=0.9, help='vae temperature (default: 0.9)')
19 | parser.add_argument('--name', type=str, default="vae", help='experiment name')
20 | parser.add_argument('--loadVAE', type=str, default="", help='name for pretrained VAE when continuing training')
21 | parser.add_argument('--start_epoch', type=int, default=0, help='start epoch numbering for continuing training (default: 0)')
22 | parser.add_argument('--clip', type=float, default=0, help='clip weights, 0 = no clipping (default: 0)')
23 | opt = parser.parse_args()
24 |
25 | imgSize = opt.imageSize #256
26 | batchSize = opt.batchSize #24
27 | n_epochs = opt.n_epochs #500
28 | log_interval = 10
29 | lr = opt.lr #1e-4
30 | temperature_scheduling = opt.tempsched #True
31 |
32 | name = opt.name #"v2vae256"
33 |
34 | # for continuing training
35 | # set loadfn: path to pretrained model
36 | # start_epoch: start epoch numbering from this
37 | loadfn = opt.loadVAE #""
38 | start_epoch = opt.start_epoch #0
39 |
40 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
41 |
42 | vae = DiscreteVAE(
43 | image_size = imgSize,
44 | num_layers = 3,
45 | channels = 3,
46 | num_tokens = 2048,
47 | codebook_dim = 256,
48 | hidden_dim = 128,
49 | temperature = opt.temperature
50 | )
51 |
52 | if loadfn != "":
53 | vae_dict = torch.load(loadfn)
54 | vae.load_state_dict(vae_dict)
55 |
56 |
57 | vae.to(device)
58 |
59 | t = transforms.Compose([
60 | transforms.Resize(imgSize),
61 | transforms.ToTensor(),
62 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) #(0.267, 0.233, 0.234))
63 | ])
64 |
65 | train_set = datasets.ImageFolder(opt.dataPath, transform=t, target_transform=None)
66 |
67 | train_loader = DataLoader(dataset=train_set, num_workers=1, batch_size=batchSize, shuffle=True)
68 |
69 | optimizer = optim.Adam(vae.parameters(), lr=lr)
70 |
71 | def clampWeights(m):
72 | if type(m) != nn.BatchNorm2d and type(m) != nn.Sequential:
73 | for p in m.parameters():
74 | p.data.clamp_(-opt.clip, opt.clip)
75 |
76 | if temperature_scheduling:
77 | vae.temperature = opt.temperature
78 | dk = 0.7 ** (1/len(train_loader))
79 | print('Scale Factor:', dk)
80 |
81 | for epoch in range(start_epoch, start_epoch + n_epochs):
82 |
83 | train_loss = 0
84 | for batch_idx, (images, _) in enumerate(train_loader):
85 | images = images.to(device)
86 | recons = vae(images)
87 | loss = F.smooth_l1_loss(images, recons) + F.mse_loss(images, recons)
88 |
89 |
90 | optimizer.zero_grad()
91 | loss.backward()
92 | train_loss += loss.item()
93 | optimizer.step()
94 |
95 | if opt.clip > 0:
96 | vae.apply(clampWeights)
97 |
98 | if batch_idx % log_interval == 0:
99 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.8f}'.format(
100 | epoch, batch_idx * len(images), len(train_loader.dataset),
101 | 100. * batch_idx / len(train_loader),
102 | loss.item() / len(images)))
103 |
104 | if temperature_scheduling:
105 | vae.temperature *= dk
106 | print("Current temperature: ", vae.temperature)
107 |
108 | k = 8
109 | with torch.no_grad():
110 | codes = vae.get_codebook_indices(images)
111 | imgx = vae.decode(codes)
112 | grid = torch.cat([images[:k], recons[:k], imgx[:k]])
113 | save_image(grid,
114 | 'results/'+name+'_epoch_' + str(epoch) + '.png', normalize=True)
115 |
116 | print('====> Epoch: {} Average loss: {:.8f}'.format(
117 | epoch, train_loss / len(train_loader.dataset)))
118 |
119 | torch.save(vae.state_dict(), "./models/"+name+"-"+str(epoch)+".pth")
120 |
121 |
122 |
123 |
124 |
--------------------------------------------------------------------------------
/dalle_pytorch/reversible.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from operator import itemgetter
4 | from torch.autograd.function import Function
5 | from torch.utils.checkpoint import get_device_states, set_device_states
6 |
7 | # for routing arguments into the functions of the reversible layer
8 | def route_args(router, args, depth):
9 | routed_args = [(dict(), dict()) for _ in range(depth)]
10 | matched_keys = [key for key in args.keys() if key in router]
11 |
12 | for key in matched_keys:
13 | val = args[key]
14 | for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
15 | new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
16 | routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
17 | return routed_args
18 |
19 | # following example for saving and setting rng here https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html
20 | class Deterministic(nn.Module):
21 | def __init__(self, net):
22 | super().__init__()
23 | self.net = net
24 | self.cpu_state = None
25 | self.cuda_in_fwd = None
26 | self.gpu_devices = None
27 | self.gpu_states = None
28 |
29 | def record_rng(self, *args):
30 | self.cpu_state = torch.get_rng_state()
31 | if torch.cuda._initialized:
32 | self.cuda_in_fwd = True
33 | self.gpu_devices, self.gpu_states = get_device_states(*args)
34 |
35 | def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
36 | if record_rng:
37 | self.record_rng(*args)
38 |
39 | if not set_rng:
40 | return self.net(*args, **kwargs)
41 |
42 | rng_devices = []
43 | if self.cuda_in_fwd:
44 | rng_devices = self.gpu_devices
45 |
46 | with torch.random.fork_rng(devices=rng_devices, enabled=True):
47 | torch.set_rng_state(self.cpu_state)
48 | if self.cuda_in_fwd:
49 | set_device_states(self.gpu_devices, self.gpu_states)
50 | return self.net(*args, **kwargs)
51 |
52 | # heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
53 | # once multi-GPU is confirmed working, refactor and send PR back to source
54 | class ReversibleBlock(nn.Module):
55 | def __init__(self, f, g):
56 | super().__init__()
57 | self.f = Deterministic(f)
58 | self.g = Deterministic(g)
59 |
60 | def forward(self, x, f_args = {}, g_args = {}):
61 | x1, x2 = torch.chunk(x, 2, dim=2)
62 | y1, y2 = None, None
63 |
64 | with torch.no_grad():
65 | y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
66 | y2 = x2 + self.g(y1, record_rng=self.training, **g_args)
67 |
68 | return torch.cat([y1, y2], dim=2)
69 |
70 | def backward_pass(self, y, dy, f_args = {}, g_args = {}):
71 | y1, y2 = torch.chunk(y, 2, dim=2)
72 | del y
73 |
74 | dy1, dy2 = torch.chunk(dy, 2, dim=2)
75 | del dy
76 |
77 | with torch.enable_grad():
78 | y1.requires_grad = True
79 | gy1 = self.g(y1, set_rng=True, **g_args)
80 | torch.autograd.backward(gy1, dy2)
81 |
82 | with torch.no_grad():
83 | x2 = y2 - gy1
84 | del y2, gy1
85 |
86 | dx1 = dy1 + y1.grad
87 | del dy1
88 | y1.grad = None
89 |
90 | with torch.enable_grad():
91 | x2.requires_grad = True
92 | fx2 = self.f(x2, set_rng=True, **f_args)
93 | torch.autograd.backward(fx2, dx1, retain_graph=True)
94 |
95 | with torch.no_grad():
96 | x1 = y1 - fx2
97 | del y1, fx2
98 |
99 | dx2 = dy2 + x2.grad
100 | del dy2
101 | x2.grad = None
102 |
103 | x = torch.cat([x1, x2.detach()], dim=2)
104 | dx = torch.cat([dx1, dx2], dim=2)
105 |
106 | return x, dx
107 |
108 | class _ReversibleFunction(Function):
109 | @staticmethod
110 | def forward(ctx, x, blocks, args):
111 | ctx.args = args
112 | for block, kwarg in zip(blocks, args):
113 | x = block(x, **kwarg)
114 | ctx.y = x.detach()
115 | ctx.blocks = blocks
116 | return x
117 |
118 | @staticmethod
119 | def backward(ctx, dy):
120 | y = ctx.y
121 | args = ctx.args
122 | for block, kwargs in zip(ctx.blocks[::-1], args[::-1]):
123 | y, dy = block.backward_pass(y, dy, **kwargs)
124 | return dy, None, None
125 |
126 | class SequentialSequence(nn.Module):
127 | def __init__(self, layers, args_route = {}, layer_dropout = 0.):
128 | super().__init__()
129 | assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
130 | self.layers = layers
131 | self.args_route = args_route
132 | self.layer_dropout = layer_dropout
133 |
134 | def forward(self, x, **kwargs):
135 | args = route_args(self.args_route, kwargs, len(self.layers))
136 | layers_and_args = list(zip(self.layers, args))
137 |
138 | for (f, g), (f_args, g_args) in layers_and_args:
139 | x = x + f(x, **f_args)
140 | x = x + g(x, **g_args)
141 | return x
142 |
143 | class ReversibleSequence(nn.Module):
144 | def __init__(self, blocks, args_route = {}):
145 | super().__init__()
146 | self.args_route = args_route
147 | self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks])
148 |
149 | def forward(self, x, **kwargs):
150 | x = torch.cat([x, x], dim=-1)
151 |
152 | blocks = self.blocks
153 | args = route_args(self.args_route, kwargs, len(blocks))
154 | args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args))
155 |
156 | out = _ReversibleFunction.apply(x, blocks, args)
157 | return torch.stack(out.chunk(2, dim=-1)).mean(dim=0)
158 |
--------------------------------------------------------------------------------
/dalle_pytorch/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from inspect import isfunction
3 | from torch import nn, einsum
4 | import torch.nn.functional as F
5 | from einops import rearrange, repeat
6 |
7 | from dalle_pytorch.reversible import ReversibleSequence, SequentialSequence
8 |
9 | # helpers
10 |
11 | def exists(val):
12 | return val is not None
13 |
14 | def default(val, d):
15 | if exists(val):
16 | return val
17 | return d() if isfunction(d) else d
18 |
19 | def cast_tuple(val, depth):
20 | return val if isinstance(val, tuple) else (val,) * depth
21 |
22 | # classes
23 |
24 | class PreNorm(nn.Module):
25 | def __init__(self, dim, fn):
26 | super().__init__()
27 | self.norm = nn.LayerNorm(dim)
28 | self.fn = fn
29 |
30 | def forward(self, x, **kwargs):
31 | return self.fn(self.norm(x), **kwargs)
32 |
33 | class GEGLU(nn.Module):
34 | def forward(self, x):
35 | x, gates = x.chunk(2, dim = -1)
36 | return x * F.gelu(gates)
37 |
38 | class FeedForward(nn.Module):
39 | def __init__(self, dim, dropout = 0., mult = 4.):
40 | super().__init__()
41 | self.net = nn.Sequential(
42 | nn.Linear(dim, dim * mult * 2),
43 | GEGLU(),
44 | nn.Dropout(dropout),
45 | nn.Linear(dim * mult, dim)
46 | )
47 |
48 | def forward(self, x):
49 | return self.net(x)
50 |
51 | class Attention(nn.Module):
52 | def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0.):
53 | super().__init__()
54 | inner_dim = dim_head * heads
55 | self.heads = heads
56 | self.seq_len = seq_len
57 | self.scale = dim ** -0.5
58 | self.causal = causal
59 |
60 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
61 | self.to_out = nn.Sequential(
62 | nn.Linear(inner_dim, dim),
63 | nn.Dropout(dropout)
64 | )
65 |
66 | def forward(self, x, mask = None):
67 | b, n, _, h, device = *x.shape, self.heads, x.device
68 | qkv = self.to_qkv(x).chunk(3, dim = -1)
69 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
70 |
71 | dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
72 | mask_value = -torch.finfo(dots.dtype).max
73 |
74 | if exists(mask):
75 | mask = rearrange(mask, 'b i -> b () i ()') * rearrange(mask, 'b j -> b () () j')
76 | dots.masked_fill_(~mask, mask_value)
77 | del mask
78 |
79 | if self.causal:
80 | i, j = dots.shape[-2:]
81 | mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
82 | dots.masked_fill_(mask, mask_value)
83 |
84 | attn = dots.softmax(dim=-1)
85 |
86 | out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
87 | out = rearrange(out, 'b h n d -> b n (h d)')
88 | out = self.to_out(out)
89 | return out
90 |
91 | class SparseAttention(Attention):
92 | def __init__(self, *args, **kwargs):
93 | super().__init__(*args, **kwargs)
94 | from deepspeed.ops.sparse_attention import SparseSelfAttention, VariableSparsityConfig
95 | self.block_size = 16
96 |
97 | self.attn_fn = SparseSelfAttention(
98 | sparsity_config = VariableSparsityConfig(
99 | num_heads = self.heads,
100 | block = self.block_size,
101 | attention = 'unidirectional' if self.causal else 'bidirectional'
102 | ),
103 | max_seq_length = self.seq_len,
104 | attn_mask_mode = 'add'
105 | )
106 |
107 | def forward(self, x, mask = None):
108 | b, n, _, h, device = *x.shape, self.heads, x.device
109 | remainder = n % self.block_size
110 | mask = default(mask, lambda: torch.ones(b, n, device = device).bool())
111 |
112 | if remainder > 0:
113 | padding = self.block_size - remainder
114 | x = F.pad(x, (0, 0, 0, padding), value = 0)
115 | mask = F.pad(mask, (0, padding), value = False)
116 |
117 | qkv = self.to_qkv(x).chunk(3, dim = -1)
118 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
119 |
120 | key_pad_mask = None
121 | if exists(mask):
122 | key_pad_mask = ~mask
123 |
124 | attn_mask = None
125 | if self.causal:
126 | i, j = q.shape[-2], k.shape[-2]
127 | mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
128 | attn_mask = torch.zeros(i, j, device = device).to(q)
129 | mask_value = -(torch.finfo(q.dtype).max / 2)
130 | attn_mask.masked_fill_(mask, mask_value)
131 |
132 | out = self.attn_fn(q, k, v, attn_mask = attn_mask, key_padding_mask = key_pad_mask)
133 | out = rearrange(out, 'b h n d -> b n (h d)')
134 | out = self.to_out(out)
135 | return out[:, :n]
136 |
137 | class Transformer(nn.Module):
138 | def __init__(
139 | self,
140 | *,
141 | dim,
142 | depth,
143 | seq_len,
144 | reversible = False,
145 | causal = True,
146 | heads = 8,
147 | dim_head = 64,
148 | ff_mult = 4,
149 | attn_dropout = 0.,
150 | ff_dropout = 0.,
151 |         sparse_attn = False   # sparse attention requires DeepSpeed, so default to dense attention
152 | ):
153 | super().__init__()
154 | layers = nn.ModuleList([])
155 | sparse_layer = cast_tuple(sparse_attn, depth)
156 |
157 | for _, sparse_attn in zip(range(depth), sparse_layer):
158 | attn_class = Attention if not sparse_attn else SparseAttention
159 |
160 | layers.append(nn.ModuleList([
161 | PreNorm(dim, attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)),
162 | PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout))
163 | ]))
164 |
165 | execute_type = ReversibleSequence if reversible else SequentialSequence
166 | route_attn = ((True, False),) * depth
167 | attn_route_map = {'mask': route_attn}
168 |
169 | self.layers = execute_type(layers, args_route = attn_route_map)
170 |
171 | def forward(self, x, **kwargs):
172 | return self.layers(x, **kwargs)
173 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## DALL-E in Pytorch
2 |
3 | Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch. It will also contain CLIP for ranking the generations.
4 |
5 | Sid, Ben, and Aran over at Eleuther AI are working on DALL-E for Mesh Tensorflow! Please lend them a hand if you would like to see DALL-E trained on TPUs.
6 |
7 | Yannic Kilcher's video
8 |
9 | ## Status
10 |
11 | Hannu has managed to train a small 6 layer DALL-E on a dataset of just 2000 landscape images! (2048 visual tokens)
12 |
13 |
14 |
15 | ## Install
16 |
17 | ```bash
18 | $ pip install dalle-pytorch
19 | ```
20 |
21 | ## Usage
22 |
23 | Train VAE
24 |
25 | ```python
26 | import torch
27 | from dalle_pytorch import DiscreteVAE
28 |
29 | vae = DiscreteVAE(
30 | image_size = 256,
31 | num_layers = 3, # number of downsamples - ex. 256 / (2 ** 3) = (32 x 32 feature map)
32 | num_tokens = 1024, # number of visual tokens. iGPT had 512, so probably should have more
33 | codebook_dim = 512, # codebook dimension
34 | hidden_dim = 64, # hidden dimension
35 | temperature = 0.9 # gumbel softmax temperature, the lower this is, the more hard the discretization
36 | )
37 |
38 | images = torch.randn(4, 3, 256, 256)
39 |
40 | loss = vae(images, return_recon_loss = True)
41 | loss.backward()
42 |
43 | # train with a lot of data to learn a good codebook (a training-loop sketch follows below)
44 | ```
45 |
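A minimal training-loop sketch for the VAE. The dataset path, transforms, optimizer and hyperparameters below are illustrative placeholders; `trainVAE.py` in this repository is the full script

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# assumes the `vae` instantiated above and a folder of training images
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor()
])

dataset = datasets.ImageFolder('./imagedata', transform = transform)
loader = DataLoader(dataset, batch_size = 8, shuffle = True)

opt = torch.optim.Adam(vae.parameters(), lr = 1e-4)

for epoch in range(20):
    for images, _ in loader:
        loss = vae(images, return_recon_loss = True)
        opt.zero_grad()
        loss.backward()
        opt.step()
```
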
46 | Train DALL-E with pretrained VAE from above
47 |
48 | ```python
49 | import torch
50 | from dalle_pytorch import DiscreteVAE, DALLE
51 |
52 | vae = DiscreteVAE(
53 | image_size = 256,
54 | num_layers = 3,
55 | num_tokens = 1024,
56 | codebook_dim = 512,
57 | hidden_dim = 64,
58 | temperature = 0.9
59 | )
60 |
61 | dalle = DALLE(
62 | dim = 512,
63 | vae = vae, # automatically infer (1) image sequence length and (2) number of image tokens
64 | num_text_tokens = 10000, # vocab size for text
65 | text_seq_len = 256, # text sequence length
66 | depth = 6, # should be 64
67 | heads = 8, # attention heads
68 | dim_head = 64, # attention head dimension
69 | attn_dropout = 0.1, # attention dropout
70 | ff_dropout = 0.1 # feedforward dropout
71 | )
72 |
73 | text = torch.randint(0, 10000, (4, 256))
74 | images = torch.randn(4, 3, 256, 256)
75 | mask = torch.ones_like(text).bool()
76 |
77 | loss = dalle(text, images, mask = mask, return_loss = True)
78 | loss.backward()
79 |
80 | # do the above for a long time with a lot of data ... then
81 |
82 | images = dalle.generate_images(text, mask = mask)
83 | images.shape # (4, 3, 256, 256)
84 | ```
85 |
86 | ## Ranking the generations
87 |
88 | Train CLIP
89 |
90 | ```python
91 | import torch
92 | from dalle_pytorch import CLIP
93 |
94 | clip = CLIP(
95 | dim_text = 512,
96 | dim_image = 512,
97 | dim_latent = 512,
98 | num_text_tokens = 10000,
99 | text_enc_depth = 6,
100 | text_seq_len = 256,
101 | text_heads = 8,
102 | num_visual_tokens = 512,
103 | visual_enc_depth = 6,
104 | visual_image_size = 256,
105 | visual_patch_size = 32,
106 | visual_heads = 8
107 | )
108 |
109 | text = torch.randint(0, 10000, (4, 256))
110 | images = torch.randn(4, 3, 256, 256)
111 | mask = torch.ones_like(text).bool()
112 |
113 | loss = clip(text, images, text_mask = mask, return_loss = True)
114 | loss.backward()
115 | ```
116 |
117 | To get the similarity scores from your trained Clipper, just do
118 |
119 | ```python
120 | images, scores = dalle.generate_images(text, mask = mask, clip = clip)
121 |
122 | scores.shape # (4,)
123 | images.shape # (4, 3, 256, 256)
124 |
125 | # do your topk here, in paper they sampled 512 and chose top 32
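# for example (a sketch only): with a larger pool of candidates per caption,
# keep the generations with the highest CLIP scores
top_scores, top_idx = scores.topk(min(32, scores.shape[0]))
best_images = images[top_idx]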
126 | ```
127 |
128 | Or you can just use the official CLIP model to rank the images from DALL-E
129 |
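A sketch of that route, assuming the `openai/CLIP` package (`pip install git+https://github.com/openai/CLIP.git`); the `pil_images` list and the caption below are placeholders, not part of this repository

```python
import torch
import clip

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model, preprocess = clip.load('ViT-B/32', device = device)

# `pil_images` stands in for PIL images saved from DALL-E's generations
image_input = torch.stack([preprocess(img) for img in pil_images]).to(device)
text_input = clip.tokenize(['a painting of a landscape']).to(device)

with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_input)

# cosine similarity between the caption and each generated image
image_features = image_features / image_features.norm(dim = -1, keepdim = True)
text_features = text_features / text_features.norm(dim = -1, keepdim = True)
scores = (image_features @ text_features.t()).squeeze(-1)

best = scores.topk(min(8, scores.shape[0])).indices # indices of the best generations
```
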
130 | ## Scaling depth
131 |
132 | In the blog post, they used 64 layers to achieve their results. I added reversible networks, from the Reformer paper, in order for users to attempt to scale depth at the cost of compute. Reversible networks allow you to scale to any depth at no memory cost, but a little over 2x compute cost (each layer is rerun on the backward pass).
133 |
134 | Simply set the `reversible` keyword to `True` for the `DALLE` class
135 |
136 | ```python
137 | dalle = DALLE(
138 | dim = 512,
139 | vae = vae,
140 | num_text_tokens = 10000,
141 | text_seq_len = 256,
142 | depth = 64,
143 | heads = 8,
144 | reversible = True # <-- reversible networks https://arxiv.org/abs/2001.04451
145 | )
146 | ```
147 |
148 | ## Sparse Attention
149 |
150 | You can also train with Microsoft Deepspeed's Sparse Attention, with any combination of dense and sparse attention that you'd like. However, you will have to endure the installation process.
151 |
152 | First, you need to install Deepspeed with Sparse Attention
153 |
154 | ```bash
155 | $ sh install_deepspeed.sh
156 | ```
157 |
158 | Next, you need to install the pip package `triton`
159 |
160 | ```bash
161 | $ pip install triton
162 | ```
163 |
164 | If both of the above succeeded, now you can train with Sparse Attention!
165 |
166 | ```python
167 | dalle = DALLE(
168 | dim = 512,
169 | vae = vae,
170 | num_text_tokens = 10000,
171 | text_seq_len = 256,
172 | depth = 64,
173 | heads = 8,
174 | sparse_attn = (True, False) * 32 # interleave sparse and dense attention for 64 layers
175 | )
176 | ```
177 |
178 | ## Citations
179 |
180 | ```bibtex
181 | @misc{unpublished2021dalle,
182 | title = {DALL·E: Creating Images from Text},
183 |     author  = {Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray},
184 | year = {2021}
185 | }
186 | ```
187 |
188 | ```bibtex
189 | @misc{unpublished2021clip,
190 | title = {CLIP: Connecting Text and Images},
191 |     author  = {Alec Radford and Ilya Sutskever and Jong Wook Kim and Gretchen Krueger and Sandhini Agarwal},
192 | year = {2021}
193 | }
194 | ```
195 |
196 | ```bibtex
197 | @misc{kitaev2020reformer,
198 | title = {Reformer: The Efficient Transformer},
199 | author = {Nikita Kitaev and Łukasz Kaiser and Anselm Levskaya},
200 | year = {2020},
201 | eprint = {2001.04451},
202 | archivePrefix = {arXiv},
203 | primaryClass = {cs.LG}
204 | }
205 | ```
206 |
207 | *Those who do not want to imitate anything, produce nothing.* - Dali
208 |
--------------------------------------------------------------------------------
/trainDALLE.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from dalle_pytorch import DiscreteVAE, DALLE
4 | from torchvision.io import read_image
5 | import torchvision.transforms as transforms
6 | import torch.optim as optim
7 | from torchvision.utils import save_image
8 |
9 | parser = argparse.ArgumentParser(description='train DALLE for DALLE-pytorch')
10 | parser.add_argument('--batchSize', type=int, default=24, help='batch size for training (default: 24)')
11 | parser.add_argument('--dataPath', type=str, default="./imagedata", help='path to imageFolder (default: ./imagedata)')
12 | parser.add_argument('--imageSize', type=int, default=256, help='image size for training (default: 256)')
13 | parser.add_argument('--n_epochs', type=int, default=500, help='number of epochs (default: 500)')
14 | parser.add_argument('--lr', type=float, default=1e-4, help='learning rate (default: 1e-4)')
15 | #parser.add_argument('--tempsched', action='store_true', default=False, help='use temperature scheduling')
16 | #parser.add_argument('--temperature', type=float, default=0.9, help='vae temperature (default: 0.9)')
17 | parser.add_argument('--vaename', type=str, default="vae", help='name of the pretrained VAE to load (default: vae)')
18 | parser.add_argument('--vae_epoch', type=int, default=0, help='epoch of the pretrained VAE checkpoint to load (default: 0)')
19 | parser.add_argument('--name', type=str, default="test", help='experiment name')
20 | parser.add_argument('--load_dalle', type=str, default="", help='path to a pretrained DALLE checkpoint when continuing training')
21 | parser.add_argument('--start_epoch', type=int, default=0, help='start epoch numbering for continuing training (default: 0)')
22 | opt = parser.parse_args()
23 |
24 | # vae
25 |
26 | load_epoch = opt.vae_epoch #499
27 | vaename = opt.vaename #"v2vae256"
28 |
29 | # general
30 |
31 | imgSize = opt.imageSize #256
32 | batchSize = opt.batchSize #24
33 | n_epochs = opt.n_epochs #500
34 | log_interval = 10
35 | lr = opt.lr #1e-4
36 |
37 | #dalle
38 |
39 | # to continue training from a saved checkpoint, give checkpoint path as loadfn and start_epoch
40 |
41 | #loadfn = "./models/dalle_vae-cdim256-140.pth"
42 | #start_epoch = 140
43 | loadfn = opt.load_dalle
44 | start_epoch = opt.start_epoch
45 | name = opt.name #v2vae256
46 |
47 |
48 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
49 |
50 | tf = transforms.Compose([
51 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
52 | ])
53 |
54 | vae = DiscreteVAE(
55 | image_size = opt.imageSize,
56 | num_layers = 3,
57 | channels = 3,
58 | num_tokens = 2048,
59 | codebook_dim = 256,
60 | hidden_dim = 128,
61 | temperature = 0.9
62 | )
63 |
64 | # load pretrained vae
65 | print("loading VAE from ./models/"+vaename+"-"+str(load_epoch)+".pth")
66 | vae_dict = torch.load("./models/"+vaename+"-"+str(load_epoch)+".pth")
67 | vae.load_state_dict(vae_dict)
68 | vae.to(device)
69 |
70 | dalle = DALLE(
71 | dim = 256, #512,
72 | vae = vae, # automatically infer (1) image sequence length and (2) number of image tokens
73 | num_text_tokens = 10000, # vocab size for text
74 | text_seq_len = 256, # text sequence length
75 | depth = 6, # should be 64
76 | heads = 8, # attention heads
77 | dim_head = 64, # attention head dimension
78 | attn_dropout = 0.1, # attention dropout
79 | ff_dropout = 0.1 # feedforward dropout
80 | )
81 |
82 |
83 | # load pretrained dalle if continuing training
84 | if loadfn != "":
85 | dalle_dict = torch.load(loadfn)
86 | dalle.load_state_dict(dalle_dict)
87 |
88 | dalle.to(device)
89 |
90 | # get image and text data
91 |
92 | lf = open("od-captionsonly.txt", "r") # file contains captions only, one caption per line
93 |
94 | # build vocabulary
95 |
96 | from Vocabulary import Vocabulary
97 |
98 | vocab = Vocabulary("captions")
99 |
100 | captions = []
101 | for lin in lf:
102 | captions.append(lin)
103 |
104 | for caption in captions:
105 | vocab.add_sentence(caption)
106 |
107 | def tokenizer(text): # create a tokenizer function
108 | return text.split(' ')
109 |
110 |
111 | lf = open("od-captions.txt", "r") # files contains lines in the format image_path : captions
112 |
113 | data = []
114 |
115 | for lin in lf:
116 | (fn, txt) = lin.split(":")
117 | tokens = tokenizer(txt)
118 | codes = []
119 | for t in tokens:
120 | #print(t)
121 | if t=="":
122 | continue
123 | codes.append(vocab.to_index(t))
124 | #print(fn, codes)
125 | data.append((fn, codes))
126 |
127 |
128 |
129 | len_data = len(data)
130 | print(len_data)
131 | #datactr = 0
132 |
133 | # an iterator for fetching data during training
134 |
135 | class ImageCaptions:
136 |
137 | def __init__(self, data, batchsize=4):
138 | self.data = data
139 | self.len = len(data)
140 | self.index = 0
141 | self.end = False
142 | self.batchsize = batchsize
143 |
144 | def __iter__(self):
145 | return self
146 |
147 | def __next__(self):
148 | if self.end:
149 | self.index = 0
150 | raise StopIteration
151 | i_data = []
152 | c_data = []
153 | for i in range(0, self.batchsize):
154 | i_data.append(self.data[self.index][0])
155 | c_tokens = [0]*256 # fill to match text_seq_len
156 | c_tokens_ = self.data[self.index][1]
157 | c_tokens[:len(c_tokens_)] = c_tokens_
158 | c_data.append(c_tokens)
159 | self.index += 1
160 | if self.index == self.len:
161 | self.end = True
162 | break
163 | return i_data, c_data
164 |
165 |
166 | optimizer = optim.Adam(dalle.parameters(), lr=lr)
167 |
168 | for epoch in range(start_epoch, start_epoch+n_epochs):
169 | batch_idx = 0
170 | train_loss = 0
171 | dset = ImageCaptions(data, batchsize=batchSize) # initialize iterator
172 |
173 | for i,c in dset: # loop through dataset by minibatch
174 | text = torch.LongTensor(c) # a minibatch of text (numerical tokens)
175 | images = torch.zeros(len(i), 3, 256, 256) # placeholder for images
176 |
177 | text = text.to(device)
178 | #print(text)
179 |
180 | # fetch images into tensor based on paths given in minibatch
181 | ix = 0
182 | for imgfn in i: # iterate through image paths in minibatch
183 |
184 | # note: images are expected to be in ./imagefolder/0/
185 | img_t = read_image(opt.dataPath+"/0/"+imgfn).float()/255. # read image and scale into float 0..1
186 | img_t = tf(img_t) # normalize
187 | images[ix,:,:,:] = img_t
188 | ix += 1
189 |
190 | images = images.to(device)
191 |
192 | mask = torch.ones_like(text).bool().to(device)
193 |
194 | # train and optimize a single minibatch
195 | optimizer.zero_grad()
196 | loss = dalle(text, images, mask = mask, return_loss = True)
197 | train_loss += loss.item()
198 | loss.backward()
199 | optimizer.step()
200 |
201 | if batch_idx % log_interval == 0:
202 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
203 | epoch, batch_idx * len(i), len(data),
204 | 100. * batch_idx / int(round(len(data)/batchSize)),
205 | loss.item() / len(i)))
206 |
207 | batch_idx += 1
208 |
209 | print('====> Epoch: {} Average loss: {:.4f}'.format(
210 | epoch, train_loss / len(data)))
211 |
212 |     torch.save(dalle.state_dict(), "./models/dalle_"+name+"-"+str(epoch)+".pth")   # use the checkpoint naming genDALLE.py expects
213 |
214 | # generate a test sample from the captions in the last minibatch
215 | oimgs = dalle.generate_images(text, mask = mask)
216 | save_image(oimgs,
217 | 'results/'+name+'_dalle_epoch_' + str(epoch) + '.png', normalize=True)
218 |
219 |
--------------------------------------------------------------------------------
/dalle_pytorch/dalle_pytorch.py:
--------------------------------------------------------------------------------
1 | from math import log2, sqrt
2 | import torch
3 | from torch import nn, einsum
4 | import torch.nn.functional as F
5 |
6 | from einops import rearrange
7 | from axial_positional_embedding import AxialPositionalEmbedding
8 | from dalle_pytorch.transformer import Transformer
9 |
10 | # helpers
11 |
12 | def exists(val):
13 | return val is not None
14 |
15 | def default(val, d):
16 | return val if exists(val) else d
17 |
18 | def always(val):
19 | def inner(*args, **kwargs):
20 | return val
21 | return inner
22 |
23 | def is_empty(t):
24 | return t.nelement() == 0
25 |
26 | def masked_mean(t, mask, dim = 1):
27 | t = t.masked_fill(~mask[:, :, None], 0.)
28 | return t.sum(dim = 1) / mask.sum(dim = 1)[..., None]
29 |
30 | def eval_decorator(fn):
31 | def inner(model, *args, **kwargs):
32 | was_training = model.training
33 | model.eval()
34 | out = fn(model, *args, **kwargs)
35 | model.train(was_training)
36 | return out
37 | return inner
38 |
39 | # sampling helpers
40 |
41 | def top_k(logits, thres = 0.5):
42 | num_logits = logits.shape[-1]
43 | k = max(int((1 - thres) * num_logits), 1)
44 | val, ind = torch.topk(logits, k)
45 | probs = torch.full_like(logits, float('-inf'))
46 | probs.scatter_(1, ind, val)
47 | return probs
48 |
49 | # discrete vae class
50 |
51 | class ResBlock(nn.Module):
52 | def __init__(self, chan):
53 | super().__init__()
54 | self.net = nn.Sequential(
55 | nn.Conv2d(chan, chan, 3, padding = 1),
56 | nn.ReLU(),
57 | nn.Conv2d(chan, chan, 3, padding = 1),
58 | nn.ReLU(),
59 | nn.Conv2d(chan, chan, 1)
60 | )
61 |
62 | def forward(self, x):
63 | return self.net(x) + x
64 |
65 | class DiscreteVAE(nn.Module):
66 | def __init__(
67 | self,
68 | image_size = 256,
69 | num_tokens = 512,
70 | codebook_dim = 512,
71 | num_layers = 3,
72 | num_resnet_blocks = 0,
73 | hidden_dim = 64,
74 | channels = 3,
75 | temperature = 0.9
76 | ):
77 | super().__init__()
78 | assert log2(image_size).is_integer(), 'image size must be a power of 2'
79 | assert num_layers >= 1, 'number of layers must be greater than or equal to 1'
80 | has_resblocks = num_resnet_blocks > 0
81 |
82 | self.image_size = image_size
83 | self.num_tokens = num_tokens
84 | self.num_layers = num_layers
85 | self.temperature = temperature
86 | self.codebook = nn.Embedding(num_tokens, codebook_dim)
87 |
88 | hdim = hidden_dim
89 |
90 | enc_chans = [hidden_dim] * num_layers
91 | dec_chans = list(reversed(enc_chans))
92 |
93 | enc_chans = [channels, *enc_chans]
94 |
95 | dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
96 | dec_chans = [dec_init_chan, *dec_chans]
97 |
98 | enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))
99 |
100 | enc_layers = []
101 | dec_layers = []
102 |
103 | for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
104 | enc_layers.append(nn.Sequential(nn.Conv2d(enc_in, enc_out, 4, stride = 2, padding = 1), nn.ReLU()))
105 | dec_layers.append(nn.Sequential(nn.ConvTranspose2d(dec_in, dec_out, 4, stride = 2, padding = 1), nn.ReLU()))
106 |
107 | for _ in range(num_resnet_blocks):
108 | dec_layers.insert(0, ResBlock(dec_chans[1]))
109 | enc_layers.append(ResBlock(enc_chans[-1]))
110 |
111 | if num_resnet_blocks > 0:
112 | dec_layers.insert(0, nn.Conv2d(codebook_dim, dec_chans[1], 1))
113 |
114 | enc_layers.append(nn.Conv2d(enc_chans[-1], num_tokens, 1))
115 | dec_layers.append(nn.Conv2d(dec_chans[-1], channels, 1))
116 |
117 | self.encoder = nn.Sequential(*enc_layers)
118 | self.decoder = nn.Sequential(*dec_layers)
119 |
120 | @torch.no_grad()
121 | def get_codebook_indices(self, images):
122 | logits = self.forward(images, return_logits = True)
123 | codebook_indices = logits.argmax(dim = 1).flatten(1)
124 | return codebook_indices
125 |
126 | def decode(
127 | self,
128 | img_seq
129 | ):
130 | image_embeds = self.codebook(img_seq)
131 | b, n, d = image_embeds.shape
132 | h = w = int(sqrt(n))
133 |
134 | image_embeds = rearrange(image_embeds, 'b (h w) d -> b d h w', h = h, w = w)
135 | images = self.decoder(image_embeds)
136 | return images
137 |
138 | def forward(
139 | self,
140 | img,
141 | return_recon_loss = False,
142 | return_logits = False
143 | ):
144 | logits = self.encoder(img)
145 |
146 | if return_logits:
147 | return logits # return logits for getting hard image indices for DALL-E training
148 |
149 | soft_one_hot = F.gumbel_softmax(logits, tau = self.temperature, dim = 1)
150 | sampled = einsum('b n h w, n d -> b d h w', soft_one_hot, self.codebook.weight)
151 | out = self.decoder(sampled)
152 |
153 | if not return_recon_loss:
154 | return out
155 |
156 | loss = F.mse_loss(img, out)
157 | return loss
158 |
159 | # main classes
160 |
161 | class CLIP(nn.Module):
162 | def __init__(
163 | self,
164 | *,
165 | dim_text = 512,
166 | dim_image = 512,
167 | dim_latent = 512,
168 | num_text_tokens = 10000,
169 | text_enc_depth = 6,
170 | text_seq_len = 256,
171 | text_heads = 8,
172 | num_visual_tokens = 512,
173 | visual_enc_depth = 6,
174 | visual_heads = 8,
175 | visual_image_size = 256,
176 | visual_patch_size = 32,
177 | channels = 3
178 | ):
179 | super().__init__()
180 | self.text_emb = nn.Embedding(num_text_tokens, dim_text)
181 | self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
182 | self.text_transformer = Transformer(causal = False, seq_len = text_seq_len, dim = dim_text, depth = text_enc_depth, heads = text_heads)
183 | self.to_text_latent = nn.Linear(dim_text, dim_latent, bias = False)
184 |
185 | assert visual_image_size % visual_patch_size == 0, 'Image dimensions must be divisible by the patch size.'
186 | num_patches = (visual_image_size // visual_patch_size) ** 2
187 | patch_dim = channels * visual_patch_size ** 2
188 |
189 | self.visual_patch_size = visual_patch_size
190 | self.to_visual_embedding = nn.Linear(patch_dim, dim_image)
191 | self.visual_pos_emb = nn.Embedding(num_patches, dim_image)
192 | self.visual_transformer = Transformer(causal = False, seq_len = num_patches, dim = dim_image, depth = visual_enc_depth, heads = visual_heads)
193 | self.to_visual_latent = nn.Linear(dim_image, dim_latent, bias = False)
194 |
195 | self.temperature = nn.Parameter(torch.tensor(1.))
196 |
197 | def forward(
198 | self,
199 | text,
200 | image,
201 | text_mask = None,
202 | return_loss = False
203 | ):
204 | b, device, p = text.shape[0], text.device, self.visual_patch_size
205 |
206 | text_emb = self.text_emb(text)
207 | text_emb += self.text_pos_emb(torch.arange(text.shape[1], device = device))
208 |
209 | image_patches = rearrange(image, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
210 | image_emb = self.to_visual_embedding(image_patches)
211 | image_emb += self.visual_pos_emb(torch.arange(image_emb.shape[1], device = device))
212 |
213 | enc_text = self.text_transformer(text_emb, mask = text_mask)
214 | enc_image = self.visual_transformer(image_emb)
215 |
216 | if exists(text_mask):
217 | text_latents = masked_mean(enc_text, text_mask, dim = 1)
218 | else:
219 | text_latents = enc_text.mean(dim = 1)
220 |
221 | image_latents = enc_image.mean(dim = 1)
222 |
223 | text_latents = self.to_text_latent(text_latents)
224 | image_latents = self.to_visual_latent(image_latents)
225 |
226 | text_latents, image_latents = map(lambda t: F.normalize(t, p = 2, dim = -1), (text_latents, image_latents))
227 |
228 | temp = self.temperature.exp()
229 |
230 | if not return_loss:
231 | sim = einsum('n d, n d -> n', text_latents, image_latents) * temp
232 | return sim
233 |
234 | sim = einsum('i d, j d -> i j', text_latents, image_latents) * temp
235 | labels = torch.arange(b, device = device)
236 | loss = F.cross_entropy(sim, labels)
237 | return loss
238 |
239 | # main DALL-E class
240 |
241 | class DALLE(nn.Module):
242 | def __init__(
243 | self,
244 | *,
245 | dim,
246 | vae,
247 | num_text_tokens = 10000,
248 | text_seq_len = 256,
249 | depth,
250 | heads = 8,
251 | dim_head = 64,
252 | reversible = False,
253 | attn_dropout = 0.,
254 | ff_dropout = 0,
255 | sparse_attn = False
256 | ):
257 | super().__init__()
258 | assert isinstance(vae, DiscreteVAE), 'vae must be an instance of DiscreteVAE'
259 |
260 | image_size = vae.image_size
261 | num_image_tokens = vae.num_tokens
262 | image_seq_len = (vae.image_size // (2 ** vae.num_layers)) ** 2
263 |
264 | self.text_emb = nn.Embedding(num_text_tokens, dim)
265 | self.image_emb = nn.Embedding(num_image_tokens, dim)
266 |
267 | self.text_pos_emb = nn.Embedding(text_seq_len, dim)
268 | self.image_pos_emb = AxialPositionalEmbedding(dim, axial_shape = (image_size, image_size))
269 |
270 | self.num_text_tokens = num_text_tokens # for offsetting logits index and calculating cross entropy loss
271 | self.num_image_tokens = num_image_tokens
272 |
273 | self.text_seq_len = text_seq_len
274 | self.image_seq_len = image_seq_len
275 |
276 | seq_len = text_seq_len + image_seq_len
277 | total_tokens = num_text_tokens + num_image_tokens + 1 # extra for EOS
278 | self.total_tokens = total_tokens
279 |
280 | self.vae = vae
281 | if exists(self.vae):
282 | self.vae = vae
283 | self.image_emb = vae.codebook
284 |
285 | self.transformer = Transformer(
286 | dim = dim,
287 | causal = True,
288 | seq_len = seq_len,
289 | depth = depth,
290 | heads = heads,
291 | dim_head = dim_head,
292 | reversible = reversible,
293 | attn_dropout = attn_dropout,
294 | ff_dropout = ff_dropout,
295 | sparse_attn = sparse_attn
296 | )
297 |
298 | self.to_logits = nn.Sequential(
299 | nn.LayerNorm(dim),
300 | nn.Linear(dim, self.total_tokens),
301 | )
302 |
303 | seq_range = torch.arange(seq_len)
304 | logits_range = torch.arange(total_tokens)
305 |
306 | seq_range = rearrange(seq_range, 'n -> () n ()')
307 | logits_range = rearrange(logits_range, 'd -> () () d')
308 |
309 | logits_mask = (
310 | ((seq_range >= (text_seq_len - 1)) & (logits_range < num_text_tokens)) |
311 | ((seq_range < (text_seq_len - 1)) & (logits_range >= num_text_tokens)) |
312 | ((seq_range != (seq_len - 1)) & (logits_range >= (total_tokens - 1)))
313 | )
314 |
315 | self.register_buffer('logits_mask', logits_mask)
316 |
317 | @torch.no_grad()
318 | @eval_decorator
319 | def generate_images(
320 | self,
321 | text,
322 | *,
323 | clip = None,
324 | mask = None,
325 | filter_thres = 0.5,
326 | temperature = 1.
327 | ):
328 | vae, text_seq_len, image_seq_len, num_text_tokens = self.vae, self.text_seq_len, self.image_seq_len, self.num_text_tokens
329 | total_len = text_seq_len + image_seq_len
330 |
331 | out = text
332 | for cur_len in range(text.shape[1], total_len):
333 | is_image = cur_len >= text_seq_len
334 |
335 | text, image = out[:, :text_seq_len], out[:, text_seq_len:]
336 |
337 | logits = self(text, image, mask = mask)[:, -1, :]
338 |
339 | filtered_logits = top_k(logits, thres = filter_thres)
340 | probs = F.softmax(filtered_logits / temperature, dim = -1)
341 | sample = torch.multinomial(probs, 1)
342 |
343 | sample -= (num_text_tokens if is_image else 0) # offset sampled token if it is an image token, since logit space is composed of text and then image tokens
344 | out = torch.cat((out, sample), dim=-1)
345 |
346 | if out.shape[1] <= text_seq_len:
347 | mask = F.pad(mask, (0, 1), value = True)
348 |
349 | text_seq = out[:, :text_seq_len]
350 |
351 | img_seq = out[:, -image_seq_len:]
352 | images = vae.decode(img_seq)
353 |
354 | if exists(clip):
355 | scores = clip(text_seq, images, return_loss = False)
356 | return images, scores
357 |
358 | return images
359 |
360 | def forward(
361 | self,
362 | text,
363 | image = None,
364 | mask = None,
365 | return_loss = False
366 | ):
367 | device = text.device
368 | eos_token_id = self.total_tokens - 1
369 |
370 | tokens = self.text_emb(text)
371 | tokens += self.text_pos_emb(torch.arange(text.shape[1], device = device))
372 |
373 | seq_len = tokens.shape[1]
374 |
375 | if exists(image) and not is_empty(image):
376 | is_raw_image = len(image.shape) == 4
377 | if is_raw_image:
378 | image = self.vae.get_codebook_indices(image)
379 |
380 | image_len = image.shape[1]
381 | image_emb = self.image_emb(image)
382 | image_emb += self.image_pos_emb(image_emb)
383 |
384 | tokens = torch.cat((tokens, image_emb), dim = 1)
385 |
386 | seq_len += image_len
387 | if exists(mask):
388 | mask = F.pad(mask, (0, image_emb.shape[1]), value = True)
389 |
390 | out = self.transformer(tokens, mask = mask)
391 | logits = self.to_logits(out)
392 |
393 | # mask logits to make sure text predicts text (except last token), and image predicts image
394 | mask = self.logits_mask[:, :seq_len]
395 | max_neg_value = -torch.finfo(logits.dtype).max
396 | logits.masked_fill_(mask, max_neg_value)
397 |
398 | if not return_loss:
399 | return logits
400 |
401 | assert exists(image), 'when training, image must be supplied'
402 |
403 | offsetted_image = image + self.num_text_tokens
404 | labels = torch.cat((text, offsetted_image), dim = 1)
405 | labels = F.pad(labels, (0, 1), value = eos_token_id) # last token predicts EOS
406 | loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels[:, 1:])
407 | return loss
408 |
--------------------------------------------------------------------------------