├── dalle_pytorch ├── __init__.py ├── reversible.py ├── transformer.py └── dalle_pytorch.py ├── images └── landscape.png ├── install_deepspeed.sh ├── setup.py ├── .github └── workflows │ └── python-publish.yml ├── LICENSE ├── Vocabulary.py ├── mixVAEcuda.py ├── .gitignore ├── genDALLE.py ├── trainVAE.py ├── README.md └── trainDALLE.py /dalle_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from dalle_pytorch.dalle_pytorch import DALLE, CLIP, DiscreteVAE 2 | -------------------------------------------------------------------------------- /images/landscape.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htoyryla/DALLE-pytorch/HEAD/images/landscape.png -------------------------------------------------------------------------------- /install_deepspeed.sh: -------------------------------------------------------------------------------- 1 | sudo apt-get -y install llvm-9-dev cmake 2 | git clone https://github.com/microsoft/DeepSpeed.git /tmp/Deepspeed 3 | cd /tmp/Deepspeed && DS_BUILD_SPARSE_ATTN=1 ./install.sh -s 4 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'dalle-pytorch', 5 | packages = find_packages(), 6 | version = '0.0.36', 7 | license='MIT', 8 | description = 'DALL-E - Pytorch', 9 | author = 'Phil Wang', 10 | author_email = 'lucidrains@gmail.com', 11 | url = 'https://github.com/lucidrains/dalle-pytorch', 12 | keywords = [ 13 | 'artificial intelligence', 14 | 'attention mechanism', 15 | 'transformers', 16 | 'text-to-image' 17 | ], 18 | install_requires=[ 19 | 'axial_positional_embedding', 20 | 'einops>=0.3', 21 | 'torch>=1.6' 22 | ], 23 | classifiers=[ 24 | 'Development Status :: 4 - Beta', 25 | 'Intended Audience :: Developers', 26 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 27 | 'License :: OSI Approved :: MIT License', 28 | 'Programming Language :: Python :: 3.6', 29 | ], 30 | ) 31 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.x' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 29 | run: | 30 | python setup.py sdist bdist_wheel 31 | twine upload dist/* 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any 
person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Vocabulary.py: -------------------------------------------------------------------------------- 1 | # adapted from https://www.kdnuggets.com/2019/11/create-vocabulary-nlp-tasks-python.html 2 | 3 | class Vocabulary: 4 | #PAD_token = 0 # Used for padding short sentences 5 | #SOS_token = 1 # Start-of-sentence token 6 | #EOS_token = 2 # End-of-sentence token 7 | 8 | def __init__(self, name): 9 | self.name = name 10 | self.word2index = {} 11 | self.word2count = {} 12 | self.index2word = {0: "PAD", 1: "SOS", 2: "EOS"} 13 | self.num_words = 3 14 | self.num_sentences = 0 15 | self.longest_sentence = 0 16 | 17 | def add_word(self, word): 18 | if word not in self.word2index: 19 | # First entry of word into vocabulary 20 | self.word2index[word] = self.num_words 21 | self.word2count[word] = 1 22 | self.index2word[self.num_words] = word 23 | self.num_words += 1 24 | else: 25 | # Word exists; increase word count 26 | self.word2count[word] += 1 27 | 28 | def add_sentence(self, sentence): 29 | sentence_len = 0 30 | for word in sentence.split(' '): 31 | sentence_len += 1 32 | self.add_word(word) 33 | if sentence_len > self.longest_sentence: 34 | # This is the longest sentence 35 | self.longest_sentence = sentence_len 36 | # Count the number of sentences 37 | self.num_sentences += 1 38 | 39 | def to_word(self, index): 40 | return self.index2word[index] 41 | 42 | def to_index(self, word): 43 | return self.word2index[word] -------------------------------------------------------------------------------- /mixVAEcuda.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, optim 3 | from torchvision import datasets, transforms 4 | from torchvision.utils import save_image 5 | from torch.utils.data import DataLoader 6 | from dalle_pytorch import DiscreteVAE 7 | 8 | imgSize = 256 9 | load_epoch = 280 10 | 11 | vae = DiscreteVAE( 12 | image_size = imgSize, 13 | num_layers = 3, 14 | channels = 3, 15 | num_tokens = 2048, 16 | codebook_dim = 1024, 17 | hidden_dim = 128 18 | ) 19 | 20 | vae_dict = torch.load("./models/dvae-"+str(load_epoch)+".pth") 21 | vae.load_state_dict(vae_dict) 22 | vae.cuda() 23 | 24 | batchSize = 12 25 | n_epochs = 500 26 | log_interval = 20 27 | #images = torch.randn(4, 3, 256, 256) 28 | 29 | t = transforms.Compose([ 30 | transforms.Resize(imgSize), 31 | transforms.RandomHorizontalFlip(), 32 | transforms.ToTensor(), 
33 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) #(0.267, 0.233, 0.234)) 34 | ]) 35 | 36 | train_set = datasets.ImageFolder('./imagedata', transform=t, target_transform=None) 37 | 38 | train_loader = DataLoader(dataset=train_set, num_workers=1, batch_size=batchSize, shuffle=True) 39 | 40 | 41 | 42 | for batch_idx, (images, _) in enumerate(train_loader): 43 | images = images.cuda() 44 | codes = vae.get_codebook_indices(images) 45 | sample1 = vae.decode(codes) 46 | #save_image(sample.view(-1, 3, imgSize, imgSize), 47 | # 'results/recon_sample_' + str(batch_idx) + '.png', normalize=True) 48 | for i in range(0, 8): 49 | j = i + 1 50 | j = j % 8 51 | codes[i,512:] = codes[j,512:] 52 | sample2 = vae.decode(codes) 53 | grid = torch.cat([images[:8], sample1[:8], sample2[:8]]) 54 | save_image(grid.view(-1, 3, imgSize, imgSize), 55 | 'mixed/mixed_epoch_' +str(load_epoch) + "_"+ str(batch_idx) + '.png', normalize=True) 56 | #break 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /genDALLE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dalle_pytorch import DiscreteVAE, DALLE 3 | from torchvision.io import read_image 4 | import torchvision.transforms as transforms 5 | import torch.optim as optim 6 | from torchvision.utils import save_image 7 | import time 8 | import sys 9 | 10 | # vae 11 | 12 | load_epoch = 390 13 | vaename = "vae-cdim256" 14 | 15 | # general 16 | 17 | imgSize = 256 18 | batchSize = 12 19 | n_epochs = 100 20 | log_interval = 10 21 | lr = 2e-5 22 | 23 | #dalle 24 | 25 | dalle_epoch = 220 26 | #loadfn = "" 27 | #start_epoch = 0 28 | name = "vae-cdim256" 29 | loadfn = "./models/dalle_"+name+"-"+str(dalle_epoch)+".pth" 30 | 31 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 32 | 33 | tf = transforms.Compose([ 34 | #transforms.Resize(imgSize), 35 | #transforms.RandomHorizontalFlip(), 36 | #transforms.ToTensor(), 37 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) #(0.267, 0.233, 0.234)) 38 | ]) 39 | 40 | vae = DiscreteVAE( 41 | image_size = 256, 42 | num_layers = 3, 43 | num_tokens = 2048, 44 | codebook_dim = 256, 45 | hidden_dim = 128, 46 | temperature = 0.9 47 | ) 48 | 49 | # load pretrained vae 50 | 51 | vae_dict = torch.load("./models/"+vaename+"-"+str(load_epoch)+".pth") 52 | vae.load_state_dict(vae_dict) 53 | vae.to(device) 54 | 55 | dalle = DALLE( 56 | dim = 256, #512, 57 | vae = vae, # automatically infer (1) image sequence length and (2) number of image tokens 58 | num_text_tokens = 10000, # vocab size for text 59 | text_seq_len = 256, # text sequence length 60 | depth = 6, # should be 64 61 | heads = 8, # attention heads 62 | dim_head = 64, # attention head dimension 63 | attn_dropout = 0.1, # attention dropout 64 | ff_dropout = 0.1 # feedforward dropout 65 | ) 66 | 67 | 68 | # load pretrained dalle if continuing training 69 | 70 | dalle_dict = torch.load(loadfn) 71 | dalle.load_state_dict(dalle_dict) 72 | 73 | dalle.to(device) 74 | 75 | # get image and text data 76 | 77 | lf = open("od-captionsonly.txt", "r") # file contains captions only, one caption per line 78 | 79 | # build vocabulary 80 | 81 | from Vocabulary import Vocabulary 82 | 83 | vocab = Vocabulary("captions") 84 | 85 | captions = [] 86 | for lin in lf: 87 | captions.append(lin) 88 | 89 | for caption in captions: 90 | vocab.add_sentence(caption) 91 | 92 | def tokenizer(text): # create a tokenizer function 93 | return text.split(' ') 94 | 95 | inp_text = sys.argv[1] 96 | print(inp_text) 97 | tokens = tokenizer(inp_text) 98 | codes = [] 99 | for t in tokens: 100 | codes.append(vocab.to_index(t)) 101 | 102 | print(codes) 103 | c_tokens = [0]*256 # fill to match text_seq_len 104 | c_tokens[:len(codes)] = codes 105 | 106 | text = torch.LongTensor(codes).unsqueeze(0).to(device) # a minibatch of text (numerical tokens) 107 | mask = 
torch.ones_like(text).bool().to(device) 108 | oimgs = dalle.generate_images(text, mask = mask) 109 | ts = int(time.time()) 110 | print(inp_text, ts) 111 | save_image(oimgs, 112 | 'results/gendalle'+name+'_epoch_' + str(dalle_epoch) + '-' +str(ts)+'.png', normalize=True) 113 | 114 | -------------------------------------------------------------------------------- /trainVAE.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from torch import nn, optim 4 | from torchvision import datasets, transforms 5 | from torchvision.utils import save_image 6 | from torch.utils.data import DataLoader 7 | import torch.nn.functional as F 8 | from dalle_pytorch import DiscreteVAE 9 | from torch.nn.utils import clip_grad_norm_ 10 | 11 | parser = argparse.ArgumentParser(description='train VAE for DALLE-pytorch') 12 | parser.add_argument('--batchSize', type=int, default=24, help='batch size for training (default: 24)') 13 | parser.add_argument('--dataPath', type=str, default="./imagedata", help='path to imageFolder (default: ./imagedata') 14 | parser.add_argument('--imageSize', type=int, default=256, help='image size for training (default: 256)') 15 | parser.add_argument('--n_epochs', type=int, default=500, help='number of epochs (default: 500)') 16 | parser.add_argument('--lr', type=float, default=1e-4, help='learning rate (default: 1e-4)') 17 | parser.add_argument('--tempsched', action='store_true', default=False, help='use temperature scheduling') 18 | parser.add_argument('--temperature', type=float, default=0.9, help='vae temperature (default: 0.9)') 19 | parser.add_argument('--name', type=str, default="vae", help='experiment name') 20 | parser.add_argument('--loadVAE', type=str, default="", help='name for pretrained VAE when continuing training') 21 | parser.add_argument('--start_epoch', type=int, default=0, help='start epoch numbering for continuing training (default: 0)') 22 | parser.add_argument('--clip', type=float, default=0, help='clip weights, 0 = no clipping (default: 0)') 23 | opt = parser.parse_args() 24 | 25 | imgSize = opt.imageSize #256 26 | batchSize = opt.batchSize #24 27 | n_epochs = opt.n_epochs #500 28 | log_interval = 10 29 | lr = opt.lr #1e-4 30 | temperature_scheduling = opt.tempsched #True 31 | 32 | name = opt.name #"v2vae256" 33 | 34 | # for continuing training 35 | # set loadfn: path to pretrained model 36 | # start_epoch: start epoch numbering from this 37 | loadfn = opt.loadVAE #"" 38 | start_epoch = opt.start_epoch #0 39 | 40 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 41 | 42 | vae = DiscreteVAE( 43 | image_size = imgSize, 44 | num_layers = 3, 45 | channels = 3, 46 | num_tokens = 2048, 47 | codebook_dim = 256, 48 | hidden_dim = 128, 49 | temperature = opt.temperature 50 | ) 51 | 52 | if loadfn != "": 53 | vae_dict = torch.load(loadfn) 54 | vae.load_state_dict(vae_dict) 55 | 56 | 57 | vae.to(device) 58 | 59 | t = transforms.Compose([ 60 | transforms.Resize(imgSize), 61 | transforms.ToTensor(), 62 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) #(0.267, 0.233, 0.234)) 63 | ]) 64 | 65 | train_set = datasets.ImageFolder(opt.dataPath, transform=t, target_transform=None) 66 | 67 | train_loader = DataLoader(dataset=train_set, num_workers=1, batch_size=batchSize, shuffle=True) 68 | 69 | optimizer = optim.Adam(vae.parameters(), lr=lr) 70 | 71 | def clampWeights(m): 72 | if type(m) != nn.BatchNorm2d and type(m) != nn.Sequential: 73 | for p in m.parameters(): 74 | 
p.data.clamp_(-opt.clip, opt.clip) 75 | 76 | if temperature_scheduling: 77 | vae.temperature = opt.temperature 78 | dk = 0.7 ** (1/len(train_loader)) 79 | print('Scale Factor:', dk) 80 | 81 | for epoch in range(start_epoch, start_epoch + n_epochs): 82 | 83 | train_loss = 0 84 | for batch_idx, (images, _) in enumerate(train_loader): 85 | images = images.to(device) 86 | recons = vae(images) 87 | loss = F.smooth_l1_loss(images, recons) + F.mse_loss(images, recons) 88 | 89 | 90 | optimizer.zero_grad() 91 | loss.backward() 92 | train_loss += loss.item() 93 | optimizer.step() 94 | 95 | if opt.clip > 0: 96 | vae.apply(clampWeights) 97 | 98 | if batch_idx % log_interval == 0: 99 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.8f}'.format( 100 | epoch, batch_idx * len(images), len(train_loader.dataset), 101 | 100. * batch_idx / len(train_loader), 102 | loss.item() / len(images))) 103 | 104 | if temperature_scheduling: 105 | vae.temperature *= dk 106 | print("Current temperature: ", vae.temperature) 107 | 108 | k = 8 109 | with torch.no_grad(): 110 | codes = vae.get_codebook_indices(images) 111 | imgx = vae.decode(codes) 112 | grid = torch.cat([images[:k], recons[:k], imgx[:k]]) 113 | save_image(grid, 114 | 'results/'+name+'_epoch_' + str(epoch) + '.png', normalize=True) 115 | 116 | print('====> Epoch: {} Average loss: {:.8f}'.format( 117 | epoch, train_loss / len(train_loader.dataset))) 118 | 119 | torch.save(vae.state_dict(), "./models/"+name+"-"+str(epoch)+".pth") 120 | 121 | 122 | 123 | 124 | -------------------------------------------------------------------------------- /dalle_pytorch/reversible.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from operator import itemgetter 4 | from torch.autograd.function import Function 5 | from torch.utils.checkpoint import get_device_states, set_device_states 6 | 7 | # for routing arguments into the functions of the reversible layer 8 | def route_args(router, args, depth): 9 | routed_args = [(dict(), dict()) for _ in range(depth)] 10 | matched_keys = [key for key in args.keys() if key in router] 11 | 12 | for key in matched_keys: 13 | val = args[key] 14 | for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])): 15 | new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes) 16 | routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args}) 17 | return routed_args 18 | 19 | # following example for saving and setting rng here https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html 20 | class Deterministic(nn.Module): 21 | def __init__(self, net): 22 | super().__init__() 23 | self.net = net 24 | self.cpu_state = None 25 | self.cuda_in_fwd = None 26 | self.gpu_devices = None 27 | self.gpu_states = None 28 | 29 | def record_rng(self, *args): 30 | self.cpu_state = torch.get_rng_state() 31 | if torch.cuda._initialized: 32 | self.cuda_in_fwd = True 33 | self.gpu_devices, self.gpu_states = get_device_states(*args) 34 | 35 | def forward(self, *args, record_rng = False, set_rng = False, **kwargs): 36 | if record_rng: 37 | self.record_rng(*args) 38 | 39 | if not set_rng: 40 | return self.net(*args, **kwargs) 41 | 42 | rng_devices = [] 43 | if self.cuda_in_fwd: 44 | rng_devices = self.gpu_devices 45 | 46 | with torch.random.fork_rng(devices=rng_devices, enabled=True): 47 | torch.set_rng_state(self.cpu_state) 48 | if self.cuda_in_fwd: 49 | set_device_states(self.gpu_devices, self.gpu_states) 50 | return 
self.net(*args, **kwargs) 51 | 52 | # heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py 53 | # once multi-GPU is confirmed working, refactor and send PR back to source 54 | class ReversibleBlock(nn.Module): 55 | def __init__(self, f, g): 56 | super().__init__() 57 | self.f = Deterministic(f) 58 | self.g = Deterministic(g) 59 | 60 | def forward(self, x, f_args = {}, g_args = {}): 61 | x1, x2 = torch.chunk(x, 2, dim=2) 62 | y1, y2 = None, None 63 | 64 | with torch.no_grad(): 65 | y1 = x1 + self.f(x2, record_rng=self.training, **f_args) 66 | y2 = x2 + self.g(y1, record_rng=self.training, **g_args) 67 | 68 | return torch.cat([y1, y2], dim=2) 69 | 70 | def backward_pass(self, y, dy, f_args = {}, g_args = {}): 71 | y1, y2 = torch.chunk(y, 2, dim=2) 72 | del y 73 | 74 | dy1, dy2 = torch.chunk(dy, 2, dim=2) 75 | del dy 76 | 77 | with torch.enable_grad(): 78 | y1.requires_grad = True 79 | gy1 = self.g(y1, set_rng=True, **g_args) 80 | torch.autograd.backward(gy1, dy2) 81 | 82 | with torch.no_grad(): 83 | x2 = y2 - gy1 84 | del y2, gy1 85 | 86 | dx1 = dy1 + y1.grad 87 | del dy1 88 | y1.grad = None 89 | 90 | with torch.enable_grad(): 91 | x2.requires_grad = True 92 | fx2 = self.f(x2, set_rng=True, **f_args) 93 | torch.autograd.backward(fx2, dx1, retain_graph=True) 94 | 95 | with torch.no_grad(): 96 | x1 = y1 - fx2 97 | del y1, fx2 98 | 99 | dx2 = dy2 + x2.grad 100 | del dy2 101 | x2.grad = None 102 | 103 | x = torch.cat([x1, x2.detach()], dim=2) 104 | dx = torch.cat([dx1, dx2], dim=2) 105 | 106 | return x, dx 107 | 108 | class _ReversibleFunction(Function): 109 | @staticmethod 110 | def forward(ctx, x, blocks, args): 111 | ctx.args = args 112 | for block, kwarg in zip(blocks, args): 113 | x = block(x, **kwarg) 114 | ctx.y = x.detach() 115 | ctx.blocks = blocks 116 | return x 117 | 118 | @staticmethod 119 | def backward(ctx, dy): 120 | y = ctx.y 121 | args = ctx.args 122 | for block, kwargs in zip(ctx.blocks[::-1], args[::-1]): 123 | y, dy = block.backward_pass(y, dy, **kwargs) 124 | return dy, None, None 125 | 126 | class SequentialSequence(nn.Module): 127 | def __init__(self, layers, args_route = {}, layer_dropout = 0.): 128 | super().__init__() 129 | assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers' 130 | self.layers = layers 131 | self.args_route = args_route 132 | self.layer_dropout = layer_dropout 133 | 134 | def forward(self, x, **kwargs): 135 | args = route_args(self.args_route, kwargs, len(self.layers)) 136 | layers_and_args = list(zip(self.layers, args)) 137 | 138 | for (f, g), (f_args, g_args) in layers_and_args: 139 | x = x + f(x, **f_args) 140 | x = x + g(x, **g_args) 141 | return x 142 | 143 | class ReversibleSequence(nn.Module): 144 | def __init__(self, blocks, args_route = {}): 145 | super().__init__() 146 | self.args_route = args_route 147 | self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks]) 148 | 149 | def forward(self, x, **kwargs): 150 | x = torch.cat([x, x], dim=-1) 151 | 152 | blocks = self.blocks 153 | args = route_args(self.args_route, kwargs, len(blocks)) 154 | args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args)) 155 | 156 | out = _ReversibleFunction.apply(x, blocks, args) 157 | return torch.stack(out.chunk(2, dim=-1)).mean(dim=0) 158 | -------------------------------------------------------------------------------- /dalle_pytorch/transformer.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from inspect import isfunction 3 | from torch import nn, einsum 4 | import torch.nn.functional as F 5 | from einops import rearrange, repeat 6 | 7 | from dalle_pytorch.reversible import ReversibleSequence, SequentialSequence 8 | 9 | # helpers 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | def default(val, d): 15 | if exists(val): 16 | return val 17 | return d() if isfunction(d) else d 18 | 19 | def cast_tuple(val, depth): 20 | return val if isinstance(val, tuple) else (val,) * depth 21 | 22 | # classes 23 | 24 | class PreNorm(nn.Module): 25 | def __init__(self, dim, fn): 26 | super().__init__() 27 | self.norm = nn.LayerNorm(dim) 28 | self.fn = fn 29 | 30 | def forward(self, x, **kwargs): 31 | return self.fn(self.norm(x), **kwargs) 32 | 33 | class GEGLU(nn.Module): 34 | def forward(self, x): 35 | x, gates = x.chunk(2, dim = -1) 36 | return x * F.gelu(gates) 37 | 38 | class FeedForward(nn.Module): 39 | def __init__(self, dim, dropout = 0., mult = 4.): 40 | super().__init__() 41 | self.net = nn.Sequential( 42 | nn.Linear(dim, dim * mult * 2), 43 | GEGLU(), 44 | nn.Dropout(dropout), 45 | nn.Linear(dim * mult, dim) 46 | ) 47 | 48 | def forward(self, x): 49 | return self.net(x) 50 | 51 | class Attention(nn.Module): 52 | def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0.): 53 | super().__init__() 54 | inner_dim = dim_head * heads 55 | self.heads = heads 56 | self.seq_len = seq_len 57 | self.scale = dim ** -0.5 58 | self.causal = causal 59 | 60 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 61 | self.to_out = nn.Sequential( 62 | nn.Linear(inner_dim, dim), 63 | nn.Dropout(dropout) 64 | ) 65 | 66 | def forward(self, x, mask = None): 67 | b, n, _, h, device = *x.shape, self.heads, x.device 68 | qkv = self.to_qkv(x).chunk(3, dim = -1) 69 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 70 | 71 | dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 72 | mask_value = -torch.finfo(dots.dtype).max 73 | 74 | if exists(mask): 75 | mask = rearrange(mask, 'b i -> b () i ()') * rearrange(mask, 'b j -> b () () j') 76 | dots.masked_fill_(~mask, mask_value) 77 | del mask 78 | 79 | if self.causal: 80 | i, j = dots.shape[-2:] 81 | mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool() 82 | dots.masked_fill_(mask, mask_value) 83 | 84 | attn = dots.softmax(dim=-1) 85 | 86 | out = torch.einsum('b h i j, b h j d -> b h i d', attn, v) 87 | out = rearrange(out, 'b h n d -> b n (h d)') 88 | out = self.to_out(out) 89 | return out 90 | 91 | class SparseAttention(Attention): 92 | def __init__(self, *args, **kwargs): 93 | super().__init__(*args, **kwargs) 94 | from deepspeed.ops.sparse_attention import SparseSelfAttention, VariableSparsityConfig 95 | self.block_size = 16 96 | 97 | self.attn_fn = SparseSelfAttention( 98 | sparsity_config = VariableSparsityConfig( 99 | num_heads = self.heads, 100 | block = self.block_size, 101 | attention = 'unidirectional' if self.causal else 'bidirectional' 102 | ), 103 | max_seq_length = self.seq_len, 104 | attn_mask_mode = 'add' 105 | ) 106 | 107 | def forward(self, x, mask = None): 108 | b, n, _, h, device = *x.shape, self.heads, x.device 109 | remainder = n % self.block_size 110 | mask = default(mask, lambda: torch.ones(b, n, device = device).bool()) 111 | 112 | if remainder > 0: 113 | padding = self.block_size - remainder 114 | x = F.pad(x, (0, 0, 0, padding), value = 
0) 115 | mask = F.pad(mask, (0, padding), value = False) 116 | 117 | qkv = self.to_qkv(x).chunk(3, dim = -1) 118 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 119 | 120 | key_pad_mask = None 121 | if exists(mask): 122 | key_pad_mask = ~mask 123 | 124 | attn_mask = None 125 | if self.causal: 126 | i, j = q.shape[-2], k.shape[-2] 127 | mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool() 128 | attn_mask = torch.zeros(i, j, device = device).to(q) 129 | mask_value = -(torch.finfo(q.dtype).max / 2) 130 | attn_mask.masked_fill_(mask, mask_value) 131 | 132 | out = self.attn_fn(q, k, v, attn_mask = attn_mask, key_padding_mask = key_pad_mask) 133 | out = rearrange(out, 'b h n d -> b n (h d)') 134 | out = self.to_out(out) 135 | return out[:, :n] 136 | 137 | class Transformer(nn.Module): 138 | def __init__( 139 | self, 140 | *, 141 | dim, 142 | depth, 143 | seq_len, 144 | reversible = False, 145 | causal = True, 146 | heads = 8, 147 | dim_head = 64, 148 | ff_mult = 4, 149 | attn_dropout = 0., 150 | ff_dropout = 0., 151 | sparse_attn = True 152 | ): 153 | super().__init__() 154 | layers = nn.ModuleList([]) 155 | sparse_layer = cast_tuple(sparse_attn, depth) 156 | 157 | for _, sparse_attn in zip(range(depth), sparse_layer): 158 | attn_class = Attention if not sparse_attn else SparseAttention 159 | 160 | layers.append(nn.ModuleList([ 161 | PreNorm(dim, attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)), 162 | PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout)) 163 | ])) 164 | 165 | execute_type = ReversibleSequence if reversible else SequentialSequence 166 | route_attn = ((True, False),) * depth 167 | attn_route_map = {'mask': route_attn} 168 | 169 | self.layers = execute_type(layers, args_route = attn_route_map) 170 | 171 | def forward(self, x, **kwargs): 172 | return self.layers(x, **kwargs) 173 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## DALL-E in Pytorch 2 | 3 | Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch. It will also contain CLIP for ranking the generations. 4 | 5 | Sid, Ben, and Aran over at Eleuther AI are working on DALL-E for Mesh Tensorflow! Please lend them a hand if you would like to see DALL-E trained on TPUs. 6 | 7 | Yannic Kilcher's video 8 | 9 | ## Status 10 | 11 | Hannu has managed to train a small 6 layer DALL-E on a dataset of just 2000 landscape images! (2048 visual tokens) 12 | 13 | 14 | 15 | ## Install 16 | 17 | ```bash 18 | $ pip install dalle-pytorch 19 | ``` 20 | 21 | ## Usage 22 | 23 | Train VAE 24 | 25 | ```python 26 | import torch 27 | from dalle_pytorch import DiscreteVAE 28 | 29 | vae = DiscreteVAE( 30 | image_size = 256, 31 | num_layers = 3, # number of downsamples - ex. 256 / (2 ** 3) = (32 x 32 feature map) 32 | num_tokens = 1024, # number of visual tokens. 
iGPT had 512, so probably should have more 33 | codebook_dim = 512, # codebook dimension 34 | hidden_dim = 64, # hidden dimension 35 | temperature = 0.9 # gumbel softmax temperature, the lower this is, the harder the discretization 36 | ) 37 | 38 | images = torch.randn(4, 3, 256, 256) 39 | 40 | loss = vae(images, return_recon_loss = True) 41 | loss.backward() 42 | 43 | # train with a lot of data to learn a good codebook 44 | ``` 45 | 46 | Train DALL-E with the pretrained VAE from above 47 | 48 | ```python 49 | import torch 50 | from dalle_pytorch import DiscreteVAE, DALLE 51 | 52 | vae = DiscreteVAE( 53 | image_size = 256, 54 | num_layers = 3, 55 | num_tokens = 1024, 56 | codebook_dim = 512, 57 | hidden_dim = 64, 58 | temperature = 0.9 59 | ) 60 | 61 | dalle = DALLE( 62 | dim = 512, 63 | vae = vae, # automatically infer (1) image sequence length and (2) number of image tokens 64 | num_text_tokens = 10000, # vocab size for text 65 | text_seq_len = 256, # text sequence length 66 | depth = 6, # should be 64 67 | heads = 8, # attention heads 68 | dim_head = 64, # attention head dimension 69 | attn_dropout = 0.1, # attention dropout 70 | ff_dropout = 0.1 # feedforward dropout 71 | ) 72 | 73 | text = torch.randint(0, 10000, (4, 256)) 74 | images = torch.randn(4, 3, 256, 256) 75 | mask = torch.ones_like(text).bool() 76 | 77 | loss = dalle(text, images, mask = mask, return_loss = True) 78 | loss.backward() 79 | 80 | # do the above for a long time with a lot of data ... then 81 | 82 | images = dalle.generate_images(text, mask = mask) 83 | images.shape # (4, 3, 256, 256) 84 | ``` 85 | 86 | ## Ranking the generations 87 | 88 | Train CLIP 89 | 90 | ```python 91 | import torch 92 | from dalle_pytorch import CLIP 93 | 94 | clip = CLIP( 95 | dim_text = 512, 96 | dim_image = 512, 97 | dim_latent = 512, 98 | num_text_tokens = 10000, 99 | text_enc_depth = 6, 100 | text_seq_len = 256, 101 | text_heads = 8, 102 | num_visual_tokens = 512, 103 | visual_enc_depth = 6, 104 | visual_image_size = 256, 105 | visual_patch_size = 32, 106 | visual_heads = 8 107 | ) 108 | 109 | text = torch.randint(0, 10000, (4, 256)) 110 | images = torch.randn(4, 3, 256, 256) 111 | mask = torch.ones_like(text).bool() 112 | 113 | loss = clip(text, images, text_mask = mask, return_loss = True) 114 | loss.backward() 115 | ``` 116 | 117 | To get the similarity scores from your trained CLIP, just do 118 | 119 | ```python 120 | images, scores = dalle.generate_images(text, mask = mask, clip = clip) 121 | 122 | scores.shape # (4,) 123 | images.shape # (4, 3, 256, 256) 124 | 125 | # do your topk here, in the paper they sampled 512 and chose the top 32 126 | ``` 127 | 128 | Or you can just use the official CLIP model to rank the images from DALL-E. 129 | 130 | ## Scaling depth 131 | 132 | In the blog post, they used 64 layers to achieve their results. I added reversible networks, from the Reformer paper, so that users can attempt to scale depth at the cost of compute. Reversible networks let you scale to any depth at no extra memory cost, but at a little over 2x the compute cost (each layer is rerun on the backward pass). 
133 | 134 | Simply set the `reversible` keyword to `True` for the `DALLE` class 135 | 136 | ```python 137 | dalle = DALLE( 138 | dim = 512, 139 | vae = vae, 140 | num_text_tokens = 10000, 141 | text_seq_len = 256, 142 | depth = 64, 143 | heads = 8, 144 | reversible = True # <-- reversible networks https://arxiv.org/abs/2001.04451 145 | ) 146 | ``` 147 | 148 | ## Sparse Attention 149 | 150 | You can also train with Microsoft Deepspeed's Sparse Attention, with any combination of dense and sparse attention that you'd like. However, you will have to endure the installation process. 151 | 152 | First, you need to install Deepspeed with Sparse Attention 153 | 154 | ```bash 155 | $ sh install_deepspeed.sh 156 | ``` 157 | 158 | Next, you need to install the pip package `triton` 159 | 160 | ```bash 161 | $ pip install triton 162 | ``` 163 | 164 | If both of the above succeeded, you can now train with Sparse Attention! 165 | 166 | ```python 167 | dalle = DALLE( 168 | dim = 512, 169 | vae = vae, 170 | num_text_tokens = 10000, 171 | text_seq_len = 256, 172 | depth = 64, 173 | heads = 8, 174 | sparse_attn = (True, False) * 32 # interleave sparse and dense attention for 64 layers 175 | ) 176 | ``` 177 | 178 | ## Citations 179 | 180 | ```bibtex 181 | @misc{unpublished2021dalle, 182 | title = {DALL·E: Creating Images from Text}, 183 | author = {Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray}, 184 | year = {2021} 185 | } 186 | ``` 187 | 188 | ```bibtex 189 | @misc{unpublished2021clip, 190 | title = {CLIP: Connecting Text and Images}, 191 | author = {Alec Radford and Ilya Sutskever and Jong Wook Kim and Gretchen Krueger and Sandhini Agarwal}, 192 | year = {2021} 193 | } 194 | ``` 195 | 196 | ```bibtex 197 | @misc{kitaev2020reformer, 198 | title = {Reformer: The Efficient Transformer}, 199 | author = {Nikita Kitaev and Łukasz Kaiser and Anselm Levskaya}, 200 | year = {2020}, 201 | eprint = {2001.04451}, 202 | archivePrefix = {arXiv}, 203 | primaryClass = {cs.LG} 204 | } 205 | ``` 206 | 207 | *Those who do not want to imitate anything, produce nothing.* - Dali 208 | -------------------------------------------------------------------------------- /trainDALLE.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from dalle_pytorch import DiscreteVAE, DALLE 4 | from torchvision.io import read_image 5 | import torchvision.transforms as transforms 6 | import torch.optim as optim 7 | from torchvision.utils import save_image 8 | 9 | parser = argparse.ArgumentParser(description='train DALL-E for DALLE-pytorch') 10 | parser.add_argument('--batchSize', type=int, default=24, help='batch size for training (default: 24)') 11 | parser.add_argument('--dataPath', type=str, default="./imagedata", help='path to imageFolder (default: ./imagedata)') 12 | parser.add_argument('--imageSize', type=int, default=256, help='image size for training (default: 256)') 13 | parser.add_argument('--n_epochs', type=int, default=500, help='number of epochs (default: 500)') 14 | parser.add_argument('--lr', type=float, default=1e-4, help='learning rate (default: 1e-4)') 15 | #parser.add_argument('--tempsched', action='store_true', default=False, help='use temperature scheduling') 16 | #parser.add_argument('--temperature', type=float, default=0.9, help='vae temperature (default: 0.9)') 17 | parser.add_argument('--vaename', type=str, default="vae", help='name of the pretrained VAE') 18 | parser.add_argument('--vae_epoch', type=int, default=0, help='epoch of the pretrained VAE checkpoint to load 
(default: 0)') 19 | parser.add_argument('--name', type=str, default="test", help='experiment name') 20 | parser.add_argument('--load_dalle', type=str, default="", help='name for pretrained VAE when continuing training') 21 | parser.add_argument('--start_epoch', type=int, default=0, help='start epoch numbering for continuing training (default: 0)') 22 | opt = parser.parse_args() 23 | 24 | # vae 25 | 26 | load_epoch = opt.vae_epoch #499 27 | vaename = opt.vaename #"v2vae256" 28 | 29 | # general 30 | 31 | imgSize = opt.imageSize #256 32 | batchSize = opt.batchSize #24 33 | n_epochs = opt.n_epochs #500 34 | log_interval = 10 35 | lr = opt.lr #1e-4 36 | 37 | #dalle 38 | 39 | # to continue training from a saved checkpoint, give checkpoint path as loadfn and start_epoch 40 | 41 | #loadfn = "./models/dalle_vae-cdim256-140.pth" 42 | #start_epoch = 140 43 | loadfn = opt.load_dalle 44 | start_epoch = opt.start_epoch 45 | name = opt.name #v2vae256 46 | 47 | 48 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 49 | 50 | tf = transforms.Compose([ 51 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 52 | ]) 53 | 54 | vae = DiscreteVAE( 55 | image_size = opt.imageSize, 56 | num_layers = 3, 57 | channels = 3, 58 | num_tokens = 2048, 59 | codebook_dim = 256, 60 | hidden_dim = 128, 61 | temperature = 0.9 62 | ) 63 | 64 | # load pretrained vae 65 | print("loading VAE from ./models/"+vaename+"-"+str(load_epoch)+".pth") 66 | vae_dict = torch.load("./models/"+vaename+"-"+str(load_epoch)+".pth") 67 | vae.load_state_dict(vae_dict) 68 | vae.to(device) 69 | 70 | dalle = DALLE( 71 | dim = 256, #512, 72 | vae = vae, # automatically infer (1) image sequence length and (2) number of image tokens 73 | num_text_tokens = 10000, # vocab size for text 74 | text_seq_len = 256, # text sequence length 75 | depth = 6, # should be 64 76 | heads = 8, # attention heads 77 | dim_head = 64, # attention head dimension 78 | attn_dropout = 0.1, # attention dropout 79 | ff_dropout = 0.1 # feedforward dropout 80 | ) 81 | 82 | 83 | # load pretrained dalle if continuing training 84 | if loadfn != "": 85 | dalle_dict = torch.load(loadfn) 86 | dalle.load_state_dict(dalle_dict) 87 | 88 | dalle.to(device) 89 | 90 | # get image and text data 91 | 92 | lf = open("od-captionsonly.txt", "r") # file contains captions only, one caption per line 93 | 94 | # build vocabulary 95 | 96 | from Vocabulary import Vocabulary 97 | 98 | vocab = Vocabulary("captions") 99 | 100 | captions = [] 101 | for lin in lf: 102 | captions.append(lin) 103 | 104 | for caption in captions: 105 | vocab.add_sentence(caption) 106 | 107 | def tokenizer(text): # create a tokenizer function 108 | return text.split(' ') 109 | 110 | 111 | lf = open("od-captions.txt", "r") # files contains lines in the format image_path : captions 112 | 113 | data = [] 114 | 115 | for lin in lf: 116 | (fn, txt) = lin.split(":") 117 | tokens = tokenizer(txt) 118 | codes = [] 119 | for t in tokens: 120 | #print(t) 121 | if t=="": 122 | continue 123 | codes.append(vocab.to_index(t)) 124 | #print(fn, codes) 125 | data.append((fn, codes)) 126 | 127 | 128 | 129 | len_data = len(data) 130 | print(len_data) 131 | #datactr = 0 132 | 133 | # an iterator for fetching data during training 134 | 135 | class ImageCaptions: 136 | 137 | def __init__(self, data, batchsize=4): 138 | self.data = data 139 | self.len = len(data) 140 | self.index = 0 141 | self.end = False 142 | self.batchsize = batchsize 143 | 144 | def __iter__(self): 145 | return self 146 | 147 | def __next__(self): 148 | if 
self.end: 149 | self.index = 0 150 | raise StopIteration 151 | i_data = [] 152 | c_data = [] 153 | for i in range(0, self.batchsize): 154 | i_data.append(self.data[self.index][0]) 155 | c_tokens = [0]*256 # fill to match text_seq_len 156 | c_tokens_ = self.data[self.index][1] 157 | c_tokens[:len(c_tokens_)] = c_tokens_ 158 | c_data.append(c_tokens) 159 | self.index += 1 160 | if self.index == self.len: 161 | self.end = True 162 | break 163 | return i_data, c_data 164 | 165 | 166 | optimizer = optim.Adam(dalle.parameters(), lr=lr) 167 | 168 | for epoch in range(start_epoch, start_epoch+n_epochs): 169 | batch_idx = 0 170 | train_loss = 0 171 | dset = ImageCaptions(data, batchsize=batchSize) # initialize iterator 172 | 173 | for i,c in dset: # loop through dataset by minibatch 174 | text = torch.LongTensor(c) # a minibatch of text (numerical tokens) 175 | images = torch.zeros(len(i), 3, 256, 256) # placeholder for images 176 | 177 | text = text.to(device) 178 | #print(text) 179 | 180 | # fetch images into tensor based on paths given in minibatch 181 | ix = 0 182 | for imgfn in i: # iterate through image paths in minibatch 183 | 184 | # note: images are expected to be in ./imagefolder/0/ 185 | img_t = read_image(opt.dataPath+"/0/"+imgfn).float()/255. # read image and scale into float 0..1 186 | img_t = tf(img_t) # normalize 187 | images[ix,:,:,:] = img_t 188 | ix += 1 189 | 190 | images = images.to(device) 191 | 192 | mask = torch.ones_like(text).bool().to(device) 193 | 194 | # train and optimize a single minibatch 195 | optimizer.zero_grad() 196 | loss = dalle(text, images, mask = mask, return_loss = True) 197 | train_loss += loss.item() 198 | loss.backward() 199 | optimizer.step() 200 | 201 | if batch_idx % log_interval == 0: 202 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 203 | epoch, batch_idx * len(i), len(data), 204 | 100. * batch_idx / int(round(len(data)/batchSize)), 205 | loss.item() / len(i))) 206 | 207 | batch_idx += 1 208 | 209 | print('====> Epoch: {} Average loss: {:.4f}'.format( 210 | epoch, train_loss / len(data))) 211 | 212 | torch.save(dalle.state_dict(), "./models/"+name+"_dalle_"+str(epoch)+".pth") 213 | 214 | # generate a test sample from the captions in the last minibatch 215 | oimgs = dalle.generate_images(text, mask = mask) 216 | save_image(oimgs, 217 | 'results/'+name+'_dalle_epoch_' + str(epoch) + '.png', normalize=True) 218 | 219 | -------------------------------------------------------------------------------- /dalle_pytorch/dalle_pytorch.py: -------------------------------------------------------------------------------- 1 | from math import log2, sqrt 2 | import torch 3 | from torch import nn, einsum 4 | import torch.nn.functional as F 5 | 6 | from einops import rearrange 7 | from axial_positional_embedding import AxialPositionalEmbedding 8 | from dalle_pytorch.transformer import Transformer 9 | 10 | # helpers 11 | 12 | def exists(val): 13 | return val is not None 14 | 15 | def default(val, d): 16 | return val if exists(val) else d 17 | 18 | def always(val): 19 | def inner(*args, **kwargs): 20 | return val 21 | return inner 22 | 23 | def is_empty(t): 24 | return t.nelement() == 0 25 | 26 | def masked_mean(t, mask, dim = 1): 27 | t = t.masked_fill(~mask[:, :, None], 0.) 
28 | return t.sum(dim = 1) / mask.sum(dim = 1)[..., None] 29 | 30 | def eval_decorator(fn): 31 | def inner(model, *args, **kwargs): 32 | was_training = model.training 33 | model.eval() 34 | out = fn(model, *args, **kwargs) 35 | model.train(was_training) 36 | return out 37 | return inner 38 | 39 | # sampling helpers 40 | 41 | def top_k(logits, thres = 0.5): 42 | num_logits = logits.shape[-1] 43 | k = max(int((1 - thres) * num_logits), 1) 44 | val, ind = torch.topk(logits, k) 45 | probs = torch.full_like(logits, float('-inf')) 46 | probs.scatter_(1, ind, val) 47 | return probs 48 | 49 | # discrete vae class 50 | 51 | class ResBlock(nn.Module): 52 | def __init__(self, chan): 53 | super().__init__() 54 | self.net = nn.Sequential( 55 | nn.Conv2d(chan, chan, 3, padding = 1), 56 | nn.ReLU(), 57 | nn.Conv2d(chan, chan, 3, padding = 1), 58 | nn.ReLU(), 59 | nn.Conv2d(chan, chan, 1) 60 | ) 61 | 62 | def forward(self, x): 63 | return self.net(x) + x 64 | 65 | class DiscreteVAE(nn.Module): 66 | def __init__( 67 | self, 68 | image_size = 256, 69 | num_tokens = 512, 70 | codebook_dim = 512, 71 | num_layers = 3, 72 | num_resnet_blocks = 0, 73 | hidden_dim = 64, 74 | channels = 3, 75 | temperature = 0.9 76 | ): 77 | super().__init__() 78 | assert log2(image_size).is_integer(), 'image size must be a power of 2' 79 | assert num_layers >= 1, 'number of layers must be greater than or equal to 1' 80 | has_resblocks = num_resnet_blocks > 0 81 | 82 | self.image_size = image_size 83 | self.num_tokens = num_tokens 84 | self.num_layers = num_layers 85 | self.temperature = temperature 86 | self.codebook = nn.Embedding(num_tokens, codebook_dim) 87 | 88 | hdim = hidden_dim 89 | 90 | enc_chans = [hidden_dim] * num_layers 91 | dec_chans = list(reversed(enc_chans)) 92 | 93 | enc_chans = [channels, *enc_chans] 94 | 95 | dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0] 96 | dec_chans = [dec_init_chan, *dec_chans] 97 | 98 | enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans)) 99 | 100 | enc_layers = [] 101 | dec_layers = [] 102 | 103 | for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io): 104 | enc_layers.append(nn.Sequential(nn.Conv2d(enc_in, enc_out, 4, stride = 2, padding = 1), nn.ReLU())) 105 | dec_layers.append(nn.Sequential(nn.ConvTranspose2d(dec_in, dec_out, 4, stride = 2, padding = 1), nn.ReLU())) 106 | 107 | for _ in range(num_resnet_blocks): 108 | dec_layers.insert(0, ResBlock(dec_chans[1])) 109 | enc_layers.append(ResBlock(enc_chans[-1])) 110 | 111 | if num_resnet_blocks > 0: 112 | dec_layers.insert(0, nn.Conv2d(codebook_dim, dec_chans[1], 1)) 113 | 114 | enc_layers.append(nn.Conv2d(enc_chans[-1], num_tokens, 1)) 115 | dec_layers.append(nn.Conv2d(dec_chans[-1], channels, 1)) 116 | 117 | self.encoder = nn.Sequential(*enc_layers) 118 | self.decoder = nn.Sequential(*dec_layers) 119 | 120 | @torch.no_grad() 121 | def get_codebook_indices(self, images): 122 | logits = self.forward(images, return_logits = True) 123 | codebook_indices = logits.argmax(dim = 1).flatten(1) 124 | return codebook_indices 125 | 126 | def decode( 127 | self, 128 | img_seq 129 | ): 130 | image_embeds = self.codebook(img_seq) 131 | b, n, d = image_embeds.shape 132 | h = w = int(sqrt(n)) 133 | 134 | image_embeds = rearrange(image_embeds, 'b (h w) d -> b d h w', h = h, w = w) 135 | images = self.decoder(image_embeds) 136 | return images 137 | 138 | def forward( 139 | self, 140 | img, 141 | return_recon_loss = False, 142 | return_logits = False 143 | ): 144 | logits = 
self.encoder(img) 145 | 146 | if return_logits: 147 | return logits # return logits for getting hard image indices for DALL-E training 148 | 149 | soft_one_hot = F.gumbel_softmax(logits, tau = self.temperature, dim = 1) 150 | sampled = einsum('b n h w, n d -> b d h w', soft_one_hot, self.codebook.weight) 151 | out = self.decoder(sampled) 152 | 153 | if not return_recon_loss: 154 | return out 155 | 156 | loss = F.mse_loss(img, out) 157 | return loss 158 | 159 | # main classes 160 | 161 | class CLIP(nn.Module): 162 | def __init__( 163 | self, 164 | *, 165 | dim_text = 512, 166 | dim_image = 512, 167 | dim_latent = 512, 168 | num_text_tokens = 10000, 169 | text_enc_depth = 6, 170 | text_seq_len = 256, 171 | text_heads = 8, 172 | num_visual_tokens = 512, 173 | visual_enc_depth = 6, 174 | visual_heads = 8, 175 | visual_image_size = 256, 176 | visual_patch_size = 32, 177 | channels = 3 178 | ): 179 | super().__init__() 180 | self.text_emb = nn.Embedding(num_text_tokens, dim_text) 181 | self.text_pos_emb = nn.Embedding(text_seq_len, dim_text) 182 | self.text_transformer = Transformer(causal = False, seq_len = text_seq_len, dim = dim_text, depth = text_enc_depth, heads = text_heads) 183 | self.to_text_latent = nn.Linear(dim_text, dim_latent, bias = False) 184 | 185 | assert visual_image_size % visual_patch_size == 0, 'Image dimensions must be divisible by the patch size.' 186 | num_patches = (visual_image_size // visual_patch_size) ** 2 187 | patch_dim = channels * visual_patch_size ** 2 188 | 189 | self.visual_patch_size = visual_patch_size 190 | self.to_visual_embedding = nn.Linear(patch_dim, dim_image) 191 | self.visual_pos_emb = nn.Embedding(num_patches, dim_image) 192 | self.visual_transformer = Transformer(causal = False, seq_len = num_patches, dim = dim_image, depth = visual_enc_depth, heads = visual_heads) 193 | self.to_visual_latent = nn.Linear(dim_image, dim_latent, bias = False) 194 | 195 | self.temperature = nn.Parameter(torch.tensor(1.)) 196 | 197 | def forward( 198 | self, 199 | text, 200 | image, 201 | text_mask = None, 202 | return_loss = False 203 | ): 204 | b, device, p = text.shape[0], text.device, self.visual_patch_size 205 | 206 | text_emb = self.text_emb(text) 207 | text_emb += self.text_pos_emb(torch.arange(text.shape[1], device = device)) 208 | 209 | image_patches = rearrange(image, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p) 210 | image_emb = self.to_visual_embedding(image_patches) 211 | image_emb += self.visual_pos_emb(torch.arange(image_emb.shape[1], device = device)) 212 | 213 | enc_text = self.text_transformer(text_emb, mask = text_mask) 214 | enc_image = self.visual_transformer(image_emb) 215 | 216 | if exists(text_mask): 217 | text_latents = masked_mean(enc_text, text_mask, dim = 1) 218 | else: 219 | text_latents = enc_text.mean(dim = 1) 220 | 221 | image_latents = enc_image.mean(dim = 1) 222 | 223 | text_latents = self.to_text_latent(text_latents) 224 | image_latents = self.to_visual_latent(image_latents) 225 | 226 | text_latents, image_latents = map(lambda t: F.normalize(t, p = 2, dim = -1), (text_latents, image_latents)) 227 | 228 | temp = self.temperature.exp() 229 | 230 | if not return_loss: 231 | sim = einsum('n d, n d -> n', text_latents, image_latents) * temp 232 | return sim 233 | 234 | sim = einsum('i d, j d -> i j', text_latents, image_latents) * temp 235 | labels = torch.arange(b, device = device) 236 | loss = F.cross_entropy(sim, labels) 237 | return loss 238 | 239 | # main DALL-E class 240 | 241 | class DALLE(nn.Module): 242 | def 
__init__( 243 | self, 244 | *, 245 | dim, 246 | vae, 247 | num_text_tokens = 10000, 248 | text_seq_len = 256, 249 | depth, 250 | heads = 8, 251 | dim_head = 64, 252 | reversible = False, 253 | attn_dropout = 0., 254 | ff_dropout = 0, 255 | sparse_attn = False 256 | ): 257 | super().__init__() 258 | assert isinstance(vae, DiscreteVAE), 'vae must be an instance of DiscreteVAE' 259 | 260 | image_size = vae.image_size 261 | num_image_tokens = vae.num_tokens 262 | image_seq_len = (vae.image_size // (2 ** vae.num_layers)) ** 2 263 | 264 | self.text_emb = nn.Embedding(num_text_tokens, dim) 265 | self.image_emb = nn.Embedding(num_image_tokens, dim) 266 | 267 | self.text_pos_emb = nn.Embedding(text_seq_len, dim) 268 | self.image_pos_emb = AxialPositionalEmbedding(dim, axial_shape = (image_size, image_size)) 269 | 270 | self.num_text_tokens = num_text_tokens # for offsetting logits index and calculating cross entropy loss 271 | self.num_image_tokens = num_image_tokens 272 | 273 | self.text_seq_len = text_seq_len 274 | self.image_seq_len = image_seq_len 275 | 276 | seq_len = text_seq_len + image_seq_len 277 | total_tokens = num_text_tokens + num_image_tokens + 1 # extra for EOS 278 | self.total_tokens = total_tokens 279 | 280 | self.vae = vae 281 | if exists(self.vae): 282 | self.vae = vae 283 | self.image_emb = vae.codebook 284 | 285 | self.transformer = Transformer( 286 | dim = dim, 287 | causal = True, 288 | seq_len = seq_len, 289 | depth = depth, 290 | heads = heads, 291 | dim_head = dim_head, 292 | reversible = reversible, 293 | attn_dropout = attn_dropout, 294 | ff_dropout = ff_dropout, 295 | sparse_attn = sparse_attn 296 | ) 297 | 298 | self.to_logits = nn.Sequential( 299 | nn.LayerNorm(dim), 300 | nn.Linear(dim, self.total_tokens), 301 | ) 302 | 303 | seq_range = torch.arange(seq_len) 304 | logits_range = torch.arange(total_tokens) 305 | 306 | seq_range = rearrange(seq_range, 'n -> () n ()') 307 | logits_range = rearrange(logits_range, 'd -> () () d') 308 | 309 | logits_mask = ( 310 | ((seq_range >= (text_seq_len - 1)) & (logits_range < num_text_tokens)) | 311 | ((seq_range < (text_seq_len - 1)) & (logits_range >= num_text_tokens)) | 312 | ((seq_range != (seq_len - 1)) & (logits_range >= (total_tokens - 1))) 313 | ) 314 | 315 | self.register_buffer('logits_mask', logits_mask) 316 | 317 | @torch.no_grad() 318 | @eval_decorator 319 | def generate_images( 320 | self, 321 | text, 322 | *, 323 | clip = None, 324 | mask = None, 325 | filter_thres = 0.5, 326 | temperature = 1. 
327 | ): 328 | vae, text_seq_len, image_seq_len, num_text_tokens = self.vae, self.text_seq_len, self.image_seq_len, self.num_text_tokens 329 | total_len = text_seq_len + image_seq_len 330 | 331 | out = text 332 | for cur_len in range(text.shape[1], total_len): 333 | is_image = cur_len >= text_seq_len 334 | 335 | text, image = out[:, :text_seq_len], out[:, text_seq_len:] 336 | 337 | logits = self(text, image, mask = mask)[:, -1, :] 338 | 339 | filtered_logits = top_k(logits, thres = filter_thres) 340 | probs = F.softmax(filtered_logits / temperature, dim = -1) 341 | sample = torch.multinomial(probs, 1) 342 | 343 | sample -= (num_text_tokens if is_image else 0) # offset sampled token if it is an image token, since logit space is composed of text and then image tokens 344 | out = torch.cat((out, sample), dim=-1) 345 | 346 | if out.shape[1] <= text_seq_len: 347 | mask = F.pad(mask, (0, 1), value = True) 348 | 349 | text_seq = out[:, :text_seq_len] 350 | 351 | img_seq = out[:, -image_seq_len:] 352 | images = vae.decode(img_seq) 353 | 354 | if exists(clip): 355 | scores = clip(text_seq, images, return_loss = False) 356 | return images, scores 357 | 358 | return images 359 | 360 | def forward( 361 | self, 362 | text, 363 | image = None, 364 | mask = None, 365 | return_loss = False 366 | ): 367 | device = text.device 368 | eos_token_id = self.total_tokens - 1 369 | 370 | tokens = self.text_emb(text) 371 | tokens += self.text_pos_emb(torch.arange(text.shape[1], device = device)) 372 | 373 | seq_len = tokens.shape[1] 374 | 375 | if exists(image) and not is_empty(image): 376 | is_raw_image = len(image.shape) == 4 377 | if is_raw_image: 378 | image = self.vae.get_codebook_indices(image) 379 | 380 | image_len = image.shape[1] 381 | image_emb = self.image_emb(image) 382 | image_emb += self.image_pos_emb(image_emb) 383 | 384 | tokens = torch.cat((tokens, image_emb), dim = 1) 385 | 386 | seq_len += image_len 387 | if exists(mask): 388 | mask = F.pad(mask, (0, image_emb.shape[1]), value = True) 389 | 390 | out = self.transformer(tokens, mask = mask) 391 | logits = self.to_logits(out) 392 | 393 | # mask logits to make sure text predicts text (except last token), and image predicts image 394 | mask = self.logits_mask[:, :seq_len] 395 | max_neg_value = -torch.finfo(logits.dtype).max 396 | logits.masked_fill_(mask, max_neg_value) 397 | 398 | if not return_loss: 399 | return logits 400 | 401 | assert exists(image), 'when training, image must be supplied' 402 | 403 | offsetted_image = image + self.num_text_tokens 404 | labels = torch.cat((text, offsetted_image), dim = 1) 405 | labels = F.pad(labels, (0, 1), value = eos_token_id) # last token predicts EOS 406 | loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels[:, 1:]) 407 | return loss 408 | --------------------------------------------------------------------------------
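For orientation, here is a minimal sketch (not part of the repository files above) of how the pieces fit together at inference time. It mirrors the hyperparameters used in trainVAE.py, trainDALLE.py and genDALLE.py; the checkpoint paths in the comments follow the naming those scripts use (with placeholder epoch/name values), and the random text tensor stands in for a caption tokenized with Vocabulary.py.

```python
import torch
from dalle_pytorch import DiscreteVAE, DALLE

# same VAE hyperparameters as trainVAE.py / genDALLE.py
vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,
    num_tokens = 2048,
    codebook_dim = 256,
    hidden_dim = 128,
    temperature = 0.9
)

# same DALLE hyperparameters as trainDALLE.py / genDALLE.py
dalle = DALLE(
    dim = 256,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 6,
    heads = 8,
    dim_head = 64
)

# in practice, load the weights saved by the training scripts, e.g.
# vae.load_state_dict(torch.load("./models/vae-<epoch>.pth"))
# dalle.load_state_dict(torch.load("./models/<name>_dalle_<epoch>.pth"))

text = torch.randint(0, 10000, (1, 256))  # stand-in for Vocabulary-encoded caption tokens
mask = torch.ones_like(text).bool()

images = dalle.generate_images(text, mask = mask)  # (1, 3, 256, 256)
```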