├── requirements.txt
├── .gitignore
├── README.md
├── setup.py
├── dall_e
│   ├── __init__.py
│   ├── utils.py
│   ├── encoder.py
│   └── decoder.py
├── LICENSE
├── model_card.md
└── notebooks
    └── usage.ipynb

/requirements.txt:
--------------------------------------------------------------------------------
1 | Pillow
2 | blobfile
3 | mypy
4 | numpy
5 | pytest
6 | requests
7 | torch
8 | torchvision
9 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # OS specific
2 | *.DS_Store
3 |
4 | # Python
5 | /build
6 | /dist
7 | __pycache__
8 | *.ipynb_checkpoints
9 | *.egg-info
10 |
11 | # Vim
12 | *.vim
13 | *.swk
14 | *.swl
15 | *.swm
16 | *.swn
17 | *.swo
18 | *.swp
19 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Overview
2 |
3 | [[Blog]](https://openai.com/blog/dall-e/) [[Paper]](https://arxiv.org/abs/2102.12092) [[Model Card]](model_card.md) [[Usage]](notebooks/usage.ipynb)
4 |
5 | This is the official PyTorch package for the discrete VAE used for DALL·E. The transformer used to generate the images from the text is not part of this code release.
6 |
7 | # Installation
8 |
9 | Before running [the example notebook](notebooks/usage.ipynb), you will need to install the package using
10 |
11 |     pip install DALL-E
12 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | def parse_requirements(filename):
4 |     lines = (line.strip() for line in open(filename))
5 |     return [line for line in lines if line and not line.startswith("#")]
6 |
7 | setup(name='DALL-E',
8 |       version='0.1',
9 |       description='PyTorch package for the discrete VAE used for DALL·E.',
10 |       url='http://github.com/openai/DALL-E',
11 |       author='Aditya Ramesh',
12 |       author_email='aramesh@openai.com',
13 |       license='BSD',
14 |       packages=['dall_e'],
15 |       install_requires=parse_requirements('requirements.txt'),
16 |       zip_safe=True)
17 |
--------------------------------------------------------------------------------
/dall_e/__init__.py:
--------------------------------------------------------------------------------
1 | import io, requests
2 | import torch
3 | import torch.nn as nn
4 |
5 | from dall_e.encoder import Encoder
6 | from dall_e.decoder import Decoder
7 | from dall_e.utils import map_pixels, unmap_pixels
8 |
9 | def load_model(path: str, device: torch.device = None) -> nn.Module:
10 |     if path.startswith('http://') or path.startswith('https://'):
11 |         resp = requests.get(path)
12 |         resp.raise_for_status()
13 |
14 |         with io.BytesIO(resp.content) as buf:
15 |             return torch.load(buf, map_location=device)
16 |     else:
17 |         with open(path, 'rb') as f:
18 |             return torch.load(f, map_location=device)
19 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Modified MIT License
2 |
3 | Software Copyright (c) 2021 OpenAI
4 |
5 | We don’t claim ownership of the content you create with the DALL-E discrete VAE, so it is yours to
6 | do with as you please. We only ask that you use the model responsibly and clearly indicate that it
7 | was used.
8 |
9 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
10 | associated documentation files (the "Software"), to deal in the Software without restriction,
11 | including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
12 | and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
13 | subject to the following conditions:
14 |
15 | The above copyright notice and this permission notice shall be included
16 | in all copies or substantial portions of the Software.
17 | The above copyright notice and this permission notice need not be included
18 | with content created by the Software.
19 |
20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
21 | INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
23 | BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
24 | TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
25 | OR OTHER DEALINGS IN THE SOFTWARE.
26 |
--------------------------------------------------------------------------------
/dall_e/utils.py:
--------------------------------------------------------------------------------
1 | import attr
2 | import math
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | logit_laplace_eps: float = 0.1
9 |
10 | @attr.s(eq=False)
11 | class Conv2d(nn.Module):
12 |     n_in: int = attr.ib(validator=lambda i, a, x: x >= 1)
13 |     n_out: int = attr.ib(validator=lambda i, a, x: x >= 1)
14 |     kw: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 2 == 1)
15 |
16 |     use_float16: bool = attr.ib(default=True)
17 |     device: torch.device = attr.ib(default=torch.device('cpu'))
18 |     requires_grad: bool = attr.ib(default=False)
19 |
20 |     def __attrs_post_init__(self) -> None:
21 |         super().__init__()
22 |
23 |         w = torch.empty((self.n_out, self.n_in, self.kw, self.kw), dtype=torch.float32,
24 |                 device=self.device, requires_grad=self.requires_grad)
25 |         w.normal_(std=1 / math.sqrt(self.n_in * self.kw ** 2))
26 |
27 |         b = torch.zeros((self.n_out,), dtype=torch.float32, device=self.device,
28 |                 requires_grad=self.requires_grad)
29 |         self.w, self.b = nn.Parameter(w), nn.Parameter(b)
30 |
31 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
32 |         if self.use_float16 and 'cuda' in self.w.device.type:
33 |             if x.dtype != torch.float16:
34 |                 x = x.half()
35 |
36 |             w, b = self.w.half(), self.b.half()
37 |         else:
38 |             if x.dtype != torch.float32:
39 |                 x = x.float()
40 |
41 |             w, b = self.w, self.b
42 |
43 |         return F.conv2d(x, w, b, padding=(self.kw - 1) // 2)
44 |
45 | def map_pixels(x: torch.Tensor) -> torch.Tensor:
46 |     if len(x.shape) != 4:
47 |         raise ValueError('expected input to be 4d')
48 |     if x.dtype != torch.float:
49 |         raise ValueError('expected input to have type float')
50 |
51 |     return (1 - 2 * logit_laplace_eps) * x + logit_laplace_eps
52 |
53 | def unmap_pixels(x: torch.Tensor) -> torch.Tensor:
54 |     if len(x.shape) != 4:
55 |         raise ValueError('expected input to be 4d')
56 |     if x.dtype != torch.float:
57 |         raise ValueError('expected input to have type float')
58 |
59 |     return torch.clamp((x - logit_laplace_eps) / (1 - 2 * logit_laplace_eps), 0, 1)
60 |
--------------------------------------------------------------------------------
/model_card.md:
--------------------------------------------------------------------------------
1 | # Model Card: DALL·E dVAE
2 |
3 | Following [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993) and [Lessons from
4 | Archives (Jo & Gebru)](https://arxiv.org/pdf/1912.10389.pdf), we're providing some information about the discrete
5 | VAE (dVAE) that was used to train DALL·E.
6 |
7 | ## Model Details
8 |
9 | The dVAE was developed by researchers at OpenAI to reduce the memory footprint of the transformer trained on the
10 | text-to-image generation task. The details involved in training the dVAE are described in [the paper][dalle_paper]. This
11 | model card describes the first version of the model, released in February 2021. The model consists of a convolutional
12 | encoder and decoder whose architectures are described [here](dall_e/encoder.py) and [here](dall_e/decoder.py), respectively.
13 | For questions or comments about the models or the code release, please file a GitHub issue.
14 |
15 | ## Model Use
16 |
17 | ### Intended Use
18 |
19 | The model is intended for others to use for training their own generative models.
20 |
21 | ### Out-of-Scope Use Cases
22 |
23 | This model is inappropriate for high-fidelity image processing applications. We also do not recommend its use as a
24 | general-purpose image compressor.
25 |
26 | ## Training Data
27 |
28 | The model was trained on publicly available text-image pairs collected from the internet. This data consists partly of
29 | [Conceptual Captions][cc] and a filtered subset of [YFCC100M][yfcc100m]. We used a subset of the filters described in
30 | [Sharma et al.][cc_paper] to construct this dataset; further details are described in [our paper][dalle_paper]. We will
31 | not be releasing the dataset.
32 |
33 | ## Performance and Limitations
34 |
35 | The heavy compression from the encoding process results in a noticeable loss of detail in the reconstructed images. This
36 | renders it inappropriate for applications that require fine-grained details of the image to be preserved.
37 |
38 | [dalle_paper]: https://arxiv.org/abs/2102.12092
39 | [cc]: https://ai.google.com/research/ConceptualCaptions
40 | [cc_paper]: https://www.aclweb.org/anthology/P18-1238/
41 | [yfcc100m]: http://projects.dfki.uni-kl.de/yfcc100m/
42 |
--------------------------------------------------------------------------------
/notebooks/usage.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": null,
6 |    "metadata": {},
7 |    "outputs": [],
8 |    "source": [
9 |     "import io\n",
10 |     "import os, sys\n",
11 |     "import requests\n",
12 |     "import PIL\n",
13 |     "\n",
14 |     "import torch\n",
15 |     "import torchvision.transforms as T\n",
16 |     "import torchvision.transforms.functional as TF\n",
17 |     "\n",
18 |     "from dall_e import map_pixels, unmap_pixels, load_model\n",
19 |     "from IPython.display import display, display_markdown\n",
20 |     "\n",
21 |     "target_image_size = 256\n",
22 |     "\n",
23 |     "def download_image(url):\n",
24 |     "    resp = requests.get(url)\n",
25 |     "    resp.raise_for_status()\n",
26 |     "    return PIL.Image.open(io.BytesIO(resp.content))\n",
27 |     "\n",
28 |     "def preprocess(img):\n",
29 |     "    s = min(img.size)\n",
30 |     "    \n",
31 |     "    if s < target_image_size:\n",
32 |     "        raise ValueError(f'min dim for image {s} < {target_image_size}')\n",
33 |     "    \n",
34 |     "    r = target_image_size / s\n",
35 |     "    s = (round(r * img.size[1]), round(r * img.size[0]))\n",
36 |     "    img = TF.resize(img, s, interpolation=PIL.Image.LANCZOS)\n",
37 |     "    img = TF.center_crop(img, output_size=2 * [target_image_size])\n",
38 |     "    img = torch.unsqueeze(T.ToTensor()(img), 0)\n",
39 |     "    return map_pixels(img)"
40 |    ]
41 |   },
42 |   {
43 |    "cell_type": "code",
44 |    "execution_count": null,
45 |    "metadata": {},
46 |    "outputs": [],
47 |    "source": [
48 |     "# This can be changed to a GPU, e.g. 'cuda:0'.\n",
49 |     "dev = torch.device('cpu')\n",
50 |     "\n",
51 |     "# For faster load times, download these files locally and use the local paths instead.\n",
52 |     "enc = load_model(\"https://cdn.openai.com/dall-e/encoder.pkl\", dev)\n",
53 |     "dec = load_model(\"https://cdn.openai.com/dall-e/decoder.pkl\", dev)"
54 |    ]
55 |   },
56 |   {
57 |    "cell_type": "code",
58 |    "execution_count": null,
59 |    "metadata": {},
60 |    "outputs": [],
61 |    "source": [
62 |     "x = preprocess(download_image('https://assets.bwbx.io/images/users/iqjWHBFdfxIU/iKIWgaiJUtss/v2/1000x-1.jpg'))\n",
63 |     "display_markdown('Original image:')\n",
64 |     "display(T.ToPILImage(mode='RGB')(x[0]))"
65 |    ]
66 |   },
67 |   {
68 |    "cell_type": "code",
69 |    "execution_count": null,
70 |    "metadata": {},
71 |    "outputs": [],
72 |    "source": [
73 |     "import torch.nn.functional as F\n",
74 |     "\n",
75 |     "z_logits = enc(x)\n",
76 |     "z = torch.argmax(z_logits, axis=1)\n",
77 |     "z = F.one_hot(z, num_classes=enc.vocab_size).permute(0, 3, 1, 2).float()\n",
78 |     "\n",
79 |     "x_stats = dec(z).float()\n",
80 |     "x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))\n",
81 |     "x_rec = T.ToPILImage(mode='RGB')(x_rec[0])\n",
82 |     "\n",
83 |     "display_markdown('Reconstructed image:')\n",
84 |     "display(x_rec)"
85 |    ]
86 |   },
87 |   {
88 |    "cell_type": "code",
89 |    "execution_count": null,
90 |    "metadata": {},
91 |    "outputs": [],
92 |    "source": []
93 |   }
94 |  ],
95 |  "metadata": {
96 |   "kernelspec": {
97 |    "display_name": "Python 3",
98 |    "language": "python",
99 |    "name": "python3"
100 |   },
101 |   "language_info": {
102 |    "codemirror_mode": {
103 |     "name": "ipython",
104 |     "version": 3
105 |    },
106 |    "file_extension": ".py",
107 |    "mimetype": "text/x-python",
108 |    "name": "python",
109 |    "nbconvert_exporter": "python",
110 |    "pygments_lexer": "ipython3",
111 |    "version": "3.9.1"
112 |   }
113 |  },
114 |  "nbformat": 4,
115 |  "nbformat_minor": 2
116 | }
117 |
--------------------------------------------------------------------------------
/dall_e/encoder.py:
--------------------------------------------------------------------------------
1 | import attr
2 | import numpy as np
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from collections import OrderedDict
9 | from functools import partial
10 | from dall_e.utils import Conv2d
11 |
12 | @attr.s(eq=False, repr=False)
13 | class EncoderBlock(nn.Module):
14 |     n_in: int = attr.ib(validator=lambda i, a, x: x >= 1)
15 |     n_out: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 == 0)
16 |     n_layers: int = attr.ib(validator=lambda i, a, x: x >= 1)
17 |
18 |     device: torch.device = attr.ib(default=None)
19 |     requires_grad: bool = attr.ib(default=False)
20 |
21 |     def __attrs_post_init__(self) -> None:
22 |         super().__init__()
23 |         self.n_hid = self.n_out // 4
24 |         self.post_gain = 1 / (self.n_layers ** 2)
25 |
26 |         make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)
27 |         self.id_path = make_conv(self.n_in, self.n_out, 1) if self.n_in != self.n_out else nn.Identity()
28 |         self.res_path = nn.Sequential(OrderedDict([
29 |             ('relu_1', nn.ReLU()),
30 |             ('conv_1', make_conv(self.n_in, self.n_hid, 3)),
31 |             ('relu_2', nn.ReLU()),
32 |             ('conv_2', make_conv(self.n_hid, self.n_hid, 3)),
33 |             ('relu_3', nn.ReLU()),
34 |             ('conv_3', make_conv(self.n_hid, self.n_hid, 3)),
35 |             ('relu_4', nn.ReLU()),
36 |             ('conv_4', make_conv(self.n_hid, self.n_out, 1)),]))
37 |
38 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
39 |         return self.id_path(x) + self.post_gain * self.res_path(x)
40 |
41 | @attr.s(eq=False, repr=False)
42 | class Encoder(nn.Module):
43 |     group_count: int = 4
44 |     n_hid: int = attr.ib(default=256, validator=lambda i, a, x: x >= 64)
45 |     n_blk_per_group: int = attr.ib(default=2, validator=lambda i, a, x: x >= 1)
46 |     input_channels: int = attr.ib(default=3, validator=lambda i, a, x: x >= 1)
47 |     vocab_size: int = attr.ib(default=8192, validator=lambda i, a, x: x >= 512)
48 |
49 |     device: torch.device = attr.ib(default=torch.device('cpu'))
50 |     requires_grad: bool = attr.ib(default=False)
51 |     use_mixed_precision: bool = attr.ib(default=True)
52 |
53 |     def __attrs_post_init__(self) -> None:
54 |         super().__init__()
55 |
56 |         blk_range = range(self.n_blk_per_group)
57 |         n_layers = self.group_count * self.n_blk_per_group
58 |         make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)
59 |         make_blk = partial(EncoderBlock, n_layers=n_layers, device=self.device,
60 |                 requires_grad=self.requires_grad)
61 |
62 |         self.blocks = nn.Sequential(OrderedDict([
63 |             ('input', make_conv(self.input_channels, 1 * self.n_hid, 7)),
64 |             ('group_1', nn.Sequential(OrderedDict([
65 |                 *[(f'block_{i + 1}', make_blk(1 * self.n_hid, 1 * self.n_hid)) for i in blk_range],
66 |                 ('pool', nn.MaxPool2d(kernel_size=2)),
67 |             ]))),
68 |             ('group_2', nn.Sequential(OrderedDict([
69 |                 *[(f'block_{i + 1}', make_blk(1 * self.n_hid if i == 0 else 2 * self.n_hid, 2 * self.n_hid)) for i in blk_range],
70 |                 ('pool', nn.MaxPool2d(kernel_size=2)),
71 |             ]))),
72 |             ('group_3', nn.Sequential(OrderedDict([
73 |                 *[(f'block_{i + 1}', make_blk(2 * self.n_hid if i == 0 else 4 * self.n_hid, 4 * self.n_hid)) for i in blk_range],
74 |                 ('pool', nn.MaxPool2d(kernel_size=2)),
75 |             ]))),
76 |             ('group_4', nn.Sequential(OrderedDict([
77 |                 *[(f'block_{i + 1}', make_blk(4 * self.n_hid if i == 0 else 8 * self.n_hid, 8 * self.n_hid)) for i in blk_range],
78 |             ]))),
79 |             ('output', nn.Sequential(OrderedDict([
80 |                 ('relu', nn.ReLU()),
81 |                 ('conv', make_conv(8 * self.n_hid, self.vocab_size, 1, use_float16=False)),
82 |             ]))),
83 |         ]))
84 |
85 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
86 |         if len(x.shape) != 4:
87 |             raise ValueError(f'input shape {x.shape} is not 4d')
88 |         if x.shape[1] != self.input_channels:
89 |             raise ValueError(f'input has {x.shape[1]} channels but model built for {self.input_channels}')
90 |         if x.dtype != torch.float32:
91 |             raise ValueError('input must have dtype torch.float32')
92 |
93 |         return self.blocks(x)
94 |
--------------------------------------------------------------------------------
/dall_e/decoder.py:
--------------------------------------------------------------------------------
1 | import attr
2 | import numpy as np
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from collections import OrderedDict
9 | from functools import partial
10 | from dall_e.utils import Conv2d
11 |
12 | @attr.s(eq=False, repr=False)
13 | class DecoderBlock(nn.Module):
14 |     n_in: int = attr.ib(validator=lambda i, a, x: x >= 1)
15 |     n_out: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 == 0)
16 |     n_layers: int = attr.ib(validator=lambda i, a, x: x >= 1)
17 |
18 |     device: torch.device = attr.ib(default=None)
19 |     requires_grad: bool = attr.ib(default=False)
20 |
21 |     def __attrs_post_init__(self) -> None:
22 |         super().__init__()
23 |         self.n_hid = self.n_out // 4
24 |         self.post_gain = 1 / (self.n_layers ** 2)
25 |
26 |         make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)
27 |         self.id_path = make_conv(self.n_in, self.n_out, 1) if self.n_in != self.n_out else nn.Identity()
28 |         self.res_path = nn.Sequential(OrderedDict([
29 |             ('relu_1', nn.ReLU()),
30 |             ('conv_1', make_conv(self.n_in, self.n_hid, 1)),
31 |             ('relu_2', nn.ReLU()),
32 |             ('conv_2', make_conv(self.n_hid, self.n_hid, 3)),
33 |             ('relu_3', nn.ReLU()),
34 |             ('conv_3', make_conv(self.n_hid, self.n_hid, 3)),
35 |             ('relu_4', nn.ReLU()),
36 |             ('conv_4', make_conv(self.n_hid, self.n_out, 3)),]))
37 |
38 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
39 |         return self.id_path(x) + self.post_gain * self.res_path(x)
40 |
41 | @attr.s(eq=False, repr=False)
42 | class Decoder(nn.Module):
43 |     group_count: int = 4
44 |     n_init: int = attr.ib(default=128, validator=lambda i, a, x: x >= 8)
45 |     n_hid: int = attr.ib(default=256, validator=lambda i, a, x: x >= 64)
46 |     n_blk_per_group: int = attr.ib(default=2, validator=lambda i, a, x: x >= 1)
47 |     output_channels: int = attr.ib(default=3, validator=lambda i, a, x: x >= 1)
48 |     vocab_size: int = attr.ib(default=8192, validator=lambda i, a, x: x >= 512)
49 |
50 |     device: torch.device = attr.ib(default=torch.device('cpu'))
51 |     requires_grad: bool = attr.ib(default=False)
52 |     use_mixed_precision: bool = attr.ib(default=True)
53 |
54 |     def __attrs_post_init__(self) -> None:
55 |         super().__init__()
56 |
57 |         blk_range = range(self.n_blk_per_group)
58 |         n_layers = self.group_count * self.n_blk_per_group
59 |         make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)
60 |         make_blk = partial(DecoderBlock, n_layers=n_layers, device=self.device,
61 |                 requires_grad=self.requires_grad)
62 |
63 |         self.blocks = nn.Sequential(OrderedDict([
64 |             ('input', make_conv(self.vocab_size, self.n_init, 1, use_float16=False)),
65 |             ('group_1', nn.Sequential(OrderedDict([
66 |                 *[(f'block_{i + 1}', make_blk(self.n_init if i == 0 else 8 * self.n_hid, 8 * self.n_hid)) for i in blk_range],
67 |                 ('upsample', nn.Upsample(scale_factor=2, mode='nearest')),
68 |             ]))),
69 |             ('group_2', nn.Sequential(OrderedDict([
70 |                 *[(f'block_{i + 1}', make_blk(8 * self.n_hid if i == 0 else 4 * self.n_hid, 4 * self.n_hid)) for i in blk_range],
71 |                 ('upsample', nn.Upsample(scale_factor=2, mode='nearest')),
72 |             ]))),
73 |             ('group_3', nn.Sequential(OrderedDict([
74 |                 *[(f'block_{i + 1}', make_blk(4 * self.n_hid if i == 0 else 2 * self.n_hid, 2 * self.n_hid)) for i in blk_range],
75 |                 ('upsample', nn.Upsample(scale_factor=2, mode='nearest')),
76 |             ]))),
77 |             ('group_4', nn.Sequential(OrderedDict([
78 |                 *[(f'block_{i + 1}', make_blk(2 * self.n_hid if i == 0 else 1 * self.n_hid, 1 * self.n_hid)) for i in blk_range],
79 |             ]))),
80 |             ('output', nn.Sequential(OrderedDict([
81 |                 ('relu', nn.ReLU()),
82 |                 ('conv', make_conv(1 * self.n_hid, 2 * self.output_channels, 1)),
83 |             ]))),
84 |         ]))
85 |
86 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
87 |         if len(x.shape) != 4:
88 |             raise ValueError(f'input shape {x.shape} is not 4d')
89 |         if x.shape[1] != self.vocab_size:
90 |             raise ValueError(f'input has {x.shape[1]} channels but model built for {self.vocab_size}')
91 |         if x.dtype != torch.float32:
92 |             raise ValueError('input must have dtype torch.float32')
93 |
94 |         return self.blocks(x)
95 |
--------------------------------------------------------------------------------
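
Taken together, the files above form the full dVAE inference path: map_pixels squeezes pixel values into [logit_laplace_eps, 1 - logit_laplace_eps], the Encoder produces logits over an 8192-entry codebook at each spatial position, and the Decoder maps one-hot code maps back to pixel statistics that unmap_pixels turns into an image. The snippet below is a minimal sketch of that round trip, condensed from notebooks/usage.ipynb; it assumes the package is installed, that the encoder.pkl/decoder.pkl checkpoints referenced in the notebook are still downloadable, and that `img` is an RGB PIL image already resized and cropped to 256x256 (the notebook's preprocess() helper does this).

    import torch
    import torch.nn.functional as F
    import torchvision.transforms as T

    from dall_e import load_model, map_pixels, unmap_pixels

    dev = torch.device('cpu')  # or torch.device('cuda:0'); on CUDA the convolutions run in float16

    # Checkpoints as used in the usage notebook; download them locally for faster load times.
    enc = load_model("https://cdn.openai.com/dall-e/encoder.pkl", dev)
    dec = load_model("https://cdn.openai.com/dall-e/decoder.pkl", dev)

    # `img` is assumed to be a 256x256 RGB PIL image (see preprocess() in the notebook).
    x = map_pixels(torch.unsqueeze(T.ToTensor()(img), 0))

    z_logits = enc(x)                    # (1, vocab_size, 32, 32) logits over the codebook
    z = torch.argmax(z_logits, axis=1)   # hard token assignment per spatial position
    z = F.one_hot(z, num_classes=enc.vocab_size).permute(0, 3, 1, 2).float()

    x_stats = dec(z).float()             # the first three output channels hold the pixel means
    x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))
    x_rec = T.ToPILImage(mode='RGB')(x_rec[0])   # reconstructed PIL image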