├── .gitignore ├── LICENSE ├── README.md ├── dalle2_decoder ├── __init__.py ├── clip │ ├── __init__.py │ ├── attention.py │ ├── config.yaml │ ├── encoders.py │ ├── model_creation.py │ └── utils.py ├── dataset.py ├── dist_util.py ├── download.py ├── fp16_util.py ├── gaussian_diffusion.py ├── logger.py ├── losses.py ├── model_creation.py ├── nn.py ├── resample.py ├── respace.py ├── text2im_model.py ├── tokenizer │ ├── __init__.py │ ├── bpe.py │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── encoder.json.gz │ ├── simple_tokenizer.py │ └── vocab.bpe.gz ├── train_utils.py ├── unet.py ├── utils.py └── xf.py ├── notebooks └── inference.ipynb └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.egg-info/ 3 | .DS_Store 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dalle2-Decoder 2 | 3 | This is the code for dalle2-decoder. 4 | 5 | # Usage 6 | 7 | To install this package, clone this repository and then run: 8 | 9 | ``` 10 | pip install -e . 11 | ``` 12 | 13 | For detailed usage examples, see the [notebooks](notebooks) directory. 14 | 15 | * The [inference](notebooks/inference.ipynb) [![][colab]][colab-inference] notebook shows how to use Dalle2-Decoder. 16 | 17 | [colab]: 18 | [colab-inference]: 19 | -------------------------------------------------------------------------------- /dalle2_decoder/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | A codebase for performing model inference with a text-conditional diffusion model. 
3 | """ 4 | -------------------------------------------------------------------------------- /dalle2_decoder/clip/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralPushkin/Dalle2-Decoder/47af769d01d50b7a7f9c1c2b8cef6112b9870bdb/dalle2_decoder/clip/__init__.py -------------------------------------------------------------------------------- /dalle2_decoder/clip/attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | from abc import ABC, abstractmethod 3 | from itertools import product 4 | from typing import Any, Optional 5 | 6 | import attr 7 | import numpy as np 8 | import torch 9 | 10 | 11 | @attr.s 12 | class AttentionMask(ABC): 13 | query_context_size: int = attr.ib(validator=lambda i, a, x: x >= 1) # type: ignore 14 | key_context_size: int = attr.ib(validator=lambda i, a, x: x >= 1) # type: ignore 15 | block_size: int = attr.ib(validator=lambda i, a, x: x >= 1) # type: ignore 16 | n_head: int = attr.ib(validator=lambda i, a, x: x >= 1) # type: ignore 17 | is_head_specific: bool = attr.ib(default=False) 18 | n_query_pad: int = attr.ib(default=0) 19 | n_key_pad: int = attr.ib(default=0) 20 | 21 | def __attrs_post_init__(self) -> None: 22 | if self.query_context_size % self.block_size != 0: 23 | raise ValueError() 24 | if self.key_context_size % self.block_size != 0: 25 | raise ValueError() 26 | if self.n_query_pad >= self.query_context_size: 27 | raise ValueError() 28 | if self.n_key_pad >= self.key_context_size: 29 | raise ValueError() 30 | 31 | self.n_query_block = self.query_context_size // self.block_size 32 | self.n_key_block = self.key_context_size // self.block_size 33 | self.first_pad_query_block_idx = self.n_query_block - int( 34 | math.ceil(self.n_query_pad / self.block_size) 35 | ) 36 | self.first_pad_key_block_idx = self.n_key_block - int( 37 | math.ceil(self.n_key_pad / self.block_size) 38 | ) 39 | 40 | def _make_global_layout(self) -> None: 41 | if not self.is_head_specific: 42 | m = np.ones([self.n_query_block, self.n_key_block], dtype=np.bool) 43 | r = product(*[range(n) for n in m.shape]) 44 | 45 | for qb, kb in r: 46 | m[qb, kb] = np.any(self.block_layout(None, 0, qb, kb, 0)) 47 | else: 48 | m = np.ones([self.n_head, self.n_query_block, self.n_key_block], dtype=np.bool) 49 | r = product(*[range(n) for n in m.shape]) 50 | 51 | for h, qb, kb in r: 52 | m[h, qb, kb] = np.any(self.block_layout(None, h, qb, kb, 0)) 53 | 54 | self.global_layout = m 55 | 56 | @abstractmethod 57 | def _block_layout( 58 | self, blk_shape: Any, head_idx: int, query_idx: int, key_idx: int, blk_idx: int 59 | ) -> np.ndarray: 60 | raise NotImplementedError() 61 | 62 | def block_layout( 63 | self, blk_shape: Any, head_idx: int, query_idx: int, key_idx: int, blk_idx: int 64 | ) -> np.ndarray: 65 | """ 66 | `query_idx`, `key_idx` are block-level, zero-based indices. 
67 | """ 68 | 69 | m = np.ones([self.block_size, self.block_size], dtype=np.bool) 70 | 71 | if query_idx >= self.first_pad_query_block_idx: 72 | n_pad = min( 73 | self.block_size, 74 | (query_idx + 1) * self.block_size - (self.query_context_size - self.n_query_pad), 75 | ) 76 | assert n_pad > 0 77 | m[self.block_size - n_pad :] = False 78 | if key_idx >= self.first_pad_key_block_idx: 79 | n_pad = min( 80 | self.block_size, 81 | (key_idx + 1) * self.block_size - (self.key_context_size - self.n_key_pad), 82 | ) 83 | assert n_pad > 0 84 | m[:, self.block_size - n_pad :] = False 85 | 86 | return m & self._block_layout(blk_shape, head_idx, query_idx, key_idx, blk_idx) 87 | 88 | 89 | @attr.s 90 | class DenseAttentionMask(AttentionMask): 91 | def __attrs_post_init__(self) -> None: 92 | super().__attrs_post_init__() 93 | 94 | self.global_layout = np.ones([self.n_query_block, self.n_key_block], dtype=np.bool) 95 | n_zero_query_blocks = self.n_query_pad // self.block_size 96 | n_zero_key_blocks = self.n_key_pad // self.block_size 97 | self.global_layout[self.n_query_block - n_zero_query_blocks :] = False 98 | self.global_layout[:, self.n_key_block - n_zero_key_blocks :] = False 99 | 100 | def _block_layout( 101 | self, blk_shape: Any, head_idx: int, query_idx: int, key_idx: int, blk_idx: int 102 | ) -> np.ndarray: 103 | return np.ones([self.block_size, self.block_size], dtype=np.bool) 104 | 105 | 106 | @attr.s 107 | class DenseCausalAttentionMask(AttentionMask): 108 | def __attrs_post_init__(self) -> None: 109 | super().__attrs_post_init__() 110 | 111 | self.global_layout = np.tril(np.ones([self.n_query_block, self.n_key_block], dtype=np.bool)) 112 | n_zero_query_blocks = self.n_query_pad // self.block_size 113 | n_zero_key_blocks = self.n_key_pad // self.block_size 114 | self.global_layout[self.n_query_block - n_zero_query_blocks :] = False 115 | self.global_layout[:, self.n_key_block - n_zero_key_blocks :] = False 116 | 117 | def _block_layout( 118 | self, blk_shape: Any, head_idx: int, query_idx: int, key_idx: int, blk_idx: int 119 | ) -> np.ndarray: 120 | if query_idx > key_idx: 121 | return np.ones(2 * [self.block_size], dtype=np.bool) 122 | elif query_idx < key_idx: 123 | return np.zeros(2 * [self.block_size], dtype=np.bool) 124 | else: 125 | return np.tril(np.ones(2 * [self.block_size], dtype=np.bool)) 126 | 127 | 128 | @attr.s(eq=False, repr=False) 129 | class AttentionInfo: 130 | n_heads: int = attr.ib() 131 | ctx_blks_q: int = attr.ib() 132 | ctx_blks_k: int = attr.ib() 133 | block_size: int = attr.ib() 134 | pytorch_attn_bias: Optional[torch.Tensor] = attr.ib() 135 | 136 | 137 | def to_attention_info(d: AttentionMask) -> AttentionInfo: 138 | return AttentionInfo( 139 | n_heads=d.n_head, 140 | ctx_blks_q=d.n_query_block, 141 | ctx_blks_k=d.n_key_block, 142 | block_size=d.block_size, 143 | pytorch_attn_bias=None, 144 | ) 145 | 146 | 147 | def make_full_layout(d: AttentionMask) -> np.ndarray: 148 | """ 149 | Returns the `context_size x context_size` layout matrix described by `d`. If the layout is dependent on the index of 150 | the attention head, a `attention_head x context_size x context_size` layout matrix is returned instead. 
151 | """ 152 | 153 | if not d.is_head_specific: 154 | u = np.reshape(d.global_layout, [d.n_query_block, d.n_key_block, 1, 1]) 155 | r = product(range(d.n_query_block), range(d.n_key_block)) 156 | v = np.array([d.block_layout(None, 0, i, j, 0) for i, j in r]) 157 | v = np.reshape(v, [d.n_query_block, d.n_key_block, d.block_size, d.block_size]) 158 | 159 | w = u * v 160 | w = np.transpose(w, [0, 2, 1, 3]) 161 | w = np.reshape(w, [d.query_context_size, d.key_context_size]) 162 | return w 163 | else: 164 | if len(d.global_layout.shape) == 2: 165 | u = np.reshape(d.global_layout, [1, d.n_query_block, d.n_key_block, 1, 1]) 166 | u = np.tile(u, [d.n_head, 1, 1, 1, 1]) 167 | elif len(d.global_layout.shape) == 3: 168 | u = np.reshape(d.global_layout, [d.n_head, d.n_query_block, d.n_key_block, 1, 1]) 169 | else: 170 | raise RuntimeError() 171 | 172 | s = product(range(d.n_head), range(d.n_query_block), range(d.n_key_block)) 173 | v = np.array([d.block_layout(None, i, j, k, 0) for i, j, k in s]) 174 | v = np.reshape(v, [d.n_head, d.n_query_block, d.n_key_block, d.block_size, d.block_size]) 175 | 176 | w = u * v 177 | w = np.transpose(w, [0, 1, 3, 2, 4]) 178 | w = np.reshape(w, [d.n_head, d.query_context_size, d.key_context_size]) 179 | return w 180 | -------------------------------------------------------------------------------- /dalle2_decoder/clip/config.yaml: -------------------------------------------------------------------------------- 1 | logit_scale: 100.0 2 | 3 | # Diffusion settings 4 | beta_schedule: "squaredcos_cap_v2" 5 | n_timesteps: 1000 6 | 7 | # Architecture settings 8 | image_size: 64 9 | patch_size: 4 10 | n_vocab: 65536 11 | max_text_len: 77 12 | n_embd: 512 13 | n_head_state_text: 64 14 | n_head_text: 8 15 | n_xf_blocks_text: 12 16 | n_head_state_image: 64 17 | n_head_image: 12 18 | n_xf_blocks_image: 12 19 | -------------------------------------------------------------------------------- /dalle2_decoder/clip/encoders.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import OrderedDict 3 | from typing import List, Optional, Tuple, cast 4 | 5 | import attr 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from .attention import ( 12 | AttentionInfo, 13 | DenseAttentionMask, 14 | DenseCausalAttentionMask, 15 | make_full_layout, 16 | to_attention_info, 17 | ) 18 | from .utils import Affine, LayerNorm, zero_key_bias_grad 19 | 20 | # Constants used in the original CLIP implementation. 
21 | image_channel_means = [122.77093945, 116.74601272, 104.09373519] 22 | image_channel_stds = [68.50053285, 66.63215831, 70.32316309] 23 | 24 | 25 | @attr.s(eq=False, repr=False) 26 | class TextEmbedding(nn.Module): 27 | n_vocab: int = attr.ib() 28 | n_context: int = attr.ib() 29 | n_state: int = attr.ib() 30 | device: torch.device = attr.ib(default=torch.device("cuda")) 31 | 32 | def __attrs_post_init__(self) -> None: 33 | super().__init__() 34 | 35 | w_voc = torch.empty((self.n_vocab, self.n_state), dtype=torch.float32, device=self.device) 36 | w_pos = torch.empty((self.n_context, self.n_state), dtype=torch.float32, device=self.device) 37 | 38 | with torch.no_grad(): 39 | w_voc.normal_(std=0.02) 40 | w_pos.normal_(std=0.01) 41 | 42 | self.w_voc = nn.Parameter(w_voc) 43 | self.w_pos = nn.Parameter(w_pos) 44 | 45 | def forward(self, x: torch.Tensor) -> torch.Tensor: 46 | if len(x.shape) != 2: 47 | raise ValueError() 48 | 49 | return F.embedding(x, self.w_voc) + self.w_pos[None, :, :] 50 | 51 | 52 | @attr.s(eq=False, repr=False) 53 | class ImageEmbedding(nn.Module): 54 | image_size: int = attr.ib() 55 | patch_size: int = attr.ib() 56 | n_state: int = attr.ib() 57 | n_timestep: int = attr.ib(default=0) 58 | device: torch.device = attr.ib(default=torch.device("cuda")) 59 | 60 | def __attrs_post_init__(self) -> None: 61 | super().__init__() 62 | 63 | if self.image_size % self.patch_size != 0: 64 | raise ValueError() 65 | 66 | n_patch = self.image_size // self.patch_size 67 | patch_proj = torch.empty( 68 | (self.n_state, 3) + 2 * (self.patch_size,), dtype=torch.float32, device=self.device 69 | ) 70 | w_pos = torch.empty( 71 | (1 + n_patch ** 2, self.n_state), dtype=torch.float32, device=self.device 72 | ) 73 | 74 | with torch.no_grad(): 75 | if self.n_timestep == 0: 76 | pred_state = torch.empty((self.n_state,), dtype=torch.float32, device=self.device) 77 | pred_state.normal_(std=1 / np.sqrt(self.n_state)) 78 | self.pred_state = nn.Parameter(pred_state) 79 | else: 80 | w_t = torch.empty( 81 | (self.n_timestep, self.n_state), dtype=torch.float32, device=self.device 82 | ) 83 | w_t.normal_(std=1 / np.sqrt(self.n_state)) 84 | self.w_t = nn.Parameter(w_t) 85 | 86 | patch_proj.normal_(std=np.sqrt(2 / (self.n_state * self.patch_size ** 2))) 87 | w_pos.normal_(std=1 / np.sqrt(self.n_state)) 88 | 89 | self.patch_proj = nn.Parameter(patch_proj) 90 | self.w_pos = nn.Parameter(w_pos) 91 | 92 | self.channel_means = torch.tensor( 93 | image_channel_means, dtype=torch.float32, device=self.device 94 | )[None, :, None, None] 95 | self.channel_stds = torch.tensor( 96 | image_channel_stds, dtype=torch.float32, device=self.device 97 | )[None, :, None, None] 98 | self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device) 99 | 100 | def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None) -> torch.Tensor: 101 | if len(x.shape) != 4: 102 | raise ValueError("input should be 4d") 103 | if x.shape[1] != 3: 104 | raise ValueError("input should have 3 channels") 105 | if not (x.shape[2] == self.image_size and x.shape[3] == self.image_size): 106 | raise ValueError(f"input is not {self.image_size} x {self.image_size}") 107 | 108 | if (self.n_timestep == 0 and t is not None) or (self.n_timestep != 0 and t is None): 109 | raise ValueError() 110 | if self.n_timestep != 0: 111 | assert t is not None 112 | if len(t.shape) != 1: 113 | raise ValueError() 114 | if t.shape[0] != x.shape[0]: 115 | raise ValueError() 116 | 117 | x = (x - self.channel_means) / self.channel_stds 118 | x = F.conv2d(x, self.patch_proj, 
stride=self.patch_size) 119 | x = x.reshape(x.shape[0], self.n_state, (self.image_size // self.patch_size) ** 2).permute( 120 | 0, 2, 1 121 | ) 122 | 123 | sot = ( 124 | self.pred_state[None, None].expand(x.shape[0], -1, -1) 125 | if self.n_timestep == 0 126 | else F.embedding(cast(torch.Tensor, t), self.w_t)[:, None] 127 | ) 128 | x = torch.cat((sot, x), dim=1) + self.w_pos[None] 129 | return self.ln(x) 130 | 131 | 132 | @attr.s(eq=False, repr=False) 133 | class AttentionResblock(nn.Module): 134 | n_state: int = attr.ib() 135 | n_resblocks: int = attr.ib() 136 | attn_fn: AttentionInfo = attr.ib() 137 | device: torch.device = attr.ib(default=torch.device("cuda")) 138 | 139 | def __attrs_post_init__(self) -> None: 140 | super().__init__() 141 | 142 | self.n_head_state = self.n_state // self.attn_fn.n_heads 143 | self.qk_scale = 1 / np.sqrt(self.n_head_state) 144 | 145 | self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device) 146 | self.f_q = Affine( 147 | self.n_state, 148 | self.n_state, 149 | std=1 / math.sqrt(self.n_state), 150 | use_bias=True, 151 | bias_filter_fn=zero_key_bias_grad, 152 | device=self.device, 153 | ) 154 | self.f_k = Affine( 155 | self.n_state, 156 | self.n_state, 157 | std=1 / math.sqrt(self.n_state), 158 | use_bias=False, 159 | bias_filter_fn=zero_key_bias_grad, 160 | device=self.device, 161 | ) 162 | self.f_v = Affine( 163 | self.n_state, 164 | self.n_state, 165 | std=1 / math.sqrt(self.n_state), 166 | use_bias=True, 167 | bias_filter_fn=zero_key_bias_grad, 168 | device=self.device, 169 | ) 170 | self.f_c = Affine( 171 | self.n_state, 172 | self.n_state, 173 | use_bias=True, 174 | std=1 / np.sqrt(self.n_state * self.n_resblocks ** 2), 175 | device=self.device, 176 | ) # XXX 177 | 178 | def forward(self, m: torch.Tensor) -> torch.Tensor: 179 | n_context = m.shape[1] 180 | n_query_pad = self.attn_fn.ctx_blks_q * self.attn_fn.block_size - n_context 181 | n_key_pad = self.attn_fn.ctx_blks_k * self.attn_fn.block_size - n_context 182 | assert n_query_pad >= 0 183 | assert n_key_pad >= 0 184 | 185 | r = m 186 | r = self.ln(r) 187 | q, k, v = self.f_q(r), self.f_k(r), self.f_v(r) 188 | 189 | if n_query_pad != 0: 190 | q = F.pad(q, (0, 0, 0, n_query_pad)) 191 | 192 | if n_key_pad != 0: 193 | k = F.pad(k, (0, 0, 0, n_key_pad)) 194 | v = F.pad(v, (0, 0, 0, n_key_pad)) 195 | 196 | q = q.view([q.shape[0], -1, self.attn_fn.n_heads, self.n_head_state]).permute((0, 2, 1, 3)) 197 | k = k.view([k.shape[0], -1, self.attn_fn.n_heads, self.n_head_state]).permute((0, 2, 1, 3)) 198 | v = v.view([v.shape[0], -1, self.attn_fn.n_heads, self.n_head_state]).permute((0, 2, 1, 3)) 199 | w = torch.einsum( 200 | "bhcd,bhkd->bhck", q * math.sqrt(self.qk_scale), k * math.sqrt(self.qk_scale) 201 | ) 202 | 203 | if hasattr(self.attn_fn, "pytorch_attn_bias"): 204 | bias = self.attn_fn.pytorch_attn_bias 205 | assert len(bias.shape) in {2, 3} 206 | 207 | if len(bias.shape) == 2: 208 | w = torch.softmax(w + self.attn_fn.pytorch_attn_bias[None, None], dim=-1) 209 | elif len(bias.shape) == 3: 210 | w = torch.softmax(w + self.attn_fn.pytorch_attn_bias[None], dim=-1) 211 | else: 212 | w = torch.softmax(w, dim=-1) 213 | 214 | r = torch.einsum("bhck,bhkd->bhcd", w, v) 215 | r = r.permute((0, 2, 1, 3)).reshape((r.shape[0], -1, self.n_state)) 216 | 217 | if n_query_pad != 0: 218 | r = r[:, :-n_query_pad] 219 | 220 | assert r.shape[1] == n_context 221 | 222 | r = self.f_c(r) 223 | return m + r 224 | 225 | 226 | @attr.s(eq=False, repr=False) 227 | class FullyConnectedResblock(nn.Module): 228 | """ 229 | Not 
imported from other files because we retain Alec's original inits. 230 | """ 231 | 232 | n_state: int = attr.ib() 233 | n_resblocks: int = attr.ib() 234 | device: torch.device = attr.ib(default=torch.device("cuda")) 235 | 236 | def __attrs_post_init__(self) -> None: 237 | super().__init__() 238 | 239 | self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device) 240 | self.f_1 = Affine( 241 | self.n_state, 242 | 4 * self.n_state, 243 | use_bias=True, 244 | std=np.sqrt(2 / (4 * self.n_state)), 245 | device=self.device, 246 | ) 247 | self.f_2 = Affine( 248 | 4 * self.n_state, 249 | self.n_state, 250 | use_bias=True, 251 | std=1 / np.sqrt(self.n_state * self.n_resblocks ** 2), 252 | device=self.device, 253 | ) # XXX 254 | 255 | def forward(self, m: torch.Tensor) -> torch.Tensor: 256 | r = m 257 | r = self.ln(r) 258 | 259 | r = self.f_2(F.gelu(self.f_1(r))) 260 | return m + r 261 | 262 | 263 | @attr.s(eq=False, repr=False) 264 | class TransformerBlock(nn.Module): 265 | n_state: int = attr.ib() 266 | n_resblocks: int = attr.ib() 267 | attn_fn: AttentionInfo = attr.ib() 268 | device: torch.device = attr.ib(default=torch.device("cuda")) 269 | 270 | def __attrs_post_init__(self) -> None: 271 | super().__init__() 272 | 273 | self.f_attn = AttentionResblock( 274 | self.n_state, 275 | self.n_resblocks, 276 | self.attn_fn, 277 | self.device, 278 | ) 279 | self.f_mlp = FullyConnectedResblock(self.n_state, self.n_resblocks, self.device) 280 | 281 | def forward(self, x: torch.Tensor) -> torch.Tensor: 282 | return self.f_mlp(self.f_attn(x)) 283 | 284 | 285 | @attr.s(eq=False, repr=False) 286 | class TextFeatureExtractor(nn.Module): 287 | n_state: int = attr.ib() 288 | n_embd: int = attr.ib() 289 | device: torch.device = attr.ib(default=torch.device("cuda")) 290 | 291 | def __attrs_post_init__(self) -> None: 292 | super().__init__() 293 | 294 | self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device) 295 | self.f = Affine(self.n_state, self.n_embd, use_bias=False, device=self.device) 296 | 297 | def forward( 298 | self, text: torch.Tensor, text_len: torch.Tensor, return_probe_features: bool = False 299 | ) -> torch.Tensor: 300 | if len(text.shape) != 3: 301 | raise ValueError("expected text to be 3d") 302 | if len(text_len.shape) != 1: 303 | raise ValueError("expected text length to be 1d") 304 | if text.shape[0] != text_len.shape[0]: 305 | raise ValueError("text and text_len have inconsistent batch dimensions") 306 | 307 | index = (text_len - 1)[:, None, None].expand(-1, 1, text.shape[2]) 308 | x = torch.gather(text, dim=1, index=index) 309 | assert list(x.shape) == [text.shape[0], 1, text.shape[2]] 310 | 311 | if return_probe_features: 312 | return x[:, 0] 313 | 314 | x = self.ln(x) 315 | return self.f(x[:, 0]) 316 | 317 | 318 | @attr.s(eq=False, repr=False) 319 | class ImageFeatureExtractor(nn.Module): 320 | n_state: int = attr.ib() 321 | n_embd: int = attr.ib() 322 | device: torch.device = attr.ib(default=torch.device("cuda")) 323 | 324 | def __attrs_post_init__(self) -> None: 325 | super().__init__() 326 | 327 | self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device) 328 | self.f = Affine(self.n_state, self.n_embd, use_bias=False, device=self.device) 329 | 330 | def forward(self, x: torch.Tensor, return_probe_features: bool = False) -> torch.Tensor: 331 | if return_probe_features: 332 | return x[:, 0] 333 | 334 | x = self.ln(x[:, :1]) 335 | return self.f(x[:, 0]) 336 | 337 | 338 | @attr.s(eq=False, repr=False) 339 | class TextEncoder(nn.Module): 340 | n_bpe_vocab: int = attr.ib() 341 
| max_text_len: int = attr.ib() 342 | n_embd: int = attr.ib() 343 | n_head: int = attr.ib() 344 | n_xf_blocks: int = attr.ib() 345 | n_head_state: int = attr.ib(default=64) 346 | device: torch.device = attr.ib(default=torch.device("cuda")) 347 | block_size: int = attr.ib(init=False, default=32) 348 | 349 | def __attrs_post_init__(self) -> None: 350 | super().__init__() 351 | 352 | self.n_state = self.n_head * self.n_head_state 353 | n_rounded_context = self.block_size * int(math.ceil(self.max_text_len / self.block_size)) 354 | n_pad = n_rounded_context - self.max_text_len 355 | 356 | args = ( 357 | n_rounded_context, 358 | n_rounded_context, 359 | self.block_size, 360 | self.n_head, 361 | False, 362 | n_pad, 363 | n_pad, 364 | ) 365 | mask = DenseCausalAttentionMask(*args) 366 | attn_fn = to_attention_info(mask) 367 | 368 | m = 1 - make_full_layout(mask).astype(np.float32) 369 | m[m == 1] = -1e10 370 | attn_fn.pytorch_attn_bias = torch.from_numpy(m).to(self.device) 371 | 372 | blocks: List[Tuple[str, nn.Module]] = [ 373 | ( 374 | "input", 375 | TextEmbedding( 376 | self.n_bpe_vocab, self.max_text_len, self.n_state, device=self.device 377 | ), 378 | ) 379 | ] 380 | 381 | for i in range(self.n_xf_blocks): 382 | blocks.append( 383 | ( 384 | f"block_{i}", 385 | TransformerBlock(self.n_state, 2 * self.n_xf_blocks, attn_fn, self.device), 386 | ) 387 | ) 388 | 389 | blocks.append( 390 | ("output", TextFeatureExtractor(self.n_state, self.n_embd, device=self.device)) 391 | ) 392 | 393 | self.blocks = nn.ModuleDict(OrderedDict(blocks)) 394 | 395 | def forward( 396 | self, 397 | text: torch.Tensor, 398 | text_len: torch.Tensor, 399 | return_probe_features: bool = False, 400 | ) -> torch.Tensor: 401 | 402 | n_batch = text.shape[0] 403 | h = self.blocks["input"](text) 404 | 405 | for i in range(self.n_xf_blocks): 406 | h = self.blocks[f"block_{i}"](h) 407 | 408 | h = self.blocks["output"](h, text_len, return_probe_features=return_probe_features) 409 | 410 | assert list(h.shape) == [ 411 | n_batch, 412 | self.n_embd if not return_probe_features else self.n_state, 413 | ] 414 | return h 415 | 416 | 417 | @attr.s(eq=False, repr=False) 418 | class ImageEncoder(nn.Module): 419 | image_size: int = attr.ib() 420 | patch_size: int = attr.ib() 421 | n_embd: int = attr.ib() 422 | n_head: int = attr.ib() 423 | n_xf_blocks: int = attr.ib() 424 | n_head_state: int = attr.ib(default=64) 425 | n_timestep: int = attr.ib(default=0) 426 | device: torch.device = attr.ib(default=torch.device("cuda")) 427 | block_size: int = attr.ib(init=False, default=32) 428 | 429 | def __attrs_post_init__(self) -> None: 430 | super().__init__() 431 | 432 | self.n_state = self.n_head * self.n_head_state 433 | self.n_context = 1 + (self.image_size // self.patch_size) ** 2 434 | n_rounded_context = self.block_size * int(math.ceil(self.n_context / self.block_size)) 435 | n_pad = n_rounded_context - self.n_context 436 | 437 | args = ( 438 | n_rounded_context, 439 | n_rounded_context, 440 | self.block_size, 441 | self.n_head, 442 | False, 443 | n_pad, 444 | n_pad, 445 | ) 446 | mask = DenseAttentionMask(*args) 447 | attn_fn = to_attention_info(mask) 448 | 449 | m = 1 - make_full_layout(mask).astype(np.float32) 450 | m[m == 1] = -1e10 451 | attn_fn.pytorch_attn_bias = torch.from_numpy(m).to(self.device) 452 | 453 | blocks: List[Tuple[str, nn.Module]] = [ 454 | ( 455 | "input", 456 | ImageEmbedding( 457 | self.image_size, 458 | self.patch_size, 459 | self.n_state, 460 | n_timestep=self.n_timestep, 461 | device=self.device, 462 | ), 463 | ) 464 
| ] 465 | 466 | for i in range(self.n_xf_blocks): 467 | blocks.append( 468 | ( 469 | f"block_{i}", 470 | TransformerBlock(self.n_state, 2 * self.n_xf_blocks, attn_fn, self.device), 471 | ) 472 | ) 473 | 474 | blocks.append(("output", ImageFeatureExtractor(self.n_state, self.n_embd, self.device))) 475 | 476 | self.blocks = nn.ModuleDict(OrderedDict(blocks)) 477 | 478 | def forward( 479 | self, 480 | image: torch.Tensor, 481 | timesteps: Optional[torch.Tensor] = None, 482 | return_probe_features: bool = False, 483 | ) -> torch.Tensor: 484 | n_batch = image.shape[0] 485 | h = self.blocks["input"](image, t=timesteps) 486 | 487 | for i in range(self.n_xf_blocks): 488 | h = self.blocks[f"block_{i}"](h) 489 | 490 | h = self.blocks["output"](h, return_probe_features=return_probe_features) 491 | 492 | assert list(h.shape) == [ 493 | n_batch, 494 | self.n_embd if not return_probe_features else self.n_state, 495 | ] 496 | 497 | return h 498 | -------------------------------------------------------------------------------- /dalle2_decoder/clip/model_creation.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import lru_cache 3 | from typing import Any, Callable, Dict, List, Optional, Tuple 4 | 5 | import attr 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import yaml 10 | from glide_text2im.tokenizer.simple_tokenizer import SimpleTokenizer 11 | 12 | from .encoders import ImageEncoder, TextEncoder 13 | 14 | 15 | @lru_cache() 16 | def default_config_path() -> str: 17 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.yaml") 18 | 19 | 20 | @attr.s 21 | class CLIPModel: 22 | config: Dict[str, Any] = attr.ib() 23 | text_encoder: nn.Module = attr.ib() 24 | image_encoder: nn.Module = attr.ib() 25 | logit_scale: torch.Tensor = attr.ib() 26 | device: torch.device = attr.ib() 27 | tokenizer: SimpleTokenizer = attr.ib() 28 | 29 | def encode_prompts(self, prompts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]: 30 | tokens = [] 31 | lens = [] 32 | for prompt in prompts: 33 | sub_tokens, sub_len = self.tokenizer.padded_tokens_and_len( 34 | self.tokenizer.encode(prompt), self.text_encoder.max_text_len 35 | ) 36 | tokens.append(sub_tokens) 37 | lens.append(sub_len) 38 | return ( 39 | torch.tensor(tokens).to(dtype=torch.long, device=self.device), 40 | torch.tensor(lens).to(dtype=torch.long, device=self.device), 41 | ) 42 | 43 | def text_embeddings(self, prompts: List[str]) -> torch.Tensor: 44 | tokens, lens = self.encode_prompts(prompts) 45 | z_t = self.text_encoder(tokens, lens) 46 | return z_t / (torch.linalg.norm(z_t, dim=-1, keepdim=True) + 1e-12) 47 | 48 | def image_embeddings(self, images: torch.Tensor, t: torch.Tensor) -> torch.Tensor: 49 | z_i = self.image_encoder((images + 1) * 127.5, t) 50 | return z_i / (torch.linalg.norm(z_i, dim=-1, keepdim=True) + 1e-12) 51 | 52 | def cond_fn(self, prompts: List[str], grad_scale: float) -> Callable[..., torch.Tensor]: 53 | with torch.no_grad(): 54 | z_t = self.text_embeddings(prompts) 55 | 56 | def cond_fn(x, t, grad_scale=grad_scale, **kwargs): 57 | with torch.enable_grad(): 58 | x_var = x.detach().requires_grad_(True) 59 | z_i = self.image_embeddings(x_var, t) 60 | loss = torch.exp(self.logit_scale) * (z_t * z_i).sum() 61 | grad = torch.autograd.grad(loss, x_var)[0].detach() 62 | return grad * grad_scale 63 | 64 | return cond_fn 65 | 66 | 67 | def create_clip_model( 68 | config_path: Optional[str] = None, 69 | device: Optional[torch.device] = None, 70 | 
tokenizer: Optional[SimpleTokenizer] = None, 71 | ) -> CLIPModel: 72 | if config_path is None: 73 | config_path = default_config_path() 74 | if device is None: 75 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 76 | if tokenizer is None: 77 | tokenizer = SimpleTokenizer() 78 | 79 | with open(config_path, "r") as f: 80 | config = yaml.load(f, Loader=yaml.SafeLoader) 81 | 82 | text_encoder = TextEncoder( 83 | n_bpe_vocab=config["n_vocab"], 84 | max_text_len=config["max_text_len"], 85 | n_embd=config["n_embd"], 86 | n_head=config["n_head_text"], 87 | n_xf_blocks=config["n_xf_blocks_text"], 88 | n_head_state=config["n_head_state_text"], 89 | device=device, 90 | ) 91 | 92 | image_encoder = ImageEncoder( 93 | image_size=config["image_size"], 94 | patch_size=config["patch_size"], 95 | n_embd=config["n_embd"], 96 | n_head=config["n_head_image"], 97 | n_xf_blocks=config["n_xf_blocks_image"], 98 | n_head_state=config["n_head_state_image"], 99 | n_timestep=config["n_timesteps"], 100 | device=device, 101 | ) 102 | 103 | logit_scale = torch.tensor( 104 | np.log(config["logit_scale"]), 105 | dtype=torch.float32, 106 | device=device, 107 | requires_grad=False, 108 | ) 109 | 110 | return CLIPModel( 111 | config=config, 112 | text_encoder=text_encoder, 113 | image_encoder=image_encoder, 114 | logit_scale=logit_scale, 115 | device=device, 116 | tokenizer=tokenizer, 117 | ) 118 | -------------------------------------------------------------------------------- /dalle2_decoder/clip/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Callable, Optional 3 | 4 | import attr 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | FilterFn = Callable[[torch.Tensor], torch.Tensor] 10 | 11 | 12 | class ZeroKeyBiasGrad(torch.autograd.Function): 13 | @staticmethod 14 | def forward(ctx, x): 15 | return x 16 | 17 | @staticmethod 18 | def backward(ctx, output_grad): 19 | output_grad = output_grad.clone() 20 | output_grad.chunk(3)[1].zero_() 21 | return output_grad 22 | 23 | 24 | def zero_key_bias_grad(x: torch.Tensor) -> torch.Tensor: 25 | return ZeroKeyBiasGrad.apply(x) 26 | 27 | 28 | @attr.s(eq=False, repr=False) 29 | class LayerNorm(nn.Module): 30 | n_state: int = attr.ib() 31 | eps: float = attr.ib(default=1e-6) 32 | device: torch.device = attr.ib(default=torch.device("cuda")) 33 | 34 | def __attrs_post_init__(self) -> None: 35 | super().__init__() 36 | self.g = nn.Parameter(torch.ones((self.n_state,), dtype=torch.float32, device=self.device)) 37 | self.b = nn.Parameter(torch.zeros((self.n_state,), dtype=torch.float32, device=self.device)) 38 | self.g.weight_decay_level = "disable" # type: ignore 39 | self.b.weight_decay_level = "disable" # type: ignore 40 | 41 | def forward(self, x: torch.Tensor) -> torch.Tensor: 42 | return F.layer_norm( 43 | x.type(torch.float32), torch.Size((self.n_state,)), self.g, self.b, self.eps 44 | ) 45 | 46 | 47 | @attr.s(eq=False, repr=False) 48 | class Affine(nn.Module): 49 | n_in: int = attr.ib() 50 | n_out: int = attr.ib() 51 | use_bias: bool = attr.ib(default=True) 52 | use_admnet_init: bool = attr.ib(default=False) 53 | std: Optional[float] = attr.ib(default=None) 54 | extra_init_scale: Optional[float] = attr.ib(default=None) 55 | bias_filter_fn: FilterFn = attr.ib(default=lambda x: x) 56 | device: torch.device = attr.ib(default=torch.device("cuda")) 57 | 58 | def __attrs_post_init__(self) -> None: 59 | super().__init__() 60 | 61 | if not 
self.use_admnet_init: 62 | self.std = self.std if self.std is not None else math.sqrt(2 / (self.n_in + self.n_out)) 63 | self.std = ( 64 | self.std if self.extra_init_scale is None else self.std * self.extra_init_scale 65 | ) 66 | 67 | w = torch.empty((self.n_out, self.n_in), dtype=torch.float32, device=self.device) 68 | self.w = nn.Parameter(w) 69 | 70 | if self.use_bias: 71 | self.b = nn.Parameter( 72 | torch.zeros((self.n_out,), dtype=torch.float32, device=self.device) 73 | ) 74 | self.b.weight_decay_level = "disable" # type: ignore 75 | else: 76 | if self.extra_init_scale is not None: 77 | raise ValueError("extra_init_scale incompatible with admnet init") 78 | 79 | w = torch.empty((self.n_out, self.n_in), dtype=torch.float32, device=self.device) 80 | 81 | if self.use_bias: 82 | b = torch.empty((self.n_out,), dtype=torch.float32, device=self.device) 83 | 84 | self.w = nn.Parameter(w) 85 | 86 | if self.use_bias: 87 | self.b = nn.Parameter(b) 88 | self.b.weight_decay_level = "disable" # type: ignore 89 | 90 | def forward(self, x: torch.Tensor) -> torch.Tensor: 91 | w = self.w if self.w.dtype == x.dtype else self.w.to(x.dtype) 92 | b = ( 93 | self.bias_filter_fn(self.b if self.b.dtype == x.dtype else self.b.to(x.dtype)) 94 | if self.use_bias 95 | else None 96 | ) 97 | return F.linear(x, w, b) 98 | -------------------------------------------------------------------------------- /dalle2_decoder/dataset.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import blobfile as bf 3 | from mpi4py import MPI 4 | import numpy as np 5 | from torch.utils.data import DataLoader, Dataset 6 | import torch 7 | import json 8 | import os 9 | 10 | def get_loader(batch_size, resolution, image_paths, clip_embedings, tokens, masks, pad_token=50256, zero_clip_emb_prob=0.1, zero_text_prob=0.5, shuffle=True,): 11 | dataset = ImageDataset(resolution, image_paths, clip_embedings, tokens, masks, pad_token, zero_clip_emb_prob, zero_text_prob) 12 | loader = DataLoader( 13 | dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1, drop_last=True 14 | ) 15 | while True: 16 | yield from loader 17 | 18 | 19 | def get_second_loader(batch_size, resolution, json_paths, main_dir, pad_token=50256, zero_clip_emb_prob=0.1, zero_text_prob=0.5, shuffle=True,): 20 | dataset = SecondImageDataset(resolution, json_paths, main_dir, pad_token, zero_clip_emb_prob, zero_text_prob) 21 | loader = DataLoader( 22 | dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1, drop_last=True 23 | ) 24 | while True: 25 | yield from loader 26 | 27 | class ImageDataset(Dataset): 28 | def __init__(self, resolution, image_paths, clip_embedings, tokens, masks, pad_token=50256, zero_clip_emb_prob=0.1, zero_text_prob=0.5): 29 | super().__init__() 30 | self.resolution = resolution 31 | self.image_paths = image_paths 32 | self.clip_embedings = clip_embedings 33 | self.tokens = tokens 34 | self.masks = masks 35 | self.pad_token = pad_token 36 | self.zero_clip_emb_prob = zero_clip_emb_prob 37 | self.zero_text_prob = zero_text_prob 38 | 39 | def __len__(self): 40 | return len(self.image_paths) 41 | 42 | def __getitem__(self, idx): 43 | path = self.image_paths[idx] 44 | with bf.BlobFile(path, "rb") as f: 45 | pil_image = Image.open(f) 46 | pil_image.load() 47 | 48 | # We are not on a new enough PIL to support the `reducing_gap` 49 | # argument, which uses BOX downsampling at powers of two first. 50 | # Thus, we do it by hand to improve downsample quality. 
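# After the power-of-two BOX reduction below, the image is resized with BICUBIC so its
# short side equals `resolution`, center-cropped to resolution x resolution, and scaled
# to the [-1, 1] range used by the diffusion model.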
51 | while min(*pil_image.size) >= 2 * self.resolution: 52 | pil_image = pil_image.resize( 53 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 54 | ) 55 | 56 | scale = self.resolution / min(*pil_image.size) 57 | pil_image = pil_image.resize( 58 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 59 | ) 60 | 61 | arr = np.array(pil_image.convert("RGB")) 62 | crop_y = (arr.shape[0] - self.resolution) // 2 63 | crop_x = (arr.shape[1] - self.resolution) // 2 64 | arr = arr[crop_y : crop_y + self.resolution, crop_x : crop_x + self.resolution] 65 | arr = arr.astype(np.float32) / 127.5 - 1 66 | 67 | out_dict = {} 68 | 69 | clip_embeding = self.clip_embedings[idx] 70 | if np.random.binomial(1, self.zero_clip_emb_prob): 71 | clip_embeding = [0] * len(clip_embeding) 72 | clip_embeding = torch.tensor(clip_embeding).float() 73 | 74 | tokens_sample = self.tokens[idx] 75 | mask = self.masks[idx] 76 | if np.random.binomial(1, self.zero_text_prob): 77 | tokens_sample = [self.pad_token] * len(tokens_sample) 78 | mask = [False] * len(mask) 79 | tokens_sample = torch.tensor(tokens_sample) 80 | mask = torch.tensor( 81 | mask, 82 | dtype=torch.bool, 83 | ) 84 | out_dict["clip_emb"] = clip_embeding 85 | out_dict["tokens"] = tokens_sample 86 | out_dict["mask"] = mask 87 | return np.transpose(arr, [2, 0, 1]), out_dict 88 | class SecondImageDataset(Dataset): 89 | def __init__(self, resolution, json_paths, main_dir, pad_token=50256, zero_clip_emb_prob=0.1, zero_text_prob=0.5): 90 | super().__init__() 91 | self.resolution = resolution 92 | self.main_dir = main_dir 93 | self.json_paths = json_paths 94 | self.pad_token = pad_token 95 | self.zero_clip_emb_prob = zero_clip_emb_prob 96 | self.zero_text_prob = zero_text_prob 97 | 98 | def __len__(self): 99 | return len(self.json_paths) 100 | 101 | def __getitem__(self, idx): 102 | 103 | with open(os.path.join(self.main_dir, self.json_paths[idx])) as json_file: 104 | in_data = json.load(json_file) 105 | path = in_data['path'] 106 | with bf.BlobFile(path, "rb") as f: 107 | pil_image = Image.open(f) 108 | pil_image.load() 109 | 110 | # We are not on a new enough PIL to support the `reducing_gap` 111 | # argument, which uses BOX downsampling at powers of two first. 112 | # Thus, we do it by hand to improve downsample quality. 
113 | while min(*pil_image.size) >= 2 * self.resolution: 114 | pil_image = pil_image.resize( 115 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 116 | ) 117 | 118 | scale = self.resolution / min(*pil_image.size) 119 | pil_image = pil_image.resize( 120 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 121 | ) 122 | 123 | arr = np.array(pil_image.convert("RGB")) 124 | crop_y = (arr.shape[0] - self.resolution) // 2 125 | crop_x = (arr.shape[1] - self.resolution) // 2 126 | arr = arr[crop_y : crop_y + self.resolution, crop_x : crop_x + self.resolution] 127 | arr = arr.astype(np.float32) / 127.5 - 1 128 | 129 | out_dict = {} 130 | 131 | clip_embeding = [float(i) for i in in_data['clip_emb']] 132 | if np.random.binomial(1, self.zero_clip_emb_prob): 133 | clip_embeding = [0] * len(clip_embeding) 134 | clip_embeding = torch.tensor(clip_embeding).float() 135 | 136 | tokens_sample = in_data['tokens'] 137 | mask = in_data['masks'] 138 | if np.random.binomial(1, self.zero_text_prob): 139 | tokens_sample = [self.pad_token] * len(tokens_sample) 140 | mask = [False] * len(mask) 141 | tokens_sample = torch.tensor(tokens_sample) 142 | mask = torch.tensor( 143 | mask, 144 | dtype=torch.bool, 145 | ) 146 | out_dict["clip_emb"] = clip_embeding 147 | out_dict["tokens"] = tokens_sample 148 | out_dict["mask"] = mask 149 | return np.transpose(arr, [2, 0, 1]), out_dict 150 | -------------------------------------------------------------------------------- /dalle2_decoder/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import io 6 | import os 7 | import socket 8 | 9 | import blobfile as bf 10 | from mpi4py import MPI 11 | import torch as th 12 | import torch.distributed as dist 13 | 14 | # Change this to reflect your cluster layout. 15 | # The GPU for a given rank is (rank % GPUS_PER_NODE). 16 | GPUS_PER_NODE = 8 17 | 18 | SETUP_RETRY_COUNT = 3 19 | 20 | 21 | def setup_dist(): 22 | """ 23 | Setup a distributed process group. 24 | """ 25 | if dist.is_initialized(): 26 | return 27 | 28 | comm = MPI.COMM_WORLD 29 | backend = "gloo" if not th.cuda.is_available() else "nccl" 30 | 31 | if backend == "gloo": 32 | hostname = "localhost" 33 | else: 34 | hostname = socket.gethostbyname(socket.getfqdn()) 35 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) 36 | os.environ["RANK"] = str(comm.rank) 37 | os.environ["WORLD_SIZE"] = str(comm.size) 38 | 39 | port = comm.bcast(_find_free_port(), root=0) 40 | os.environ["MASTER_PORT"] = str(port) 41 | dist.init_process_group(backend=backend, init_method="env://") 42 | 43 | 44 | def dev(): 45 | """ 46 | Get the device to use for torch.distributed. 47 | """ 48 | if th.cuda.is_available(): 49 | return th.device(f"cuda:{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}") 50 | return th.device("cpu") 51 | 52 | 53 | def load_state_dict(path, **kwargs): 54 | """ 55 | Load a PyTorch file without redundant fetches across MPI ranks. 56 | """ 57 | if MPI.COMM_WORLD.Get_rank() == 0: 58 | with bf.BlobFile(path, "rb") as f: 59 | data = f.read() 60 | else: 61 | data = None 62 | data = MPI.COMM_WORLD.bcast(data) 63 | return th.load(io.BytesIO(data), **kwargs) 64 | 65 | 66 | def sync_params(params): 67 | """ 68 | Synchronize a sequence of Tensors across ranks from rank 0. 
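Typically called once right after model construction, e.g. `sync_params(model.parameters())`, so that every rank starts from the same weights.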
69 | """ 70 | for p in params: 71 | with th.no_grad(): 72 | dist.broadcast(p, 0) 73 | 74 | 75 | def _find_free_port(): 76 | try: 77 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 78 | s.bind(("", 0)) 79 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 80 | return s.getsockname()[1] 81 | finally: 82 | s.close() 83 | s.close() -------------------------------------------------------------------------------- /dalle2_decoder/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import lru_cache 3 | from typing import Dict, Optional 4 | 5 | import requests 6 | import torch as th 7 | from filelock import FileLock 8 | from tqdm.auto import tqdm 9 | 10 | MODEL_PATHS = { 11 | "base": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt", 12 | "upsample": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt", 13 | "base-inpaint": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base_inpaint.pt", 14 | "upsample-inpaint": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample_inpaint.pt", 15 | "clip/image-enc": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/clip_image_enc.pt", 16 | "clip/text-enc": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/clip_text_enc.pt", 17 | } 18 | 19 | 20 | @lru_cache() 21 | def default_cache_dir() -> str: 22 | return os.path.join(os.path.abspath(os.getcwd()), "glide_model_cache") 23 | 24 | 25 | def fetch_file_cached( 26 | url: str, progress: bool = True, cache_dir: Optional[str] = None, chunk_size: int = 4096 27 | ) -> str: 28 | """ 29 | Download the file at the given URL into a local file and return the path. 30 | 31 | If cache_dir is specified, it will be used to download the files. 32 | Otherwise, default_cache_dir() is used. 33 | """ 34 | if cache_dir is None: 35 | cache_dir = default_cache_dir() 36 | os.makedirs(cache_dir, exist_ok=True) 37 | local_path = os.path.join(cache_dir, url.split("/")[-1]) 38 | if os.path.exists(local_path): 39 | return local_path 40 | response = requests.get(url, stream=True) 41 | size = int(response.headers.get("content-length", "0")) 42 | with FileLock(local_path + ".lock"): 43 | if progress: 44 | pbar = tqdm(total=size, unit="iB", unit_scale=True) 45 | tmp_path = local_path + ".tmp" 46 | with open(tmp_path, "wb") as f: 47 | for chunk in response.iter_content(chunk_size): 48 | if progress: 49 | pbar.update(len(chunk)) 50 | f.write(chunk) 51 | os.rename(tmp_path, local_path) 52 | if progress: 53 | pbar.close() 54 | return local_path 55 | 56 | 57 | def load_checkpoint( 58 | checkpoint_name: str, 59 | device: th.device, 60 | progress: bool = True, 61 | cache_dir: Optional[str] = None, 62 | chunk_size: int = 4096, 63 | ) -> Dict[str, th.Tensor]: 64 | if checkpoint_name not in MODEL_PATHS: 65 | raise ValueError( 66 | f"Unknown checkpoint name {checkpoint_name}. Known names are: {MODEL_PATHS.keys()}." 67 | ) 68 | path = fetch_file_cached( 69 | MODEL_PATHS[checkpoint_name], progress=progress, cache_dir=cache_dir, chunk_size=chunk_size 70 | ) 71 | return th.load(path, map_location=device) 72 | -------------------------------------------------------------------------------- /dalle2_decoder/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to inference with 16-bit precision. 
3 | """ 4 | 5 | import torch.nn as nn 6 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 7 | 8 | def convert_module_to_f16(l): 9 | """ 10 | Convert primitive modules to float16. 11 | """ 12 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 13 | l.weight.data = l.weight.data.half() 14 | if l.bias is not None: 15 | l.bias.data = l.bias.data.half() 16 | 17 | 18 | def convert_module_to_f32(l): 19 | """ 20 | Convert primitive modules to float32, undoing convert_module_to_f16(). 21 | """ 22 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 23 | l.weight.data = l.weight.data.float() 24 | if l.bias is not None: 25 | l.bias.data = l.bias.data.float() 26 | def make_master_params(model_params): 27 | """ 28 | Copy model parameters into a (differently-shaped) list of full-precision 29 | parameters. 30 | """ 31 | master_params = _flatten_dense_tensors( 32 | [param.detach().float() for param in model_params] 33 | ) 34 | master_params = nn.Parameter(master_params) 35 | master_params.requires_grad = True 36 | return [master_params] 37 | 38 | 39 | def model_grads_to_master_grads(model_params, master_params): 40 | """ 41 | Copy the gradients from the model parameters into the master parameters 42 | from make_master_params(). 43 | """ 44 | master_params[0].grad = _flatten_dense_tensors( 45 | [param.grad.data.detach().float() for param in model_params] 46 | ) 47 | 48 | 49 | def master_params_to_model_params(model_params, master_params): 50 | """ 51 | Copy the master parameter data back into the model parameters. 52 | """ 53 | # Without copying to a list, if a generator is passed, this will 54 | # silently not copy any parameters. 55 | model_params = list(model_params) 56 | 57 | for param, master_param in zip( 58 | model_params, unflatten_master_params(model_params, master_params) 59 | ): 60 | param.detach().copy_(master_param) 61 | 62 | 63 | def unflatten_master_params(model_params, master_params): 64 | """ 65 | Unflatten the master parameters to look like model_params. 66 | """ 67 | return _unflatten_dense_tensors(master_params[0].detach(), model_params) 68 | 69 | 70 | def zero_grad(model_params): 71 | for param in model_params: 72 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 73 | if param.grad is not None: 74 | param.grad.detach_() 75 | param.grad.zero_() -------------------------------------------------------------------------------- /dalle2_decoder/gaussian_diffusion.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | This code started out as a PyTorch port of Ho et al's diffusion models: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py 5 | Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules. 6 | """ 7 | 8 | import enum 9 | import math 10 | 11 | import numpy as np 12 | import torch as th 13 | 14 | from .utils import mean_flat 15 | from .losses import normal_kl, discretized_gaussian_log_likelihood 16 | 17 | 18 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): 19 | """ 20 | Get a pre-defined beta schedule for the given name. 21 | The beta schedule library consists of beta schedules which remain similar 22 | in the limit of num_diffusion_timesteps. 23 | Beta schedules may be added, but should not be removed or changed once 24 | they are committed to maintain backwards compatibility. 
25 | """ 26 | if schedule_name == "linear": 27 | # Linear schedule from Ho et al, extended to work for any number of 28 | # diffusion steps. 29 | scale = 1000 / num_diffusion_timesteps 30 | beta_start = scale * 0.0001 31 | beta_end = scale * 0.02 32 | return np.linspace( 33 | beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 34 | ) 35 | elif schedule_name == "cosine": 36 | return betas_for_alpha_bar( 37 | num_diffusion_timesteps, 38 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, 39 | ) 40 | else: 41 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}") 42 | 43 | 44 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 45 | """ 46 | Create a beta schedule that discretizes the given alpha_t_bar function, 47 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 48 | :param num_diffusion_timesteps: the number of betas to produce. 49 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 50 | produces the cumulative product of (1-beta) up to that 51 | part of the diffusion process. 52 | :param max_beta: the maximum beta to use; use values lower than 1 to 53 | prevent singularities. 54 | """ 55 | betas = [] 56 | for i in range(num_diffusion_timesteps): 57 | t1 = i / num_diffusion_timesteps 58 | t2 = (i + 1) / num_diffusion_timesteps 59 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 60 | return np.array(betas) 61 | 62 | 63 | class ModelMeanType(enum.Enum): 64 | """ 65 | Which type of output the model predicts. 66 | """ 67 | 68 | PREVIOUS_X = enum.auto() # the model predicts x_{t-1} 69 | START_X = enum.auto() # the model predicts x_0 70 | EPSILON = enum.auto() # the model predicts epsilon 71 | 72 | 73 | class ModelVarType(enum.Enum): 74 | """ 75 | What is used as the model's output variance. 76 | The LEARNED_RANGE option has been added to allow the model to predict 77 | values between FIXED_SMALL and FIXED_LARGE, making its job easier. 78 | """ 79 | 80 | LEARNED = enum.auto() 81 | FIXED_SMALL = enum.auto() 82 | FIXED_LARGE = enum.auto() 83 | LEARNED_RANGE = enum.auto() 84 | 85 | 86 | class LossType(enum.Enum): 87 | MSE = enum.auto() # use raw MSE loss (and KL when learning variances) 88 | RESCALED_MSE = ( 89 | enum.auto() 90 | ) # use raw MSE loss (with RESCALED_KL when learning variances) 91 | KL = enum.auto() # use the variational lower-bound 92 | RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB 93 | 94 | def is_vb(self): 95 | return self == LossType.KL or self == LossType.RESCALED_KL 96 | 97 | 98 | class GaussianDiffusion: 99 | """ 100 | Utilities for training and sampling diffusion models. 101 | Ported directly from here, and then adapted over time to further experimentation. 102 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 103 | :param betas: a 1-D numpy array of betas for each diffusion timestep, 104 | starting at T and going to 1. 105 | :param model_mean_type: a ModelMeanType determining what the model outputs. 106 | :param model_var_type: a ModelVarType determining how variance is output. 107 | :param loss_type: a LossType determining the loss function to use. 108 | :param rescale_timesteps: if True, pass floating point timesteps into the 109 | model so that they are always scaled like in the 110 | original paper (0 to 1000). 
111 | """ 112 | 113 | def __init__( 114 | self, 115 | *, 116 | betas, 117 | model_mean_type, 118 | model_var_type, 119 | loss_type, 120 | rescale_timesteps=False, 121 | ): 122 | self.model_mean_type = model_mean_type 123 | self.model_var_type = model_var_type 124 | self.loss_type = loss_type 125 | self.rescale_timesteps = rescale_timesteps 126 | 127 | # Use float64 for accuracy. 128 | betas = np.array(betas, dtype=np.float64) 129 | self.betas = betas 130 | assert len(betas.shape) == 1, "betas must be 1-D" 131 | assert (betas > 0).all() and (betas <= 1).all() 132 | 133 | self.num_timesteps = int(betas.shape[0]) 134 | 135 | alphas = 1.0 - betas 136 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 137 | self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) 138 | self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) 139 | assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) 140 | 141 | # calculations for diffusion q(x_t | x_{t-1}) and others 142 | self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) 143 | self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) 144 | self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) 145 | self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) 146 | self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) 147 | 148 | # calculations for posterior q(x_{t-1} | x_t, x_0) 149 | self.posterior_variance = ( 150 | betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 151 | ) 152 | # log calculation clipped because the posterior variance is 0 at the 153 | # beginning of the diffusion chain. 154 | self.posterior_log_variance_clipped = np.log( 155 | np.append(self.posterior_variance[1], self.posterior_variance[1:]) 156 | ) 157 | self.posterior_mean_coef1 = ( 158 | betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 159 | ) 160 | self.posterior_mean_coef2 = ( 161 | (1.0 - self.alphas_cumprod_prev) 162 | * np.sqrt(alphas) 163 | / (1.0 - self.alphas_cumprod) 164 | ) 165 | 166 | def q_mean_variance(self, x_start, t): 167 | """ 168 | Get the distribution q(x_t | x_0). 169 | :param x_start: the [N x C x ...] tensor of noiseless inputs. 170 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 171 | :return: A tuple (mean, variance, log_variance), all of x_start's shape. 172 | """ 173 | mean = ( 174 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 175 | ) 176 | variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) 177 | log_variance = _extract_into_tensor( 178 | self.log_one_minus_alphas_cumprod, t, x_start.shape 179 | ) 180 | return mean, variance, log_variance 181 | 182 | def q_sample(self, x_start, t, noise=None): 183 | """ 184 | Diffuse the data for a given number of diffusion steps. 185 | In other words, sample from q(x_t | x_0). 186 | :param x_start: the initial data batch. 187 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 188 | :param noise: if specified, the split-out normal noise. 189 | :return: A noisy version of x_start. 
190 | """ 191 | if noise is None: 192 | noise = th.randn_like(x_start) 193 | assert noise.shape == x_start.shape 194 | return ( 195 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 196 | + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) 197 | * noise 198 | ) 199 | 200 | def q_posterior_mean_variance(self, x_start, x_t, t): 201 | """ 202 | Compute the mean and variance of the diffusion posterior: 203 | q(x_{t-1} | x_t, x_0) 204 | """ 205 | assert x_start.shape == x_t.shape 206 | posterior_mean = ( 207 | _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start 208 | + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t 209 | ) 210 | posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) 211 | posterior_log_variance_clipped = _extract_into_tensor( 212 | self.posterior_log_variance_clipped, t, x_t.shape 213 | ) 214 | assert ( 215 | posterior_mean.shape[0] 216 | == posterior_variance.shape[0] 217 | == posterior_log_variance_clipped.shape[0] 218 | == x_start.shape[0] 219 | ) 220 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 221 | 222 | def p_mean_variance( 223 | self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None 224 | ): 225 | """ 226 | Apply the model to get p(x_{t-1} | x_t), as well as a prediction of 227 | the initial x, x_0. 228 | :param model: the model, which takes a signal and a batch of timesteps 229 | as input. 230 | :param x: the [N x C x ...] tensor at time t. 231 | :param t: a 1-D Tensor of timesteps. 232 | :param clip_denoised: if True, clip the denoised signal into [-1, 1]. 233 | :param denoised_fn: if not None, a function which applies to the 234 | x_start prediction before it is used to sample. Applies before 235 | clip_denoised. 236 | :param model_kwargs: if not None, a dict of extra keyword arguments to 237 | pass to the model. This can be used for conditioning. 238 | :return: a dict with the following keys: 239 | - 'mean': the model mean output. 240 | - 'variance': the model variance output. 241 | - 'log_variance': the log of 'variance'. 242 | - 'pred_xstart': the prediction for x_0. 243 | """ 244 | if model_kwargs is None: 245 | model_kwargs = {} 246 | 247 | B, C = x.shape[:2] 248 | assert t.shape == (B,) 249 | model_output = model(x, self._scale_timesteps(t), **model_kwargs) 250 | 251 | if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: 252 | assert model_output.shape == (B, C * 2, *x.shape[2:]) 253 | model_output, model_var_values = th.split(model_output, C, dim=1) 254 | if self.model_var_type == ModelVarType.LEARNED: 255 | model_log_variance = model_var_values 256 | model_variance = th.exp(model_log_variance) 257 | else: 258 | min_log = _extract_into_tensor( 259 | self.posterior_log_variance_clipped, t, x.shape 260 | ) 261 | max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) 262 | # The model_var_values is [-1, 1] for [min_var, max_var]. 263 | frac = (model_var_values + 1) / 2 264 | model_log_variance = frac * max_log + (1 - frac) * min_log 265 | model_variance = th.exp(model_log_variance) 266 | else: 267 | model_variance, model_log_variance = { 268 | # for fixedlarge, we set the initial (log-)variance like so 269 | # to get a better decoder log likelihood. 
270 | ModelVarType.FIXED_LARGE: ( 271 | np.append(self.posterior_variance[1], self.betas[1:]), 272 | np.log(np.append(self.posterior_variance[1], self.betas[1:])), 273 | ), 274 | ModelVarType.FIXED_SMALL: ( 275 | self.posterior_variance, 276 | self.posterior_log_variance_clipped, 277 | ), 278 | }[self.model_var_type] 279 | model_variance = _extract_into_tensor(model_variance, t, x.shape) 280 | model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) 281 | 282 | def process_xstart(x): 283 | if denoised_fn is not None: 284 | x = denoised_fn(x) 285 | if clip_denoised: 286 | return x.clamp(-1, 1) 287 | return x 288 | 289 | if self.model_mean_type == ModelMeanType.PREVIOUS_X: 290 | pred_xstart = process_xstart( 291 | self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) 292 | ) 293 | model_mean = model_output 294 | elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: 295 | if self.model_mean_type == ModelMeanType.START_X: 296 | pred_xstart = process_xstart(model_output) 297 | else: 298 | pred_xstart = process_xstart( 299 | self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) 300 | ) 301 | model_mean, _, _ = self.q_posterior_mean_variance( 302 | x_start=pred_xstart, x_t=x, t=t 303 | ) 304 | else: 305 | raise NotImplementedError(self.model_mean_type) 306 | 307 | assert ( 308 | model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape 309 | ) 310 | return { 311 | "mean": model_mean, 312 | "variance": model_variance, 313 | "log_variance": model_log_variance, 314 | "pred_xstart": pred_xstart, 315 | } 316 | 317 | def _predict_xstart_from_eps(self, x_t, t, eps): 318 | assert x_t.shape == eps.shape 319 | return ( 320 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 321 | - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps 322 | ) 323 | 324 | def _predict_xstart_from_xprev(self, x_t, t, xprev): 325 | assert x_t.shape == xprev.shape 326 | return ( # (xprev - coef2*x_t) / coef1 327 | _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev 328 | - _extract_into_tensor( 329 | self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape 330 | ) 331 | * x_t 332 | ) 333 | 334 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart): 335 | return ( 336 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 337 | - pred_xstart 338 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) 339 | 340 | def _scale_timesteps(self, t): 341 | if self.rescale_timesteps: 342 | return t.float() * (1000.0 / self.num_timesteps) 343 | return t 344 | 345 | def p_sample( 346 | self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None 347 | ): 348 | """ 349 | Sample x_{t-1} from the model at the given timestep. 350 | :param model: the model to sample from. 351 | :param x: the current tensor at x_{t-1}. 352 | :param t: the value of t, starting at 0 for the first diffusion step. 353 | :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. 354 | :param denoised_fn: if not None, a function which applies to the 355 | x_start prediction before it is used to sample. 356 | :param model_kwargs: if not None, a dict of extra keyword arguments to 357 | pass to the model. This can be used for conditioning. 358 | :return: a dict containing the following keys: 359 | - 'sample': a random sample from the model. 360 | - 'pred_xstart': a prediction of x_0. 
361 | """ 362 | out = self.p_mean_variance( 363 | model, 364 | x, 365 | t, 366 | clip_denoised=clip_denoised, 367 | denoised_fn=denoised_fn, 368 | model_kwargs=model_kwargs, 369 | ) 370 | noise = th.randn_like(x) 371 | nonzero_mask = ( 372 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 373 | ) # no noise when t == 0 374 | sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise 375 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 376 | 377 | def p_sample_loop( 378 | self, 379 | model, 380 | shape, 381 | noise=None, 382 | clip_denoised=True, 383 | denoised_fn=None, 384 | model_kwargs=None, 385 | device=None, 386 | progress=False, 387 | ): 388 | """ 389 | Generate samples from the model. 390 | :param model: the model module. 391 | :param shape: the shape of the samples, (N, C, H, W). 392 | :param noise: if specified, the noise from the encoder to sample. 393 | Should be of the same shape as `shape`. 394 | :param clip_denoised: if True, clip x_start predictions to [-1, 1]. 395 | :param denoised_fn: if not None, a function which applies to the 396 | x_start prediction before it is used to sample. 397 | :param model_kwargs: if not None, a dict of extra keyword arguments to 398 | pass to the model. This can be used for conditioning. 399 | :param device: if specified, the device to create the samples on. 400 | If not specified, use a model parameter's device. 401 | :param progress: if True, show a tqdm progress bar. 402 | :return: a non-differentiable batch of samples. 403 | """ 404 | final = None 405 | for sample in self.p_sample_loop_progressive( 406 | model, 407 | shape, 408 | noise=noise, 409 | clip_denoised=clip_denoised, 410 | denoised_fn=denoised_fn, 411 | model_kwargs=model_kwargs, 412 | device=device, 413 | progress=progress, 414 | ): 415 | final = sample 416 | return final["sample"] 417 | 418 | def p_sample_loop_progressive( 419 | self, 420 | model, 421 | shape, 422 | noise=None, 423 | clip_denoised=True, 424 | denoised_fn=None, 425 | model_kwargs=None, 426 | device=None, 427 | progress=False, 428 | ): 429 | """ 430 | Generate samples from the model and yield intermediate samples from 431 | each timestep of diffusion. 432 | Arguments are the same as p_sample_loop(). 433 | Returns a generator over dicts, where each dict is the return value of 434 | p_sample(). 435 | """ 436 | if device is None: 437 | device = next(model.parameters()).device 438 | assert isinstance(shape, (tuple, list)) 439 | if noise is not None: 440 | img = noise 441 | else: 442 | img = th.randn(*shape, device=device) 443 | indices = list(range(self.num_timesteps))[::-1] 444 | 445 | if progress: 446 | # Lazy import so that we don't depend on tqdm. 447 | from tqdm.auto import tqdm 448 | 449 | indices = tqdm(indices) 450 | 451 | for i in indices: 452 | t = th.tensor([i] * shape[0], device=device) 453 | with th.no_grad(): 454 | out = self.p_sample( 455 | model, 456 | img, 457 | t, 458 | clip_denoised=clip_denoised, 459 | denoised_fn=denoised_fn, 460 | model_kwargs=model_kwargs, 461 | ) 462 | yield out 463 | img = out["sample"] 464 | 465 | def ddim_sample( 466 | self, 467 | model, 468 | x, 469 | t, 470 | clip_denoised=True, 471 | denoised_fn=None, 472 | model_kwargs=None, 473 | eta=0.0, 474 | ): 475 | """ 476 | Sample x_{t-1} from the model using DDIM. 477 | Same usage as p_sample(). 
478 | """ 479 | out = self.p_mean_variance( 480 | model, 481 | x, 482 | t, 483 | clip_denoised=clip_denoised, 484 | denoised_fn=denoised_fn, 485 | model_kwargs=model_kwargs, 486 | ) 487 | # Usually our model outputs epsilon, but we re-derive it 488 | # in case we used x_start or x_prev prediction. 489 | eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) 490 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 491 | alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) 492 | sigma = ( 493 | eta 494 | * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) 495 | * th.sqrt(1 - alpha_bar / alpha_bar_prev) 496 | ) 497 | # Equation 12. 498 | noise = th.randn_like(x) 499 | mean_pred = ( 500 | out["pred_xstart"] * th.sqrt(alpha_bar_prev) 501 | + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps 502 | ) 503 | nonzero_mask = ( 504 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 505 | ) # no noise when t == 0 506 | sample = mean_pred + nonzero_mask * sigma * noise 507 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 508 | 509 | def ddim_reverse_sample( 510 | self, 511 | model, 512 | x, 513 | t, 514 | clip_denoised=True, 515 | denoised_fn=None, 516 | model_kwargs=None, 517 | eta=0.0, 518 | ): 519 | """ 520 | Sample x_{t+1} from the model using DDIM reverse ODE. 521 | """ 522 | assert eta == 0.0, "Reverse ODE only for deterministic path" 523 | out = self.p_mean_variance( 524 | model, 525 | x, 526 | t, 527 | clip_denoised=clip_denoised, 528 | denoised_fn=denoised_fn, 529 | model_kwargs=model_kwargs, 530 | ) 531 | # Usually our model outputs epsilon, but we re-derive it 532 | # in case we used x_start or x_prev prediction. 533 | eps = ( 534 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x 535 | - out["pred_xstart"] 536 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) 537 | alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) 538 | 539 | # Equation 12. reversed 540 | mean_pred = ( 541 | out["pred_xstart"] * th.sqrt(alpha_bar_next) 542 | + th.sqrt(1 - alpha_bar_next) * eps 543 | ) 544 | 545 | return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} 546 | 547 | def ddim_sample_loop( 548 | self, 549 | model, 550 | shape, 551 | noise=None, 552 | clip_denoised=True, 553 | denoised_fn=None, 554 | model_kwargs=None, 555 | device=None, 556 | progress=False, 557 | eta=0.0, 558 | ): 559 | """ 560 | Generate samples from the model using DDIM. 561 | Same usage as p_sample_loop(). 562 | """ 563 | final = None 564 | for sample in self.ddim_sample_loop_progressive( 565 | model, 566 | shape, 567 | noise=noise, 568 | clip_denoised=clip_denoised, 569 | denoised_fn=denoised_fn, 570 | model_kwargs=model_kwargs, 571 | device=device, 572 | progress=progress, 573 | eta=eta, 574 | ): 575 | final = sample 576 | return final["sample"] 577 | 578 | def ddim_sample_loop_progressive( 579 | self, 580 | model, 581 | shape, 582 | noise=None, 583 | clip_denoised=True, 584 | denoised_fn=None, 585 | model_kwargs=None, 586 | device=None, 587 | progress=False, 588 | eta=0.0, 589 | ): 590 | """ 591 | Use DDIM to sample from the model and yield intermediate samples from 592 | each timestep of DDIM. 593 | Same usage as p_sample_loop_progressive(). 
594 | """ 595 | if device is None: 596 | device = next(model.parameters()).device 597 | assert isinstance(shape, (tuple, list)) 598 | if noise is not None: 599 | img = noise 600 | else: 601 | img = th.randn(*shape, device=device) 602 | indices = list(range(self.num_timesteps))[::-1] 603 | 604 | if progress: 605 | # Lazy import so that we don't depend on tqdm. 606 | from tqdm.auto import tqdm 607 | 608 | indices = tqdm(indices) 609 | 610 | for i in indices: 611 | t = th.tensor([i] * shape[0], device=device) 612 | with th.no_grad(): 613 | out = self.ddim_sample( 614 | model, 615 | img, 616 | t, 617 | clip_denoised=clip_denoised, 618 | denoised_fn=denoised_fn, 619 | model_kwargs=model_kwargs, 620 | eta=eta, 621 | ) 622 | yield out 623 | img = out["sample"] 624 | 625 | def _vb_terms_bpd( 626 | self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None 627 | ): 628 | """ 629 | Get a term for the variational lower-bound. 630 | The resulting units are bits (rather than nats, as one might expect). 631 | This allows for comparison to other papers. 632 | :return: a dict with the following keys: 633 | - 'output': a shape [N] tensor of NLLs or KLs. 634 | - 'pred_xstart': the x_0 predictions. 635 | """ 636 | true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( 637 | x_start=x_start, x_t=x_t, t=t 638 | ) 639 | out = self.p_mean_variance( 640 | model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs 641 | ) 642 | kl = normal_kl( 643 | true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] 644 | ) 645 | kl = mean_flat(kl) / np.log(2.0) 646 | 647 | decoder_nll = -discretized_gaussian_log_likelihood( 648 | x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] 649 | ) 650 | assert decoder_nll.shape == x_start.shape 651 | decoder_nll = mean_flat(decoder_nll) / np.log(2.0) 652 | 653 | # At the first timestep return the decoder NLL, 654 | # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) 655 | output = th.where((t == 0), decoder_nll, kl) 656 | return {"output": output, "pred_xstart": out["pred_xstart"]} 657 | 658 | def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): 659 | """ 660 | Compute training losses for a single timestep. 661 | :param model: the model to evaluate loss on. 662 | :param x_start: the [N x C x ...] tensor of inputs. 663 | :param t: a batch of timestep indices. 664 | :param model_kwargs: if not None, a dict of extra keyword arguments to 665 | pass to the model. This can be used for conditioning. 666 | :param noise: if specified, the specific Gaussian noise to try to remove. 667 | :return: a dict with the key "loss" containing a tensor of shape [N]. 668 | Some mean or variance settings may also have other keys. 
669 | """ 670 | if model_kwargs is None: 671 | model_kwargs = {} 672 | if noise is None: 673 | noise = th.randn_like(x_start) 674 | x_t = self.q_sample(x_start, t, noise=noise) 675 | 676 | terms = {} 677 | 678 | if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: 679 | terms["loss"] = self._vb_terms_bpd( 680 | model=model, 681 | x_start=x_start, 682 | x_t=x_t, 683 | t=t, 684 | clip_denoised=False, 685 | model_kwargs=model_kwargs, 686 | )["output"] 687 | if self.loss_type == LossType.RESCALED_KL: 688 | terms["loss"] *= self.num_timesteps 689 | elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: 690 | model_output = model(x_t, self._scale_timesteps(t), **model_kwargs) 691 | 692 | if self.model_var_type in [ 693 | ModelVarType.LEARNED, 694 | ModelVarType.LEARNED_RANGE, 695 | ]: 696 | B, C = x_t.shape[:2] 697 | assert model_output.shape == (B, C * 2, *x_t.shape[2:]) 698 | model_output, model_var_values = th.split(model_output, C, dim=1) 699 | # Learn the variance using the variational bound, but don't let 700 | # it affect our mean prediction. 701 | frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) 702 | terms["vb"] = self._vb_terms_bpd( 703 | model=lambda *args, r=frozen_out: r, 704 | x_start=x_start, 705 | x_t=x_t, 706 | t=t, 707 | clip_denoised=False, 708 | )["output"] 709 | if self.loss_type == LossType.RESCALED_MSE: 710 | # Divide by 1000 for equivalence with initial implementation. 711 | # Without a factor of 1/1000, the VB term hurts the MSE term. 712 | terms["vb"] *= self.num_timesteps / 1000.0 713 | 714 | target = { 715 | ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( 716 | x_start=x_start, x_t=x_t, t=t 717 | )[0], 718 | ModelMeanType.START_X: x_start, 719 | ModelMeanType.EPSILON: noise, 720 | }[self.model_mean_type] 721 | assert model_output.shape == target.shape == x_start.shape 722 | terms["mse"] = mean_flat((target - model_output) ** 2) 723 | if "vb" in terms: 724 | terms["loss"] = terms["mse"] + terms["vb"] 725 | else: 726 | terms["loss"] = terms["mse"] 727 | else: 728 | raise NotImplementedError(self.loss_type) 729 | 730 | return terms 731 | 732 | def _prior_bpd(self, x_start): 733 | """ 734 | Get the prior KL term for the variational lower-bound, measured in 735 | bits-per-dim. 736 | This term can't be optimized, as it only depends on the encoder. 737 | :param x_start: the [N x C x ...] tensor of inputs. 738 | :return: a batch of [N] KL values (in bits), one per batch element. 739 | """ 740 | batch_size = x_start.shape[0] 741 | t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) 742 | qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) 743 | kl_prior = normal_kl( 744 | mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 745 | ) 746 | return mean_flat(kl_prior) / np.log(2.0) 747 | 748 | def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): 749 | """ 750 | Compute the entire variational lower-bound, measured in bits-per-dim, 751 | as well as other related quantities. 752 | :param model: the model to evaluate loss on. 753 | :param x_start: the [N x C x ...] tensor of inputs. 754 | :param clip_denoised: if True, clip denoised samples. 755 | :param model_kwargs: if not None, a dict of extra keyword arguments to 756 | pass to the model. This can be used for conditioning. 757 | :return: a dict containing the following keys: 758 | - total_bpd: the total variational lower-bound, per batch element. 
759 | - prior_bpd: the prior term in the lower-bound. 760 | - vb: an [N x T] tensor of terms in the lower-bound. 761 | - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. 762 | - mse: an [N x T] tensor of epsilon MSEs for each timestep. 763 | """ 764 | device = x_start.device 765 | batch_size = x_start.shape[0] 766 | 767 | vb = [] 768 | xstart_mse = [] 769 | mse = [] 770 | for t in list(range(self.num_timesteps))[::-1]: 771 | t_batch = th.tensor([t] * batch_size, device=device) 772 | noise = th.randn_like(x_start) 773 | x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) 774 | # Calculate VLB term at the current timestep 775 | with th.no_grad(): 776 | out = self._vb_terms_bpd( 777 | model, 778 | x_start=x_start, 779 | x_t=x_t, 780 | t=t_batch, 781 | clip_denoised=clip_denoised, 782 | model_kwargs=model_kwargs, 783 | ) 784 | vb.append(out["output"]) 785 | xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) 786 | eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) 787 | mse.append(mean_flat((eps - noise) ** 2)) 788 | 789 | vb = th.stack(vb, dim=1) 790 | xstart_mse = th.stack(xstart_mse, dim=1) 791 | mse = th.stack(mse, dim=1) 792 | 793 | prior_bpd = self._prior_bpd(x_start) 794 | total_bpd = vb.sum(dim=1) + prior_bpd 795 | return { 796 | "total_bpd": total_bpd, 797 | "prior_bpd": prior_bpd, 798 | "vb": vb, 799 | "xstart_mse": xstart_mse, 800 | "mse": mse, 801 | } 802 | 803 | 804 | def _extract_into_tensor(arr, timesteps, broadcast_shape): 805 | """ 806 | Extract values from a 1-D numpy array for a batch of indices. 807 | :param arr: the 1-D numpy array. 808 | :param timesteps: a tensor of indices into the array to extract. 809 | :param broadcast_shape: a larger shape of K dimensions with the batch 810 | dimension equal to the length of timesteps. 811 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 
812 | """ 813 | res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() 814 | while len(res.shape) < len(broadcast_shape): 815 | res = res[..., None] 816 | return res.expand(broadcast_shape) 817 | -------------------------------------------------------------------------------- /dalle2_decoder/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies: 3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py 4 | """ 5 | 6 | import os 7 | import sys 8 | import shutil 9 | import os.path as osp 10 | import json 11 | import time 12 | import datetime 13 | import tempfile 14 | import warnings 15 | from collections import defaultdict 16 | from contextlib import contextmanager 17 | 18 | DEBUG = 10 19 | INFO = 20 20 | WARN = 30 21 | ERROR = 40 22 | 23 | DISABLED = 50 24 | 25 | 26 | class KVWriter(object): 27 | def writekvs(self, kvs): 28 | raise NotImplementedError 29 | 30 | 31 | class SeqWriter(object): 32 | def writeseq(self, seq): 33 | raise NotImplementedError 34 | 35 | 36 | class HumanOutputFormat(KVWriter, SeqWriter): 37 | def __init__(self, filename_or_file): 38 | if isinstance(filename_or_file, str): 39 | self.file = open(filename_or_file, "wt") 40 | self.own_file = True 41 | else: 42 | assert hasattr(filename_or_file, "read"), ( 43 | "expected file or str, got %s" % filename_or_file 44 | ) 45 | self.file = filename_or_file 46 | self.own_file = False 47 | 48 | def writekvs(self, kvs): 49 | # Create strings for printing 50 | key2str = {} 51 | for (key, val) in sorted(kvs.items()): 52 | if hasattr(val, "__float__"): 53 | valstr = "%-8.3g" % val 54 | else: 55 | valstr = str(val) 56 | key2str[self._truncate(key)] = self._truncate(valstr) 57 | 58 | # Find max widths 59 | if len(key2str) == 0: 60 | print("WARNING: tried to write empty key-value dict") 61 | return 62 | else: 63 | keywidth = max(map(len, key2str.keys())) 64 | valwidth = max(map(len, key2str.values())) 65 | 66 | # Write out the data 67 | dashes = "-" * (keywidth + valwidth + 7) 68 | lines = [dashes] 69 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 70 | lines.append( 71 | "| %s%s | %s%s |" 72 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) 73 | ) 74 | lines.append(dashes) 75 | self.file.write("\n".join(lines) + "\n") 76 | 77 | # Flush the output to the file 78 | self.file.flush() 79 | 80 | def _truncate(self, s): 81 | maxlen = 30 82 | return s[: maxlen - 3] + "..." 
if len(s) > maxlen else s 83 | 84 | def writeseq(self, seq): 85 | seq = list(seq) 86 | for (i, elem) in enumerate(seq): 87 | self.file.write(elem) 88 | if i < len(seq) - 1: # add space unless this is the last one 89 | self.file.write(" ") 90 | self.file.write("\n") 91 | self.file.flush() 92 | 93 | def close(self): 94 | if self.own_file: 95 | self.file.close() 96 | 97 | 98 | class JSONOutputFormat(KVWriter): 99 | def __init__(self, filename): 100 | self.file = open(filename, "wt") 101 | 102 | def writekvs(self, kvs): 103 | for k, v in sorted(kvs.items()): 104 | if hasattr(v, "dtype"): 105 | kvs[k] = float(v) 106 | self.file.write(json.dumps(kvs) + "\n") 107 | self.file.flush() 108 | 109 | def close(self): 110 | self.file.close() 111 | 112 | 113 | class CSVOutputFormat(KVWriter): 114 | def __init__(self, filename): 115 | self.file = open(filename, "w+t") 116 | self.keys = [] 117 | self.sep = "," 118 | 119 | def writekvs(self, kvs): 120 | # Add our current row to the history 121 | extra_keys = list(kvs.keys() - self.keys) 122 | extra_keys.sort() 123 | if extra_keys: 124 | self.keys.extend(extra_keys) 125 | self.file.seek(0) 126 | lines = self.file.readlines() 127 | self.file.seek(0) 128 | for (i, k) in enumerate(self.keys): 129 | if i > 0: 130 | self.file.write(",") 131 | self.file.write(k) 132 | self.file.write("\n") 133 | for line in lines[1:]: 134 | self.file.write(line[:-1]) 135 | self.file.write(self.sep * len(extra_keys)) 136 | self.file.write("\n") 137 | for (i, k) in enumerate(self.keys): 138 | if i > 0: 139 | self.file.write(",") 140 | v = kvs.get(k) 141 | if v is not None: 142 | self.file.write(str(v)) 143 | self.file.write("\n") 144 | self.file.flush() 145 | 146 | def close(self): 147 | self.file.close() 148 | 149 | 150 | class TensorBoardOutputFormat(KVWriter): 151 | """ 152 | Dumps key/value pairs into TensorBoard's numeric format. 153 | """ 154 | 155 | def __init__(self, dir): 156 | os.makedirs(dir, exist_ok=True) 157 | self.dir = dir 158 | self.step = 1 159 | prefix = "events" 160 | path = osp.join(osp.abspath(dir), prefix) 161 | import tensorflow as tf 162 | from tensorflow.python import pywrap_tensorflow 163 | from tensorflow.core.util import event_pb2 164 | from tensorflow.python.util import compat 165 | 166 | self.tf = tf 167 | self.event_pb2 = event_pb2 168 | self.pywrap_tensorflow = pywrap_tensorflow 169 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 170 | 171 | def writekvs(self, kvs): 172 | def summary_val(k, v): 173 | kwargs = {"tag": k, "simple_value": float(v)} 174 | return self.tf.Summary.Value(**kwargs) 175 | 176 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 177 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 178 | event.step = ( 179 | self.step 180 | ) # is there any reason why you'd want to specify the step? 
181 | self.writer.WriteEvent(event) 182 | self.writer.Flush() 183 | self.step += 1 184 | 185 | def close(self): 186 | if self.writer: 187 | self.writer.Close() 188 | self.writer = None 189 | 190 | 191 | def make_output_format(format, ev_dir, log_suffix=""): 192 | os.makedirs(ev_dir, exist_ok=True) 193 | if format == "stdout": 194 | return HumanOutputFormat(sys.stdout) 195 | elif format == "log": 196 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) 197 | elif format == "json": 198 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) 199 | elif format == "csv": 200 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) 201 | elif format == "tensorboard": 202 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) 203 | else: 204 | raise ValueError("Unknown format specified: %s" % (format,)) 205 | 206 | 207 | # ================================================================ 208 | # API 209 | # ================================================================ 210 | 211 | 212 | def logkv(key, val): 213 | """ 214 | Log a value of some diagnostic 215 | Call this once for each diagnostic quantity, each iteration 216 | If called many times, last value will be used. 217 | """ 218 | get_current().logkv(key, val) 219 | 220 | 221 | def logkv_mean(key, val): 222 | """ 223 | The same as logkv(), but if called many times, values averaged. 224 | """ 225 | get_current().logkv_mean(key, val) 226 | 227 | 228 | def logkvs(d): 229 | """ 230 | Log a dictionary of key-value pairs 231 | """ 232 | for (k, v) in d.items(): 233 | logkv(k, v) 234 | 235 | 236 | def dumpkvs(): 237 | """ 238 | Write all of the diagnostics from the current iteration 239 | """ 240 | return get_current().dumpkvs() 241 | 242 | 243 | def getkvs(): 244 | return get_current().name2val 245 | 246 | 247 | def log(*args, level=INFO): 248 | """ 249 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 250 | """ 251 | get_current().log(*args, level=level) 252 | 253 | 254 | def debug(*args): 255 | log(*args, level=DEBUG) 256 | 257 | 258 | def info(*args): 259 | log(*args, level=INFO) 260 | 261 | 262 | def warn(*args): 263 | log(*args, level=WARN) 264 | 265 | 266 | def error(*args): 267 | log(*args, level=ERROR) 268 | 269 | 270 | def set_level(level): 271 | """ 272 | Set logging threshold on current logger. 273 | """ 274 | get_current().set_level(level) 275 | 276 | 277 | def set_comm(comm): 278 | get_current().set_comm(comm) 279 | 280 | 281 | def get_dir(): 282 | """ 283 | Get directory that log files are being written to. 
284 | will be None if there is no output directory (i.e., if you didn't call start) 285 | """ 286 | return get_current().get_dir() 287 | 288 | 289 | record_tabular = logkv 290 | dump_tabular = dumpkvs 291 | 292 | 293 | @contextmanager 294 | def profile_kv(scopename): 295 | logkey = "wait_" + scopename 296 | tstart = time.time() 297 | try: 298 | yield 299 | finally: 300 | get_current().name2val[logkey] += time.time() - tstart 301 | 302 | 303 | def profile(n): 304 | """ 305 | Usage: 306 | @profile("my_func") 307 | def my_func(): code 308 | """ 309 | 310 | def decorator_with_name(func): 311 | def func_wrapper(*args, **kwargs): 312 | with profile_kv(n): 313 | return func(*args, **kwargs) 314 | 315 | return func_wrapper 316 | 317 | return decorator_with_name 318 | 319 | 320 | # ================================================================ 321 | # Backend 322 | # ================================================================ 323 | 324 | 325 | def get_current(): 326 | if Logger.CURRENT is None: 327 | _configure_default_logger() 328 | 329 | return Logger.CURRENT 330 | 331 | 332 | class Logger(object): 333 | DEFAULT = None # A logger with no output files. (See right below class definition) 334 | # So that you can still log to the terminal without setting up any output files 335 | CURRENT = None # Current logger being used by the free functions above 336 | 337 | def __init__(self, dir, output_formats, comm=None): 338 | self.name2val = defaultdict(float) # values this iteration 339 | self.name2cnt = defaultdict(int) 340 | self.level = INFO 341 | self.dir = dir 342 | self.output_formats = output_formats 343 | self.comm = comm 344 | 345 | # Logging API, forwarded 346 | # ---------------------------------------- 347 | def logkv(self, key, val): 348 | self.name2val[key] = val 349 | 350 | def logkv_mean(self, key, val): 351 | oldval, cnt = self.name2val[key], self.name2cnt[key] 352 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 353 | self.name2cnt[key] = cnt + 1 354 | 355 | def dumpkvs(self): 356 | if self.comm is None: 357 | d = self.name2val 358 | else: 359 | d = mpi_weighted_mean( 360 | self.comm, 361 | { 362 | name: (val, self.name2cnt.get(name, 1)) 363 | for (name, val) in self.name2val.items() 364 | }, 365 | ) 366 | if self.comm.rank != 0: 367 | d["dummy"] = 1 # so we don't get a warning about empty dict 368 | out = d.copy() # Return the dict for unit testing purposes 369 | for fmt in self.output_formats: 370 | if isinstance(fmt, KVWriter): 371 | fmt.writekvs(d) 372 | self.name2val.clear() 373 | self.name2cnt.clear() 374 | return out 375 | 376 | def log(self, *args, level=INFO): 377 | if self.level <= level: 378 | self._do_log(args) 379 | 380 | # Configuration 381 | # ---------------------------------------- 382 | def set_level(self, level): 383 | self.level = level 384 | 385 | def set_comm(self, comm): 386 | self.comm = comm 387 | 388 | def get_dir(self): 389 | return self.dir 390 | 391 | def close(self): 392 | for fmt in self.output_formats: 393 | fmt.close() 394 | 395 | # Misc 396 | # ---------------------------------------- 397 | def _do_log(self, args): 398 | for fmt in self.output_formats: 399 | if isinstance(fmt, SeqWriter): 400 | fmt.writeseq(map(str, args)) 401 | 402 | 403 | def get_rank_without_mpi_import(): 404 | # check environment variables here instead of importing mpi4py 405 | # to avoid calling MPI_Init() when this module is imported 406 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: 407 | if varname in os.environ: 408 | return int(os.environ[varname]) 
409 | return 0 410 | 411 | 412 | def mpi_weighted_mean(comm, local_name2valcount): 413 | """ 414 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 415 | Perform a weighted average over dicts that are each on a different node 416 | Input: local_name2valcount: dict mapping key -> (value, count) 417 | Returns: key -> mean 418 | """ 419 | all_name2valcount = comm.gather(local_name2valcount) 420 | if comm.rank == 0: 421 | name2sum = defaultdict(float) 422 | name2count = defaultdict(float) 423 | for n2vc in all_name2valcount: 424 | for (name, (val, count)) in n2vc.items(): 425 | try: 426 | val = float(val) 427 | except ValueError: 428 | if comm.rank == 0: 429 | warnings.warn( 430 | "WARNING: tried to compute mean on non-float {}={}".format( 431 | name, val 432 | ) 433 | ) 434 | else: 435 | name2sum[name] += val * count 436 | name2count[name] += count 437 | return {name: name2sum[name] / name2count[name] for name in name2sum} 438 | else: 439 | return {} 440 | 441 | 442 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""): 443 | """ 444 | If comm is provided, average all numerical stats across that comm 445 | """ 446 | if dir is None: 447 | dir = os.getenv("OPENAI_LOGDIR") 448 | if dir is None: 449 | dir = osp.join( 450 | tempfile.gettempdir(), 451 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), 452 | ) 453 | assert isinstance(dir, str) 454 | dir = os.path.expanduser(dir) 455 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 456 | 457 | rank = get_rank_without_mpi_import() 458 | if rank > 0: 459 | log_suffix = log_suffix + "-rank%03i" % rank 460 | 461 | if format_strs is None: 462 | if rank == 0: 463 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") 464 | else: 465 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") 466 | format_strs = filter(None, format_strs) 467 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 468 | 469 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 470 | if output_formats: 471 | log("Logging to %s" % dir) 472 | 473 | 474 | def _configure_default_logger(): 475 | configure() 476 | Logger.DEFAULT = Logger.CURRENT 477 | 478 | 479 | def reset(): 480 | if Logger.CURRENT is not Logger.DEFAULT: 481 | Logger.CURRENT.close() 482 | Logger.CURRENT = Logger.DEFAULT 483 | log("Reset logger") 484 | 485 | 486 | @contextmanager 487 | def scoped_configure(dir=None, format_strs=None, comm=None): 488 | prevlogger = Logger.CURRENT 489 | configure(dir=dir, format_strs=format_strs, comm=comm) 490 | try: 491 | yield 492 | finally: 493 | Logger.CURRENT.close() 494 | Logger.CURRENT = prevlogger -------------------------------------------------------------------------------- /dalle2_decoder/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | Shapes are automatically broadcasted, so batches can be compared to 16 | scalars, among other use cases. 
17 | """ 18 | tensor = None 19 | for obj in (mean1, logvar1, mean2, logvar2): 20 | if isinstance(obj, th.Tensor): 21 | tensor = obj 22 | break 23 | assert tensor is not None, "at least one argument must be a Tensor" 24 | 25 | # Force variances to be Tensors. Broadcasting helps convert scalars to 26 | # Tensors, but it does not work for th.exp(). 27 | logvar1, logvar2 = [ 28 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 29 | for x in (logvar1, logvar2) 30 | ] 31 | 32 | return 0.5 * ( 33 | -1.0 34 | + logvar2 35 | - logvar1 36 | + th.exp(logvar1 - logvar2) 37 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 38 | ) 39 | 40 | 41 | def approx_standard_normal_cdf(x): 42 | """ 43 | A fast approximation of the cumulative distribution function of the 44 | standard normal. 45 | """ 46 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 47 | 48 | 49 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 50 | """ 51 | Compute the log-likelihood of a Gaussian distribution discretizing to a 52 | given image. 53 | :param x: the target images. It is assumed that this was uint8 values, 54 | rescaled to the range [-1, 1]. 55 | :param means: the Gaussian mean Tensor. 56 | :param log_scales: the Gaussian log stddev Tensor. 57 | :return: a tensor like x of log probabilities (in nats). 58 | """ 59 | assert x.shape == means.shape == log_scales.shape 60 | centered_x = x - means 61 | inv_stdv = th.exp(-log_scales) 62 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 63 | cdf_plus = approx_standard_normal_cdf(plus_in) 64 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 65 | cdf_min = approx_standard_normal_cdf(min_in) 66 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 67 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 68 | cdf_delta = cdf_plus - cdf_min 69 | log_probs = th.where( 70 | x < -0.999, 71 | log_cdf_plus, 72 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 73 | ) 74 | assert log_probs.shape == x.shape 75 | return log_probs -------------------------------------------------------------------------------- /dalle2_decoder/model_creation.py: -------------------------------------------------------------------------------- 1 | from .gaussian_diffusion import get_named_beta_schedule 2 | from . 
import gaussian_diffusion as gd 3 | from .respace import SpacedDiffusion, space_timesteps 4 | from dalle2_decoder.text2im_model import ( 5 | InpaintText2ImUNet, 6 | SuperResInpaintText2ImUnet, 7 | SuperResText2ImUNet, 8 | Text2ImUNet, 9 | ) 10 | from dalle2_decoder.tokenizer.bpe import get_encoder 11 | 12 | 13 | def model_and_diffusion_defaults(): 14 | return dict( 15 | image_size=64, 16 | num_channels=192, 17 | num_res_blocks=3, 18 | channel_mult="", 19 | num_heads=1, 20 | num_head_channels=64, 21 | num_heads_upsample=-1, 22 | attention_resolutions="32,16,8", 23 | dropout=0.1, 24 | text_ctx=128, 25 | xf_width=512, 26 | xf_layers=16, 27 | xf_heads=8, 28 | xf_final_ln=True, 29 | xf_padding=True, 30 | use_scale_shift_norm=True, 31 | resblock_updown=True, 32 | use_fp16=True, 33 | cache_text_emb=False, 34 | inpaint=False, 35 | super_res=False, 36 | 37 | learn_sigma=True, 38 | sigma_small=False, 39 | diffusion_steps=1000, 40 | noise_schedule="linear", 41 | timestep_respacing="", 42 | use_kl=False, 43 | predict_xstart=False, 44 | rescale_timesteps=True, 45 | rescale_learned_sigmas=True, 46 | 47 | ) 48 | 49 | def model_and_diffusion_defaults_upsampler(): 50 | result = model_and_diffusion_defaults() 51 | result.update( 52 | dict( 53 | image_size=256, 54 | num_res_blocks=2, 55 | noise_schedule="linear", 56 | super_res=True, 57 | ) 58 | ) 59 | return result 60 | 61 | 62 | def create_model_and_diffusion( 63 | image_size, 64 | num_channels, 65 | num_res_blocks, 66 | channel_mult, 67 | num_heads, 68 | num_head_channels, 69 | num_heads_upsample, 70 | attention_resolutions, 71 | dropout, 72 | text_ctx, 73 | xf_width, 74 | xf_layers, 75 | xf_heads, 76 | xf_final_ln, 77 | xf_padding, 78 | use_scale_shift_norm, 79 | resblock_updown, 80 | use_fp16, 81 | cache_text_emb, 82 | inpaint, 83 | super_res, 84 | 85 | learn_sigma, 86 | sigma_small, 87 | diffusion_steps, 88 | noise_schedule, 89 | timestep_respacing, 90 | use_kl, 91 | predict_xstart, 92 | rescale_timesteps, 93 | rescale_learned_sigmas, 94 | 95 | ): 96 | model = create_model( 97 | image_size, 98 | num_channels, 99 | num_res_blocks, 100 | channel_mult=channel_mult, 101 | attention_resolutions=attention_resolutions, 102 | num_heads=num_heads, 103 | num_head_channels=num_head_channels, 104 | num_heads_upsample=num_heads_upsample, 105 | use_scale_shift_norm=use_scale_shift_norm, 106 | dropout=dropout, 107 | text_ctx=text_ctx, 108 | xf_width=xf_width, 109 | xf_layers=xf_layers, 110 | xf_heads=xf_heads, 111 | xf_final_ln=xf_final_ln, 112 | xf_padding=xf_padding, 113 | resblock_updown=resblock_updown, 114 | use_fp16=use_fp16, 115 | cache_text_emb=cache_text_emb, 116 | inpaint=inpaint, 117 | super_res=super_res, 118 | ) 119 | diffusion = create_gaussian_diffusion( 120 | steps=diffusion_steps, 121 | learn_sigma=learn_sigma, 122 | sigma_small=sigma_small, 123 | noise_schedule=noise_schedule, 124 | use_kl=use_kl, 125 | predict_xstart=predict_xstart, 126 | rescale_timesteps=rescale_timesteps, 127 | rescale_learned_sigmas=rescale_learned_sigmas, 128 | timestep_respacing=timestep_respacing, 129 | ) 130 | return model, diffusion 131 | 132 | 133 | def create_model( 134 | image_size, 135 | num_channels, 136 | num_res_blocks, 137 | channel_mult, 138 | attention_resolutions, 139 | num_heads, 140 | num_head_channels, 141 | num_heads_upsample, 142 | use_scale_shift_norm, 143 | dropout, 144 | text_ctx, 145 | xf_width, 146 | xf_layers, 147 | xf_heads, 148 | xf_final_ln, 149 | xf_padding, 150 | resblock_updown, 151 | use_fp16, 152 | cache_text_emb, 153 | inpaint, 154 | 
super_res, 155 | ): 156 | if channel_mult == "": 157 | if image_size == 256: 158 | channel_mult = (1, 1, 2, 2, 4, 4) 159 | elif image_size == 128: 160 | channel_mult = (1, 1, 2, 3, 4) 161 | elif image_size == 64: 162 | channel_mult = (1, 2, 3, 4) 163 | else: 164 | raise ValueError(f"unsupported image size: {image_size}") 165 | else: 166 | channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(",")) 167 | assert 2 ** (len(channel_mult) + 2) == image_size 168 | 169 | attention_ds = [] 170 | for res in attention_resolutions.split(","): 171 | attention_ds.append(image_size // int(res)) 172 | 173 | if inpaint and super_res: 174 | model_cls = SuperResInpaintText2ImUnet 175 | elif inpaint: 176 | model_cls = InpaintText2ImUNet 177 | elif super_res: 178 | model_cls = SuperResText2ImUNet 179 | else: 180 | model_cls = Text2ImUNet 181 | return model_cls( 182 | text_ctx=text_ctx, 183 | xf_width=xf_width, 184 | xf_layers=xf_layers, 185 | xf_heads=xf_heads, 186 | xf_final_ln=xf_final_ln, 187 | tokenizer=get_encoder(), 188 | xf_padding=xf_padding, 189 | in_channels=3, 190 | model_channels=num_channels, 191 | out_channels=6, 192 | num_res_blocks=num_res_blocks, 193 | attention_resolutions=tuple(attention_ds), 194 | dropout=dropout, 195 | channel_mult=channel_mult, 196 | use_fp16=use_fp16, 197 | num_heads=num_heads, 198 | num_head_channels=num_head_channels, 199 | num_heads_upsample=num_heads_upsample, 200 | use_scale_shift_norm=use_scale_shift_norm, 201 | resblock_updown=resblock_updown, 202 | cache_text_emb=cache_text_emb, 203 | ) 204 | 205 | 206 | def create_gaussian_diffusion( 207 | *, 208 | steps=1000, 209 | learn_sigma=False, 210 | sigma_small=False, 211 | noise_schedule="linear", 212 | use_kl=False, 213 | predict_xstart=False, 214 | rescale_timesteps=False, 215 | rescale_learned_sigmas=False, 216 | timestep_respacing="", 217 | ): 218 | betas = gd.get_named_beta_schedule(noise_schedule, steps) 219 | if use_kl: 220 | loss_type = gd.LossType.RESCALED_KL 221 | elif rescale_learned_sigmas: 222 | loss_type = gd.LossType.RESCALED_MSE 223 | else: 224 | loss_type = gd.LossType.MSE 225 | if not timestep_respacing: 226 | timestep_respacing = [steps] 227 | return SpacedDiffusion( 228 | use_timesteps=space_timesteps(steps, timestep_respacing), 229 | betas=betas, 230 | model_mean_type=( 231 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 232 | ), 233 | model_var_type=( 234 | ( 235 | gd.ModelVarType.FIXED_LARGE 236 | if not sigma_small 237 | else gd.ModelVarType.FIXED_SMALL 238 | ) 239 | if not learn_sigma 240 | else gd.ModelVarType.LEARNED_RANGE 241 | ), 242 | loss_type=loss_type, 243 | rescale_timesteps=rescale_timesteps, 244 | ) 245 | -------------------------------------------------------------------------------- /dalle2_decoder/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | def update_ema(target_params, source_params, rate=0.99): 12 | """ 13 | Update target parameters to be closer to those of source parameters using 14 | an exponential moving average. 15 | :param target_params: the target parameter sequence. 16 | :param source_params: the source parameter sequence. 17 | :param rate: the EMA rate (closer to 1 means slower). 
18 | """ 19 | for targ, src in zip(target_params, source_params): 20 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 21 | 22 | class GroupNorm32(nn.GroupNorm): 23 | def __init__(self, num_groups, num_channels, swish, eps=1e-5): 24 | super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps) 25 | self.swish = swish 26 | 27 | def forward(self, x): 28 | y = super().forward(x.float()).to(x.dtype) 29 | if self.swish == 1.0: 30 | y = F.silu(y) 31 | elif self.swish: 32 | y = y * F.sigmoid(y * float(self.swish)) 33 | return y 34 | 35 | 36 | def conv_nd(dims, *args, **kwargs): 37 | """ 38 | Create a 1D, 2D, or 3D convolution module. 39 | """ 40 | if dims == 1: 41 | return nn.Conv1d(*args, **kwargs) 42 | elif dims == 2: 43 | return nn.Conv2d(*args, **kwargs) 44 | elif dims == 3: 45 | return nn.Conv3d(*args, **kwargs) 46 | raise ValueError(f"unsupported dimensions: {dims}") 47 | 48 | 49 | def linear(*args, **kwargs): 50 | """ 51 | Create a linear module. 52 | """ 53 | return nn.Linear(*args, **kwargs) 54 | 55 | 56 | def avg_pool_nd(dims, *args, **kwargs): 57 | """ 58 | Create a 1D, 2D, or 3D average pooling module. 59 | """ 60 | if dims == 1: 61 | return nn.AvgPool1d(*args, **kwargs) 62 | elif dims == 2: 63 | return nn.AvgPool2d(*args, **kwargs) 64 | elif dims == 3: 65 | return nn.AvgPool3d(*args, **kwargs) 66 | raise ValueError(f"unsupported dimensions: {dims}") 67 | 68 | 69 | def zero_module(module): 70 | """ 71 | Zero out the parameters of a module and return it. 72 | """ 73 | for p in module.parameters(): 74 | p.detach().zero_() 75 | return module 76 | 77 | 78 | def scale_module(module, scale): 79 | """ 80 | Scale the parameters of a module and return it. 81 | """ 82 | for p in module.parameters(): 83 | p.detach().mul_(scale) 84 | return module 85 | 86 | 87 | def normalization(channels, swish=0.0): 88 | """ 89 | Make a standard normalization layer, with an optional swish activation. 90 | 91 | :param channels: number of input channels. 92 | :return: an nn.Module for normalization. 93 | """ 94 | return GroupNorm32(num_channels=channels, num_groups=32, swish=swish) 95 | 96 | 97 | def timestep_embedding(timesteps, dim, max_period=10000): 98 | """ 99 | Create sinusoidal timestep embeddings. 100 | 101 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 102 | These may be fractional. 103 | :param dim: the dimension of the output. 104 | :param max_period: controls the minimum frequency of the embeddings. 105 | :return: an [N x dim] Tensor of positional embeddings. 106 | """ 107 | half = dim // 2 108 | freqs = th.exp( 109 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 110 | ).to(device=timesteps.device) 111 | args = timesteps[:, None].float() * freqs[None] 112 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 113 | if dim % 2: 114 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 115 | return embedding 116 | -------------------------------------------------------------------------------- /dalle2_decoder/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | :param name: the name of the sampler. 12 | :param diffusion: the diffusion object to sample for. 
13 | """ 14 | if name == "uniform": 15 | return UniformSampler(diffusion) 16 | elif name == "loss-second-moment": 17 | return LossSecondMomentResampler(diffusion) 18 | else: 19 | raise NotImplementedError(f"unknown schedule sampler: {name}") 20 | 21 | 22 | class ScheduleSampler(ABC): 23 | """ 24 | A distribution over timesteps in the diffusion process, intended to reduce 25 | variance of the objective. 26 | By default, samplers perform unbiased importance sampling, in which the 27 | objective's mean is unchanged. 28 | However, subclasses may override sample() to change how the resampled 29 | terms are reweighted, allowing for actual changes in the objective. 30 | """ 31 | 32 | @abstractmethod 33 | def weights(self): 34 | """ 35 | Get a numpy array of weights, one per diffusion step. 36 | The weights needn't be normalized, but must be positive. 37 | """ 38 | 39 | def sample(self, batch_size, device): 40 | """ 41 | Importance-sample timesteps for a batch. 42 | :param batch_size: the number of timesteps. 43 | :param device: the torch device to save to. 44 | :return: a tuple (timesteps, weights): 45 | - timesteps: a tensor of timestep indices. 46 | - weights: a tensor of weights to scale the resulting losses. 47 | """ 48 | w = self.weights() 49 | p = w / np.sum(w) 50 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 51 | indices = th.from_numpy(indices_np).long().to(device) 52 | weights_np = 1 / (len(p) * p[indices_np]) 53 | weights = th.from_numpy(weights_np).float().to(device) 54 | return indices, weights 55 | 56 | 57 | class UniformSampler(ScheduleSampler): 58 | def __init__(self, diffusion): 59 | self.diffusion = diffusion 60 | self._weights = np.ones([diffusion.num_timesteps]) 61 | 62 | def weights(self): 63 | return self._weights 64 | 65 | 66 | class LossAwareSampler(ScheduleSampler): 67 | def update_with_local_losses(self, local_ts, local_losses): 68 | """ 69 | Update the reweighting using losses from a model. 70 | Call this method from each rank with a batch of timesteps and the 71 | corresponding losses for each of those timesteps. 72 | This method will perform synchronization to make sure all of the ranks 73 | maintain the exact same reweighting. 74 | :param local_ts: an integer Tensor of timesteps. 75 | :param local_losses: a 1D Tensor of losses. 76 | """ 77 | batch_sizes = [ 78 | th.tensor([0], dtype=th.int32, device=local_ts.device) 79 | for _ in range(dist.get_world_size()) 80 | ] 81 | dist.all_gather( 82 | batch_sizes, 83 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 84 | ) 85 | 86 | # Pad all_gather batches to be the maximum batch size. 87 | batch_sizes = [x.item() for x in batch_sizes] 88 | max_bs = max(batch_sizes) 89 | 90 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 91 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 92 | dist.all_gather(timestep_batches, local_ts) 93 | dist.all_gather(loss_batches, local_losses) 94 | timesteps = [ 95 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 96 | ] 97 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 98 | self.update_with_all_losses(timesteps, losses) 99 | 100 | @abstractmethod 101 | def update_with_all_losses(self, ts, losses): 102 | """ 103 | Update the reweighting using losses from a model. 104 | Sub-classes should override this method to update the reweighting 105 | using losses from the model. 
106 | This method directly updates the reweighting without synchronizing 107 | between workers. It is called by update_with_local_losses from all 108 | ranks with identical arguments. Thus, it should have deterministic 109 | behavior to maintain state across workers. 110 | :param ts: a list of int timesteps. 111 | :param losses: a list of float losses, one per timestep. 112 | """ 113 | 114 | 115 | class LossSecondMomentResampler(LossAwareSampler): 116 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 117 | self.diffusion = diffusion 118 | self.history_per_term = history_per_term 119 | self.uniform_prob = uniform_prob 120 | self._loss_history = np.zeros( 121 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 122 | ) 123 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 124 | 125 | def weights(self): 126 | if not self._warmed_up(): 127 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 128 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 129 | weights /= np.sum(weights) 130 | weights *= 1 - self.uniform_prob 131 | weights += self.uniform_prob / len(weights) 132 | return weights 133 | 134 | def update_with_all_losses(self, ts, losses): 135 | for t, loss in zip(ts, losses): 136 | if self._loss_counts[t] == self.history_per_term: 137 | # Shift out the oldest loss term. 138 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 139 | self._loss_history[t, -1] = loss 140 | else: 141 | self._loss_history[t, self._loss_counts[t]] = loss 142 | self._loss_counts[t] += 1 143 | 144 | def _warmed_up(self): 145 | return (self._loss_counts == self.history_per_term).all() -------------------------------------------------------------------------------- /dalle2_decoder/respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | from .gaussian_diffusion import GaussianDiffusion 5 | 6 | 7 | def space_timesteps(num_timesteps, section_counts): 8 | """ 9 | Create a list of timesteps to use from an original diffusion process, 10 | given the number of timesteps we want to take from equally-sized portions 11 | of the original process. 12 | For example, if there's 300 timesteps and the section counts are [10,15,20] 13 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 14 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 15 | If the stride is a string starting with "ddim", then the fixed striding 16 | from the DDIM paper is used, and only one section is allowed. 17 | :param num_timesteps: the number of diffusion steps in the original 18 | process to divide up. 19 | :param section_counts: either a list of numbers, or a string containing 20 | comma-separated numbers, indicating the step count 21 | per section. As a special case, use "ddimN" where N 22 | is a number of steps to use the striding from the 23 | DDIM paper. 24 | :return: a set of diffusion steps from the original process to use. 
25 | """ 26 | if isinstance(section_counts, str): 27 | if section_counts.startswith("ddim"): 28 | desired_count = int(section_counts[len("ddim") :]) 29 | for i in range(1, num_timesteps): 30 | if len(range(0, num_timesteps, i)) == desired_count: 31 | return set(range(0, num_timesteps, i)) 32 | raise ValueError( 33 | f"cannot create exactly {num_timesteps} steps with an integer stride" 34 | ) 35 | section_counts = [int(x) for x in section_counts.split(",")] 36 | size_per = num_timesteps // len(section_counts) 37 | extra = num_timesteps % len(section_counts) 38 | start_idx = 0 39 | all_steps = [] 40 | for i, section_count in enumerate(section_counts): 41 | size = size_per + (1 if i < extra else 0) 42 | if size < section_count: 43 | raise ValueError( 44 | f"cannot divide section of {size} steps into {section_count}" 45 | ) 46 | if section_count <= 1: 47 | frac_stride = 1 48 | else: 49 | frac_stride = (size - 1) / (section_count - 1) 50 | cur_idx = 0.0 51 | taken_steps = [] 52 | for _ in range(section_count): 53 | taken_steps.append(start_idx + round(cur_idx)) 54 | cur_idx += frac_stride 55 | all_steps += taken_steps 56 | start_idx += size 57 | return set(all_steps) 58 | 59 | 60 | class SpacedDiffusion(GaussianDiffusion): 61 | """ 62 | A diffusion process which can skip steps in a base diffusion process. 63 | :param use_timesteps: a collection (sequence or set) of timesteps from the 64 | original diffusion process to retain. 65 | :param kwargs: the kwargs to create the base diffusion process. 66 | """ 67 | 68 | def __init__(self, use_timesteps, **kwargs): 69 | self.use_timesteps = set(use_timesteps) 70 | self.timestep_map = [] 71 | self.original_num_steps = len(kwargs["betas"]) 72 | 73 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 74 | last_alpha_cumprod = 1.0 75 | new_betas = [] 76 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 77 | if i in self.use_timesteps: 78 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 79 | last_alpha_cumprod = alpha_cumprod 80 | self.timestep_map.append(i) 81 | kwargs["betas"] = np.array(new_betas) 82 | super().__init__(**kwargs) 83 | 84 | def p_mean_variance( 85 | self, model, *args, **kwargs 86 | ): # pylint: disable=signature-differs 87 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 88 | 89 | def training_losses( 90 | self, model, *args, **kwargs 91 | ): # pylint: disable=signature-differs 92 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 93 | 94 | def _wrap_model(self, model): 95 | if isinstance(model, _WrappedModel): 96 | return model 97 | return _WrappedModel( 98 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 99 | ) 100 | 101 | def _scale_timesteps(self, t): 102 | # Scaling is done by the wrapped model. 
103 | return t 104 | 105 | 106 | class _WrappedModel: 107 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 108 | self.model = model 109 | self.timestep_map = timestep_map 110 | self.rescale_timesteps = rescale_timesteps 111 | self.original_num_steps = original_num_steps 112 | 113 | def __call__(self, x, ts, **kwargs): 114 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 115 | new_ts = map_tensor[ts] 116 | if self.rescale_timesteps: 117 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 118 | return self.model(x, new_ts, **kwargs) -------------------------------------------------------------------------------- /dalle2_decoder/text2im_model.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .nn import timestep_embedding 6 | from .unet import UNetModel 7 | from .xf import LayerNorm, Transformer, convert_module_to_f16 8 | 9 | 10 | class Text2ImUNet(UNetModel): 11 | """ 12 | A UNetModel that conditions on text with an encoding transformer. 13 | 14 | Expects an extra kwarg `tokens` of text. 15 | 16 | :param text_ctx: number of text tokens to expect. 17 | :param xf_width: width of the transformer. 18 | :param xf_layers: depth of the transformer. 19 | :param xf_heads: heads in the transformer. 20 | :param xf_final_ln: use a LayerNorm after the output layer. 21 | :param tokenizer: the text tokenizer for sampling/vocab size. 22 | """ 23 | 24 | def __init__( 25 | self, 26 | text_ctx, 27 | xf_width, 28 | xf_layers, 29 | xf_heads, 30 | xf_final_ln, 31 | tokenizer, 32 | *args, 33 | cache_text_emb=False, 34 | xf_ar=0.0, 35 | xf_padding=False, 36 | share_unemb=False, 37 | **kwargs, 38 | ): 39 | self.text_ctx = text_ctx 40 | self.xf_width = xf_width 41 | self.xf_ar = xf_ar 42 | self.xf_padding = xf_padding 43 | self.tokenizer = tokenizer 44 | 45 | if not xf_width: 46 | super().__init__(*args, **kwargs, encoder_channels=None) 47 | else: 48 | super().__init__(*args, **kwargs, encoder_channels=xf_width) 49 | if self.xf_width: 50 | self.transformer = Transformer( 51 | text_ctx, 52 | xf_width, 53 | xf_layers, 54 | xf_heads, 55 | ) 56 | if xf_final_ln: 57 | self.final_ln = LayerNorm(xf_width) 58 | else: 59 | self.final_ln = None 60 | 61 | self.token_embedding = nn.Embedding(self.tokenizer.n_vocab, xf_width) 62 | self.positional_embedding = nn.Parameter(th.empty(text_ctx, xf_width, dtype=th.float32)) 63 | self.transformer_proj = nn.Linear(xf_width, self.model_channels * 4) 64 | 65 | if self.xf_padding: 66 | self.padding_embedding = nn.Parameter( 67 | th.empty(text_ctx, xf_width, dtype=th.float32) 68 | ) 69 | if self.xf_ar: 70 | self.unemb = nn.Linear(xf_width, self.tokenizer.n_vocab) 71 | if share_unemb: 72 | self.unemb.weight = self.token_embedding.weight 73 | 74 | self.cache_text_emb = cache_text_emb 75 | self.cache = None 76 | # 77 | self.time_to_half = nn.Linear(768, 768 // 2) 78 | self.clip_to_half = nn.Linear(768, 768 // 2) 79 | # 80 | def convert_to_fp16(self): 81 | super().convert_to_fp16() 82 | if self.xf_width: 83 | self.transformer.apply(convert_module_to_f16) 84 | self.transformer_proj.to(th.float16) 85 | self.token_embedding.to(th.float16) 86 | self.positional_embedding.to(th.float16) 87 | if self.xf_padding: 88 | self.padding_embedding.to(th.float16) 89 | if self.xf_ar: 90 | self.unemb.to(th.float16) 91 | 92 | def get_text_emb(self, tokens, mask): 93 | assert tokens is not None 94 | 95 | if 
self.cache_text_emb and self.cache is not None: 96 | assert ( 97 | tokens == self.cache["tokens"] 98 | ).all(), f"Tokens {tokens.cpu().numpy().tolist()} do not match cache {self.cache['tokens'].cpu().numpy().tolist()}" 99 | return self.cache 100 | 101 | xf_in = self.token_embedding(tokens.long()) 102 | xf_in = xf_in + self.positional_embedding[None] 103 | if self.xf_padding: 104 | assert mask is not None 105 | xf_in = th.where(mask[..., None], xf_in, self.padding_embedding[None]) 106 | xf_out = self.transformer(xf_in.to(self.dtype)) 107 | if self.final_ln is not None: 108 | xf_out = self.final_ln(xf_out) 109 | xf_proj = self.transformer_proj(xf_out[:, -1]) 110 | xf_out = xf_out.permute(0, 2, 1) # NLC -> NCL 111 | 112 | outputs = dict(xf_proj=xf_proj, xf_out=xf_out) 113 | 114 | if self.cache_text_emb: 115 | self.cache = dict( 116 | tokens=tokens, 117 | xf_proj=xf_proj.detach(), 118 | xf_out=xf_out.detach() if xf_out is not None else None, 119 | ) 120 | 121 | return outputs 122 | 123 | def del_cache(self): 124 | self.cache = None 125 | 126 | def forward(self, x, timesteps, clip_emb=None, tokens=None, mask=None): 127 | hs = [] 128 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 129 | if self.xf_width: 130 | text_outputs = self.get_text_emb(tokens, mask) 131 | xf_proj, xf_out = text_outputs["xf_proj"], text_outputs["xf_out"] 132 | clip_emb = clip_emb.to(emb) 133 | emb = th.cat([self.time_to_half(emb + xf_proj.to(emb) + clip_emb), self.clip_to_half(clip_emb)], dim=1).to(clip_emb) 134 | else: 135 | xf_out = None 136 | h = x.type(self.dtype) 137 | for module in self.input_blocks: 138 | h = module(h, emb, xf_out) 139 | hs.append(h) 140 | h = self.middle_block(h, emb, xf_out) 141 | for module in self.output_blocks: 142 | h = th.cat([h, hs.pop()], dim=1) 143 | h = module(h, emb, xf_out) 144 | h = h.type(x.dtype) 145 | h = self.out(h) 146 | return h 147 | 148 | 149 | class SuperResText2ImUNet(Text2ImUNet): 150 | """ 151 | A text2im model that performs super-resolution. 152 | Expects an extra kwarg `low_res` to condition on a low-resolution image. 153 | """ 154 | 155 | def __init__(self, *args, **kwargs): 156 | if "in_channels" in kwargs: 157 | kwargs = dict(kwargs) 158 | kwargs["in_channels"] = kwargs["in_channels"] * 2 159 | else: 160 | # Curse you, Python. Or really, just curse positional arguments :|. 161 | args = list(args) 162 | args[1] = args[1] * 2 163 | super().__init__(*args, **kwargs) 164 | 165 | def forward(self, x, timesteps, low_res=None, **kwargs): 166 | _, _, new_height, new_width = x.shape 167 | upsampled = F.interpolate( 168 | low_res, (new_height, new_width), mode="bilinear", align_corners=False 169 | ) 170 | x = th.cat([x, upsampled], dim=1) 171 | return super().forward(x, timesteps, **kwargs) 172 | 173 | 174 | class InpaintText2ImUNet(Text2ImUNet): 175 | """ 176 | A text2im model which can perform inpainting. 
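    Expects extra kwargs `inpaint_image` and `inpaint_mask`: where `inpaint_mask`
    is nonzero the input is replaced by the masked `inpaint_image`, and elsewhere
    the original (noisy) input is kept.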
177 | """ 178 | 179 | def __init__(self, *args, **kwargs): 180 | super().__init__(*args, **kwargs) 181 | 182 | def forward(self, x, timesteps, inpaint_image=None, inpaint_mask=None, **kwargs): 183 | if inpaint_image is None: 184 | inpaint_image = th.zeros_like(x) 185 | if inpaint_mask is None: 186 | inpaint_mask = th.zeros_like(x[:, :1]) 187 | inverted_mask = (inpaint_mask == 0).int() 188 | inpaint_image = inpaint_image * inpaint_mask 189 | new_x = inpaint_image + inverted_mask * x 190 | return super().forward( 191 | new_x, 192 | timesteps, 193 | **kwargs, 194 | ) 195 | 196 | 197 | class SuperResInpaintText2ImUnet(Text2ImUNet): 198 | """ 199 | A text2im model which can perform both upsampling and inpainting. 200 | """ 201 | 202 | def __init__(self, *args, **kwargs): 203 | if "in_channels" in kwargs: 204 | kwargs = dict(kwargs) 205 | kwargs["in_channels"] = kwargs["in_channels"] * 3 + 1 206 | else: 207 | # Curse you, Python. Or really, just curse positional arguments :|. 208 | args = list(args) 209 | args[1] = args[1] * 3 + 1 210 | super().__init__(*args, **kwargs) 211 | 212 | def forward( 213 | self, 214 | x, 215 | timesteps, 216 | inpaint_image=None, 217 | inpaint_mask=None, 218 | low_res=None, 219 | **kwargs, 220 | ): 221 | if inpaint_image is None: 222 | inpaint_image = th.zeros_like(x) 223 | if inpaint_mask is None: 224 | inpaint_mask = th.zeros_like(x[:, :1]) 225 | _, _, new_height, new_width = x.shape 226 | upsampled = F.interpolate( 227 | low_res, (new_height, new_width), mode="bilinear", align_corners=False 228 | ) 229 | return super().forward( 230 | th.cat([x, inpaint_image * inpaint_mask, inpaint_mask, upsampled], dim=1), 231 | timesteps, 232 | **kwargs, 233 | ) 234 | -------------------------------------------------------------------------------- /dalle2_decoder/tokenizer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralPushkin/Dalle2-Decoder/47af769d01d50b7a7f9c1c2b8cef6112b9870bdb/dalle2_decoder/tokenizer/__init__.py -------------------------------------------------------------------------------- /dalle2_decoder/tokenizer/bpe.py: -------------------------------------------------------------------------------- 1 | """ 2 | Byte pair encoding utilities adapted from: 3 | https://github.com/openai/gpt-2/blob/master/src/encoder.py 4 | """ 5 | 6 | import gzip 7 | import json 8 | import os 9 | from functools import lru_cache 10 | from typing import List, Tuple 11 | 12 | import regex as re 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 
25 | """ 26 | bs = ( 27 | list(range(ord("!"), ord("~") + 1)) 28 | + list(range(ord("¡"), ord("¬") + 1)) 29 | + list(range(ord("®"), ord("ÿ") + 1)) 30 | ) 31 | cs = bs[:] 32 | n = 0 33 | for b in range(2 ** 8): 34 | if b not in bs: 35 | bs.append(b) 36 | cs.append(2 ** 8 + n) 37 | n += 1 38 | cs = [chr(n) for n in cs] 39 | return dict(zip(bs, cs)) 40 | 41 | 42 | def get_pairs(word): 43 | """Return set of symbol pairs in a word. 44 | Word is represented as tuple of symbols (symbols being variable-length strings). 45 | """ 46 | pairs = set() 47 | prev_char = word[0] 48 | for char in word[1:]: 49 | pairs.add((prev_char, char)) 50 | prev_char = char 51 | return pairs 52 | 53 | 54 | class Encoder: 55 | def __init__(self, encoder, bpe_merges, errors="replace"): 56 | self.encoder = encoder 57 | self.decoder = {v: k for k, v in self.encoder.items()} 58 | self.errors = errors # how to handle errors in decoding 59 | self.byte_encoder = bytes_to_unicode() 60 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 61 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 62 | self.cache = {} 63 | 64 | # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions 65 | self.pat = re.compile( 66 | r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" 67 | ) 68 | 69 | @property 70 | def n_vocab(self) -> int: 71 | return len(self.encoder) 72 | 73 | @property 74 | def end_token(self) -> int: 75 | return self.n_vocab - 1 76 | 77 | def padded_tokens_and_mask( 78 | self, tokens: List[int], text_ctx: int 79 | ) -> Tuple[List[int], List[bool]]: 80 | tokens = tokens[:text_ctx] 81 | padding = text_ctx - len(tokens) 82 | padded_tokens = tokens + [self.end_token] * padding 83 | mask = [True] * len(tokens) + [False] * padding 84 | return padded_tokens, mask 85 | 86 | def bpe(self, token): 87 | if token in self.cache: 88 | return self.cache[token] 89 | word = tuple(token) 90 | pairs = get_pairs(word) 91 | 92 | if not pairs: 93 | return token 94 | 95 | while True: 96 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) 97 | if bigram not in self.bpe_ranks: 98 | break 99 | first, second = bigram 100 | new_word = [] 101 | i = 0 102 | while i < len(word): 103 | try: 104 | j = word.index(first, i) 105 | new_word.extend(word[i:j]) 106 | i = j 107 | except: # pylint: disable=bare-except 108 | new_word.extend(word[i:]) 109 | break 110 | 111 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 112 | new_word.append(first + second) 113 | i += 2 114 | else: 115 | new_word.append(word[i]) 116 | i += 1 117 | new_word = tuple(new_word) 118 | word = new_word 119 | if len(word) == 1: 120 | break 121 | else: 122 | pairs = get_pairs(word) 123 | word = " ".join(word) 124 | self.cache[token] = word 125 | return word 126 | 127 | def encode(self, text): 128 | text = text.lower() 129 | bpe_tokens = [] 130 | for token in re.findall(self.pat, text): 131 | token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) 132 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")) 133 | return bpe_tokens 134 | 135 | def decode(self, tokens): 136 | text = "".join([self.decoder[token] for token in tokens]) 137 | text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) 138 | return text 139 | 140 | 141 | def get_encoder(): 142 | root_dir = os.path.dirname(os.path.abspath(__file__)) 143 | with gzip.open(os.path.join(root_dir, "encoder.json.gz"), "r") 
as f: 144 | encoder = json.load(f) 145 | with gzip.open(os.path.join(root_dir, "vocab.bpe.gz"), "r") as f: 146 | bpe_data = str(f.read(), "utf-8") 147 | bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]] 148 | return Encoder( 149 | encoder=encoder, 150 | bpe_merges=bpe_merges, 151 | ) 152 | -------------------------------------------------------------------------------- /dalle2_decoder/tokenizer/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralPushkin/Dalle2-Decoder/47af769d01d50b7a7f9c1c2b8cef6112b9870bdb/dalle2_decoder/tokenizer/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /dalle2_decoder/tokenizer/encoder.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralPushkin/Dalle2-Decoder/47af769d01d50b7a7f9c1c2b8cef6112b9870bdb/dalle2_decoder/tokenizer/encoder.json.gz -------------------------------------------------------------------------------- /dalle2_decoder/tokenizer/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copied from: https://github.com/openai/CLIP/blob/573315e83f07b53a61ff5098757e8fc885f1703e/clip/simple_tokenizer.py 3 | """ 4 | 5 | import gzip 6 | import html 7 | import os 8 | from functools import lru_cache 9 | from typing import List, Tuple 10 | 11 | import ftfy 12 | import regex as re 13 | 14 | 15 | @lru_cache() 16 | def default_bpe(): 17 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 18 | 19 | 20 | @lru_cache() 21 | def bytes_to_unicode(): 22 | """ 23 | Returns list of utf-8 byte and a corresponding list of unicode strings. 24 | The reversible bpe codes work on unicode strings. 25 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 26 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 27 | This is a signficant percentage of your normal, say, 32K bpe vocab. 28 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 29 | And avoids mapping to whitespace/control characters the bpe code barfs on. 30 | """ 31 | bs = ( 32 | list(range(ord("!"), ord("~") + 1)) 33 | + list(range(ord("¡"), ord("¬") + 1)) 34 | + list(range(ord("®"), ord("ÿ") + 1)) 35 | ) 36 | cs = bs[:] 37 | n = 0 38 | for b in range(2 ** 8): 39 | if b not in bs: 40 | bs.append(b) 41 | cs.append(2 ** 8 + n) 42 | n += 1 43 | cs = [chr(n) for n in cs] 44 | return dict(zip(bs, cs)) 45 | 46 | 47 | def get_pairs(word): 48 | """Return set of symbol pairs in a word. 49 | Word is represented as tuple of symbols (symbols being variable-length strings). 
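    For example, get_pairs(("l", "o", "w")) returns {("l", "o"), ("o", "w")}.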
50 | """ 51 | pairs = set() 52 | prev_char = word[0] 53 | for char in word[1:]: 54 | pairs.add((prev_char, char)) 55 | prev_char = char 56 | return pairs 57 | 58 | 59 | def basic_clean(text): 60 | text = ftfy.fix_text(text) 61 | text = html.unescape(html.unescape(text)) 62 | return text.strip() 63 | 64 | 65 | def whitespace_clean(text): 66 | text = re.sub(r"\s+", " ", text) 67 | text = text.strip() 68 | return text 69 | 70 | 71 | class SimpleTokenizer(object): 72 | def __init__(self, bpe_path: str = default_bpe()): 73 | self.byte_encoder = bytes_to_unicode() 74 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 75 | merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") 76 | merges = merges[1 : 49152 - 256 - 2 + 1] 77 | merges = [tuple(merge.split()) for merge in merges] 78 | vocab = list(bytes_to_unicode().values()) 79 | vocab = vocab + [v + "" for v in vocab] 80 | for merge in merges: 81 | vocab.append("".join(merge)) 82 | vocab.extend(["<|startoftext|>", "<|endoftext|>"]) 83 | self.encoder = dict(zip(vocab, range(len(vocab)))) 84 | self.decoder = {v: k for k, v in self.encoder.items()} 85 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 86 | self.cache = {"<|startoftext|>": "<|startoftext|>", "<|endoftext|>": "<|endoftext|>"} 87 | self.pat = re.compile( 88 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", 89 | re.IGNORECASE, 90 | ) 91 | 92 | @property 93 | def start_token(self): 94 | return self.encoder["<|startoftext|>"] 95 | 96 | @property 97 | def end_token(self): 98 | return self.encoder["<|endoftext|>"] 99 | 100 | def padded_tokens_and_len(self, tokens: List[int], text_ctx: int) -> Tuple[List[int], int]: 101 | tokens = [self.start_token] + tokens[: text_ctx - 2] + [self.end_token] 102 | text_len = len(tokens) 103 | padding = text_ctx - len(tokens) 104 | padded_tokens = tokens + [0] * padding 105 | return padded_tokens, text_len 106 | 107 | def bpe(self, token): 108 | if token in self.cache: 109 | return self.cache[token] 110 | word = tuple(token[:-1]) + (token[-1] + "",) 111 | pairs = get_pairs(word) 112 | 113 | if not pairs: 114 | return token + "" 115 | 116 | while True: 117 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) 118 | if bigram not in self.bpe_ranks: 119 | break 120 | first, second = bigram 121 | new_word = [] 122 | i = 0 123 | while i < len(word): 124 | try: 125 | j = word.index(first, i) 126 | new_word.extend(word[i:j]) 127 | i = j 128 | except: # pylint: disable=bare-except 129 | new_word.extend(word[i:]) 130 | break 131 | 132 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 133 | new_word.append(first + second) 134 | i += 2 135 | else: 136 | new_word.append(word[i]) 137 | i += 1 138 | new_word = tuple(new_word) 139 | word = new_word 140 | if len(word) == 1: 141 | break 142 | else: 143 | pairs = get_pairs(word) 144 | word = " ".join(word) 145 | self.cache[token] = word 146 | return word 147 | 148 | def encode(self, text): 149 | bpe_tokens = [] 150 | text = whitespace_clean(basic_clean(text)).lower() 151 | for token in re.findall(self.pat, text): 152 | token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) 153 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")) 154 | return bpe_tokens 155 | 156 | def decode(self, tokens): 157 | text = "".join([self.decoder[token] for token in tokens]) 158 | text = ( 159 | bytearray([self.byte_decoder[c] for c in text]) 160 | .decode("utf-8", 
errors="replace") 161 | .replace("", " ") 162 | ) 163 | return text 164 | -------------------------------------------------------------------------------- /dalle2_decoder/tokenizer/vocab.bpe.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralPushkin/Dalle2-Decoder/47af769d01d50b7a7f9c1c2b8cef6112b9870bdb/dalle2_decoder/tokenizer/vocab.bpe.gz -------------------------------------------------------------------------------- /dalle2_decoder/train_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import functools 3 | import os 4 | 5 | import blobfile as bf 6 | import numpy as np 7 | import torch as th 8 | import torch.distributed as dist 9 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP 10 | from torch.optim import AdamW 11 | 12 | from . import dist_util, logger 13 | from .fp16_util import ( 14 | make_master_params, 15 | master_params_to_model_params, 16 | model_grads_to_master_grads, 17 | unflatten_master_params, 18 | zero_grad, 19 | ) 20 | from .nn import update_ema 21 | from .resample import LossAwareSampler, UniformSampler 22 | 23 | # For ImageNet experiments, this was a good default value. 24 | # We found that the lg_loss_scale quickly climbed to 25 | # 20-21 within the first ~1K steps of training. 26 | INITIAL_LOG_LOSS_SCALE = 20.0 27 | 28 | 29 | class TrainLoop: 30 | def __init__( 31 | self, 32 | *, 33 | model, 34 | diffusion, 35 | data, 36 | batch_size, 37 | microbatch, 38 | lr, 39 | ema_rate, 40 | log_interval, 41 | save_interval, 42 | resume_checkpoint, 43 | save_dir, 44 | use_fp16=False, 45 | fp16_scale_growth=1e-3, 46 | schedule_sampler=None, 47 | weight_decay=0.0, 48 | lr_anneal_steps=0, 49 | ): 50 | self.save_dir = save_dir 51 | self.model = model 52 | self.diffusion = diffusion 53 | self.data = data 54 | self.batch_size = batch_size 55 | self.microbatch = microbatch if microbatch > 0 else batch_size 56 | self.lr = lr 57 | self.ema_rate = ( 58 | [ema_rate] 59 | if isinstance(ema_rate, float) 60 | else [float(x) for x in ema_rate.split(",")] 61 | ) 62 | self.log_interval = log_interval 63 | self.save_interval = save_interval 64 | self.resume_checkpoint = resume_checkpoint 65 | self.use_fp16 = use_fp16 66 | self.fp16_scale_growth = fp16_scale_growth 67 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion) 68 | self.weight_decay = weight_decay 69 | self.lr_anneal_steps = lr_anneal_steps 70 | 71 | self.step = 0 72 | self.resume_step = 0 73 | self.global_batch = self.batch_size * dist.get_world_size() 74 | 75 | self.model_params = list(self.model.parameters()) 76 | self.master_params = self.model_params 77 | self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE 78 | self.sync_cuda = th.cuda.is_available() 79 | 80 | self._load_and_sync_parameters() 81 | if self.use_fp16: 82 | self._setup_fp16() 83 | 84 | self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay) 85 | if self.resume_step: 86 | self._load_optimizer_state() 87 | # Model was resumed, either due to a restart or a checkpoint 88 | # being specified at the command line. 
89 | self.ema_params = [ 90 | self._load_ema_parameters(rate) for rate in self.ema_rate 91 | ] 92 | else: 93 | self.ema_params = [ 94 | copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate)) 95 | ] 96 | 97 | if th.cuda.is_available(): 98 | self.use_ddp = True 99 | self.ddp_model = DDP( 100 | self.model, 101 | device_ids=[dist_util.dev()], 102 | output_device=dist_util.dev(), 103 | broadcast_buffers=False, 104 | bucket_cap_mb=128, 105 | find_unused_parameters=False, 106 | ) 107 | else: 108 | if dist.get_world_size() > 1: 109 | logger.warn( 110 | "Distributed training requires CUDA. " 111 | "Gradients will not be synchronized properly!" 112 | ) 113 | self.use_ddp = False 114 | self.ddp_model = self.model 115 | 116 | def _load_and_sync_parameters(self): 117 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 118 | 119 | if resume_checkpoint: 120 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint) 121 | if dist.get_rank() == 0: 122 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...") 123 | self.model.load_state_dict( 124 | dist_util.load_state_dict( 125 | resume_checkpoint, map_location=dist_util.dev() 126 | ) 127 | ) 128 | 129 | dist_util.sync_params(self.model.parameters()) 130 | 131 | def _load_ema_parameters(self, rate): 132 | ema_params = copy.deepcopy(self.master_params) 133 | 134 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 135 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate) 136 | if ema_checkpoint: 137 | if dist.get_rank() == 0: 138 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...") 139 | state_dict = dist_util.load_state_dict( 140 | ema_checkpoint, map_location=dist_util.dev() 141 | ) 142 | ema_params = self._state_dict_to_master_params(state_dict) 143 | 144 | dist_util.sync_params(ema_params) 145 | return ema_params 146 | 147 | def _load_optimizer_state(self): 148 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 149 | opt_checkpoint = bf.join( 150 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt" 151 | ) 152 | if bf.exists(opt_checkpoint): 153 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") 154 | state_dict = dist_util.load_state_dict( 155 | opt_checkpoint, map_location=dist_util.dev() 156 | ) 157 | self.opt.load_state_dict(state_dict) 158 | 159 | def _setup_fp16(self): 160 | self.master_params = make_master_params(self.model_params) 161 | self.model.convert_to_fp16() 162 | 163 | def run_loop(self): 164 | while ( 165 | not self.lr_anneal_steps 166 | or self.step + self.resume_step < self.lr_anneal_steps 167 | ): 168 | batch, cond = next(self.data) 169 | self.run_step(batch, cond) 170 | if self.step % self.log_interval == 0: 171 | logger.dumpkvs() 172 | if self.step % self.save_interval == 0: 173 | self.save() 174 | # Run for a finite amount of time in integration tests. 175 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: 176 | return 177 | self.step += 1 178 | # Save the last checkpoint if it wasn't already saved. 
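        # After the loop, (self.step - 1) is the index of the last completed step;
        # it was only saved above if it landed on a save_interval boundary.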
179 | if (self.step - 1) % self.save_interval != 0: 180 | self.save() 181 | 182 | def run_step(self, batch, cond): 183 | self.forward_backward(batch, cond) 184 | if self.use_fp16: 185 | self.optimize_fp16() 186 | else: 187 | self.optimize_normal() 188 | self.log_step() 189 | 190 | def forward_backward(self, batch, cond): 191 | zero_grad(self.model_params) 192 | for i in range(0, batch.shape[0], self.microbatch): 193 | micro = batch[i : i + self.microbatch].to(dist_util.dev()) 194 | micro_cond = { 195 | k: v[i : i + self.microbatch].to(dist_util.dev()) 196 | for k, v in cond.items() 197 | } 198 | last_batch = (i + self.microbatch) >= batch.shape[0] 199 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) 200 | 201 | compute_losses = functools.partial( 202 | self.diffusion.training_losses, 203 | self.ddp_model, 204 | micro, 205 | t, 206 | model_kwargs=micro_cond, 207 | ) 208 | 209 | if last_batch or not self.use_ddp: 210 | losses = compute_losses() 211 | else: 212 | with self.ddp_model.no_sync(): 213 | losses = compute_losses() 214 | 215 | if isinstance(self.schedule_sampler, LossAwareSampler): 216 | self.schedule_sampler.update_with_local_losses( 217 | t, losses["loss"].detach() 218 | ) 219 | 220 | loss = (losses["loss"] * weights).mean() 221 | log_loss_dict( 222 | self.diffusion, t, {k: v * weights for k, v in losses.items()} 223 | ) 224 | if self.use_fp16: 225 | loss_scale = 2 ** self.lg_loss_scale 226 | (loss * loss_scale).backward() 227 | else: 228 | loss.backward() 229 | 230 | def optimize_fp16(self): 231 | if any(not th.isfinite(p.grad).all() for p in self.model_params): 232 | self.lg_loss_scale -= 1 233 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 234 | return 235 | 236 | model_grads_to_master_grads(self.model_params, self.master_params) 237 | self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 238 | self._log_grad_norm() 239 | self._anneal_lr() 240 | self.opt.step() 241 | for rate, params in zip(self.ema_rate, self.ema_params): 242 | update_ema(params, self.master_params, rate=rate) 243 | master_params_to_model_params(self.model_params, self.master_params) 244 | self.lg_loss_scale += self.fp16_scale_growth 245 | 246 | def optimize_normal(self): 247 | self._log_grad_norm() 248 | self._anneal_lr() 249 | self.opt.step() 250 | for rate, params in zip(self.ema_rate, self.ema_params): 251 | update_ema(params, self.master_params, rate=rate) 252 | 253 | def _log_grad_norm(self): 254 | sqsum = 0.0 255 | for p in self.master_params: 256 | sqsum += (p.grad ** 2).sum().item() 257 | logger.logkv_mean("grad_norm", np.sqrt(sqsum)) 258 | 259 | def _anneal_lr(self): 260 | if not self.lr_anneal_steps: 261 | return 262 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps 263 | lr = self.lr * (1 - frac_done) 264 | for param_group in self.opt.param_groups: 265 | param_group["lr"] = lr 266 | 267 | def log_step(self): 268 | logger.logkv("step", self.step + self.resume_step) 269 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) 270 | if self.use_fp16: 271 | logger.logkv("lg_loss_scale", self.lg_loss_scale) 272 | 273 | def save(self): 274 | filename = os.path.join(self.save_dir, f'model_{self.step}.pt') 275 | state_dict = self.model.state_dict() 276 | th.save(state_dict, filename) 277 | 278 | def _master_params_to_state_dict(self, master_params): 279 | if self.use_fp16: 280 | master_params = unflatten_master_params( 281 | self.model.parameters(), master_params 282 | ) 283 | state_dict = 
self.model.state_dict() 284 | for i, (name, _value) in enumerate(self.model.named_parameters()): 285 | assert name in state_dict 286 | state_dict[name] = master_params[i] 287 | return state_dict 288 | 289 | def _state_dict_to_master_params(self, state_dict): 290 | params = [state_dict[name] for name, _ in self.model.named_parameters()] 291 | if self.use_fp16: 292 | return make_master_params(params) 293 | else: 294 | return params 295 | 296 | 297 | def parse_resume_step_from_filename(filename): 298 | """ 299 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the 300 | checkpoint's number of steps. 301 | """ 302 | split = filename.split("model") 303 | if len(split) < 2: 304 | return 0 305 | split1 = split[-1].split(".")[0] 306 | try: 307 | return int(split1) 308 | except ValueError: 309 | return 0 310 | 311 | 312 | def get_blob_logdir(): 313 | return os.environ.get("DIFFUSION_BLOB_LOGDIR", logger.get_dir()) 314 | 315 | 316 | def find_resume_checkpoint(): 317 | # On your infrastructure, you may want to override this to automatically 318 | # discover the latest checkpoint on your blob storage, etc. 319 | return None 320 | 321 | 322 | def find_ema_checkpoint(main_checkpoint, step, rate): 323 | if main_checkpoint is None: 324 | return None 325 | filename = f"ema_{rate}_{(step):06d}.pt" 326 | path = bf.join(bf.dirname(main_checkpoint), filename) 327 | if bf.exists(path): 328 | return path 329 | return None 330 | 331 | 332 | def log_loss_dict(diffusion, ts, losses): 333 | for key, values in losses.items(): 334 | logger.logkv_mean(key, values.mean().item()) 335 | # Log the quantiles (four quartiles, in particular). 336 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): 337 | quartile = int(4 * sub_t / diffusion.num_timesteps) 338 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss) 339 | -------------------------------------------------------------------------------- /dalle2_decoder/unet.py: -------------------------------------------------------------------------------- 1 | import math 2 | from abc import abstractmethod 3 | 4 | import torch as th 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from .fp16_util import convert_module_to_f16, convert_module_to_f32 9 | from .nn import avg_pool_nd, conv_nd, linear, normalization, timestep_embedding, zero_module 10 | 11 | 12 | class TimestepBlock(nn.Module): 13 | """ 14 | Any module where forward() takes timestep embeddings as a second argument. 15 | """ 16 | 17 | @abstractmethod 18 | def forward(self, x, emb): 19 | """ 20 | Apply the module to `x` given `emb` timestep embeddings. 21 | """ 22 | 23 | 24 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 25 | """ 26 | A sequential module that passes timestep embeddings to the children that 27 | support it as an extra input. 28 | """ 29 | 30 | def forward(self, x, emb, encoder_out=None): 31 | for layer in self: 32 | if isinstance(layer, TimestepBlock): 33 | x = layer(x, emb) 34 | elif isinstance(layer, AttentionBlock): 35 | x = layer(x, encoder_out) 36 | else: 37 | x = layer(x) 38 | return x 39 | 40 | 41 | class Upsample(nn.Module): 42 | """ 43 | An upsampling layer with an optional convolution. 44 | 45 | :param channels: channels in the inputs and outputs. 46 | :param use_conv: a bool determining if a convolution is applied. 47 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 48 | upsampling occurs in the inner-two dimensions. 
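    :param out_channels: if specified, the number of output channels; defaults
        to `channels`.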
49 | """ 50 | 51 | def __init__(self, channels, use_conv, dims=2, out_channels=None): 52 | super().__init__() 53 | self.channels = channels 54 | self.out_channels = out_channels or channels 55 | self.use_conv = use_conv 56 | self.dims = dims 57 | if use_conv: 58 | self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) 59 | 60 | def forward(self, x): 61 | assert x.shape[1] == self.channels 62 | if self.dims == 3: 63 | x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") 64 | else: 65 | x = F.interpolate(x, scale_factor=2, mode="nearest") 66 | if self.use_conv: 67 | x = self.conv(x) 68 | return x 69 | 70 | 71 | class Downsample(nn.Module): 72 | """ 73 | A downsampling layer with an optional convolution. 74 | 75 | :param channels: channels in the inputs and outputs. 76 | :param use_conv: a bool determining if a convolution is applied. 77 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 78 | downsampling occurs in the inner-two dimensions. 79 | """ 80 | 81 | def __init__(self, channels, use_conv, dims=2, out_channels=None): 82 | super().__init__() 83 | self.channels = channels 84 | self.out_channels = out_channels or channels 85 | self.use_conv = use_conv 86 | self.dims = dims 87 | stride = 2 if dims != 3 else (1, 2, 2) 88 | if use_conv: 89 | self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=1) 90 | else: 91 | assert self.channels == self.out_channels 92 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) 93 | 94 | def forward(self, x): 95 | assert x.shape[1] == self.channels 96 | return self.op(x) 97 | 98 | 99 | class ResBlock(TimestepBlock): 100 | """ 101 | A residual block that can optionally change the number of channels. 102 | 103 | :param channels: the number of input channels. 104 | :param emb_channels: the number of timestep embedding channels. 105 | :param dropout: the rate of dropout. 106 | :param out_channels: if specified, the number of out channels. 107 | :param use_conv: if True and out_channels is specified, use a spatial 108 | convolution instead of a smaller 1x1 convolution to change the 109 | channels in the skip connection. 110 | :param dims: determines if the signal is 1D, 2D, or 3D. 111 | :param use_checkpoint: if True, use gradient checkpointing on this module. 112 | :param up: if True, use this block for upsampling. 113 | :param down: if True, use this block for downsampling. 
114 | """ 115 | 116 | def __init__( 117 | self, 118 | channels, 119 | emb_channels, 120 | dropout, 121 | out_channels=None, 122 | use_conv=False, 123 | use_scale_shift_norm=False, 124 | dims=2, 125 | use_checkpoint=False, 126 | up=False, 127 | down=False, 128 | ): 129 | super().__init__() 130 | self.channels = channels 131 | self.emb_channels = emb_channels 132 | self.dropout = dropout 133 | self.out_channels = out_channels or channels 134 | self.use_conv = use_conv 135 | self.use_checkpoint = use_checkpoint 136 | self.use_scale_shift_norm = use_scale_shift_norm 137 | 138 | self.in_layers = nn.Sequential( 139 | normalization(channels, swish=1.0), 140 | nn.Identity(), 141 | conv_nd(dims, channels, self.out_channels, 3, padding=1), 142 | ) 143 | 144 | self.updown = up or down 145 | 146 | if up: 147 | self.h_upd = Upsample(channels, False, dims) 148 | self.x_upd = Upsample(channels, False, dims) 149 | elif down: 150 | self.h_upd = Downsample(channels, False, dims) 151 | self.x_upd = Downsample(channels, False, dims) 152 | else: 153 | self.h_upd = self.x_upd = nn.Identity() 154 | 155 | self.emb_layers = nn.Sequential( 156 | nn.SiLU(), 157 | linear( 158 | emb_channels, 159 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels, 160 | ), 161 | ) 162 | self.out_layers = nn.Sequential( 163 | normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0), 164 | nn.SiLU() if use_scale_shift_norm else nn.Identity(), 165 | nn.Dropout(p=dropout), 166 | zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)), 167 | ) 168 | 169 | if self.out_channels == channels: 170 | self.skip_connection = nn.Identity() 171 | elif use_conv: 172 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) 173 | else: 174 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) 175 | 176 | def forward(self, x, emb): 177 | """ 178 | Apply the block to a Tensor, conditioned on a timestep embedding. 179 | 180 | :param x: an [N x C x ...] Tensor of features. 181 | :param emb: an [N x emb_channels] Tensor of timestep embeddings. 182 | :return: an [N x C x ...] Tensor of outputs. 183 | """ 184 | if self.updown: 185 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 186 | h = in_rest(x) 187 | h = self.h_upd(h) 188 | x = self.x_upd(x) 189 | h = in_conv(h) 190 | else: 191 | h = self.in_layers(x) 192 | emb_out = self.emb_layers(emb).type(h.dtype) 193 | while len(emb_out.shape) < len(h.shape): 194 | emb_out = emb_out[..., None] 195 | if self.use_scale_shift_norm: 196 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 197 | scale, shift = th.chunk(emb_out, 2, dim=1) 198 | h = out_norm(h) * (1 + scale) + shift 199 | h = out_rest(h) 200 | else: 201 | h = h + emb_out 202 | h = self.out_layers(h) 203 | return self.skip_connection(x) + h 204 | 205 | 206 | class AttentionBlock(nn.Module): 207 | """ 208 | An attention block that allows spatial positions to attend to each other. 209 | 210 | Originally ported from here, but adapted to the N-d case. 211 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
212 | """ 213 | 214 | def __init__( 215 | self, 216 | channels, 217 | num_heads=1, 218 | num_head_channels=-1, 219 | use_checkpoint=False, 220 | encoder_channels=None, 221 | ): 222 | super().__init__() 223 | self.channels = channels 224 | if num_head_channels == -1: 225 | self.num_heads = num_heads 226 | else: 227 | assert ( 228 | channels % num_head_channels == 0 229 | ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" 230 | self.num_heads = channels // num_head_channels 231 | self.use_checkpoint = use_checkpoint 232 | self.norm = normalization(channels, swish=0.0) 233 | self.qkv = conv_nd(1, channels, channels * 3, 1) 234 | self.attention = QKVAttention(self.num_heads) 235 | 236 | if encoder_channels is not None: 237 | self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1) 238 | self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) 239 | 240 | def forward(self, x, encoder_out=None): 241 | b, c, *spatial = x.shape 242 | qkv = self.qkv(self.norm(x).view(b, c, -1)) 243 | if encoder_out is not None: 244 | encoder_out = self.encoder_kv(encoder_out) 245 | h = self.attention(qkv, encoder_out) 246 | else: 247 | h = self.attention(qkv) 248 | h = self.proj_out(h) 249 | return x + h.reshape(b, c, *spatial) 250 | 251 | 252 | class QKVAttention(nn.Module): 253 | """ 254 | A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping 255 | """ 256 | 257 | def __init__(self, n_heads): 258 | super().__init__() 259 | self.n_heads = n_heads 260 | 261 | def forward(self, qkv, encoder_kv=None): 262 | """ 263 | Apply QKV attention. 264 | 265 | :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. 266 | :return: an [N x (H * C) x T] tensor after attention. 267 | """ 268 | bs, width, length = qkv.shape 269 | assert width % (3 * self.n_heads) == 0 270 | ch = width // (3 * self.n_heads) 271 | q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) 272 | if encoder_kv is not None: 273 | assert encoder_kv.shape[1] == self.n_heads * ch * 2 274 | ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1) 275 | k = th.cat([ek, k], dim=-1) 276 | v = th.cat([ev, v], dim=-1) 277 | scale = 1 / math.sqrt(math.sqrt(ch)) 278 | weight = th.einsum( 279 | "bct,bcs->bts", q * scale, k * scale 280 | ) # More stable with f16 than dividing afterwards 281 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) 282 | a = th.einsum("bts,bcs->bct", weight, v) 283 | return a.reshape(bs, -1, length) 284 | 285 | 286 | class UNetModel(nn.Module): 287 | """ 288 | The full UNet model with attention and timestep embedding. 289 | 290 | :param in_channels: channels in the input Tensor. 291 | :param model_channels: base channel count for the model. 292 | :param out_channels: channels in the output Tensor. 293 | :param num_res_blocks: number of residual blocks per downsample. 294 | :param attention_resolutions: a collection of downsample rates at which 295 | attention will take place. May be a set, list, or tuple. 296 | For example, if this contains 4, then at 4x downsampling, attention 297 | will be used. 298 | :param dropout: the dropout probability. 299 | :param channel_mult: channel multiplier for each level of the UNet. 300 | :param conv_resample: if True, use learned convolutions for upsampling and 301 | downsampling. 302 | :param dims: determines if the signal is 1D, 2D, or 3D. 
303 | :param num_classes: if specified (as an int), then this model will be 304 | class-conditional with `num_classes` classes. 305 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. 306 | :param num_heads: the number of attention heads in each attention layer. 307 | :param num_heads_channels: if specified, ignore num_heads and instead use 308 | a fixed channel width per attention head. 309 | :param num_heads_upsample: works with num_heads to set a different number 310 | of heads for upsampling. Deprecated. 311 | :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. 312 | :param resblock_updown: use residual blocks for up/downsampling. 313 | """ 314 | 315 | def __init__( 316 | self, 317 | in_channels, 318 | model_channels, 319 | out_channels, 320 | num_res_blocks, 321 | attention_resolutions, 322 | dropout=0, 323 | channel_mult=(1, 2, 4, 8), 324 | conv_resample=True, 325 | dims=2, 326 | num_classes=None, 327 | use_checkpoint=False, 328 | use_fp16=False, 329 | num_heads=1, 330 | num_head_channels=-1, 331 | num_heads_upsample=-1, 332 | use_scale_shift_norm=False, 333 | resblock_updown=False, 334 | encoder_channels=None, 335 | ): 336 | super().__init__() 337 | 338 | if num_heads_upsample == -1: 339 | num_heads_upsample = num_heads 340 | 341 | self.in_channels = in_channels 342 | self.model_channels = model_channels 343 | self.out_channels = out_channels 344 | self.num_res_blocks = num_res_blocks 345 | self.attention_resolutions = attention_resolutions 346 | self.dropout = dropout 347 | self.channel_mult = channel_mult 348 | self.conv_resample = conv_resample 349 | self.num_classes = num_classes 350 | self.use_checkpoint = use_checkpoint 351 | self.dtype = th.float16 if use_fp16 else th.float32 352 | self.num_heads = num_heads 353 | self.num_head_channels = num_head_channels 354 | self.num_heads_upsample = num_heads_upsample 355 | 356 | time_embed_dim = model_channels * 4 357 | self.time_embed = nn.Sequential( 358 | linear(model_channels, time_embed_dim), 359 | nn.SiLU(), 360 | linear(time_embed_dim, time_embed_dim), 361 | ) 362 | 363 | if self.num_classes is not None: 364 | self.label_emb = nn.Embedding(num_classes, time_embed_dim) 365 | 366 | ch = input_ch = int(channel_mult[0] * model_channels) 367 | self.input_blocks = nn.ModuleList( 368 | [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))] 369 | ) 370 | self._feature_size = ch 371 | input_block_chans = [ch] 372 | ds = 1 373 | for level, mult in enumerate(channel_mult): 374 | for _ in range(num_res_blocks): 375 | layers = [ 376 | ResBlock( 377 | ch, 378 | time_embed_dim, 379 | dropout, 380 | out_channels=int(mult * model_channels), 381 | dims=dims, 382 | use_checkpoint=use_checkpoint, 383 | use_scale_shift_norm=use_scale_shift_norm, 384 | ) 385 | ] 386 | ch = int(mult * model_channels) 387 | if ds in attention_resolutions: 388 | layers.append( 389 | AttentionBlock( 390 | ch, 391 | use_checkpoint=use_checkpoint, 392 | num_heads=num_heads, 393 | num_head_channels=num_head_channels, 394 | encoder_channels=encoder_channels, 395 | ) 396 | ) 397 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 398 | self._feature_size += ch 399 | input_block_chans.append(ch) 400 | if level != len(channel_mult) - 1: 401 | out_ch = ch 402 | self.input_blocks.append( 403 | TimestepEmbedSequential( 404 | ResBlock( 405 | ch, 406 | time_embed_dim, 407 | dropout, 408 | out_channels=out_ch, 409 | dims=dims, 410 | use_checkpoint=use_checkpoint, 411 | use_scale_shift_norm=use_scale_shift_norm, 412 | 
down=True, 413 | ) 414 | if resblock_updown 415 | else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) 416 | ) 417 | ) 418 | ch = out_ch 419 | input_block_chans.append(ch) 420 | ds *= 2 421 | self._feature_size += ch 422 | 423 | self.middle_block = TimestepEmbedSequential( 424 | ResBlock( 425 | ch, 426 | time_embed_dim, 427 | dropout, 428 | dims=dims, 429 | use_checkpoint=use_checkpoint, 430 | use_scale_shift_norm=use_scale_shift_norm, 431 | ), 432 | AttentionBlock( 433 | ch, 434 | use_checkpoint=use_checkpoint, 435 | num_heads=num_heads, 436 | num_head_channels=num_head_channels, 437 | encoder_channels=encoder_channels, 438 | ), 439 | ResBlock( 440 | ch, 441 | time_embed_dim, 442 | dropout, 443 | dims=dims, 444 | use_checkpoint=use_checkpoint, 445 | use_scale_shift_norm=use_scale_shift_norm, 446 | ), 447 | ) 448 | self._feature_size += ch 449 | 450 | self.output_blocks = nn.ModuleList([]) 451 | for level, mult in list(enumerate(channel_mult))[::-1]: 452 | for i in range(num_res_blocks + 1): 453 | ich = input_block_chans.pop() 454 | layers = [ 455 | ResBlock( 456 | ch + ich, 457 | time_embed_dim, 458 | dropout, 459 | out_channels=int(model_channels * mult), 460 | dims=dims, 461 | use_checkpoint=use_checkpoint, 462 | use_scale_shift_norm=use_scale_shift_norm, 463 | ) 464 | ] 465 | ch = int(model_channels * mult) 466 | if ds in attention_resolutions: 467 | layers.append( 468 | AttentionBlock( 469 | ch, 470 | use_checkpoint=use_checkpoint, 471 | num_heads=num_heads_upsample, 472 | num_head_channels=num_head_channels, 473 | encoder_channels=encoder_channels, 474 | ) 475 | ) 476 | if level and i == num_res_blocks: 477 | out_ch = ch 478 | layers.append( 479 | ResBlock( 480 | ch, 481 | time_embed_dim, 482 | dropout, 483 | out_channels=out_ch, 484 | dims=dims, 485 | use_checkpoint=use_checkpoint, 486 | use_scale_shift_norm=use_scale_shift_norm, 487 | up=True, 488 | ) 489 | if resblock_updown 490 | else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) 491 | ) 492 | ds //= 2 493 | self.output_blocks.append(TimestepEmbedSequential(*layers)) 494 | self._feature_size += ch 495 | 496 | self.out = nn.Sequential( 497 | normalization(ch, swish=1.0), 498 | nn.Identity(), 499 | zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)), 500 | ) 501 | self.use_fp16 = use_fp16 502 | 503 | def convert_to_fp16(self): 504 | """ 505 | Convert the torso of the model to float16. 506 | """ 507 | self.input_blocks.apply(convert_module_to_f16) 508 | self.middle_block.apply(convert_module_to_f16) 509 | self.output_blocks.apply(convert_module_to_f16) 510 | 511 | def convert_to_fp32(self): 512 | """ 513 | Convert the torso of the model to float32. 514 | """ 515 | self.input_blocks.apply(convert_module_to_f32) 516 | self.middle_block.apply(convert_module_to_f32) 517 | self.output_blocks.apply(convert_module_to_f32) 518 | 519 | def forward(self, x, timesteps, y=None): 520 | """ 521 | Apply the model to an input batch. 522 | 523 | :param x: an [N x C x ...] Tensor of inputs. 524 | :param timesteps: a 1-D batch of timesteps. 525 | :param y: an [N] Tensor of labels, if class-conditional. 526 | :return: an [N x C x ...] Tensor of outputs. 
527 | """ 528 | assert (y is not None) == ( 529 | self.num_classes is not None 530 | ), "must specify y if and only if the model is class-conditional" 531 | 532 | hs = [] 533 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 534 | 535 | if self.num_classes is not None: 536 | assert y.shape == (x.shape[0],) 537 | emb = emb + self.label_emb(y) 538 | 539 | h = x.type(self.dtype) 540 | for module in self.input_blocks: 541 | h = module(h, emb) 542 | hs.append(h) 543 | h = self.middle_block(h, emb) 544 | for module in self.output_blocks: 545 | h = th.cat([h, hs.pop()], dim=1) 546 | h = module(h, emb) 547 | h = h.type(x.dtype) 548 | return self.out(h) 549 | 550 | class SuperResUNetModel(UNetModel): 551 | """ 552 | A UNetModel that performs super-resolution. 553 | 554 | Expects an extra kwarg `low_res` to condition on a low-resolution image. 555 | """ 556 | 557 | def __init__(self, *args, **kwargs): 558 | if "in_channels" in kwargs: 559 | kwargs = dict(kwargs) 560 | kwargs["in_channels"] = kwargs["in_channels"] * 2 561 | else: 562 | # Curse you, Python. Or really, just curse positional arguments :|. 563 | args = list(args) 564 | args[1] = args[1] * 2 565 | super().__init__(*args, **kwargs) 566 | 567 | def forward(self, x, timesteps, low_res=None, **kwargs): 568 | _, _, new_height, new_width = x.shape 569 | upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear") 570 | x = th.cat([x, upsampled], dim=1) 571 | return super().forward(x, timesteps, **kwargs) 572 | 573 | 574 | class InpaintUNetModel(UNetModel): 575 | """ 576 | A UNetModel which can perform inpainting. 577 | """ 578 | 579 | def __init__(self, *args, **kwargs): 580 | if "in_channels" in kwargs: 581 | kwargs = dict(kwargs) 582 | kwargs["in_channels"] = kwargs["in_channels"] * 2 + 1 583 | else: 584 | # Curse you, Python. Or really, just curse positional arguments :|. 585 | args = list(args) 586 | args[1] = args[1] * 2 + 1 587 | super().__init__(*args, **kwargs) 588 | 589 | def forward(self, x, timesteps, inpaint_image=None, inpaint_mask=None, **kwargs): 590 | if inpaint_image is None: 591 | inpaint_image = th.zeros_like(x) 592 | if inpaint_mask is None: 593 | inpaint_mask = th.zeros_like(x[:, :1]) 594 | return super().forward( 595 | th.cat([x, inpaint_image * inpaint_mask, inpaint_mask], dim=1), 596 | timesteps, 597 | **kwargs, 598 | ) 599 | 600 | 601 | class SuperResInpaintUNetModel(UNetModel): 602 | """ 603 | A UNetModel which can perform both upsampling and inpainting. 604 | """ 605 | 606 | def __init__(self, *args, **kwargs): 607 | if "in_channels" in kwargs: 608 | kwargs = dict(kwargs) 609 | kwargs["in_channels"] = kwargs["in_channels"] * 3 + 1 610 | else: 611 | # Curse you, Python. Or really, just curse positional arguments :|. 
612 | args = list(args) 613 | args[1] = args[1] * 3 + 1 614 | super().__init__(*args, **kwargs) 615 | 616 | def forward( 617 | self, 618 | x, 619 | timesteps, 620 | inpaint_image=None, 621 | inpaint_mask=None, 622 | low_res=None, 623 | **kwargs, 624 | ): 625 | if inpaint_image is None: 626 | inpaint_image = th.zeros_like(x) 627 | if inpaint_mask is None: 628 | inpaint_mask = th.zeros_like(x[:, :1]) 629 | _, _, new_height, new_width = x.shape 630 | upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear") 631 | return super().forward( 632 | th.cat([x, inpaint_image * inpaint_mask, inpaint_mask, upsampled], dim=1), 633 | timesteps, 634 | **kwargs, 635 | ) 636 | -------------------------------------------------------------------------------- /dalle2_decoder/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch as th 4 | import torch.nn as nn 5 | 6 | def mean_flat(tensor): 7 | """ 8 | Take the mean over all non-batch dimensions. 9 | """ 10 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) -------------------------------------------------------------------------------- /dalle2_decoder/xf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Transformer implementation adapted from CLIP ViT: 3 | https://github.com/openai/CLIP/blob/4c0275784d6d9da97ca1f47eaaee31de1867da91/clip/model.py 4 | """ 5 | 6 | import math 7 | 8 | import torch as th 9 | import torch.nn as nn 10 | 11 | 12 | def convert_module_to_f16(l): 13 | """ 14 | Convert primitive modules to float16. 15 | """ 16 | if isinstance(l, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): 17 | l.weight.data = l.weight.data.half() 18 | if l.bias is not None: 19 | l.bias.data = l.bias.data.half() 20 | 21 | 22 | class LayerNorm(nn.LayerNorm): 23 | """ 24 | Implementation that supports fp16 inputs but fp32 gains/biases. 
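    Inputs are upcast to float32 for the normalization and then cast back to the
    original dtype.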
25 | """ 26 | 27 | def forward(self, x: th.Tensor): 28 | return super().forward(x.float()).to(x.dtype) 29 | 30 | 31 | class MultiheadAttention(nn.Module): 32 | def __init__(self, n_ctx, width, heads): 33 | super().__init__() 34 | self.n_ctx = n_ctx 35 | self.width = width 36 | self.heads = heads 37 | self.c_qkv = nn.Linear(width, width * 3) 38 | self.c_proj = nn.Linear(width, width) 39 | self.attention = QKVMultiheadAttention(heads, n_ctx) 40 | 41 | def forward(self, x): 42 | x = self.c_qkv(x) 43 | x = self.attention(x) 44 | x = self.c_proj(x) 45 | return x 46 | 47 | 48 | class MLP(nn.Module): 49 | def __init__(self, width): 50 | super().__init__() 51 | self.width = width 52 | self.c_fc = nn.Linear(width, width * 4) 53 | self.c_proj = nn.Linear(width * 4, width) 54 | self.gelu = nn.GELU() 55 | 56 | def forward(self, x): 57 | return self.c_proj(self.gelu(self.c_fc(x))) 58 | 59 | 60 | class QKVMultiheadAttention(nn.Module): 61 | def __init__(self, n_heads: int, n_ctx: int): 62 | super().__init__() 63 | self.n_heads = n_heads 64 | self.n_ctx = n_ctx 65 | 66 | def forward(self, qkv): 67 | bs, n_ctx, width = qkv.shape 68 | attn_ch = width // self.n_heads // 3 69 | scale = 1 / math.sqrt(math.sqrt(attn_ch)) 70 | qkv = qkv.view(bs, n_ctx, self.n_heads, -1) 71 | q, k, v = th.split(qkv, attn_ch, dim=-1) 72 | weight = th.einsum( 73 | "bthc,bshc->bhts", q * scale, k * scale 74 | ) # More stable with f16 than dividing afterwards 75 | wdtype = weight.dtype 76 | weight = th.softmax(weight.float(), dim=-1).type(wdtype) 77 | return th.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) 78 | 79 | 80 | class ResidualAttentionBlock(nn.Module): 81 | def __init__( 82 | self, 83 | n_ctx: int, 84 | width: int, 85 | heads: int, 86 | ): 87 | super().__init__() 88 | 89 | self.attn = MultiheadAttention( 90 | n_ctx, 91 | width, 92 | heads, 93 | ) 94 | self.ln_1 = LayerNorm(width) 95 | self.mlp = MLP(width) 96 | self.ln_2 = LayerNorm(width) 97 | 98 | def forward(self, x: th.Tensor): 99 | x = x + self.attn(self.ln_1(x)) 100 | x = x + self.mlp(self.ln_2(x)) 101 | return x 102 | 103 | 104 | class Transformer(nn.Module): 105 | def __init__( 106 | self, 107 | n_ctx: int, 108 | width: int, 109 | layers: int, 110 | heads: int, 111 | ): 112 | super().__init__() 113 | self.n_ctx = n_ctx 114 | self.width = width 115 | self.layers = layers 116 | self.resblocks = nn.ModuleList( 117 | [ 118 | ResidualAttentionBlock( 119 | n_ctx, 120 | width, 121 | heads, 122 | ) 123 | for _ in range(layers) 124 | ] 125 | ) 126 | 127 | def forward(self, x: th.Tensor): 128 | for block in self.resblocks: 129 | x = block(x) 130 | return x 131 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="dalle2_decoder", 5 | packages=[ 6 | "dalle2_decoder", 7 | "dalle2_decoder.clip", 8 | "dalle2_decoder.tokenizer", 9 | ], 10 | package_data={ 11 | "dalle2_decoder.tokenizer": [ 12 | "bpe_simple_vocab_16e6.txt.gz", 13 | "encoder.json.gz", 14 | "vocab.bpe.gz", 15 | ], 16 | "dalle2_decoder.clip": ["config.yaml"], 17 | }, 18 | install_requires=[ 19 | "Pillow", 20 | "attrs", 21 | "torch", 22 | "filelock", 23 | "requests", 24 | "tqdm", 25 | "ftfy", 26 | "regex", 27 | "numpy", 28 | "blobfile", 29 | "mpi4py", 30 | ], 31 | author="NeuralPushkin", 32 | ) 33 | --------------------------------------------------------------------------------