├── README.md ├── imagen_pytorch ├── __init__.py ├── clip │ ├── __init__.py │ ├── attention.py │ ├── config.yaml │ ├── encoders.py │ ├── model_creation.py │ └── utils.py ├── dataset.py ├── download.py ├── fp16_util.py ├── gaussian_diffusion.py ├── get_webdataset_loader.py ├── logger.py ├── losses.py ├── model_creation.py ├── nn.py ├── resample.py ├── respace.py ├── text2im_model.py ├── tokenizer │ ├── __init__.py │ ├── bpe.py │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── encoder.json.gz │ ├── simple_tokenizer.py │ └── vocab.bpe.gz ├── unet.py ├── utils.py └── xf.py ├── images ├── 1.jpg ├── 2.jpg ├── 3.jpg ├── 4.jpg ├── 5.jpg ├── 6.jpg ├── 7.jpg └── 8.jpg ├── notebooks └── Imagen_pytorch_inference_new.ipynb └── setup.py /README.md: -------------------------------------------------------------------------------- 1 | # Imagen-pytorch 2 | Implementation of [Imagen](https://gweb-research-imagen.appspot.com/), Google's Text-to-Image Neural Network. It is the new SOTA for text-to-image synthesis. 3 | 4 | 5 | For detailed usage examples, see the [notebooks](notebooks) directory. 6 | 7 | * The [inference](notebooks/Imagen_pytorch_inference_new.ipynb) [![][colab]][colab-inference] notebook shows how to use Imagen. 8 | 9 | [colab]: 10 | [colab-inference]: 11 | 12 | # A red cube on top of blue cube 13 | 14 | ![image1](https://github.com/cene555/Imagen-pytorch/blob/main/images/1.jpg) 15 | 16 | # Teddy bears swimming at the Olympics 400m Butterfly event. 17 | 18 | ![image2](https://github.com/cene555/Imagen-pytorch/blob/main/images/2.jpg) 19 | 20 | # A teddy bear in times square 21 | 22 | ![image3](https://github.com/cene555/Imagen-pytorch/blob/main/images/3.jpg) 23 | 24 | ![image4](https://github.com/cene555/Imagen-pytorch/blob/main/images/6.jpg) 25 | 26 | # A face 27 | 28 | ![image5](https://github.com/cene555/Imagen-pytorch/blob/main/images/4.jpg) 29 | 30 | ![image6](https://github.com/cene555/Imagen-pytorch/blob/main/images/5.jpg) 31 | 32 | # A photo of teddy bear 33 | ![image7](https://github.com/cene555/Imagen-pytorch/blob/main/images/8.jpg) 34 | 35 | ![image8](https://github.com/cene555/Imagen-pytorch/blob/main/images/7.jpg) 36 | -------------------------------------------------------------------------------- /imagen_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | A codebase for performing model inference with a text-conditional diffusion model. 
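The diffusion utilities live in gaussian_diffusion.py; the CLIP model used for gradient-based guidance (its cond_fn is built in clip/model_creation.py) lives under clip/.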
3 | """ 4 | -------------------------------------------------------------------------------- /imagen_pytorch/clip/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cene555/Imagen-pytorch/20318089b42cf25fea27618319aa4f7a105be2ec/imagen_pytorch/clip/__init__.py -------------------------------------------------------------------------------- /imagen_pytorch/clip/attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | from abc import ABC, abstractmethod 3 | from itertools import product 4 | from typing import Any, Optional 5 | 6 | import attr 7 | import numpy as np 8 | import torch 9 | 10 | 11 | @attr.s 12 | class AttentionMask(ABC): 13 | query_context_size: int = attr.ib(validator=lambda i, a, x: x >= 1) # type: ignore 14 | key_context_size: int = attr.ib(validator=lambda i, a, x: x >= 1) # type: ignore 15 | block_size: int = attr.ib(validator=lambda i, a, x: x >= 1) # type: ignore 16 | n_head: int = attr.ib(validator=lambda i, a, x: x >= 1) # type: ignore 17 | is_head_specific: bool = attr.ib(default=False) 18 | n_query_pad: int = attr.ib(default=0) 19 | n_key_pad: int = attr.ib(default=0) 20 | 21 | def __attrs_post_init__(self) -> None: 22 | if self.query_context_size % self.block_size != 0: 23 | raise ValueError() 24 | if self.key_context_size % self.block_size != 0: 25 | raise ValueError() 26 | if self.n_query_pad >= self.query_context_size: 27 | raise ValueError() 28 | if self.n_key_pad >= self.key_context_size: 29 | raise ValueError() 30 | 31 | self.n_query_block = self.query_context_size // self.block_size 32 | self.n_key_block = self.key_context_size // self.block_size 33 | self.first_pad_query_block_idx = self.n_query_block - int( 34 | math.ceil(self.n_query_pad / self.block_size) 35 | ) 36 | self.first_pad_key_block_idx = self.n_key_block - int( 37 | math.ceil(self.n_key_pad / self.block_size) 38 | ) 39 | 40 | def _make_global_layout(self) -> None: 41 | if not self.is_head_specific: 42 | m = np.ones([self.n_query_block, self.n_key_block], dtype=np.bool) 43 | r = product(*[range(n) for n in m.shape]) 44 | 45 | for qb, kb in r: 46 | m[qb, kb] = np.any(self.block_layout(None, 0, qb, kb, 0)) 47 | else: 48 | m = np.ones([self.n_head, self.n_query_block, self.n_key_block], dtype=np.bool) 49 | r = product(*[range(n) for n in m.shape]) 50 | 51 | for h, qb, kb in r: 52 | m[h, qb, kb] = np.any(self.block_layout(None, h, qb, kb, 0)) 53 | 54 | self.global_layout = m 55 | 56 | @abstractmethod 57 | def _block_layout( 58 | self, blk_shape: Any, head_idx: int, query_idx: int, key_idx: int, blk_idx: int 59 | ) -> np.ndarray: 60 | raise NotImplementedError() 61 | 62 | def block_layout( 63 | self, blk_shape: Any, head_idx: int, query_idx: int, key_idx: int, blk_idx: int 64 | ) -> np.ndarray: 65 | """ 66 | `query_idx`, `key_idx` are block-level, zero-based indices. 
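Returns a boolean `block_size x block_size` mask for that block pair: rows/columns that fall into the query/key padding are forced to False, and the result is AND-ed with the subclass-specific `_block_layout`.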
67 | """ 68 | 69 | m = np.ones([self.block_size, self.block_size], dtype=np.bool) 70 | 71 | if query_idx >= self.first_pad_query_block_idx: 72 | n_pad = min( 73 | self.block_size, 74 | (query_idx + 1) * self.block_size - (self.query_context_size - self.n_query_pad), 75 | ) 76 | assert n_pad > 0 77 | m[self.block_size - n_pad :] = False 78 | if key_idx >= self.first_pad_key_block_idx: 79 | n_pad = min( 80 | self.block_size, 81 | (key_idx + 1) * self.block_size - (self.key_context_size - self.n_key_pad), 82 | ) 83 | assert n_pad > 0 84 | m[:, self.block_size - n_pad :] = False 85 | 86 | return m & self._block_layout(blk_shape, head_idx, query_idx, key_idx, blk_idx) 87 | 88 | 89 | @attr.s 90 | class DenseAttentionMask(AttentionMask): 91 | def __attrs_post_init__(self) -> None: 92 | super().__attrs_post_init__() 93 | 94 | self.global_layout = np.ones([self.n_query_block, self.n_key_block], dtype=np.bool) 95 | n_zero_query_blocks = self.n_query_pad // self.block_size 96 | n_zero_key_blocks = self.n_key_pad // self.block_size 97 | self.global_layout[self.n_query_block - n_zero_query_blocks :] = False 98 | self.global_layout[:, self.n_key_block - n_zero_key_blocks :] = False 99 | 100 | def _block_layout( 101 | self, blk_shape: Any, head_idx: int, query_idx: int, key_idx: int, blk_idx: int 102 | ) -> np.ndarray: 103 | return np.ones([self.block_size, self.block_size], dtype=np.bool) 104 | 105 | 106 | @attr.s 107 | class DenseCausalAttentionMask(AttentionMask): 108 | def __attrs_post_init__(self) -> None: 109 | super().__attrs_post_init__() 110 | 111 | self.global_layout = np.tril(np.ones([self.n_query_block, self.n_key_block], dtype=np.bool)) 112 | n_zero_query_blocks = self.n_query_pad // self.block_size 113 | n_zero_key_blocks = self.n_key_pad // self.block_size 114 | self.global_layout[self.n_query_block - n_zero_query_blocks :] = False 115 | self.global_layout[:, self.n_key_block - n_zero_key_blocks :] = False 116 | 117 | def _block_layout( 118 | self, blk_shape: Any, head_idx: int, query_idx: int, key_idx: int, blk_idx: int 119 | ) -> np.ndarray: 120 | if query_idx > key_idx: 121 | return np.ones(2 * [self.block_size], dtype=np.bool) 122 | elif query_idx < key_idx: 123 | return np.zeros(2 * [self.block_size], dtype=np.bool) 124 | else: 125 | return np.tril(np.ones(2 * [self.block_size], dtype=np.bool)) 126 | 127 | 128 | @attr.s(eq=False, repr=False) 129 | class AttentionInfo: 130 | n_heads: int = attr.ib() 131 | ctx_blks_q: int = attr.ib() 132 | ctx_blks_k: int = attr.ib() 133 | block_size: int = attr.ib() 134 | pytorch_attn_bias: Optional[torch.Tensor] = attr.ib() 135 | 136 | 137 | def to_attention_info(d: AttentionMask) -> AttentionInfo: 138 | return AttentionInfo( 139 | n_heads=d.n_head, 140 | ctx_blks_q=d.n_query_block, 141 | ctx_blks_k=d.n_key_block, 142 | block_size=d.block_size, 143 | pytorch_attn_bias=None, 144 | ) 145 | 146 | 147 | def make_full_layout(d: AttentionMask) -> np.ndarray: 148 | """ 149 | Returns the `context_size x context_size` layout matrix described by `d`. If the layout is dependent on the index of 150 | the attention head, a `attention_head x context_size x context_size` layout matrix is returned instead. 
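The encoders in encoders.py turn this layout into an additive attention bias: masked positions become a large negative value (-1e10) that is added to the attention logits before the softmax.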
151 | """ 152 | 153 | if not d.is_head_specific: 154 | u = np.reshape(d.global_layout, [d.n_query_block, d.n_key_block, 1, 1]) 155 | r = product(range(d.n_query_block), range(d.n_key_block)) 156 | v = np.array([d.block_layout(None, 0, i, j, 0) for i, j in r]) 157 | v = np.reshape(v, [d.n_query_block, d.n_key_block, d.block_size, d.block_size]) 158 | 159 | w = u * v 160 | w = np.transpose(w, [0, 2, 1, 3]) 161 | w = np.reshape(w, [d.query_context_size, d.key_context_size]) 162 | return w 163 | else: 164 | if len(d.global_layout.shape) == 2: 165 | u = np.reshape(d.global_layout, [1, d.n_query_block, d.n_key_block, 1, 1]) 166 | u = np.tile(u, [d.n_head, 1, 1, 1, 1]) 167 | elif len(d.global_layout.shape) == 3: 168 | u = np.reshape(d.global_layout, [d.n_head, d.n_query_block, d.n_key_block, 1, 1]) 169 | else: 170 | raise RuntimeError() 171 | 172 | s = product(range(d.n_head), range(d.n_query_block), range(d.n_key_block)) 173 | v = np.array([d.block_layout(None, i, j, k, 0) for i, j, k in s]) 174 | v = np.reshape(v, [d.n_head, d.n_query_block, d.n_key_block, d.block_size, d.block_size]) 175 | 176 | w = u * v 177 | w = np.transpose(w, [0, 1, 3, 2, 4]) 178 | w = np.reshape(w, [d.n_head, d.query_context_size, d.key_context_size]) 179 | return w 180 | -------------------------------------------------------------------------------- /imagen_pytorch/clip/config.yaml: -------------------------------------------------------------------------------- 1 | logit_scale: 100.0 2 | 3 | # Diffusion settings 4 | beta_schedule: "squaredcos_cap_v2" 5 | n_timesteps: 1000 6 | 7 | # Architecture settings 8 | image_size: 64 9 | patch_size: 4 10 | n_vocab: 65536 11 | max_text_len: 77 12 | n_embd: 512 13 | n_head_state_text: 64 14 | n_head_text: 8 15 | n_xf_blocks_text: 12 16 | n_head_state_image: 64 17 | n_head_image: 12 18 | n_xf_blocks_image: 12 19 | -------------------------------------------------------------------------------- /imagen_pytorch/clip/encoders.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import OrderedDict 3 | from typing import List, Optional, Tuple, cast 4 | 5 | import attr 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from .attention import ( 12 | AttentionInfo, 13 | DenseAttentionMask, 14 | DenseCausalAttentionMask, 15 | make_full_layout, 16 | to_attention_info, 17 | ) 18 | from .utils import Affine, LayerNorm, zero_key_bias_grad 19 | 20 | # Constants used in the original CLIP implementation. 
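# They are per-channel RGB means/stds on a 0-255 pixel scale: ImageEmbedding normalizes its
# input as (x - mean) / std, and CLIPModel.image_embeddings first maps [-1, 1] images into
# that range via (images + 1) * 127.5.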
21 | image_channel_means = [122.77093945, 116.74601272, 104.09373519] 22 | image_channel_stds = [68.50053285, 66.63215831, 70.32316309] 23 | 24 | 25 | @attr.s(eq=False, repr=False) 26 | class TextEmbedding(nn.Module): 27 | n_vocab: int = attr.ib() 28 | n_context: int = attr.ib() 29 | n_state: int = attr.ib() 30 | device: torch.device = attr.ib(default=torch.device("cuda")) 31 | 32 | def __attrs_post_init__(self) -> None: 33 | super().__init__() 34 | 35 | w_voc = torch.empty((self.n_vocab, self.n_state), dtype=torch.float32, device=self.device) 36 | w_pos = torch.empty((self.n_context, self.n_state), dtype=torch.float32, device=self.device) 37 | 38 | with torch.no_grad(): 39 | w_voc.normal_(std=0.02) 40 | w_pos.normal_(std=0.01) 41 | 42 | self.w_voc = nn.Parameter(w_voc) 43 | self.w_pos = nn.Parameter(w_pos) 44 | 45 | def forward(self, x: torch.Tensor) -> torch.Tensor: 46 | if len(x.shape) != 2: 47 | raise ValueError() 48 | 49 | return F.embedding(x, self.w_voc) + self.w_pos[None, :, :] 50 | 51 | 52 | @attr.s(eq=False, repr=False) 53 | class ImageEmbedding(nn.Module): 54 | image_size: int = attr.ib() 55 | patch_size: int = attr.ib() 56 | n_state: int = attr.ib() 57 | n_timestep: int = attr.ib(default=0) 58 | device: torch.device = attr.ib(default=torch.device("cuda")) 59 | 60 | def __attrs_post_init__(self) -> None: 61 | super().__init__() 62 | 63 | if self.image_size % self.patch_size != 0: 64 | raise ValueError() 65 | 66 | n_patch = self.image_size // self.patch_size 67 | patch_proj = torch.empty( 68 | (self.n_state, 3) + 2 * (self.patch_size,), dtype=torch.float32, device=self.device 69 | ) 70 | w_pos = torch.empty( 71 | (1 + n_patch ** 2, self.n_state), dtype=torch.float32, device=self.device 72 | ) 73 | 74 | with torch.no_grad(): 75 | if self.n_timestep == 0: 76 | pred_state = torch.empty((self.n_state,), dtype=torch.float32, device=self.device) 77 | pred_state.normal_(std=1 / np.sqrt(self.n_state)) 78 | self.pred_state = nn.Parameter(pred_state) 79 | else: 80 | w_t = torch.empty( 81 | (self.n_timestep, self.n_state), dtype=torch.float32, device=self.device 82 | ) 83 | w_t.normal_(std=1 / np.sqrt(self.n_state)) 84 | self.w_t = nn.Parameter(w_t) 85 | 86 | patch_proj.normal_(std=np.sqrt(2 / (self.n_state * self.patch_size ** 2))) 87 | w_pos.normal_(std=1 / np.sqrt(self.n_state)) 88 | 89 | self.patch_proj = nn.Parameter(patch_proj) 90 | self.w_pos = nn.Parameter(w_pos) 91 | 92 | self.channel_means = torch.tensor( 93 | image_channel_means, dtype=torch.float32, device=self.device 94 | )[None, :, None, None] 95 | self.channel_stds = torch.tensor( 96 | image_channel_stds, dtype=torch.float32, device=self.device 97 | )[None, :, None, None] 98 | self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device) 99 | 100 | def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None) -> torch.Tensor: 101 | if len(x.shape) != 4: 102 | raise ValueError("input should be 4d") 103 | if x.shape[1] != 3: 104 | raise ValueError("input should have 3 channels") 105 | if not (x.shape[2] == self.image_size and x.shape[3] == self.image_size): 106 | raise ValueError(f"input is not {self.image_size} x {self.image_size}") 107 | 108 | if (self.n_timestep == 0 and t is not None) or (self.n_timestep != 0 and t is None): 109 | raise ValueError() 110 | if self.n_timestep != 0: 111 | assert t is not None 112 | if len(t.shape) != 1: 113 | raise ValueError() 114 | if t.shape[0] != x.shape[0]: 115 | raise ValueError() 116 | 117 | x = (x - self.channel_means) / self.channel_stds 118 | x = F.conv2d(x, self.patch_proj, 
stride=self.patch_size) 119 | x = x.reshape(x.shape[0], self.n_state, (self.image_size // self.patch_size) ** 2).permute( 120 | 0, 2, 1 121 | ) 122 | 123 | sot = ( 124 | self.pred_state[None, None].expand(x.shape[0], -1, -1) 125 | if self.n_timestep == 0 126 | else F.embedding(cast(torch.Tensor, t), self.w_t)[:, None] 127 | ) 128 | x = torch.cat((sot, x), dim=1) + self.w_pos[None] 129 | return self.ln(x) 130 | 131 | 132 | @attr.s(eq=False, repr=False) 133 | class AttentionResblock(nn.Module): 134 | n_state: int = attr.ib() 135 | n_resblocks: int = attr.ib() 136 | attn_fn: AttentionInfo = attr.ib() 137 | device: torch.device = attr.ib(default=torch.device("cuda")) 138 | 139 | def __attrs_post_init__(self) -> None: 140 | super().__init__() 141 | 142 | self.n_head_state = self.n_state // self.attn_fn.n_heads 143 | self.qk_scale = 1 / np.sqrt(self.n_head_state) 144 | 145 | self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device) 146 | self.f_q = Affine( 147 | self.n_state, 148 | self.n_state, 149 | std=1 / math.sqrt(self.n_state), 150 | use_bias=True, 151 | bias_filter_fn=zero_key_bias_grad, 152 | device=self.device, 153 | ) 154 | self.f_k = Affine( 155 | self.n_state, 156 | self.n_state, 157 | std=1 / math.sqrt(self.n_state), 158 | use_bias=False, 159 | bias_filter_fn=zero_key_bias_grad, 160 | device=self.device, 161 | ) 162 | self.f_v = Affine( 163 | self.n_state, 164 | self.n_state, 165 | std=1 / math.sqrt(self.n_state), 166 | use_bias=True, 167 | bias_filter_fn=zero_key_bias_grad, 168 | device=self.device, 169 | ) 170 | self.f_c = Affine( 171 | self.n_state, 172 | self.n_state, 173 | use_bias=True, 174 | std=1 / np.sqrt(self.n_state * self.n_resblocks ** 2), 175 | device=self.device, 176 | ) # XXX 177 | 178 | def forward(self, m: torch.Tensor) -> torch.Tensor: 179 | n_context = m.shape[1] 180 | n_query_pad = self.attn_fn.ctx_blks_q * self.attn_fn.block_size - n_context 181 | n_key_pad = self.attn_fn.ctx_blks_k * self.attn_fn.block_size - n_context 182 | assert n_query_pad >= 0 183 | assert n_key_pad >= 0 184 | 185 | r = m 186 | r = self.ln(r) 187 | q, k, v = self.f_q(r), self.f_k(r), self.f_v(r) 188 | 189 | if n_query_pad != 0: 190 | q = F.pad(q, (0, 0, 0, n_query_pad)) 191 | 192 | if n_key_pad != 0: 193 | k = F.pad(k, (0, 0, 0, n_key_pad)) 194 | v = F.pad(v, (0, 0, 0, n_key_pad)) 195 | 196 | q = q.view([q.shape[0], -1, self.attn_fn.n_heads, self.n_head_state]).permute((0, 2, 1, 3)) 197 | k = k.view([k.shape[0], -1, self.attn_fn.n_heads, self.n_head_state]).permute((0, 2, 1, 3)) 198 | v = v.view([v.shape[0], -1, self.attn_fn.n_heads, self.n_head_state]).permute((0, 2, 1, 3)) 199 | w = torch.einsum( 200 | "bhcd,bhkd->bhck", q * math.sqrt(self.qk_scale), k * math.sqrt(self.qk_scale) 201 | ) 202 | 203 | if hasattr(self.attn_fn, "pytorch_attn_bias"): 204 | bias = self.attn_fn.pytorch_attn_bias 205 | assert len(bias.shape) in {2, 3} 206 | 207 | if len(bias.shape) == 2: 208 | w = torch.softmax(w + self.attn_fn.pytorch_attn_bias[None, None], dim=-1) 209 | elif len(bias.shape) == 3: 210 | w = torch.softmax(w + self.attn_fn.pytorch_attn_bias[None], dim=-1) 211 | else: 212 | w = torch.softmax(w, dim=-1) 213 | 214 | r = torch.einsum("bhck,bhkd->bhcd", w, v) 215 | r = r.permute((0, 2, 1, 3)).reshape((r.shape[0], -1, self.n_state)) 216 | 217 | if n_query_pad != 0: 218 | r = r[:, :-n_query_pad] 219 | 220 | assert r.shape[1] == n_context 221 | 222 | r = self.f_c(r) 223 | return m + r 224 | 225 | 226 | @attr.s(eq=False, repr=False) 227 | class FullyConnectedResblock(nn.Module): 228 | """ 229 | Not 
imported from other files because we retain Alec's original inits. 230 | """ 231 | 232 | n_state: int = attr.ib() 233 | n_resblocks: int = attr.ib() 234 | device: torch.device = attr.ib(default=torch.device("cuda")) 235 | 236 | def __attrs_post_init__(self) -> None: 237 | super().__init__() 238 | 239 | self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device) 240 | self.f_1 = Affine( 241 | self.n_state, 242 | 4 * self.n_state, 243 | use_bias=True, 244 | std=np.sqrt(2 / (4 * self.n_state)), 245 | device=self.device, 246 | ) 247 | self.f_2 = Affine( 248 | 4 * self.n_state, 249 | self.n_state, 250 | use_bias=True, 251 | std=1 / np.sqrt(self.n_state * self.n_resblocks ** 2), 252 | device=self.device, 253 | ) # XXX 254 | 255 | def forward(self, m: torch.Tensor) -> torch.Tensor: 256 | r = m 257 | r = self.ln(r) 258 | 259 | r = self.f_2(F.gelu(self.f_1(r))) 260 | return m + r 261 | 262 | 263 | @attr.s(eq=False, repr=False) 264 | class TransformerBlock(nn.Module): 265 | n_state: int = attr.ib() 266 | n_resblocks: int = attr.ib() 267 | attn_fn: AttentionInfo = attr.ib() 268 | device: torch.device = attr.ib(default=torch.device("cuda")) 269 | 270 | def __attrs_post_init__(self) -> None: 271 | super().__init__() 272 | 273 | self.f_attn = AttentionResblock( 274 | self.n_state, 275 | self.n_resblocks, 276 | self.attn_fn, 277 | self.device, 278 | ) 279 | self.f_mlp = FullyConnectedResblock(self.n_state, self.n_resblocks, self.device) 280 | 281 | def forward(self, x: torch.Tensor) -> torch.Tensor: 282 | return self.f_mlp(self.f_attn(x)) 283 | 284 | 285 | @attr.s(eq=False, repr=False) 286 | class TextFeatureExtractor(nn.Module): 287 | n_state: int = attr.ib() 288 | n_embd: int = attr.ib() 289 | device: torch.device = attr.ib(default=torch.device("cuda")) 290 | 291 | def __attrs_post_init__(self) -> None: 292 | super().__init__() 293 | 294 | self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device) 295 | self.f = Affine(self.n_state, self.n_embd, use_bias=False, device=self.device) 296 | 297 | def forward( 298 | self, text: torch.Tensor, text_len: torch.Tensor, return_probe_features: bool = False 299 | ) -> torch.Tensor: 300 | if len(text.shape) != 3: 301 | raise ValueError("expected text to be 3d") 302 | if len(text_len.shape) != 1: 303 | raise ValueError("expected text length to be 1d") 304 | if text.shape[0] != text_len.shape[0]: 305 | raise ValueError("text and text_len have inconsistent batch dimensions") 306 | 307 | index = (text_len - 1)[:, None, None].expand(-1, 1, text.shape[2]) 308 | x = torch.gather(text, dim=1, index=index) 309 | assert list(x.shape) == [text.shape[0], 1, text.shape[2]] 310 | 311 | if return_probe_features: 312 | return x[:, 0] 313 | 314 | x = self.ln(x) 315 | return self.f(x[:, 0]) 316 | 317 | 318 | @attr.s(eq=False, repr=False) 319 | class ImageFeatureExtractor(nn.Module): 320 | n_state: int = attr.ib() 321 | n_embd: int = attr.ib() 322 | device: torch.device = attr.ib(default=torch.device("cuda")) 323 | 324 | def __attrs_post_init__(self) -> None: 325 | super().__init__() 326 | 327 | self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device) 328 | self.f = Affine(self.n_state, self.n_embd, use_bias=False, device=self.device) 329 | 330 | def forward(self, x: torch.Tensor, return_probe_features: bool = False) -> torch.Tensor: 331 | if return_probe_features: 332 | return x[:, 0] 333 | 334 | x = self.ln(x[:, :1]) 335 | return self.f(x[:, 0]) 336 | 337 | 338 | @attr.s(eq=False, repr=False) 339 | class TextEncoder(nn.Module): 340 | n_bpe_vocab: int = attr.ib() 341 
| max_text_len: int = attr.ib() 342 | n_embd: int = attr.ib() 343 | n_head: int = attr.ib() 344 | n_xf_blocks: int = attr.ib() 345 | n_head_state: int = attr.ib(default=64) 346 | device: torch.device = attr.ib(default=torch.device("cuda")) 347 | block_size: int = attr.ib(init=False, default=32) 348 | 349 | def __attrs_post_init__(self) -> None: 350 | super().__init__() 351 | 352 | self.n_state = self.n_head * self.n_head_state 353 | n_rounded_context = self.block_size * int(math.ceil(self.max_text_len / self.block_size)) 354 | n_pad = n_rounded_context - self.max_text_len 355 | 356 | args = ( 357 | n_rounded_context, 358 | n_rounded_context, 359 | self.block_size, 360 | self.n_head, 361 | False, 362 | n_pad, 363 | n_pad, 364 | ) 365 | mask = DenseCausalAttentionMask(*args) 366 | attn_fn = to_attention_info(mask) 367 | 368 | m = 1 - make_full_layout(mask).astype(np.float32) 369 | m[m == 1] = -1e10 370 | attn_fn.pytorch_attn_bias = torch.from_numpy(m).to(self.device) 371 | 372 | blocks: List[Tuple[str, nn.Module]] = [ 373 | ( 374 | "input", 375 | TextEmbedding( 376 | self.n_bpe_vocab, self.max_text_len, self.n_state, device=self.device 377 | ), 378 | ) 379 | ] 380 | 381 | for i in range(self.n_xf_blocks): 382 | blocks.append( 383 | ( 384 | f"block_{i}", 385 | TransformerBlock(self.n_state, 2 * self.n_xf_blocks, attn_fn, self.device), 386 | ) 387 | ) 388 | 389 | blocks.append( 390 | ("output", TextFeatureExtractor(self.n_state, self.n_embd, device=self.device)) 391 | ) 392 | 393 | self.blocks = nn.ModuleDict(OrderedDict(blocks)) 394 | 395 | def forward( 396 | self, 397 | text: torch.Tensor, 398 | text_len: torch.Tensor, 399 | return_probe_features: bool = False, 400 | ) -> torch.Tensor: 401 | 402 | n_batch = text.shape[0] 403 | h = self.blocks["input"](text) 404 | 405 | for i in range(self.n_xf_blocks): 406 | h = self.blocks[f"block_{i}"](h) 407 | 408 | h = self.blocks["output"](h, text_len, return_probe_features=return_probe_features) 409 | 410 | assert list(h.shape) == [ 411 | n_batch, 412 | self.n_embd if not return_probe_features else self.n_state, 413 | ] 414 | return h 415 | 416 | 417 | @attr.s(eq=False, repr=False) 418 | class ImageEncoder(nn.Module): 419 | image_size: int = attr.ib() 420 | patch_size: int = attr.ib() 421 | n_embd: int = attr.ib() 422 | n_head: int = attr.ib() 423 | n_xf_blocks: int = attr.ib() 424 | n_head_state: int = attr.ib(default=64) 425 | n_timestep: int = attr.ib(default=0) 426 | device: torch.device = attr.ib(default=torch.device("cuda")) 427 | block_size: int = attr.ib(init=False, default=32) 428 | 429 | def __attrs_post_init__(self) -> None: 430 | super().__init__() 431 | 432 | self.n_state = self.n_head * self.n_head_state 433 | self.n_context = 1 + (self.image_size // self.patch_size) ** 2 434 | n_rounded_context = self.block_size * int(math.ceil(self.n_context / self.block_size)) 435 | n_pad = n_rounded_context - self.n_context 436 | 437 | args = ( 438 | n_rounded_context, 439 | n_rounded_context, 440 | self.block_size, 441 | self.n_head, 442 | False, 443 | n_pad, 444 | n_pad, 445 | ) 446 | mask = DenseAttentionMask(*args) 447 | attn_fn = to_attention_info(mask) 448 | 449 | m = 1 - make_full_layout(mask).astype(np.float32) 450 | m[m == 1] = -1e10 451 | attn_fn.pytorch_attn_bias = torch.from_numpy(m).to(self.device) 452 | 453 | blocks: List[Tuple[str, nn.Module]] = [ 454 | ( 455 | "input", 456 | ImageEmbedding( 457 | self.image_size, 458 | self.patch_size, 459 | self.n_state, 460 | n_timestep=self.n_timestep, 461 | device=self.device, 462 | ), 463 | ) 464 
| ] 465 | 466 | for i in range(self.n_xf_blocks): 467 | blocks.append( 468 | ( 469 | f"block_{i}", 470 | TransformerBlock(self.n_state, 2 * self.n_xf_blocks, attn_fn, self.device), 471 | ) 472 | ) 473 | 474 | blocks.append(("output", ImageFeatureExtractor(self.n_state, self.n_embd, self.device))) 475 | 476 | self.blocks = nn.ModuleDict(OrderedDict(blocks)) 477 | 478 | def forward( 479 | self, 480 | image: torch.Tensor, 481 | timesteps: Optional[torch.Tensor] = None, 482 | return_probe_features: bool = False, 483 | ) -> torch.Tensor: 484 | n_batch = image.shape[0] 485 | h = self.blocks["input"](image, t=timesteps) 486 | 487 | for i in range(self.n_xf_blocks): 488 | h = self.blocks[f"block_{i}"](h) 489 | 490 | h = self.blocks["output"](h, return_probe_features=return_probe_features) 491 | 492 | assert list(h.shape) == [ 493 | n_batch, 494 | self.n_embd if not return_probe_features else self.n_state, 495 | ] 496 | 497 | return h 498 | -------------------------------------------------------------------------------- /imagen_pytorch/clip/model_creation.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import lru_cache 3 | from typing import Any, Callable, Dict, List, Optional, Tuple 4 | 5 | import attr 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import yaml 10 | from glide_text2im.tokenizer.simple_tokenizer import SimpleTokenizer 11 | 12 | from .encoders import ImageEncoder, TextEncoder 13 | 14 | 15 | @lru_cache() 16 | def default_config_path() -> str: 17 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.yaml") 18 | 19 | 20 | @attr.s 21 | class CLIPModel: 22 | config: Dict[str, Any] = attr.ib() 23 | text_encoder: nn.Module = attr.ib() 24 | image_encoder: nn.Module = attr.ib() 25 | logit_scale: torch.Tensor = attr.ib() 26 | device: torch.device = attr.ib() 27 | tokenizer: SimpleTokenizer = attr.ib() 28 | 29 | def encode_prompts(self, prompts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]: 30 | tokens = [] 31 | lens = [] 32 | for prompt in prompts: 33 | sub_tokens, sub_len = self.tokenizer.padded_tokens_and_len( 34 | self.tokenizer.encode(prompt), self.text_encoder.max_text_len 35 | ) 36 | tokens.append(sub_tokens) 37 | lens.append(sub_len) 38 | return ( 39 | torch.tensor(tokens).to(dtype=torch.long, device=self.device), 40 | torch.tensor(lens).to(dtype=torch.long, device=self.device), 41 | ) 42 | 43 | def text_embeddings(self, prompts: List[str]) -> torch.Tensor: 44 | tokens, lens = self.encode_prompts(prompts) 45 | z_t = self.text_encoder(tokens, lens) 46 | return z_t / (torch.linalg.norm(z_t, dim=-1, keepdim=True) + 1e-12) 47 | 48 | def image_embeddings(self, images: torch.Tensor, t: torch.Tensor) -> torch.Tensor: 49 | z_i = self.image_encoder((images + 1) * 127.5, t) 50 | return z_i / (torch.linalg.norm(z_i, dim=-1, keepdim=True) + 1e-12) 51 | 52 | def cond_fn(self, prompts: List[str], grad_scale: float) -> Callable[..., torch.Tensor]: 53 | with torch.no_grad(): 54 | z_t = self.text_embeddings(prompts) 55 | 56 | def cond_fn(x, t, grad_scale=grad_scale, **kwargs): 57 | with torch.enable_grad(): 58 | x_var = x.detach().requires_grad_(True) 59 | z_i = self.image_embeddings(x_var, t) 60 | loss = torch.exp(self.logit_scale) * (z_t * z_i).sum() 61 | grad = torch.autograd.grad(loss, x_var)[0].detach() 62 | return grad * grad_scale 63 | 64 | return cond_fn 65 | 66 | 67 | def create_clip_model( 68 | config_path: Optional[str] = None, 69 | device: Optional[torch.device] = None, 70 | 
tokenizer: Optional[SimpleTokenizer] = None, 71 | ) -> CLIPModel: 72 | if config_path is None: 73 | config_path = default_config_path() 74 | if device is None: 75 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 76 | if tokenizer is None: 77 | tokenizer = SimpleTokenizer() 78 | 79 | with open(config_path, "r") as f: 80 | config = yaml.load(f, Loader=yaml.SafeLoader) 81 | 82 | text_encoder = TextEncoder( 83 | n_bpe_vocab=config["n_vocab"], 84 | max_text_len=config["max_text_len"], 85 | n_embd=config["n_embd"], 86 | n_head=config["n_head_text"], 87 | n_xf_blocks=config["n_xf_blocks_text"], 88 | n_head_state=config["n_head_state_text"], 89 | device=device, 90 | ) 91 | 92 | image_encoder = ImageEncoder( 93 | image_size=config["image_size"], 94 | patch_size=config["patch_size"], 95 | n_embd=config["n_embd"], 96 | n_head=config["n_head_image"], 97 | n_xf_blocks=config["n_xf_blocks_image"], 98 | n_head_state=config["n_head_state_image"], 99 | n_timestep=config["n_timesteps"], 100 | device=device, 101 | ) 102 | 103 | logit_scale = torch.tensor( 104 | np.log(config["logit_scale"]), 105 | dtype=torch.float32, 106 | device=device, 107 | requires_grad=False, 108 | ) 109 | 110 | return CLIPModel( 111 | config=config, 112 | text_encoder=text_encoder, 113 | image_encoder=image_encoder, 114 | logit_scale=logit_scale, 115 | device=device, 116 | tokenizer=tokenizer, 117 | ) 118 | -------------------------------------------------------------------------------- /imagen_pytorch/clip/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Callable, Optional 3 | 4 | import attr 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | FilterFn = Callable[[torch.Tensor], torch.Tensor] 10 | 11 | 12 | class ZeroKeyBiasGrad(torch.autograd.Function): 13 | @staticmethod 14 | def forward(ctx, x): 15 | return x 16 | 17 | @staticmethod 18 | def backward(ctx, output_grad): 19 | output_grad = output_grad.clone() 20 | output_grad.chunk(3)[1].zero_() 21 | return output_grad 22 | 23 | 24 | def zero_key_bias_grad(x: torch.Tensor) -> torch.Tensor: 25 | return ZeroKeyBiasGrad.apply(x) 26 | 27 | 28 | @attr.s(eq=False, repr=False) 29 | class LayerNorm(nn.Module): 30 | n_state: int = attr.ib() 31 | eps: float = attr.ib(default=1e-6) 32 | device: torch.device = attr.ib(default=torch.device("cuda")) 33 | 34 | def __attrs_post_init__(self) -> None: 35 | super().__init__() 36 | self.g = nn.Parameter(torch.ones((self.n_state,), dtype=torch.float32, device=self.device)) 37 | self.b = nn.Parameter(torch.zeros((self.n_state,), dtype=torch.float32, device=self.device)) 38 | self.g.weight_decay_level = "disable" # type: ignore 39 | self.b.weight_decay_level = "disable" # type: ignore 40 | 41 | def forward(self, x: torch.Tensor) -> torch.Tensor: 42 | return F.layer_norm( 43 | x.type(torch.float32), torch.Size((self.n_state,)), self.g, self.b, self.eps 44 | ) 45 | 46 | 47 | @attr.s(eq=False, repr=False) 48 | class Affine(nn.Module): 49 | n_in: int = attr.ib() 50 | n_out: int = attr.ib() 51 | use_bias: bool = attr.ib(default=True) 52 | use_admnet_init: bool = attr.ib(default=False) 53 | std: Optional[float] = attr.ib(default=None) 54 | extra_init_scale: Optional[float] = attr.ib(default=None) 55 | bias_filter_fn: FilterFn = attr.ib(default=lambda x: x) 56 | device: torch.device = attr.ib(default=torch.device("cuda")) 57 | 58 | def __attrs_post_init__(self) -> None: 59 | super().__init__() 60 | 61 | if not 
self.use_admnet_init: 62 | self.std = self.std if self.std is not None else math.sqrt(2 / (self.n_in + self.n_out)) 63 | self.std = ( 64 | self.std if self.extra_init_scale is None else self.std * self.extra_init_scale 65 | ) 66 | 67 | w = torch.empty((self.n_out, self.n_in), dtype=torch.float32, device=self.device) 68 | self.w = nn.Parameter(w) 69 | 70 | if self.use_bias: 71 | self.b = nn.Parameter( 72 | torch.zeros((self.n_out,), dtype=torch.float32, device=self.device) 73 | ) 74 | self.b.weight_decay_level = "disable" # type: ignore 75 | else: 76 | if self.extra_init_scale is not None: 77 | raise ValueError("extra_init_scale incompatible with admnet init") 78 | 79 | w = torch.empty((self.n_out, self.n_in), dtype=torch.float32, device=self.device) 80 | 81 | if self.use_bias: 82 | b = torch.empty((self.n_out,), dtype=torch.float32, device=self.device) 83 | 84 | self.w = nn.Parameter(w) 85 | 86 | if self.use_bias: 87 | self.b = nn.Parameter(b) 88 | self.b.weight_decay_level = "disable" # type: ignore 89 | 90 | def forward(self, x: torch.Tensor) -> torch.Tensor: 91 | w = self.w if self.w.dtype == x.dtype else self.w.to(x.dtype) 92 | b = ( 93 | self.bias_filter_fn(self.b if self.b.dtype == x.dtype else self.b.to(x.dtype)) 94 | if self.use_bias 95 | else None 96 | ) 97 | return F.linear(x, w, b) 98 | -------------------------------------------------------------------------------- /imagen_pytorch/dataset.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import blobfile as bf 3 | import numpy as np 4 | from torch.utils.data import DataLoader, Dataset 5 | import torch 6 | import json 7 | import os 8 | from transformers import AutoTokenizer 9 | 10 | 11 | 12 | def get_loader(batch_size, resolution, image_dir, df, zero_text_prob=0.1, tokenizer_name='t5-large', max_len=128, shuffle=True,): 13 | dataset = ImageDataset(resolution, image_dir, df, tokenizer_name, max_len, zero_text_prob) 14 | loader = DataLoader( 15 | dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1, drop_last=True 16 | ) 17 | return loader 18 | 19 | class ImageDataset(Dataset): 20 | def __init__(self, resolution, image_dir, df, tokenizer_name, max_len, zero_text_prob): 21 | super().__init__() 22 | self.resolution = resolution 23 | self.image_dir = image_dir 24 | self.df = df 25 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 26 | self.max_len = max_len 27 | self.zero_text_prob = zero_text_prob 28 | 29 | def __len__(self): 30 | return len(self.df) 31 | 32 | def __getitem__(self, idx): 33 | out_dict = {} 34 | path, text = self.df.iloc[idx]['path'], self.df.iloc[idx]['text'] 35 | 36 | with bf.BlobFile(os.path.join(self.image_dir, path), "rb") as f: 37 | pil_image = Image.open(f) 38 | pil_image.load() 39 | 40 | 41 | while min(*pil_image.size) >= 2 * self.resolution: 42 | pil_image = pil_image.resize( 43 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 44 | ) 45 | 46 | scale = self.resolution / min(*pil_image.size) 47 | pil_image = pil_image.resize( 48 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 49 | ) 50 | 51 | arr = np.array(pil_image.convert("RGB")) 52 | crop_y = (arr.shape[0] - self.resolution) // 2 53 | crop_x = (arr.shape[1] - self.resolution) // 2 54 | arr = arr[crop_y: crop_y + self.resolution, crop_x: crop_x + self.resolution] 55 | arr = arr.astype(np.float32) / 127.5 - 1 56 | if np.random.binomial(1, self.zero_text_prob): 57 | text = '' 58 | text_encoding = self.tokenizer( 59 | text, 60 | 
max_length=self.max_len, 61 | padding="max_length", 62 | truncation=True, 63 | return_attention_mask=True, 64 | add_special_tokens=True, 65 | return_tensors="pt") 66 | 67 | out_dict["tokens"] = text_encoding['input_ids'][0] 68 | out_dict["mask"] = text_encoding['attention_mask'][0] 69 | return np.transpose(arr, [2, 0, 1]), out_dict -------------------------------------------------------------------------------- /imagen_pytorch/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import lru_cache 3 | from typing import Dict, Optional 4 | 5 | import requests 6 | import torch as th 7 | from filelock import FileLock 8 | from tqdm.auto import tqdm 9 | 10 | MODEL_PATHS = { 11 | "base": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt", 12 | "upsample": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt", 13 | "base-inpaint": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base_inpaint.pt", 14 | "upsample-inpaint": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample_inpaint.pt", 15 | "clip/image-enc": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/clip_image_enc.pt", 16 | "clip/text-enc": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/clip_text_enc.pt", 17 | } 18 | 19 | 20 | @lru_cache() 21 | def default_cache_dir() -> str: 22 | return os.path.join(os.path.abspath(os.getcwd()), "glide_model_cache") 23 | 24 | 25 | def fetch_file_cached( 26 | url: str, progress: bool = True, cache_dir: Optional[str] = None, chunk_size: int = 4096 27 | ) -> str: 28 | """ 29 | Download the file at the given URL into a local file and return the path. 30 | 31 | If cache_dir is specified, it will be used to download the files. 32 | Otherwise, default_cache_dir() is used. 33 | """ 34 | if cache_dir is None: 35 | cache_dir = default_cache_dir() 36 | os.makedirs(cache_dir, exist_ok=True) 37 | local_path = os.path.join(cache_dir, url.split("/")[-1]) 38 | if os.path.exists(local_path): 39 | return local_path 40 | response = requests.get(url, stream=True) 41 | size = int(response.headers.get("content-length", "0")) 42 | with FileLock(local_path + ".lock"): 43 | if progress: 44 | pbar = tqdm(total=size, unit="iB", unit_scale=True) 45 | tmp_path = local_path + ".tmp" 46 | with open(tmp_path, "wb") as f: 47 | for chunk in response.iter_content(chunk_size): 48 | if progress: 49 | pbar.update(len(chunk)) 50 | f.write(chunk) 51 | os.rename(tmp_path, local_path) 52 | if progress: 53 | pbar.close() 54 | return local_path 55 | 56 | 57 | def load_checkpoint( 58 | checkpoint_name: str, 59 | device: th.device, 60 | progress: bool = True, 61 | cache_dir: Optional[str] = None, 62 | chunk_size: int = 4096, 63 | ) -> Dict[str, th.Tensor]: 64 | if checkpoint_name not in MODEL_PATHS: 65 | raise ValueError( 66 | f"Unknown checkpoint name {checkpoint_name}. Known names are: {MODEL_PATHS.keys()}." 67 | ) 68 | path = fetch_file_cached( 69 | MODEL_PATHS[checkpoint_name], progress=progress, cache_dir=cache_dir, chunk_size=chunk_size 70 | ) 71 | return th.load(path, map_location=device) 72 | -------------------------------------------------------------------------------- /imagen_pytorch/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to inference with 16-bit precision. 
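Model weights are kept in float16 while a flattened float32 'master' copy receives the
gradients and optimizer updates (see make_master_params, model_grads_to_master_grads and
master_params_to_model_params below).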
3 | """ 4 | 5 | import torch.nn as nn 6 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 7 | 8 | def convert_module_to_f16(l): 9 | """ 10 | Convert primitive modules to float16. 11 | """ 12 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 13 | l.weight.data = l.weight.data.half() 14 | if l.bias is not None: 15 | l.bias.data = l.bias.data.half() 16 | 17 | 18 | def convert_module_to_f32(l): 19 | """ 20 | Convert primitive modules to float32, undoing convert_module_to_f16(). 21 | """ 22 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 23 | l.weight.data = l.weight.data.float() 24 | if l.bias is not None: 25 | l.bias.data = l.bias.data.float() 26 | def make_master_params(model_params): 27 | """ 28 | Copy model parameters into a (differently-shaped) list of full-precision 29 | parameters. 30 | """ 31 | master_params = _flatten_dense_tensors( 32 | [param.detach().float() for param in model_params] 33 | ) 34 | master_params = nn.Parameter(master_params) 35 | master_params.requires_grad = True 36 | return [master_params] 37 | 38 | 39 | def model_grads_to_master_grads(model_params, master_params): 40 | """ 41 | Copy the gradients from the model parameters into the master parameters 42 | from make_master_params(). 43 | """ 44 | master_params[0].grad = _flatten_dense_tensors( 45 | [param.grad.data.detach().float() for param in model_params] 46 | ) 47 | 48 | 49 | def master_params_to_model_params(model_params, master_params): 50 | """ 51 | Copy the master parameter data back into the model parameters. 52 | """ 53 | # Without copying to a list, if a generator is passed, this will 54 | # silently not copy any parameters. 55 | model_params = list(model_params) 56 | 57 | for param, master_param in zip( 58 | model_params, unflatten_master_params(model_params, master_params) 59 | ): 60 | param.detach().copy_(master_param) 61 | 62 | 63 | def unflatten_master_params(model_params, master_params): 64 | """ 65 | Unflatten the master parameters to look like model_params. 66 | """ 67 | return _unflatten_dense_tensors(master_params[0].detach(), model_params) 68 | 69 | 70 | def zero_grad(model_params): 71 | for param in model_params: 72 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 73 | if param.grad is not None: 74 | param.grad.detach_() 75 | param.grad.zero_() -------------------------------------------------------------------------------- /imagen_pytorch/gaussian_diffusion.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | This code started out as a PyTorch port of Ho et al's diffusion models: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py 5 | Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules. 6 | """ 7 | 8 | import enum 9 | import math 10 | from copy import deepcopy 11 | import numpy as np 12 | import torch as th 13 | 14 | from .utils import mean_flat 15 | from .losses import normal_kl, discretized_gaussian_log_likelihood 16 | 17 | 18 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): 19 | """ 20 | Get a pre-defined beta schedule for the given name. 21 | The beta schedule library consists of beta schedules which remain similar 22 | in the limit of num_diffusion_timesteps. 23 | Beta schedules may be added, but should not be removed or changed once 24 | they are committed to maintain backwards compatibility. 
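Supported names are "linear" (the Ho et al. schedule, rescaled so it works for any number
of steps) and "cosine" (a squared-cosine alpha-bar schedule built via betas_for_alpha_bar).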
25 | """ 26 | if schedule_name == "linear": 27 | # Linear schedule from Ho et al, extended to work for any number of 28 | # diffusion steps. 29 | scale = 1000 / num_diffusion_timesteps 30 | beta_start = scale * 0.0001 31 | beta_end = scale * 0.02 32 | return np.linspace( 33 | beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 34 | ) 35 | elif schedule_name == "cosine": 36 | return betas_for_alpha_bar( 37 | num_diffusion_timesteps, 38 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, 39 | ) 40 | else: 41 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}") 42 | 43 | 44 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 45 | """ 46 | Create a beta schedule that discretizes the given alpha_t_bar function, 47 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 48 | :param num_diffusion_timesteps: the number of betas to produce. 49 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 50 | produces the cumulative product of (1-beta) up to that 51 | part of the diffusion process. 52 | :param max_beta: the maximum beta to use; use values lower than 1 to 53 | prevent singularities. 54 | """ 55 | betas = [] 56 | for i in range(num_diffusion_timesteps): 57 | t1 = i / num_diffusion_timesteps 58 | t2 = (i + 1) / num_diffusion_timesteps 59 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 60 | return np.array(betas) 61 | 62 | 63 | class ModelMeanType(enum.Enum): 64 | """ 65 | Which type of output the model predicts. 66 | """ 67 | 68 | PREVIOUS_X = enum.auto() # the model predicts x_{t-1} 69 | START_X = enum.auto() # the model predicts x_0 70 | EPSILON = enum.auto() # the model predicts epsilon 71 | 72 | 73 | class ModelVarType(enum.Enum): 74 | """ 75 | What is used as the model's output variance. 76 | The LEARNED_RANGE option has been added to allow the model to predict 77 | values between FIXED_SMALL and FIXED_LARGE, making its job easier. 78 | """ 79 | 80 | LEARNED = enum.auto() 81 | FIXED_SMALL = enum.auto() 82 | FIXED_LARGE = enum.auto() 83 | LEARNED_RANGE = enum.auto() 84 | 85 | 86 | class LossType(enum.Enum): 87 | MSE = enum.auto() # use raw MSE loss (and KL when learning variances) 88 | RESCALED_MSE = ( 89 | enum.auto() 90 | ) # use raw MSE loss (with RESCALED_KL when learning variances) 91 | KL = enum.auto() # use the variational lower-bound 92 | RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB 93 | 94 | def is_vb(self): 95 | return self == LossType.KL or self == LossType.RESCALED_KL 96 | 97 | 98 | class GaussianDiffusion: 99 | """ 100 | Utilities for training and sampling diffusion models. 101 | Ported directly from here, and then adapted over time to further experimentation. 102 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 103 | :param betas: a 1-D numpy array of betas for each diffusion timestep, 104 | starting at T and going to 1. 105 | :param model_mean_type: a ModelMeanType determining what the model outputs. 106 | :param model_var_type: a ModelVarType determining how variance is output. 107 | :param loss_type: a LossType determining the loss function to use. 108 | :param rescale_timesteps: if True, pass floating point timesteps into the 109 | model so that they are always scaled like in the 110 | original paper (0 to 1000). 
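A minimal sampling sketch (an illustrative example, not from the original docs; it assumes
`model` is any callable with the `model(x, t, **model_kwargs)` interface used by
`p_mean_variance`, e.g. an epsilon-prediction network for 64x64 images):

    betas = get_named_beta_schedule("cosine", 1000)
    diffusion = GaussianDiffusion(
        betas=betas,
        model_mean_type=ModelMeanType.EPSILON,
        model_var_type=ModelVarType.FIXED_SMALL,
        loss_type=LossType.MSE,
    )
    samples = diffusion.p_sample_loop(model, (4, 3, 64, 64))  # 4 samples, NCHW; shape is illustrative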
111 | """ 112 | 113 | def __init__( 114 | self, 115 | *, 116 | betas, 117 | model_mean_type, 118 | model_var_type, 119 | loss_type, 120 | rescale_timesteps=False, 121 | ): 122 | self.model_mean_type = model_mean_type 123 | self.model_var_type = model_var_type 124 | self.loss_type = loss_type 125 | self.rescale_timesteps = rescale_timesteps 126 | 127 | # Use float64 for accuracy. 128 | betas = np.array(betas, dtype=np.float64) 129 | self.betas = betas 130 | assert len(betas.shape) == 1, "betas must be 1-D" 131 | assert (betas > 0).all() and (betas <= 1).all() 132 | 133 | self.num_timesteps = int(betas.shape[0]) 134 | 135 | alphas = 1.0 - betas 136 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 137 | self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) 138 | self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) 139 | assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) 140 | 141 | # calculations for diffusion q(x_t | x_{t-1}) and others 142 | self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) 143 | self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) 144 | self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) 145 | self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) 146 | self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) 147 | 148 | # calculations for posterior q(x_{t-1} | x_t, x_0) 149 | self.posterior_variance = ( 150 | betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 151 | ) 152 | # log calculation clipped because the posterior variance is 0 at the 153 | # beginning of the diffusion chain. 154 | self.posterior_log_variance_clipped = np.log( 155 | np.append(self.posterior_variance[1], self.posterior_variance[1:]) 156 | ) 157 | self.posterior_mean_coef1 = ( 158 | betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 159 | ) 160 | self.posterior_mean_coef2 = ( 161 | (1.0 - self.alphas_cumprod_prev) 162 | * np.sqrt(alphas) 163 | / (1.0 - self.alphas_cumprod) 164 | ) 165 | 166 | def q_mean_variance(self, x_start, t): 167 | """ 168 | Get the distribution q(x_t | x_0). 169 | :param x_start: the [N x C x ...] tensor of noiseless inputs. 170 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 171 | :return: A tuple (mean, variance, log_variance), all of x_start's shape. 172 | """ 173 | mean = ( 174 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 175 | ) 176 | variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) 177 | log_variance = _extract_into_tensor( 178 | self.log_one_minus_alphas_cumprod, t, x_start.shape 179 | ) 180 | return mean, variance, log_variance 181 | 182 | def q_sample(self, x_start, t, noise=None): 183 | """ 184 | Diffuse the data for a given number of diffusion steps. 185 | In other words, sample from q(x_t | x_0). 186 | :param x_start: the initial data batch. 187 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 188 | :param noise: if specified, the split-out normal noise. 189 | :return: A noisy version of x_start. 
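Concretely: x_t = sqrt(alphas_cumprod[t]) * x_start + sqrt(1 - alphas_cumprod[t]) * noise.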
190 | """ 191 | if noise is None: 192 | noise = th.randn_like(x_start) 193 | assert noise.shape == x_start.shape 194 | return ( 195 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 196 | + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) 197 | * noise 198 | ) 199 | 200 | def q_posterior_mean_variance(self, x_start, x_t, t): 201 | """ 202 | Compute the mean and variance of the diffusion posterior: 203 | q(x_{t-1} | x_t, x_0) 204 | """ 205 | assert x_start.shape == x_t.shape 206 | posterior_mean = ( 207 | _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start 208 | + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t 209 | ) 210 | posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) 211 | posterior_log_variance_clipped = _extract_into_tensor( 212 | self.posterior_log_variance_clipped, t, x_t.shape 213 | ) 214 | assert ( 215 | posterior_mean.shape[0] 216 | == posterior_variance.shape[0] 217 | == posterior_log_variance_clipped.shape[0] 218 | == x_start.shape[0] 219 | ) 220 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 221 | 222 | def p_mean_variance( 223 | self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None 224 | ): 225 | """ 226 | Apply the model to get p(x_{t-1} | x_t), as well as a prediction of 227 | the initial x, x_0. 228 | :param model: the model, which takes a signal and a batch of timesteps 229 | as input. 230 | :param x: the [N x C x ...] tensor at time t. 231 | :param t: a 1-D Tensor of timesteps. 232 | :param clip_denoised: if True, clip the denoised signal into [-1, 1]. 233 | :param denoised_fn: if not None, a function which applies to the 234 | x_start prediction before it is used to sample. Applies before 235 | clip_denoised. 236 | :param model_kwargs: if not None, a dict of extra keyword arguments to 237 | pass to the model. This can be used for conditioning. 238 | :return: a dict with the following keys: 239 | - 'mean': the model mean output. 240 | - 'variance': the model variance output. 241 | - 'log_variance': the log of 'variance'. 242 | - 'pred_xstart': the prediction for x_0. 243 | """ 244 | if model_kwargs is None: 245 | model_kwargs = {} 246 | 247 | B, C = x.shape[:2] 248 | assert t.shape == (B,) 249 | model_output = model(x, self._scale_timesteps(t), **model_kwargs) 250 | 251 | 252 | if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: 253 | assert model_output.shape == (B, C * 2, *x.shape[2:]) 254 | model_output, model_var_values = th.split(model_output, C, dim=1) 255 | if self.model_var_type == ModelVarType.LEARNED: 256 | model_log_variance = model_var_values 257 | model_variance = th.exp(model_log_variance) 258 | else: 259 | min_log = _extract_into_tensor( 260 | self.posterior_log_variance_clipped, t, x.shape 261 | ) 262 | max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) 263 | # The model_var_values is [-1, 1] for [min_var, max_var]. 264 | frac = (model_var_values + 1) / 2 265 | model_log_variance = frac * max_log + (1 - frac) * min_log 266 | model_variance = th.exp(model_log_variance) 267 | else: 268 | model_variance, model_log_variance = { 269 | # for fixedlarge, we set the initial (log-)variance like so 270 | # to get a better decoder log likelihood. 
271 | ModelVarType.FIXED_LARGE: ( 272 | np.append(self.posterior_variance[1], self.betas[1:]), 273 | np.log(np.append(self.posterior_variance[1], self.betas[1:])), 274 | ), 275 | ModelVarType.FIXED_SMALL: ( 276 | self.posterior_variance, 277 | self.posterior_log_variance_clipped, 278 | ), 279 | }[self.model_var_type] 280 | model_variance = _extract_into_tensor(model_variance, t, x.shape) 281 | model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) 282 | 283 | def process_xstart(x): 284 | if denoised_fn is not None: 285 | x = denoised_fn(x) 286 | if clip_denoised: 287 | x2 = th.clone(x).cpu().detach().numpy() 288 | p = 80 289 | s = np.percentile( 290 | np.abs(x2), p, 291 | axis=tuple(range(1, x2.ndim)))[0] 292 | s = max(s, 1.0) 293 | x = th.clip(x, -s, s) / s 294 | return x#x.clamp(-1, 1) 295 | return x 296 | 297 | if self.model_mean_type == ModelMeanType.PREVIOUS_X: 298 | pred_xstart = process_xstart( 299 | self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) 300 | ) 301 | model_mean = model_output 302 | elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: 303 | if self.model_mean_type == ModelMeanType.START_X: 304 | pred_xstart = process_xstart(model_output) 305 | else: 306 | pred_xstart = process_xstart( 307 | self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) 308 | ) 309 | model_mean, _, _ = self.q_posterior_mean_variance( 310 | x_start=pred_xstart, x_t=x, t=t 311 | ) 312 | else: 313 | raise NotImplementedError(self.model_mean_type) 314 | 315 | assert ( 316 | model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape 317 | ) 318 | return { 319 | "mean": model_mean, 320 | "variance": model_variance, 321 | "log_variance": model_log_variance, 322 | "pred_xstart": pred_xstart, 323 | } 324 | 325 | def _predict_xstart_from_eps(self, x_t, t, eps): 326 | assert x_t.shape == eps.shape 327 | return ( 328 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 329 | - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps 330 | ) 331 | 332 | def _predict_xstart_from_xprev(self, x_t, t, xprev): 333 | assert x_t.shape == xprev.shape 334 | return ( # (xprev - coef2*x_t) / coef1 335 | _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev 336 | - _extract_into_tensor( 337 | self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape 338 | ) 339 | * x_t 340 | ) 341 | 342 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart): 343 | return ( 344 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 345 | - pred_xstart 346 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) 347 | 348 | def _scale_timesteps(self, t): 349 | if self.rescale_timesteps: 350 | return t.float() * (1000.0 / self.num_timesteps) 351 | return t 352 | 353 | def p_sample( 354 | self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None 355 | ): 356 | """ 357 | Sample x_{t-1} from the model at the given timestep. 358 | :param model: the model to sample from. 359 | :param x: the current tensor at x_{t-1}. 360 | :param t: the value of t, starting at 0 for the first diffusion step. 361 | :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. 362 | :param denoised_fn: if not None, a function which applies to the 363 | x_start prediction before it is used to sample. 364 | :param model_kwargs: if not None, a dict of extra keyword arguments to 365 | pass to the model. This can be used for conditioning. 
366 | :return: a dict containing the following keys: 367 | - 'sample': a random sample from the model. 368 | - 'pred_xstart': a prediction of x_0. 369 | """ 370 | out = self.p_mean_variance( 371 | model, 372 | x, 373 | t, 374 | clip_denoised=clip_denoised, 375 | denoised_fn=denoised_fn, 376 | model_kwargs=model_kwargs, 377 | ) 378 | noise = th.randn_like(x) 379 | nonzero_mask = ( 380 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 381 | ) # no noise when t == 0 382 | sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise 383 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 384 | 385 | def p_sample_loop( 386 | self, 387 | model, 388 | shape, 389 | noise=None, 390 | clip_denoised=True, 391 | denoised_fn=None, 392 | model_kwargs=None, 393 | device=None, 394 | progress=False, 395 | ): 396 | """ 397 | Generate samples from the model. 398 | :param model: the model module. 399 | :param shape: the shape of the samples, (N, C, H, W). 400 | :param noise: if specified, the noise from the encoder to sample. 401 | Should be of the same shape as `shape`. 402 | :param clip_denoised: if True, clip x_start predictions to [-1, 1]. 403 | :param denoised_fn: if not None, a function which applies to the 404 | x_start prediction before it is used to sample. 405 | :param model_kwargs: if not None, a dict of extra keyword arguments to 406 | pass to the model. This can be used for conditioning. 407 | :param device: if specified, the device to create the samples on. 408 | If not specified, use a model parameter's device. 409 | :param progress: if True, show a tqdm progress bar. 410 | :return: a non-differentiable batch of samples. 411 | """ 412 | final = None 413 | for sample in self.p_sample_loop_progressive( 414 | model, 415 | shape, 416 | noise=noise, 417 | clip_denoised=clip_denoised, 418 | denoised_fn=denoised_fn, 419 | model_kwargs=model_kwargs, 420 | device=device, 421 | progress=progress, 422 | ): 423 | final = sample 424 | return final["sample"] 425 | 426 | def p_sample_loop_progressive( 427 | self, 428 | model, 429 | shape, 430 | noise=None, 431 | clip_denoised=True, 432 | denoised_fn=None, 433 | model_kwargs=None, 434 | device=None, 435 | progress=False, 436 | ): 437 | """ 438 | Generate samples from the model and yield intermediate samples from 439 | each timestep of diffusion. 440 | Arguments are the same as p_sample_loop(). 441 | Returns a generator over dicts, where each dict is the return value of 442 | p_sample(). 443 | """ 444 | if device is None: 445 | device = next(model.parameters()).device 446 | assert isinstance(shape, (tuple, list)) 447 | if noise is not None: 448 | img = noise 449 | else: 450 | img = th.randn(*shape, device=device) 451 | indices = list(range(self.num_timesteps))[::-1] 452 | 453 | if progress: 454 | # Lazy import so that we don't depend on tqdm. 455 | from tqdm.auto import tqdm 456 | 457 | indices = tqdm(indices) 458 | 459 | for i in indices: 460 | t = th.tensor([i] * shape[0], device=device) 461 | with th.no_grad(): 462 | out = self.p_sample( 463 | model, 464 | img, 465 | t, 466 | clip_denoised=clip_denoised, 467 | denoised_fn=denoised_fn, 468 | model_kwargs=model_kwargs, 469 | ) 470 | yield out 471 | img = out["sample"] 472 | 473 | def ddim_sample( 474 | self, 475 | model, 476 | x, 477 | t, 478 | clip_denoised=True, 479 | denoised_fn=None, 480 | model_kwargs=None, 481 | eta=0.0, 482 | ): 483 | """ 484 | Sample x_{t-1} from the model using DDIM. 485 | Same usage as p_sample(). 
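With eta == 0 the update is deterministic ("Equation 12" in the comments below):
x_{t-1} = sqrt(alpha_bar_prev) * pred_xstart + sqrt(1 - alpha_bar_prev) * eps;
eta == 1 reproduces the DDPM posterior noise level.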
486 | """ 487 | out = self.p_mean_variance( 488 | model, 489 | x, 490 | t, 491 | clip_denoised=clip_denoised, 492 | denoised_fn=denoised_fn, 493 | model_kwargs=model_kwargs, 494 | ) 495 | # Usually our model outputs epsilon, but we re-derive it 496 | # in case we used x_start or x_prev prediction. 497 | eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) 498 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 499 | alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) 500 | sigma = ( 501 | eta 502 | * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) 503 | * th.sqrt(1 - alpha_bar / alpha_bar_prev) 504 | ) 505 | # Equation 12. 506 | noise = th.randn_like(x) 507 | mean_pred = ( 508 | out["pred_xstart"] * th.sqrt(alpha_bar_prev) 509 | + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps 510 | ) 511 | nonzero_mask = ( 512 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 513 | ) # no noise when t == 0 514 | sample = mean_pred + nonzero_mask * sigma * noise 515 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 516 | 517 | def ddim_reverse_sample( 518 | self, 519 | model, 520 | x, 521 | t, 522 | clip_denoised=True, 523 | denoised_fn=None, 524 | model_kwargs=None, 525 | eta=0.0, 526 | ): 527 | """ 528 | Sample x_{t+1} from the model using DDIM reverse ODE. 529 | """ 530 | assert eta == 0.0, "Reverse ODE only for deterministic path" 531 | out = self.p_mean_variance( 532 | model, 533 | x, 534 | t, 535 | clip_denoised=clip_denoised, 536 | denoised_fn=denoised_fn, 537 | model_kwargs=model_kwargs, 538 | ) 539 | # Usually our model outputs epsilon, but we re-derive it 540 | # in case we used x_start or x_prev prediction. 541 | eps = ( 542 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x 543 | - out["pred_xstart"] 544 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) 545 | alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) 546 | 547 | # Equation 12. reversed 548 | mean_pred = ( 549 | out["pred_xstart"] * th.sqrt(alpha_bar_next) 550 | + th.sqrt(1 - alpha_bar_next) * eps 551 | ) 552 | 553 | return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} 554 | 555 | def ddim_sample_loop( 556 | self, 557 | model, 558 | shape, 559 | noise=None, 560 | clip_denoised=True, 561 | denoised_fn=None, 562 | model_kwargs=None, 563 | device=None, 564 | progress=False, 565 | eta=0.0, 566 | ): 567 | """ 568 | Generate samples from the model using DDIM. 569 | Same usage as p_sample_loop(). 570 | """ 571 | final = None 572 | for sample in self.ddim_sample_loop_progressive( 573 | model, 574 | shape, 575 | noise=noise, 576 | clip_denoised=clip_denoised, 577 | denoised_fn=denoised_fn, 578 | model_kwargs=model_kwargs, 579 | device=device, 580 | progress=progress, 581 | eta=eta, 582 | ): 583 | final = sample 584 | return final["sample"] 585 | 586 | def ddim_sample_loop_progressive( 587 | self, 588 | model, 589 | shape, 590 | noise=None, 591 | clip_denoised=True, 592 | denoised_fn=None, 593 | model_kwargs=None, 594 | device=None, 595 | progress=False, 596 | eta=0.0, 597 | ): 598 | """ 599 | Use DDIM to sample from the model and yield intermediate samples from 600 | each timestep of DDIM. 601 | Same usage as p_sample_loop_progressive(). 
602 | """ 603 | if device is None: 604 | device = next(model.parameters()).device 605 | assert isinstance(shape, (tuple, list)) 606 | if noise is not None: 607 | img = noise 608 | else: 609 | img = th.randn(*shape, device=device) 610 | indices = list(range(self.num_timesteps))[::-1] 611 | 612 | if progress: 613 | # Lazy import so that we don't depend on tqdm. 614 | from tqdm.auto import tqdm 615 | 616 | indices = tqdm(indices) 617 | 618 | for i in indices: 619 | t = th.tensor([i] * shape[0], device=device) 620 | with th.no_grad(): 621 | out = self.ddim_sample( 622 | model, 623 | img, 624 | t, 625 | clip_denoised=clip_denoised, 626 | denoised_fn=denoised_fn, 627 | model_kwargs=model_kwargs, 628 | eta=eta, 629 | ) 630 | yield out 631 | img = out["sample"] 632 | 633 | def _vb_terms_bpd( 634 | self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None 635 | ): 636 | """ 637 | Get a term for the variational lower-bound. 638 | The resulting units are bits (rather than nats, as one might expect). 639 | This allows for comparison to other papers. 640 | :return: a dict with the following keys: 641 | - 'output': a shape [N] tensor of NLLs or KLs. 642 | - 'pred_xstart': the x_0 predictions. 643 | """ 644 | true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( 645 | x_start=x_start, x_t=x_t, t=t 646 | ) 647 | out = self.p_mean_variance( 648 | model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs 649 | ) 650 | kl = normal_kl( 651 | true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] 652 | ) 653 | kl = mean_flat(kl) / np.log(2.0) 654 | 655 | decoder_nll = -discretized_gaussian_log_likelihood( 656 | x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] 657 | ) 658 | assert decoder_nll.shape == x_start.shape 659 | decoder_nll = mean_flat(decoder_nll) / np.log(2.0) 660 | 661 | # At the first timestep return the decoder NLL, 662 | # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) 663 | output = th.where((t == 0), decoder_nll, kl) 664 | return {"output": output, "pred_xstart": out["pred_xstart"]} 665 | 666 | def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): 667 | """ 668 | Compute training losses for a single timestep. 669 | :param model: the model to evaluate loss on. 670 | :param x_start: the [N x C x ...] tensor of inputs. 671 | :param t: a batch of timestep indices. 672 | :param model_kwargs: if not None, a dict of extra keyword arguments to 673 | pass to the model. This can be used for conditioning. 674 | :param noise: if specified, the specific Gaussian noise to try to remove. 675 | :return: a dict with the key "loss" containing a tensor of shape [N]. 676 | Some mean or variance settings may also have other keys. 
677 | """ 678 | if model_kwargs is None: 679 | model_kwargs = {} 680 | if noise is None: 681 | noise = th.randn_like(x_start) 682 | x_t = self.q_sample(x_start, t, noise=noise) 683 | 684 | terms = {} 685 | 686 | if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: 687 | terms["loss"] = self._vb_terms_bpd( 688 | model=model, 689 | x_start=x_start, 690 | x_t=x_t, 691 | t=t, 692 | clip_denoised=False, 693 | model_kwargs=model_kwargs, 694 | )["output"] 695 | if self.loss_type == LossType.RESCALED_KL: 696 | terms["loss"] *= self.num_timesteps 697 | elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: 698 | model_output = model(x_t, self._scale_timesteps(t), **model_kwargs) 699 | 700 | if self.model_var_type in [ 701 | ModelVarType.LEARNED, 702 | ModelVarType.LEARNED_RANGE, 703 | ]: 704 | B, C = x_t.shape[:2] 705 | assert model_output.shape == (B, C * 2, *x_t.shape[2:]) 706 | model_output, model_var_values = th.split(model_output, C, dim=1) 707 | # Learn the variance using the variational bound, but don't let 708 | # it affect our mean prediction. 709 | frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) 710 | terms["vb"] = self._vb_terms_bpd( 711 | model=lambda *args, r=frozen_out: r, 712 | x_start=x_start, 713 | x_t=x_t, 714 | t=t, 715 | clip_denoised=False, 716 | )["output"] 717 | if self.loss_type == LossType.RESCALED_MSE: 718 | # Divide by 1000 for equivalence with initial implementation. 719 | # Without a factor of 1/1000, the VB term hurts the MSE term. 720 | terms["vb"] *= self.num_timesteps / 1000.0 721 | 722 | target = { 723 | ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( 724 | x_start=x_start, x_t=x_t, t=t 725 | )[0], 726 | ModelMeanType.START_X: x_start, 727 | ModelMeanType.EPSILON: noise, 728 | }[self.model_mean_type] 729 | assert model_output.shape == target.shape == x_start.shape 730 | terms["mse"] = mean_flat((target - model_output) ** 2) 731 | if "vb" in terms: 732 | terms["loss"] = terms["mse"] + terms["vb"] 733 | else: 734 | terms["loss"] = terms["mse"] 735 | else: 736 | raise NotImplementedError(self.loss_type) 737 | 738 | return terms 739 | 740 | def _prior_bpd(self, x_start): 741 | """ 742 | Get the prior KL term for the variational lower-bound, measured in 743 | bits-per-dim. 744 | This term can't be optimized, as it only depends on the encoder. 745 | :param x_start: the [N x C x ...] tensor of inputs. 746 | :return: a batch of [N] KL values (in bits), one per batch element. 747 | """ 748 | batch_size = x_start.shape[0] 749 | t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) 750 | qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) 751 | kl_prior = normal_kl( 752 | mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 753 | ) 754 | return mean_flat(kl_prior) / np.log(2.0) 755 | 756 | def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): 757 | """ 758 | Compute the entire variational lower-bound, measured in bits-per-dim, 759 | as well as other related quantities. 760 | :param model: the model to evaluate loss on. 761 | :param x_start: the [N x C x ...] tensor of inputs. 762 | :param clip_denoised: if True, clip denoised samples. 763 | :param model_kwargs: if not None, a dict of extra keyword arguments to 764 | pass to the model. This can be used for conditioning. 765 | :return: a dict containing the following keys: 766 | - total_bpd: the total variational lower-bound, per batch element. 
767 | - prior_bpd: the prior term in the lower-bound. 768 | - vb: an [N x T] tensor of terms in the lower-bound. 769 | - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. 770 | - mse: an [N x T] tensor of epsilon MSEs for each timestep. 771 | """ 772 | device = x_start.device 773 | batch_size = x_start.shape[0] 774 | 775 | vb = [] 776 | xstart_mse = [] 777 | mse = [] 778 | for t in list(range(self.num_timesteps))[::-1]: 779 | t_batch = th.tensor([t] * batch_size, device=device) 780 | noise = th.randn_like(x_start) 781 | x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) 782 | # Calculate VLB term at the current timestep 783 | with th.no_grad(): 784 | out = self._vb_terms_bpd( 785 | model, 786 | x_start=x_start, 787 | x_t=x_t, 788 | t=t_batch, 789 | clip_denoised=clip_denoised, 790 | model_kwargs=model_kwargs, 791 | ) 792 | vb.append(out["output"]) 793 | xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) 794 | eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) 795 | mse.append(mean_flat((eps - noise) ** 2)) 796 | 797 | vb = th.stack(vb, dim=1) 798 | xstart_mse = th.stack(xstart_mse, dim=1) 799 | mse = th.stack(mse, dim=1) 800 | 801 | prior_bpd = self._prior_bpd(x_start) 802 | total_bpd = vb.sum(dim=1) + prior_bpd 803 | return { 804 | "total_bpd": total_bpd, 805 | "prior_bpd": prior_bpd, 806 | "vb": vb, 807 | "xstart_mse": xstart_mse, 808 | "mse": mse, 809 | } 810 | 811 | 812 | def _extract_into_tensor(arr, timesteps, broadcast_shape): 813 | """ 814 | Extract values from a 1-D numpy array for a batch of indices. 815 | :param arr: the 1-D numpy array. 816 | :param timesteps: a tensor of indices into the array to extract. 817 | :param broadcast_shape: a larger shape of K dimensions with the batch 818 | dimension equal to the length of timesteps. 819 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 
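    Example (illustrative sketch):

        betas = np.linspace(1e-4, 0.02, 1000)
        t = th.tensor([0, 999])
        coef = _extract_into_tensor(betas, t, (2, 3, 64, 64))   # broadcast to (2, 3, 64, 64)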
820 | """ 821 | res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() 822 | while len(res.shape) < len(broadcast_shape): 823 | res = res[..., None] 824 | return res.expand(broadcast_shape) 825 | -------------------------------------------------------------------------------- /imagen_pytorch/get_webdataset_loader.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 2 | from torch.utils.data import DataLoader 3 | import torch 4 | 5 | import os 6 | import argparse 7 | import io 8 | import numpy as np 9 | from PIL import Image 10 | from transformers import AutoTokenizer 11 | 12 | try: 13 | from torchvision.transforms import InterpolationMode 14 | BICUBIC = InterpolationMode.BICUBIC 15 | except ImportError: 16 | BICUBIC = Image.BICUBIC 17 | 18 | 19 | def _convert_image_to_rgb(image): 20 | return image.convert("RGB") 21 | 22 | def _transform(n_px): 23 | return Compose([ 24 | Resize(n_px, interpolation=BICUBIC), 25 | CenterCrop(n_px), 26 | _convert_image_to_rgb, 27 | ToTensor(), 28 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 29 | ]) 30 | 31 | def create_webdataset( 32 | urls, 33 | enable_text=True, 34 | enable_image=True, 35 | image_key="jpg", 36 | caption_key="txt", 37 | enable_metadata=False, 38 | cache_path=None, 39 | t5_name='t5-11b' 40 | 41 | ): 42 | """Create a WebDataset reader, it can read a webdataset of image, text and json""" 43 | import webdataset as wds # pylint: disable=import-outside-toplevel 44 | 45 | 46 | dataset = wds.WebDataset(wds.ResampledShards(urls)) 47 | print('dataset_created') 48 | tokenizer_t = AutoTokenizer.from_pretrained(t5_name) 49 | def tokenizer(text): 50 | out_dict = {} 51 | if np.random.binomial(1, 0.08): 52 | text = '' 53 | text_encoding = tokenizer_t( 54 | text, 55 | max_length=128, 56 | padding="max_length", 57 | truncation=True, 58 | return_attention_mask=True, 59 | add_special_tokens=True, 60 | return_tensors="pt") 61 | 62 | out_dict["tokens"] = text_encoding['input_ids'][0] 63 | out_dict["mask"] = text_encoding['attention_mask'][0] 64 | return out_dict 65 | def filter_dataset(item): 66 | if enable_text and caption_key not in item: 67 | return False 68 | if enable_image and image_key not in item: 69 | return False 70 | if enable_metadata and "json" not in item: 71 | return False 72 | return True 73 | 74 | filtered_dataset = dataset.select(filter_dataset) 75 | print('dataset filtered') 76 | resolution = 64 77 | print('resolution is', resolution) 78 | def preprocess_dataset(item): 79 | if enable_image: 80 | image_data = item[image_key] 81 | 82 | pil_image = Image.open(io.BytesIO(image_data)) 83 | pil_image.load() 84 | while min(*pil_image.size) >= 2 * resolution: 85 | pil_image = pil_image.resize( 86 | tuple(x // 2 for x in pil_image.size), resample=Image.Resampling.BOX 87 | ) 88 | 89 | scale = resolution / min(*pil_image.size) 90 | pil_image = pil_image.resize( 91 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 92 | ) 93 | 94 | arr = np.array(pil_image.convert("RGB")) 95 | crop_y = (arr.shape[0] - resolution) // 2 96 | crop_x = (arr.shape[1] - resolution) // 2 97 | 98 | arr = arr[crop_y: crop_y + resolution, crop_x: crop_x + resolution] 99 | arr = arr.astype(np.float32) / 127.5 - 1 100 | 101 | if enable_text: 102 | text = item[caption_key] 103 | caption = text.decode("utf-8") 104 | tokenized_text = tokenizer(caption) 105 | return np.transpose(arr, [2, 0, 
1]), tokenized_text 106 | 107 | transformed_dataset = filtered_dataset.map(preprocess_dataset, handler=wds.handlers.warn_and_continue) 108 | print('dataset transformed') 109 | return transformed_dataset 110 | 111 | 112 | def dataset_to_dataloader(dataset, batch_size, num_prepro_workers, input_format): 113 | """Create a pytorch dataloader from a dataset""" 114 | 115 | def collate_fn(batch): 116 | batch = list(filter(lambda x: x is not None, batch)) 117 | return default_collate(batch) 118 | 119 | data = DataLoader( 120 | dataset, 121 | batch_size=batch_size, 122 | shuffle=False, 123 | num_workers=num_prepro_workers, 124 | pin_memory=True, 125 | ) 126 | return data 127 | 128 | 129 | class WebdatasetReader: 130 | """WebdatasetReader is a reader that reads samples from a webdataset""" 131 | 132 | def __init__( 133 | self, 134 | input_dataset, 135 | batch_size, 136 | num_prepro_workers, 137 | enable_text=True, 138 | enable_image=True, 139 | enable_metadata=False, 140 | wds_image_key="jpg", 141 | wds_caption_key="txt", 142 | cache_path=None, 143 | t5_name='t5-11b', 144 | 145 | ): 146 | self.batch_size = batch_size 147 | dataset = create_webdataset( 148 | input_dataset, 149 | enable_text=enable_text, 150 | enable_image=enable_image, 151 | image_key=wds_image_key, 152 | caption_key=wds_caption_key, 153 | enable_metadata=enable_metadata, 154 | cache_path=cache_path, 155 | t5_name=t5_name 156 | ) 157 | self.dataloader = dataset_to_dataloader(dataset, batch_size, num_prepro_workers, "webdataset") 158 | def get_loader(self): 159 | return self.dataloader 160 | def __iter__(self): 161 | for batch in self.dataloader: 162 | yield batch 163 | -------------------------------------------------------------------------------- /imagen_pytorch/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies: 3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py 4 | """ 5 | 6 | import os 7 | import sys 8 | import shutil 9 | import os.path as osp 10 | import json 11 | import time 12 | import datetime 13 | import tempfile 14 | import warnings 15 | from collections import defaultdict 16 | from contextlib import contextmanager 17 | 18 | DEBUG = 10 19 | INFO = 20 20 | WARN = 30 21 | ERROR = 40 22 | 23 | DISABLED = 50 24 | 25 | 26 | class KVWriter(object): 27 | def writekvs(self, kvs): 28 | raise NotImplementedError 29 | 30 | 31 | class SeqWriter(object): 32 | def writeseq(self, seq): 33 | raise NotImplementedError 34 | 35 | 36 | class HumanOutputFormat(KVWriter, SeqWriter): 37 | def __init__(self, filename_or_file): 38 | if isinstance(filename_or_file, str): 39 | self.file = open(filename_or_file, "wt") 40 | self.own_file = True 41 | else: 42 | assert hasattr(filename_or_file, "read"), ( 43 | "expected file or str, got %s" % filename_or_file 44 | ) 45 | self.file = filename_or_file 46 | self.own_file = False 47 | 48 | def writekvs(self, kvs): 49 | # Create strings for printing 50 | key2str = {} 51 | for (key, val) in sorted(kvs.items()): 52 | if hasattr(val, "__float__"): 53 | valstr = "%-8.3g" % val 54 | else: 55 | valstr = str(val) 56 | key2str[self._truncate(key)] = self._truncate(valstr) 57 | 58 | # Find max widths 59 | if len(key2str) == 0: 60 | print("WARNING: tried to write empty key-value dict") 61 | return 62 | else: 63 | keywidth = max(map(len, key2str.keys())) 64 | valwidth = max(map(len, key2str.values())) 65 | 66 | # Write out the data 67 | 
dashes = "-" * (keywidth + valwidth + 7) 68 | lines = [dashes] 69 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 70 | lines.append( 71 | "| %s%s | %s%s |" 72 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) 73 | ) 74 | lines.append(dashes) 75 | self.file.write("\n".join(lines) + "\n") 76 | 77 | # Flush the output to the file 78 | self.file.flush() 79 | 80 | def _truncate(self, s): 81 | maxlen = 30 82 | return s[: maxlen - 3] + "..." if len(s) > maxlen else s 83 | 84 | def writeseq(self, seq): 85 | seq = list(seq) 86 | for (i, elem) in enumerate(seq): 87 | self.file.write(elem) 88 | if i < len(seq) - 1: # add space unless this is the last one 89 | self.file.write(" ") 90 | self.file.write("\n") 91 | self.file.flush() 92 | 93 | def close(self): 94 | if self.own_file: 95 | self.file.close() 96 | 97 | 98 | class JSONOutputFormat(KVWriter): 99 | def __init__(self, filename): 100 | self.file = open(filename, "wt") 101 | 102 | def writekvs(self, kvs): 103 | for k, v in sorted(kvs.items()): 104 | if hasattr(v, "dtype"): 105 | kvs[k] = float(v) 106 | self.file.write(json.dumps(kvs) + "\n") 107 | self.file.flush() 108 | 109 | def close(self): 110 | self.file.close() 111 | 112 | 113 | class CSVOutputFormat(KVWriter): 114 | def __init__(self, filename): 115 | self.file = open(filename, "w+t") 116 | self.keys = [] 117 | self.sep = "," 118 | 119 | def writekvs(self, kvs): 120 | # Add our current row to the history 121 | extra_keys = list(kvs.keys() - self.keys) 122 | extra_keys.sort() 123 | if extra_keys: 124 | self.keys.extend(extra_keys) 125 | self.file.seek(0) 126 | lines = self.file.readlines() 127 | self.file.seek(0) 128 | for (i, k) in enumerate(self.keys): 129 | if i > 0: 130 | self.file.write(",") 131 | self.file.write(k) 132 | self.file.write("\n") 133 | for line in lines[1:]: 134 | self.file.write(line[:-1]) 135 | self.file.write(self.sep * len(extra_keys)) 136 | self.file.write("\n") 137 | for (i, k) in enumerate(self.keys): 138 | if i > 0: 139 | self.file.write(",") 140 | v = kvs.get(k) 141 | if v is not None: 142 | self.file.write(str(v)) 143 | self.file.write("\n") 144 | self.file.flush() 145 | 146 | def close(self): 147 | self.file.close() 148 | 149 | 150 | class TensorBoardOutputFormat(KVWriter): 151 | """ 152 | Dumps key/value pairs into TensorBoard's numeric format. 153 | """ 154 | 155 | def __init__(self, dir): 156 | os.makedirs(dir, exist_ok=True) 157 | self.dir = dir 158 | self.step = 1 159 | prefix = "events" 160 | path = osp.join(osp.abspath(dir), prefix) 161 | import tensorflow as tf 162 | from tensorflow.python import pywrap_tensorflow 163 | from tensorflow.core.util import event_pb2 164 | from tensorflow.python.util import compat 165 | 166 | self.tf = tf 167 | self.event_pb2 = event_pb2 168 | self.pywrap_tensorflow = pywrap_tensorflow 169 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 170 | 171 | def writekvs(self, kvs): 172 | def summary_val(k, v): 173 | kwargs = {"tag": k, "simple_value": float(v)} 174 | return self.tf.Summary.Value(**kwargs) 175 | 176 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 177 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 178 | event.step = ( 179 | self.step 180 | ) # is there any reason why you'd want to specify the step? 
181 | self.writer.WriteEvent(event) 182 | self.writer.Flush() 183 | self.step += 1 184 | 185 | def close(self): 186 | if self.writer: 187 | self.writer.Close() 188 | self.writer = None 189 | 190 | 191 | def make_output_format(format, ev_dir, log_suffix=""): 192 | os.makedirs(ev_dir, exist_ok=True) 193 | if format == "stdout": 194 | return HumanOutputFormat(sys.stdout) 195 | elif format == "log": 196 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) 197 | elif format == "json": 198 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) 199 | elif format == "csv": 200 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) 201 | elif format == "tensorboard": 202 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) 203 | else: 204 | raise ValueError("Unknown format specified: %s" % (format,)) 205 | 206 | 207 | # ================================================================ 208 | # API 209 | # ================================================================ 210 | 211 | 212 | def logkv(key, val): 213 | """ 214 | Log a value of some diagnostic 215 | Call this once for each diagnostic quantity, each iteration 216 | If called many times, last value will be used. 217 | """ 218 | get_current().logkv(key, val) 219 | 220 | 221 | def logkv_mean(key, val): 222 | """ 223 | The same as logkv(), but if called many times, values averaged. 224 | """ 225 | get_current().logkv_mean(key, val) 226 | 227 | 228 | def logkvs(d): 229 | """ 230 | Log a dictionary of key-value pairs 231 | """ 232 | for (k, v) in d.items(): 233 | logkv(k, v) 234 | 235 | 236 | def dumpkvs(): 237 | """ 238 | Write all of the diagnostics from the current iteration 239 | """ 240 | return get_current().dumpkvs() 241 | 242 | 243 | def getkvs(): 244 | return get_current().name2val 245 | 246 | 247 | def log(*args, level=INFO): 248 | """ 249 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 250 | """ 251 | get_current().log(*args, level=level) 252 | 253 | 254 | def debug(*args): 255 | log(*args, level=DEBUG) 256 | 257 | 258 | def info(*args): 259 | log(*args, level=INFO) 260 | 261 | 262 | def warn(*args): 263 | log(*args, level=WARN) 264 | 265 | 266 | def error(*args): 267 | log(*args, level=ERROR) 268 | 269 | 270 | def set_level(level): 271 | """ 272 | Set logging threshold on current logger. 273 | """ 274 | get_current().set_level(level) 275 | 276 | 277 | def set_comm(comm): 278 | get_current().set_comm(comm) 279 | 280 | 281 | def get_dir(): 282 | """ 283 | Get directory that log files are being written to. 
284 | will be None if there is no output directory (i.e., if you didn't call start) 285 | """ 286 | return get_current().get_dir() 287 | 288 | 289 | record_tabular = logkv 290 | dump_tabular = dumpkvs 291 | 292 | 293 | @contextmanager 294 | def profile_kv(scopename): 295 | logkey = "wait_" + scopename 296 | tstart = time.time() 297 | try: 298 | yield 299 | finally: 300 | get_current().name2val[logkey] += time.time() - tstart 301 | 302 | 303 | def profile(n): 304 | """ 305 | Usage: 306 | @profile("my_func") 307 | def my_func(): code 308 | """ 309 | 310 | def decorator_with_name(func): 311 | def func_wrapper(*args, **kwargs): 312 | with profile_kv(n): 313 | return func(*args, **kwargs) 314 | 315 | return func_wrapper 316 | 317 | return decorator_with_name 318 | 319 | 320 | # ================================================================ 321 | # Backend 322 | # ================================================================ 323 | 324 | 325 | def get_current(): 326 | if Logger.CURRENT is None: 327 | _configure_default_logger() 328 | 329 | return Logger.CURRENT 330 | 331 | 332 | class Logger(object): 333 | DEFAULT = None # A logger with no output files. (See right below class definition) 334 | # So that you can still log to the terminal without setting up any output files 335 | CURRENT = None # Current logger being used by the free functions above 336 | 337 | def __init__(self, dir, output_formats, comm=None): 338 | self.name2val = defaultdict(float) # values this iteration 339 | self.name2cnt = defaultdict(int) 340 | self.level = INFO 341 | self.dir = dir 342 | self.output_formats = output_formats 343 | self.comm = comm 344 | 345 | # Logging API, forwarded 346 | # ---------------------------------------- 347 | def logkv(self, key, val): 348 | self.name2val[key] = val 349 | 350 | def logkv_mean(self, key, val): 351 | oldval, cnt = self.name2val[key], self.name2cnt[key] 352 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 353 | self.name2cnt[key] = cnt + 1 354 | 355 | def dumpkvs(self): 356 | if self.comm is None: 357 | d = self.name2val 358 | else: 359 | d = mpi_weighted_mean( 360 | self.comm, 361 | { 362 | name: (val, self.name2cnt.get(name, 1)) 363 | for (name, val) in self.name2val.items() 364 | }, 365 | ) 366 | if self.comm.rank != 0: 367 | d["dummy"] = 1 # so we don't get a warning about empty dict 368 | out = d.copy() # Return the dict for unit testing purposes 369 | for fmt in self.output_formats: 370 | if isinstance(fmt, KVWriter): 371 | fmt.writekvs(d) 372 | self.name2val.clear() 373 | self.name2cnt.clear() 374 | return out 375 | 376 | def log(self, *args, level=INFO): 377 | if self.level <= level: 378 | self._do_log(args) 379 | 380 | # Configuration 381 | # ---------------------------------------- 382 | def set_level(self, level): 383 | self.level = level 384 | 385 | def set_comm(self, comm): 386 | self.comm = comm 387 | 388 | def get_dir(self): 389 | return self.dir 390 | 391 | def close(self): 392 | for fmt in self.output_formats: 393 | fmt.close() 394 | 395 | # Misc 396 | # ---------------------------------------- 397 | def _do_log(self, args): 398 | for fmt in self.output_formats: 399 | if isinstance(fmt, SeqWriter): 400 | fmt.writeseq(map(str, args)) 401 | 402 | 403 | def get_rank_without_mpi_import(): 404 | # check environment variables here instead of importing mpi4py 405 | # to avoid calling MPI_Init() when this module is imported 406 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: 407 | if varname in os.environ: 408 | return int(os.environ[varname]) 
409 | return 0 410 | 411 | 412 | def mpi_weighted_mean(comm, local_name2valcount): 413 | """ 414 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 415 | Perform a weighted average over dicts that are each on a different node 416 | Input: local_name2valcount: dict mapping key -> (value, count) 417 | Returns: key -> mean 418 | """ 419 | all_name2valcount = comm.gather(local_name2valcount) 420 | if comm.rank == 0: 421 | name2sum = defaultdict(float) 422 | name2count = defaultdict(float) 423 | for n2vc in all_name2valcount: 424 | for (name, (val, count)) in n2vc.items(): 425 | try: 426 | val = float(val) 427 | except ValueError: 428 | if comm.rank == 0: 429 | warnings.warn( 430 | "WARNING: tried to compute mean on non-float {}={}".format( 431 | name, val 432 | ) 433 | ) 434 | else: 435 | name2sum[name] += val * count 436 | name2count[name] += count 437 | return {name: name2sum[name] / name2count[name] for name in name2sum} 438 | else: 439 | return {} 440 | 441 | 442 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""): 443 | """ 444 | If comm is provided, average all numerical stats across that comm 445 | """ 446 | if dir is None: 447 | dir = os.getenv("OPENAI_LOGDIR") 448 | if dir is None: 449 | dir = osp.join( 450 | tempfile.gettempdir(), 451 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), 452 | ) 453 | assert isinstance(dir, str) 454 | dir = os.path.expanduser(dir) 455 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 456 | 457 | rank = get_rank_without_mpi_import() 458 | if rank > 0: 459 | log_suffix = log_suffix + "-rank%03i" % rank 460 | 461 | if format_strs is None: 462 | if rank == 0: 463 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") 464 | else: 465 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") 466 | format_strs = filter(None, format_strs) 467 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 468 | 469 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 470 | if output_formats: 471 | log("Logging to %s" % dir) 472 | 473 | 474 | def _configure_default_logger(): 475 | configure() 476 | Logger.DEFAULT = Logger.CURRENT 477 | 478 | 479 | def reset(): 480 | if Logger.CURRENT is not Logger.DEFAULT: 481 | Logger.CURRENT.close() 482 | Logger.CURRENT = Logger.DEFAULT 483 | log("Reset logger") 484 | 485 | 486 | @contextmanager 487 | def scoped_configure(dir=None, format_strs=None, comm=None): 488 | prevlogger = Logger.CURRENT 489 | configure(dir=dir, format_strs=format_strs, comm=comm) 490 | try: 491 | yield 492 | finally: 493 | Logger.CURRENT.close() 494 | Logger.CURRENT = prevlogger -------------------------------------------------------------------------------- /imagen_pytorch/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | Shapes are automatically broadcasted, so batches can be compared to 16 | scalars, among other use cases. 
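    Example (illustrative sketch): the KL between two identical unit Gaussians is zero,

        normal_kl(th.zeros(4), th.zeros(4), 0.0, 0.0)   # -> tensor([0., 0., 0., 0.])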
17 | """ 18 | tensor = None 19 | for obj in (mean1, logvar1, mean2, logvar2): 20 | if isinstance(obj, th.Tensor): 21 | tensor = obj 22 | break 23 | assert tensor is not None, "at least one argument must be a Tensor" 24 | 25 | # Force variances to be Tensors. Broadcasting helps convert scalars to 26 | # Tensors, but it does not work for th.exp(). 27 | logvar1, logvar2 = [ 28 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 29 | for x in (logvar1, logvar2) 30 | ] 31 | 32 | return 0.5 * ( 33 | -1.0 34 | + logvar2 35 | - logvar1 36 | + th.exp(logvar1 - logvar2) 37 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 38 | ) 39 | 40 | 41 | def approx_standard_normal_cdf(x): 42 | """ 43 | A fast approximation of the cumulative distribution function of the 44 | standard normal. 45 | """ 46 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 47 | 48 | 49 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 50 | """ 51 | Compute the log-likelihood of a Gaussian distribution discretizing to a 52 | given image. 53 | :param x: the target images. It is assumed that this was uint8 values, 54 | rescaled to the range [-1, 1]. 55 | :param means: the Gaussian mean Tensor. 56 | :param log_scales: the Gaussian log stddev Tensor. 57 | :return: a tensor like x of log probabilities (in nats). 58 | """ 59 | assert x.shape == means.shape == log_scales.shape 60 | centered_x = x - means 61 | inv_stdv = th.exp(-log_scales) 62 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 63 | cdf_plus = approx_standard_normal_cdf(plus_in) 64 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 65 | cdf_min = approx_standard_normal_cdf(min_in) 66 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 67 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 68 | cdf_delta = cdf_plus - cdf_min 69 | log_probs = th.where( 70 | x < -0.999, 71 | log_cdf_plus, 72 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 73 | ) 74 | assert log_probs.shape == x.shape 75 | return log_probs -------------------------------------------------------------------------------- /imagen_pytorch/model_creation.py: -------------------------------------------------------------------------------- 1 | from .gaussian_diffusion import get_named_beta_schedule 2 | from . 
import gaussian_diffusion as gd 3 | from .respace import SpacedDiffusion, space_timesteps 4 | from imagen_pytorch.text2im_model import ( 5 | Text2ImUNet, 6 | ) 7 | 8 | 9 | def model_and_diffusion_defaults(): 10 | return dict( 11 | image_size=64, 12 | num_channels=192, 13 | num_res_blocks=3, 14 | channel_mult="", 15 | num_heads=1, 16 | num_head_channels=64, 17 | num_heads_upsample=-1, 18 | attention_resolutions="32,16,8", 19 | dropout=0.1, 20 | t5_name='t5-large', 21 | xf_width=512, 22 | use_scale_shift_norm=True, 23 | resblock_updown=True, 24 | use_fp16=True, 25 | cache_text_emb=False, 26 | 27 | learn_sigma=True, 28 | sigma_small=False, 29 | diffusion_steps=1000, 30 | noise_schedule="linear", 31 | timestep_respacing="", 32 | use_kl=False, 33 | predict_xstart=False, 34 | rescale_timesteps=True, 35 | rescale_learned_sigmas=True, 36 | 37 | ) 38 | 39 | 40 | 41 | def create_model_and_diffusion( 42 | image_size, 43 | num_channels, 44 | num_res_blocks, 45 | channel_mult, 46 | num_heads, 47 | num_head_channels, 48 | num_heads_upsample, 49 | attention_resolutions, 50 | dropout, 51 | t5_name, 52 | xf_width, 53 | use_scale_shift_norm, 54 | resblock_updown, 55 | use_fp16, 56 | cache_text_emb, 57 | 58 | learn_sigma, 59 | sigma_small, 60 | diffusion_steps, 61 | noise_schedule, 62 | timestep_respacing, 63 | use_kl, 64 | predict_xstart, 65 | rescale_timesteps, 66 | rescale_learned_sigmas, 67 | 68 | ): 69 | model = create_model( 70 | image_size, 71 | num_channels, 72 | num_res_blocks, 73 | channel_mult=channel_mult, 74 | attention_resolutions=attention_resolutions, 75 | num_heads=num_heads, 76 | num_head_channels=num_head_channels, 77 | num_heads_upsample=num_heads_upsample, 78 | use_scale_shift_norm=use_scale_shift_norm, 79 | dropout=dropout, 80 | xf_width=xf_width, 81 | t5_name=t5_name, 82 | resblock_updown=resblock_updown, 83 | use_fp16=use_fp16, 84 | cache_text_emb=cache_text_emb, 85 | ) 86 | diffusion = create_gaussian_diffusion( 87 | steps=diffusion_steps, 88 | learn_sigma=learn_sigma, 89 | sigma_small=sigma_small, 90 | noise_schedule=noise_schedule, 91 | use_kl=use_kl, 92 | predict_xstart=predict_xstart, 93 | rescale_timesteps=rescale_timesteps, 94 | rescale_learned_sigmas=rescale_learned_sigmas, 95 | timestep_respacing=timestep_respacing, 96 | ) 97 | return model, diffusion 98 | 99 | 100 | def create_model( 101 | image_size, 102 | num_channels, 103 | num_res_blocks, 104 | channel_mult, 105 | attention_resolutions, 106 | num_heads, 107 | num_head_channels, 108 | num_heads_upsample, 109 | use_scale_shift_norm, 110 | dropout, 111 | xf_width, 112 | t5_name, 113 | resblock_updown, 114 | use_fp16, 115 | cache_text_emb, 116 | ): 117 | if channel_mult == "": 118 | if image_size == 256: 119 | channel_mult = (1, 1, 2, 2, 4, 4) 120 | elif image_size == 128: 121 | channel_mult = (1, 1, 2, 3, 4) 122 | elif image_size == 64: 123 | channel_mult = (1, 2, 3, 4) 124 | else: 125 | raise ValueError(f"unsupported image size: {image_size}") 126 | else: 127 | channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(",")) 128 | assert 2 ** (len(channel_mult) + 2) == image_size 129 | 130 | attention_ds = [] 131 | for res in attention_resolutions.split(","): 132 | attention_ds.append(image_size // int(res)) 133 | 134 | model_cls = Text2ImUNet 135 | return model_cls( 136 | in_channels=3, 137 | model_channels=num_channels, 138 | out_channels=6, 139 | num_res_blocks=num_res_blocks, 140 | attention_resolutions=tuple(attention_ds), 141 | dropout=dropout, 142 | xf_width=xf_width, 143 | t5_name=t5_name, 144 | 
channel_mult=channel_mult, 145 | use_fp16=use_fp16, 146 | num_heads=num_heads, 147 | num_head_channels=num_head_channels, 148 | num_heads_upsample=num_heads_upsample, 149 | use_scale_shift_norm=use_scale_shift_norm, 150 | resblock_updown=resblock_updown, 151 | cache_text_emb=cache_text_emb, 152 | ) 153 | 154 | 155 | def create_gaussian_diffusion( 156 | *, 157 | steps=1000, 158 | learn_sigma=False, 159 | sigma_small=False, 160 | noise_schedule="linear", 161 | use_kl=False, 162 | predict_xstart=False, 163 | rescale_timesteps=False, 164 | rescale_learned_sigmas=False, 165 | timestep_respacing="", 166 | ): 167 | betas = gd.get_named_beta_schedule(noise_schedule, steps) 168 | if use_kl: 169 | loss_type = gd.LossType.RESCALED_KL 170 | elif rescale_learned_sigmas: 171 | loss_type = gd.LossType.RESCALED_MSE 172 | else: 173 | loss_type = gd.LossType.MSE 174 | if not timestep_respacing: 175 | timestep_respacing = [steps] 176 | return SpacedDiffusion( 177 | use_timesteps=space_timesteps(steps, timestep_respacing), 178 | betas=betas, 179 | model_mean_type=( 180 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 181 | ), 182 | model_var_type=( 183 | ( 184 | gd.ModelVarType.FIXED_LARGE 185 | if not sigma_small 186 | else gd.ModelVarType.FIXED_SMALL 187 | ) 188 | if not learn_sigma 189 | else gd.ModelVarType.LEARNED_RANGE 190 | ), 191 | loss_type=loss_type, 192 | rescale_timesteps=rescale_timesteps, 193 | ) 194 | -------------------------------------------------------------------------------- /imagen_pytorch/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | def update_ema(target_params, source_params, rate=0.99): 12 | """ 13 | Update target parameters to be closer to those of source parameters using 14 | an exponential moving average. 15 | :param target_params: the target parameter sequence. 16 | :param source_params: the source parameter sequence. 17 | :param rate: the EMA rate (closer to 1 means slower). 18 | """ 19 | for targ, src in zip(target_params, source_params): 20 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 21 | 22 | class GroupNorm32(nn.GroupNorm): 23 | def __init__(self, num_groups, num_channels, swish, eps=1e-5): 24 | super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps) 25 | self.swish = swish 26 | 27 | def forward(self, x): 28 | y = super().forward(x.float()).to(x.dtype) 29 | if self.swish == 1.0: 30 | y = F.silu(y) 31 | elif self.swish: 32 | y = y * F.sigmoid(y * float(self.swish)) 33 | return y 34 | 35 | 36 | def conv_nd(dims, *args, **kwargs): 37 | """ 38 | Create a 1D, 2D, or 3D convolution module. 39 | """ 40 | if dims == 1: 41 | return nn.Conv1d(*args, **kwargs) 42 | elif dims == 2: 43 | return nn.Conv2d(*args, **kwargs) 44 | elif dims == 3: 45 | return nn.Conv3d(*args, **kwargs) 46 | raise ValueError(f"unsupported dimensions: {dims}") 47 | 48 | 49 | def linear(*args, **kwargs): 50 | """ 51 | Create a linear module. 52 | """ 53 | return nn.Linear(*args, **kwargs) 54 | 55 | 56 | def avg_pool_nd(dims, *args, **kwargs): 57 | """ 58 | Create a 1D, 2D, or 3D average pooling module. 
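    Example (illustrative): `avg_pool_nd(2, kernel_size=2, stride=2)` builds the same module
    as `nn.AvgPool2d(kernel_size=2, stride=2)`; `dims` only selects the 1D/2D/3D variant,
    mirroring conv_nd above.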
59 | """ 60 | if dims == 1: 61 | return nn.AvgPool1d(*args, **kwargs) 62 | elif dims == 2: 63 | return nn.AvgPool2d(*args, **kwargs) 64 | elif dims == 3: 65 | return nn.AvgPool3d(*args, **kwargs) 66 | raise ValueError(f"unsupported dimensions: {dims}") 67 | 68 | 69 | def zero_module(module): 70 | """ 71 | Zero out the parameters of a module and return it. 72 | """ 73 | for p in module.parameters(): 74 | p.detach().zero_() 75 | return module 76 | 77 | 78 | def scale_module(module, scale): 79 | """ 80 | Scale the parameters of a module and return it. 81 | """ 82 | for p in module.parameters(): 83 | p.detach().mul_(scale) 84 | return module 85 | 86 | 87 | def normalization(channels, swish=0.0): 88 | """ 89 | Make a standard normalization layer, with an optional swish activation. 90 | 91 | :param channels: number of input channels. 92 | :return: an nn.Module for normalization. 93 | """ 94 | return GroupNorm32(num_channels=channels, num_groups=32, swish=swish) 95 | 96 | 97 | def timestep_embedding(timesteps, dim, max_period=10000): 98 | """ 99 | Create sinusoidal timestep embeddings. 100 | 101 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 102 | These may be fractional. 103 | :param dim: the dimension of the output. 104 | :param max_period: controls the minimum frequency of the embeddings. 105 | :return: an [N x dim] Tensor of positional embeddings. 106 | """ 107 | half = dim // 2 108 | freqs = th.exp( 109 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 110 | ).to(device=timesteps.device) 111 | args = timesteps[:, None].float() * freqs[None] 112 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 113 | if dim % 2: 114 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 115 | return embedding 116 | -------------------------------------------------------------------------------- /imagen_pytorch/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | :param name: the name of the sampler. 12 | :param diffusion: the diffusion object to sample for. 13 | """ 14 | if name == "uniform": 15 | return UniformSampler(diffusion) 16 | elif name == "loss-second-moment": 17 | return LossSecondMomentResampler(diffusion) 18 | else: 19 | raise NotImplementedError(f"unknown schedule sampler: {name}") 20 | 21 | 22 | class ScheduleSampler(ABC): 23 | """ 24 | A distribution over timesteps in the diffusion process, intended to reduce 25 | variance of the objective. 26 | By default, samplers perform unbiased importance sampling, in which the 27 | objective's mean is unchanged. 28 | However, subclasses may override sample() to change how the resampled 29 | terms are reweighted, allowing for actual changes in the objective. 30 | """ 31 | 32 | @abstractmethod 33 | def weights(self): 34 | """ 35 | Get a numpy array of weights, one per diffusion step. 36 | The weights needn't be normalized, but must be positive. 37 | """ 38 | 39 | def sample(self, batch_size, device): 40 | """ 41 | Importance-sample timesteps for a batch. 42 | :param batch_size: the number of timesteps. 43 | :param device: the torch device to save to. 44 | :return: a tuple (timesteps, weights): 45 | - timesteps: a tensor of timestep indices. 
46 | - weights: a tensor of weights to scale the resulting losses. 47 | """ 48 | w = self.weights() 49 | p = w / np.sum(w) 50 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 51 | indices = th.from_numpy(indices_np).long().to(device) 52 | weights_np = 1 / (len(p) * p[indices_np]) 53 | weights = th.from_numpy(weights_np).float().to(device) 54 | return indices, weights 55 | 56 | 57 | class UniformSampler(ScheduleSampler): 58 | def __init__(self, diffusion): 59 | self.diffusion = diffusion 60 | self._weights = np.ones([diffusion.num_timesteps]) 61 | 62 | def weights(self): 63 | return self._weights 64 | 65 | 66 | class LossAwareSampler(ScheduleSampler): 67 | def update_with_local_losses(self, local_ts, local_losses): 68 | """ 69 | Update the reweighting using losses from a model. 70 | Call this method from each rank with a batch of timesteps and the 71 | corresponding losses for each of those timesteps. 72 | This method will perform synchronization to make sure all of the ranks 73 | maintain the exact same reweighting. 74 | :param local_ts: an integer Tensor of timesteps. 75 | :param local_losses: a 1D Tensor of losses. 76 | """ 77 | batch_sizes = [ 78 | th.tensor([0], dtype=th.int32, device=local_ts.device) 79 | for _ in range(dist.get_world_size()) 80 | ] 81 | dist.all_gather( 82 | batch_sizes, 83 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 84 | ) 85 | 86 | # Pad all_gather batches to be the maximum batch size. 87 | batch_sizes = [x.item() for x in batch_sizes] 88 | max_bs = max(batch_sizes) 89 | 90 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 91 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 92 | dist.all_gather(timestep_batches, local_ts) 93 | dist.all_gather(loss_batches, local_losses) 94 | timesteps = [ 95 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 96 | ] 97 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 98 | self.update_with_all_losses(timesteps, losses) 99 | 100 | @abstractmethod 101 | def update_with_all_losses(self, ts, losses): 102 | """ 103 | Update the reweighting using losses from a model. 104 | Sub-classes should override this method to update the reweighting 105 | using losses from the model. 106 | This method directly updates the reweighting without synchronizing 107 | between workers. It is called by update_with_local_losses from all 108 | ranks with identical arguments. Thus, it should have deterministic 109 | behavior to maintain state across workers. 110 | :param ts: a list of int timesteps. 111 | :param losses: a list of float losses, one per timestep. 
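        Example (illustrative sketch of the training-loop plumbing, assuming a distributed
        run where `sampler` is a LossSecondMomentResampler):

            t, weights = sampler.sample(x_start.shape[0], x_start.device)
            losses = diffusion.training_losses(model, x_start, t, model_kwargs=model_kwargs)
            sampler.update_with_local_losses(t, losses["loss"].detach())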
112 | """ 113 | 114 | 115 | class LossSecondMomentResampler(LossAwareSampler): 116 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 117 | self.diffusion = diffusion 118 | self.history_per_term = history_per_term 119 | self.uniform_prob = uniform_prob 120 | self._loss_history = np.zeros( 121 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 122 | ) 123 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 124 | 125 | def weights(self): 126 | if not self._warmed_up(): 127 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 128 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 129 | weights /= np.sum(weights) 130 | weights *= 1 - self.uniform_prob 131 | weights += self.uniform_prob / len(weights) 132 | return weights 133 | 134 | def update_with_all_losses(self, ts, losses): 135 | for t, loss in zip(ts, losses): 136 | if self._loss_counts[t] == self.history_per_term: 137 | # Shift out the oldest loss term. 138 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 139 | self._loss_history[t, -1] = loss 140 | else: 141 | self._loss_history[t, self._loss_counts[t]] = loss 142 | self._loss_counts[t] += 1 143 | 144 | def _warmed_up(self): 145 | return (self._loss_counts == self.history_per_term).all() -------------------------------------------------------------------------------- /imagen_pytorch/respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | from .gaussian_diffusion import GaussianDiffusion 5 | 6 | 7 | def space_timesteps(num_timesteps, section_counts): 8 | """ 9 | Create a list of timesteps to use from an original diffusion process, 10 | given the number of timesteps we want to take from equally-sized portions 11 | of the original process. 12 | For example, if there's 300 timesteps and the section counts are [10,15,20] 13 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 14 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 15 | If the stride is a string starting with "ddim", then the fixed striding 16 | from the DDIM paper is used, and only one section is allowed. 17 | :param num_timesteps: the number of diffusion steps in the original 18 | process to divide up. 19 | :param section_counts: either a list of numbers, or a string containing 20 | comma-separated numbers, indicating the step count 21 | per section. As a special case, use "ddimN" where N 22 | is a number of steps to use the striding from the 23 | DDIM paper. 24 | :return: a set of diffusion steps from the original process to use. 
25 | """ 26 | if isinstance(section_counts, str): 27 | if section_counts.startswith("ddim"): 28 | desired_count = int(section_counts[len("ddim") :]) 29 | for i in range(1, num_timesteps): 30 | if len(range(0, num_timesteps, i)) == desired_count: 31 | return set(range(0, num_timesteps, i)) 32 | raise ValueError( 33 | f"cannot create exactly {num_timesteps} steps with an integer stride" 34 | ) 35 | section_counts = [int(x) for x in section_counts.split(",")] 36 | size_per = num_timesteps // len(section_counts) 37 | extra = num_timesteps % len(section_counts) 38 | start_idx = 0 39 | all_steps = [] 40 | for i, section_count in enumerate(section_counts): 41 | size = size_per + (1 if i < extra else 0) 42 | if size < section_count: 43 | raise ValueError( 44 | f"cannot divide section of {size} steps into {section_count}" 45 | ) 46 | if section_count <= 1: 47 | frac_stride = 1 48 | else: 49 | frac_stride = (size - 1) / (section_count - 1) 50 | cur_idx = 0.0 51 | taken_steps = [] 52 | for _ in range(section_count): 53 | taken_steps.append(start_idx + round(cur_idx)) 54 | cur_idx += frac_stride 55 | all_steps += taken_steps 56 | start_idx += size 57 | return set(all_steps) 58 | 59 | 60 | class SpacedDiffusion(GaussianDiffusion): 61 | """ 62 | A diffusion process which can skip steps in a base diffusion process. 63 | :param use_timesteps: a collection (sequence or set) of timesteps from the 64 | original diffusion process to retain. 65 | :param kwargs: the kwargs to create the base diffusion process. 66 | """ 67 | 68 | def __init__(self, use_timesteps, **kwargs): 69 | self.use_timesteps = set(use_timesteps) 70 | self.timestep_map = [] 71 | self.original_num_steps = len(kwargs["betas"]) 72 | 73 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 74 | last_alpha_cumprod = 1.0 75 | new_betas = [] 76 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 77 | if i in self.use_timesteps: 78 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 79 | last_alpha_cumprod = alpha_cumprod 80 | self.timestep_map.append(i) 81 | kwargs["betas"] = np.array(new_betas) 82 | super().__init__(**kwargs) 83 | 84 | def p_mean_variance( 85 | self, model, *args, **kwargs 86 | ): # pylint: disable=signature-differs 87 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 88 | 89 | def training_losses( 90 | self, model, *args, **kwargs 91 | ): # pylint: disable=signature-differs 92 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 93 | 94 | def _wrap_model(self, model): 95 | if isinstance(model, _WrappedModel): 96 | return model 97 | return _WrappedModel( 98 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 99 | ) 100 | 101 | def _scale_timesteps(self, t): 102 | # Scaling is done by the wrapped model. 
103 | return t 104 | 105 | 106 | class _WrappedModel: 107 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 108 | self.model = model 109 | self.timestep_map = timestep_map 110 | self.rescale_timesteps = rescale_timesteps 111 | self.original_num_steps = original_num_steps 112 | 113 | def __call__(self, x, ts, **kwargs): 114 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 115 | new_ts = map_tensor[ts] 116 | if self.rescale_timesteps: 117 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 118 | return self.model(x, new_ts, **kwargs) -------------------------------------------------------------------------------- /imagen_pytorch/text2im_model.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .nn import timestep_embedding 5 | from .unet import UNetModel 6 | from .xf import LayerNorm, Transformer, convert_module_to_f16 7 | from transformers import T5EncoderModel 8 | 9 | 10 | class Text2ImUNet(UNetModel): 11 | """ 12 | A UNetModel that conditions on text with an encoding transformer. 13 | Expects an extra kwarg `tokens` of text. 14 | :param text_ctx: number of text tokens to expect. 15 | :param xf_width: width of the transformer. 16 | :param xf_layers: depth of the transformer. 17 | :param xf_heads: heads in the transformer. 18 | :param xf_final_ln: use a LayerNorm after the output layer. 19 | :param tokenizer: the text tokenizer for sampling/vocab size. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | xf_width, 25 | t5_name, 26 | *args, 27 | cache_text_emb=True, 28 | **kwargs, 29 | ): 30 | self.xf_width = xf_width 31 | if not xf_width: 32 | super().__init__(*args, **kwargs, encoder_channels=None) 33 | else: 34 | super().__init__(*args, **kwargs, encoder_channels=xf_width) 35 | 36 | self.t5 = T5EncoderModel.from_pretrained(t5_name) 37 | if t5_name == 't5-11b': 38 | self.t5.to(th.float16) 39 | for param in self.t5.parameters(): 40 | param.requires_grad = False 41 | self.t5_proj = nn.Linear(self.t5.shared.embedding_dim, self.model_channels * 4) 42 | self.to_xf_width = nn.Linear(self.t5.shared.embedding_dim, xf_width) 43 | self.cache_text_emb = cache_text_emb 44 | self.cache = None 45 | def convert_to_fp16(self): 46 | 47 | super().convert_to_fp16() 48 | self.t5_proj.to(th.float16) 49 | self.t5.to(th.float16) 50 | self.to_xf_width.to(th.float16) 51 | def get_text_emb(self, tokens, mask): 52 | #with th.no_grad(): 53 | if self.cache is not None and self.cache_text_emb: 54 | return self.cache 55 | xf_out = self.t5(input_ids=tokens, attention_mask=mask)['last_hidden_state'].float()#.detach() 56 | xf_proj = self.t5_proj(xf_out[:, -1]) 57 | xf_out2 = self.to_xf_width(xf_out) 58 | xf_out2 = xf_out2.permute(0, 2, 1) # NLC -> NCL 59 | 60 | outputs = dict(xf_proj=xf_proj, xf_out=xf_out2) 61 | if self.cache_text_emb: 62 | self.cache = outputs 63 | return outputs 64 | 65 | 66 | def del_cache(self): 67 | self.cache = None 68 | 69 | def forward(self, x, timesteps, tokens=None, mask=None): 70 | hs = [] 71 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 72 | if self.xf_width: 73 | text_outputs = self.get_text_emb(tokens, mask) 74 | xf_proj, xf_out = text_outputs["xf_proj"], text_outputs["xf_out"] 75 | emb = emb + xf_proj.to(emb) 76 | else: 77 | xf_out = None 78 | h = x.type(self.dtype) 79 | for module in self.input_blocks: 80 | h = module(h, emb, xf_out) 81 | hs.append(h) 82 | h = 
self.middle_block(h, emb, xf_out) 83 | for module in self.output_blocks: 84 | h = th.cat([h, hs.pop()], dim=1) 85 | h = module(h, emb, xf_out) 86 | h = h.type(x.dtype) 87 | h = self.out(h) 88 | return h 89 | class SuperResText2ImUNet(Text2ImUNet): 90 | """ 91 | A text2im model that performs super-resolution. 92 | Expects an extra kwarg `low_res` to condition on a low-resolution image. 93 | """ 94 | 95 | def __init__(self, *args, **kwargs): 96 | if "in_channels" in kwargs: 97 | kwargs = dict(kwargs) 98 | kwargs["in_channels"] = kwargs["in_channels"] * 2 99 | else: 100 | # Curse you, Python. Or really, just curse positional arguments :|. 101 | args = list(args) 102 | args[1] = args[1] * 2 103 | super().__init__(*args, **kwargs) 104 | 105 | def forward(self, x, timesteps, low_res=None, **kwargs): 106 | _, _, new_height, new_width = x.shape 107 | upsampled = F.interpolate( 108 | low_res, (new_height, new_width), mode="bilinear", align_corners=False 109 | ) 110 | x = th.cat([x, upsampled], dim=1) 111 | return super().forward(x, timesteps, **kwargs) 112 | 113 | -------------------------------------------------------------------------------- /imagen_pytorch/tokenizer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cene555/Imagen-pytorch/20318089b42cf25fea27618319aa4f7a105be2ec/imagen_pytorch/tokenizer/__init__.py -------------------------------------------------------------------------------- /imagen_pytorch/tokenizer/bpe.py: -------------------------------------------------------------------------------- 1 | """ 2 | Byte pair encoding utilities adapted from: 3 | https://github.com/openai/gpt-2/blob/master/src/encoder.py 4 | """ 5 | 6 | import gzip 7 | import json 8 | import os 9 | from functools import lru_cache 10 | from typing import List, Tuple 11 | 12 | import regex as re 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = ( 27 | list(range(ord("!"), ord("~") + 1)) 28 | + list(range(ord("¡"), ord("¬") + 1)) 29 | + list(range(ord("®"), ord("ÿ") + 1)) 30 | ) 31 | cs = bs[:] 32 | n = 0 33 | for b in range(2 ** 8): 34 | if b not in bs: 35 | bs.append(b) 36 | cs.append(2 ** 8 + n) 37 | n += 1 38 | cs = [chr(n) for n in cs] 39 | return dict(zip(bs, cs)) 40 | 41 | 42 | def get_pairs(word): 43 | """Return set of symbol pairs in a word. 44 | Word is represented as tuple of symbols (symbols being variable-length strings). 
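    Example (illustrative):

        get_pairs(("l", "o", "w", "e", "r"))
        # -> {("l", "o"), ("o", "w"), ("w", "e"), ("e", "r")}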
45 | """ 46 | pairs = set() 47 | prev_char = word[0] 48 | for char in word[1:]: 49 | pairs.add((prev_char, char)) 50 | prev_char = char 51 | return pairs 52 | 53 | 54 | class Encoder: 55 | def __init__(self, encoder, bpe_merges, errors="replace"): 56 | self.encoder = encoder 57 | self.decoder = {v: k for k, v in self.encoder.items()} 58 | self.errors = errors # how to handle errors in decoding 59 | self.byte_encoder = bytes_to_unicode() 60 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 61 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 62 | self.cache = {} 63 | 64 | # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions 65 | self.pat = re.compile( 66 | r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" 67 | ) 68 | 69 | @property 70 | def n_vocab(self) -> int: 71 | return len(self.encoder) 72 | 73 | @property 74 | def end_token(self) -> int: 75 | return self.n_vocab - 1 76 | 77 | def padded_tokens_and_mask( 78 | self, tokens: List[int], text_ctx: int 79 | ) -> Tuple[List[int], List[bool]]: 80 | tokens = tokens[:text_ctx] 81 | padding = text_ctx - len(tokens) 82 | padded_tokens = tokens + [self.end_token] * padding 83 | mask = [True] * len(tokens) + [False] * padding 84 | return padded_tokens, mask 85 | 86 | def bpe(self, token): 87 | if token in self.cache: 88 | return self.cache[token] 89 | word = tuple(token) 90 | pairs = get_pairs(word) 91 | 92 | if not pairs: 93 | return token 94 | 95 | while True: 96 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) 97 | if bigram not in self.bpe_ranks: 98 | break 99 | first, second = bigram 100 | new_word = [] 101 | i = 0 102 | while i < len(word): 103 | try: 104 | j = word.index(first, i) 105 | new_word.extend(word[i:j]) 106 | i = j 107 | except: # pylint: disable=bare-except 108 | new_word.extend(word[i:]) 109 | break 110 | 111 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 112 | new_word.append(first + second) 113 | i += 2 114 | else: 115 | new_word.append(word[i]) 116 | i += 1 117 | new_word = tuple(new_word) 118 | word = new_word 119 | if len(word) == 1: 120 | break 121 | else: 122 | pairs = get_pairs(word) 123 | word = " ".join(word) 124 | self.cache[token] = word 125 | return word 126 | 127 | def encode(self, text): 128 | text = text.lower() 129 | bpe_tokens = [] 130 | for token in re.findall(self.pat, text): 131 | token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) 132 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")) 133 | return bpe_tokens 134 | 135 | def decode(self, tokens): 136 | text = "".join([self.decoder[token] for token in tokens]) 137 | text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) 138 | return text 139 | 140 | 141 | def get_encoder(): 142 | root_dir = os.path.dirname(os.path.abspath(__file__)) 143 | with gzip.open(os.path.join(root_dir, "encoder.json.gz"), "r") as f: 144 | encoder = json.load(f) 145 | with gzip.open(os.path.join(root_dir, "vocab.bpe.gz"), "r") as f: 146 | bpe_data = str(f.read(), "utf-8") 147 | bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]] 148 | return Encoder( 149 | encoder=encoder, 150 | bpe_merges=bpe_merges, 151 | ) 152 | -------------------------------------------------------------------------------- /imagen_pytorch/tokenizer/bpe_simple_vocab_16e6.txt.gz: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cene555/Imagen-pytorch/20318089b42cf25fea27618319aa4f7a105be2ec/imagen_pytorch/tokenizer/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /imagen_pytorch/tokenizer/encoder.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cene555/Imagen-pytorch/20318089b42cf25fea27618319aa4f7a105be2ec/imagen_pytorch/tokenizer/encoder.json.gz -------------------------------------------------------------------------------- /imagen_pytorch/tokenizer/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copied from: https://github.com/openai/CLIP/blob/573315e83f07b53a61ff5098757e8fc885f1703e/clip/simple_tokenizer.py 3 | """ 4 | 5 | import gzip 6 | import html 7 | import os 8 | from functools import lru_cache 9 | from typing import List, Tuple 10 | 11 | import ftfy 12 | import regex as re 13 | 14 | 15 | @lru_cache() 16 | def default_bpe(): 17 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 18 | 19 | 20 | @lru_cache() 21 | def bytes_to_unicode(): 22 | """ 23 | Returns list of utf-8 byte and a corresponding list of unicode strings. 24 | The reversible bpe codes work on unicode strings. 25 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 26 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 27 | This is a signficant percentage of your normal, say, 32K bpe vocab. 28 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 29 | And avoids mapping to whitespace/control characters the bpe code barfs on. 30 | """ 31 | bs = ( 32 | list(range(ord("!"), ord("~") + 1)) 33 | + list(range(ord("¡"), ord("¬") + 1)) 34 | + list(range(ord("®"), ord("ÿ") + 1)) 35 | ) 36 | cs = bs[:] 37 | n = 0 38 | for b in range(2 ** 8): 39 | if b not in bs: 40 | bs.append(b) 41 | cs.append(2 ** 8 + n) 42 | n += 1 43 | cs = [chr(n) for n in cs] 44 | return dict(zip(bs, cs)) 45 | 46 | 47 | def get_pairs(word): 48 | """Return set of symbol pairs in a word. 49 | Word is represented as tuple of symbols (symbols being variable-length strings). 
50 | """ 51 | pairs = set() 52 | prev_char = word[0] 53 | for char in word[1:]: 54 | pairs.add((prev_char, char)) 55 | prev_char = char 56 | return pairs 57 | 58 | 59 | def basic_clean(text): 60 | text = ftfy.fix_text(text) 61 | text = html.unescape(html.unescape(text)) 62 | return text.strip() 63 | 64 | 65 | def whitespace_clean(text): 66 | text = re.sub(r"\s+", " ", text) 67 | text = text.strip() 68 | return text 69 | 70 | 71 | class SimpleTokenizer(object): 72 | def __init__(self, bpe_path: str = default_bpe()): 73 | self.byte_encoder = bytes_to_unicode() 74 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 75 | merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") 76 | merges = merges[1 : 49152 - 256 - 2 + 1] 77 | merges = [tuple(merge.split()) for merge in merges] 78 | vocab = list(bytes_to_unicode().values()) 79 | vocab = vocab + [v + "" for v in vocab] 80 | for merge in merges: 81 | vocab.append("".join(merge)) 82 | vocab.extend(["<|startoftext|>", "<|endoftext|>"]) 83 | self.encoder = dict(zip(vocab, range(len(vocab)))) 84 | self.decoder = {v: k for k, v in self.encoder.items()} 85 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 86 | self.cache = {"<|startoftext|>": "<|startoftext|>", "<|endoftext|>": "<|endoftext|>"} 87 | self.pat = re.compile( 88 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", 89 | re.IGNORECASE, 90 | ) 91 | 92 | @property 93 | def start_token(self): 94 | return self.encoder["<|startoftext|>"] 95 | 96 | @property 97 | def end_token(self): 98 | return self.encoder["<|endoftext|>"] 99 | 100 | def padded_tokens_and_len(self, tokens: List[int], text_ctx: int) -> Tuple[List[int], int]: 101 | tokens = [self.start_token] + tokens[: text_ctx - 2] + [self.end_token] 102 | text_len = len(tokens) 103 | padding = text_ctx - len(tokens) 104 | padded_tokens = tokens + [0] * padding 105 | return padded_tokens, text_len 106 | 107 | def bpe(self, token): 108 | if token in self.cache: 109 | return self.cache[token] 110 | word = tuple(token[:-1]) + (token[-1] + "",) 111 | pairs = get_pairs(word) 112 | 113 | if not pairs: 114 | return token + "" 115 | 116 | while True: 117 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) 118 | if bigram not in self.bpe_ranks: 119 | break 120 | first, second = bigram 121 | new_word = [] 122 | i = 0 123 | while i < len(word): 124 | try: 125 | j = word.index(first, i) 126 | new_word.extend(word[i:j]) 127 | i = j 128 | except: # pylint: disable=bare-except 129 | new_word.extend(word[i:]) 130 | break 131 | 132 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 133 | new_word.append(first + second) 134 | i += 2 135 | else: 136 | new_word.append(word[i]) 137 | i += 1 138 | new_word = tuple(new_word) 139 | word = new_word 140 | if len(word) == 1: 141 | break 142 | else: 143 | pairs = get_pairs(word) 144 | word = " ".join(word) 145 | self.cache[token] = word 146 | return word 147 | 148 | def encode(self, text): 149 | bpe_tokens = [] 150 | text = whitespace_clean(basic_clean(text)).lower() 151 | for token in re.findall(self.pat, text): 152 | token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) 153 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")) 154 | return bpe_tokens 155 | 156 | def decode(self, tokens): 157 | text = "".join([self.decoder[token] for token in tokens]) 158 | text = ( 159 | bytearray([self.byte_decoder[c] for c in text]) 160 | .decode("utf-8", 
errors="replace") 161 | .replace("", " ") 162 | ) 163 | return text 164 | -------------------------------------------------------------------------------- /imagen_pytorch/tokenizer/vocab.bpe.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cene555/Imagen-pytorch/20318089b42cf25fea27618319aa4f7a105be2ec/imagen_pytorch/tokenizer/vocab.bpe.gz -------------------------------------------------------------------------------- /imagen_pytorch/unet.py: -------------------------------------------------------------------------------- 1 | import math 2 | from abc import abstractmethod 3 | 4 | import torch as th 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from .fp16_util import convert_module_to_f16, convert_module_to_f32 9 | from .nn import avg_pool_nd, conv_nd, linear, normalization, timestep_embedding, zero_module 10 | 11 | 12 | class TimestepBlock(nn.Module): 13 | """ 14 | Any module where forward() takes timestep embeddings as a second argument. 15 | """ 16 | 17 | @abstractmethod 18 | def forward(self, x, emb): 19 | """ 20 | Apply the module to `x` given `emb` timestep embeddings. 21 | """ 22 | 23 | 24 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 25 | """ 26 | A sequential module that passes timestep embeddings to the children that 27 | support it as an extra input. 28 | """ 29 | 30 | def forward(self, x, emb, encoder_out=None): 31 | for layer in self: 32 | if isinstance(layer, TimestepBlock): 33 | x = layer(x, emb) 34 | elif isinstance(layer, AttentionBlock): 35 | x = layer(x, encoder_out) 36 | else: 37 | x = layer(x) 38 | return x 39 | 40 | 41 | class Upsample(nn.Module): 42 | """ 43 | An upsampling layer with an optional convolution. 44 | 45 | :param channels: channels in the inputs and outputs. 46 | :param use_conv: a bool determining if a convolution is applied. 47 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 48 | upsampling occurs in the inner-two dimensions. 49 | """ 50 | 51 | def __init__(self, channels, use_conv, dims=2, out_channels=None): 52 | super().__init__() 53 | self.channels = channels 54 | self.out_channels = out_channels or channels 55 | self.use_conv = use_conv 56 | self.dims = dims 57 | if use_conv: 58 | self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) 59 | 60 | def forward(self, x): 61 | assert x.shape[1] == self.channels 62 | if self.dims == 3: 63 | x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") 64 | else: 65 | x = F.interpolate(x, scale_factor=2, mode="nearest") 66 | if self.use_conv: 67 | x = self.conv(x) 68 | return x 69 | 70 | 71 | class Downsample(nn.Module): 72 | """ 73 | A downsampling layer with an optional convolution. 74 | 75 | :param channels: channels in the inputs and outputs. 76 | :param use_conv: a bool determining if a convolution is applied. 77 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 78 | downsampling occurs in the inner-two dimensions. 
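    :param out_channels: if specified, the number of output channels; when `use_conv` is False this must equal `channels`.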
79 | """ 80 | 81 | def __init__(self, channels, use_conv, dims=2, out_channels=None): 82 | super().__init__() 83 | self.channels = channels 84 | self.out_channels = out_channels or channels 85 | self.use_conv = use_conv 86 | self.dims = dims 87 | stride = 2 if dims != 3 else (1, 2, 2) 88 | if use_conv: 89 | self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=1) 90 | else: 91 | assert self.channels == self.out_channels 92 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) 93 | 94 | def forward(self, x): 95 | assert x.shape[1] == self.channels 96 | return self.op(x) 97 | 98 | 99 | class ResBlock(TimestepBlock): 100 | """ 101 | A residual block that can optionally change the number of channels. 102 | 103 | :param channels: the number of input channels. 104 | :param emb_channels: the number of timestep embedding channels. 105 | :param dropout: the rate of dropout. 106 | :param out_channels: if specified, the number of out channels. 107 | :param use_conv: if True and out_channels is specified, use a spatial 108 | convolution instead of a smaller 1x1 convolution to change the 109 | channels in the skip connection. 110 | :param dims: determines if the signal is 1D, 2D, or 3D. 111 | :param use_checkpoint: if True, use gradient checkpointing on this module. 112 | :param up: if True, use this block for upsampling. 113 | :param down: if True, use this block for downsampling. 114 | """ 115 | 116 | def __init__( 117 | self, 118 | channels, 119 | emb_channels, 120 | dropout, 121 | out_channels=None, 122 | use_conv=False, 123 | use_scale_shift_norm=False, 124 | dims=2, 125 | use_checkpoint=False, 126 | up=False, 127 | down=False, 128 | ): 129 | super().__init__() 130 | self.channels = channels 131 | self.emb_channels = emb_channels 132 | self.dropout = dropout 133 | self.out_channels = out_channels or channels 134 | self.use_conv = use_conv 135 | self.use_checkpoint = use_checkpoint 136 | self.use_scale_shift_norm = use_scale_shift_norm 137 | 138 | self.in_layers = nn.Sequential( 139 | normalization(channels, swish=1.0), 140 | nn.Identity(), 141 | conv_nd(dims, channels, self.out_channels, 3, padding=1), 142 | ) 143 | 144 | self.updown = up or down 145 | 146 | if up: 147 | self.h_upd = Upsample(channels, False, dims) 148 | self.x_upd = Upsample(channels, False, dims) 149 | elif down: 150 | self.h_upd = Downsample(channels, False, dims) 151 | self.x_upd = Downsample(channels, False, dims) 152 | else: 153 | self.h_upd = self.x_upd = nn.Identity() 154 | 155 | self.emb_layers = nn.Sequential( 156 | nn.SiLU(), 157 | linear( 158 | emb_channels, 159 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels, 160 | ), 161 | ) 162 | self.out_layers = nn.Sequential( 163 | normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0), 164 | nn.SiLU() if use_scale_shift_norm else nn.Identity(), 165 | nn.Dropout(p=dropout), 166 | zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)), 167 | ) 168 | 169 | if self.out_channels == channels: 170 | self.skip_connection = nn.Identity() 171 | elif use_conv: 172 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) 173 | else: 174 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) 175 | 176 | def forward(self, x, emb): 177 | """ 178 | Apply the block to a Tensor, conditioned on a timestep embedding. 179 | 180 | :param x: an [N x C x ...] Tensor of features. 181 | :param emb: an [N x emb_channels] Tensor of timestep embeddings. 
182 | :return: an [N x C x ...] Tensor of outputs. 183 | """ 184 | if self.updown: 185 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 186 | h = in_rest(x) 187 | h = self.h_upd(h) 188 | x = self.x_upd(x) 189 | h = in_conv(h) 190 | else: 191 | h = self.in_layers(x) 192 | emb_out = self.emb_layers(emb).type(h.dtype) 193 | while len(emb_out.shape) < len(h.shape): 194 | emb_out = emb_out[..., None] 195 | if self.use_scale_shift_norm: 196 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 197 | scale, shift = th.chunk(emb_out, 2, dim=1) 198 | h = out_norm(h) * (1 + scale) + shift 199 | h = out_rest(h) 200 | else: 201 | h = h + emb_out 202 | h = self.out_layers(h) 203 | return self.skip_connection(x) + h 204 | 205 | 206 | class AttentionBlock(nn.Module): 207 | """ 208 | An attention block that allows spatial positions to attend to each other. 209 | 210 | Originally ported from here, but adapted to the N-d case. 211 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 212 | """ 213 | 214 | def __init__( 215 | self, 216 | channels, 217 | num_heads=1, 218 | num_head_channels=-1, 219 | use_checkpoint=False, 220 | encoder_channels=None, 221 | ): 222 | super().__init__() 223 | self.channels = channels 224 | if num_head_channels == -1: 225 | self.num_heads = num_heads 226 | else: 227 | assert ( 228 | channels % num_head_channels == 0 229 | ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" 230 | self.num_heads = channels // num_head_channels 231 | self.use_checkpoint = use_checkpoint 232 | self.norm = normalization(channels, swish=0.0) 233 | self.qkv = conv_nd(1, channels, channels * 3, 1) 234 | self.attention = QKVAttention(self.num_heads) 235 | 236 | if encoder_channels is not None: 237 | self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1) 238 | self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) 239 | 240 | def forward(self, x, encoder_out=None): 241 | b, c, *spatial = x.shape 242 | qkv = self.qkv(self.norm(x).view(b, c, -1)) 243 | if encoder_out is not None: 244 | encoder_out = self.encoder_kv(encoder_out) 245 | h = self.attention(qkv, encoder_out) 246 | else: 247 | h = self.attention(qkv) 248 | h = self.proj_out(h) 249 | return x + h.reshape(b, c, *spatial) 250 | 251 | 252 | class QKVAttention(nn.Module): 253 | """ 254 | A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping 255 | """ 256 | 257 | def __init__(self, n_heads): 258 | super().__init__() 259 | self.n_heads = n_heads 260 | 261 | def forward(self, qkv, encoder_kv=None): 262 | """ 263 | Apply QKV attention. 264 | 265 | :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. 266 | :return: an [N x (H * C) x T] tensor after attention. 
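        If `encoder_kv` is given (an [N x (H * 2 * C) x S] tensor of extra keys and values, e.g. projected text features), it is prepended to the image keys/values so every spatial position can also attend to the encoder sequence.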
267 | """ 268 | bs, width, length = qkv.shape 269 | assert width % (3 * self.n_heads) == 0 270 | ch = width // (3 * self.n_heads) 271 | q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) 272 | if encoder_kv is not None: 273 | assert encoder_kv.shape[1] == self.n_heads * ch * 2 274 | ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1) 275 | k = th.cat([ek, k], dim=-1) 276 | v = th.cat([ev, v], dim=-1) 277 | scale = 1 / math.sqrt(math.sqrt(ch)) 278 | weight = th.einsum( 279 | "bct,bcs->bts", q * scale, k * scale 280 | ) # More stable with f16 than dividing afterwards 281 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) 282 | a = th.einsum("bts,bcs->bct", weight, v) 283 | return a.reshape(bs, -1, length) 284 | 285 | 286 | class UNetModel(nn.Module): 287 | """ 288 | The full UNet model with attention and timestep embedding. 289 | 290 | :param in_channels: channels in the input Tensor. 291 | :param model_channels: base channel count for the model. 292 | :param out_channels: channels in the output Tensor. 293 | :param num_res_blocks: number of residual blocks per downsample. 294 | :param attention_resolutions: a collection of downsample rates at which 295 | attention will take place. May be a set, list, or tuple. 296 | For example, if this contains 4, then at 4x downsampling, attention 297 | will be used. 298 | :param dropout: the dropout probability. 299 | :param channel_mult: channel multiplier for each level of the UNet. 300 | :param conv_resample: if True, use learned convolutions for upsampling and 301 | downsampling. 302 | :param dims: determines if the signal is 1D, 2D, or 3D. 303 | :param num_classes: if specified (as an int), then this model will be 304 | class-conditional with `num_classes` classes. 305 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. 306 | :param num_heads: the number of attention heads in each attention layer. 307 | :param num_heads_channels: if specified, ignore num_heads and instead use 308 | a fixed channel width per attention head. 309 | :param num_heads_upsample: works with num_heads to set a different number 310 | of heads for upsampling. Deprecated. 311 | :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. 312 | :param resblock_updown: use residual blocks for up/downsampling. 
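    :param encoder_channels: if specified, the channel width of an external encoder sequence (e.g. text features); attention blocks then additionally attend to it, as done by Text2ImUNet.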
313 | """ 314 | 315 | def __init__( 316 | self, 317 | in_channels, 318 | model_channels, 319 | out_channels, 320 | num_res_blocks, 321 | attention_resolutions, 322 | dropout=0, 323 | channel_mult=(1, 2, 4, 8), 324 | conv_resample=True, 325 | dims=2, 326 | num_classes=None, 327 | use_checkpoint=False, 328 | use_fp16=False, 329 | num_heads=1, 330 | num_head_channels=-1, 331 | num_heads_upsample=-1, 332 | use_scale_shift_norm=False, 333 | resblock_updown=False, 334 | encoder_channels=None, 335 | ): 336 | super().__init__() 337 | 338 | if num_heads_upsample == -1: 339 | num_heads_upsample = num_heads 340 | 341 | self.in_channels = in_channels 342 | self.model_channels = model_channels 343 | self.out_channels = out_channels 344 | self.num_res_blocks = num_res_blocks 345 | self.attention_resolutions = attention_resolutions 346 | self.dropout = dropout 347 | self.channel_mult = channel_mult 348 | self.conv_resample = conv_resample 349 | self.num_classes = num_classes 350 | self.use_checkpoint = use_checkpoint 351 | self.dtype = th.float16 if use_fp16 else th.float32 352 | self.num_heads = num_heads 353 | self.num_head_channels = num_head_channels 354 | self.num_heads_upsample = num_heads_upsample 355 | 356 | time_embed_dim = model_channels * 4 357 | self.time_embed = nn.Sequential( 358 | linear(model_channels, time_embed_dim), 359 | nn.SiLU(), 360 | linear(time_embed_dim, time_embed_dim), 361 | ) 362 | 363 | if self.num_classes is not None: 364 | self.label_emb = nn.Embedding(num_classes, time_embed_dim) 365 | 366 | ch = input_ch = int(channel_mult[0] * model_channels) 367 | self.input_blocks = nn.ModuleList( 368 | [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))] 369 | ) 370 | self._feature_size = ch 371 | input_block_chans = [ch] 372 | ds = 1 373 | for level, mult in enumerate(channel_mult): 374 | for _ in range(num_res_blocks): 375 | layers = [ 376 | ResBlock( 377 | ch, 378 | time_embed_dim, 379 | dropout, 380 | out_channels=int(mult * model_channels), 381 | dims=dims, 382 | use_checkpoint=use_checkpoint, 383 | use_scale_shift_norm=use_scale_shift_norm, 384 | ) 385 | ] 386 | ch = int(mult * model_channels) 387 | if ds in attention_resolutions: 388 | layers.append( 389 | AttentionBlock( 390 | ch, 391 | use_checkpoint=use_checkpoint, 392 | num_heads=num_heads, 393 | num_head_channels=num_head_channels, 394 | encoder_channels=encoder_channels, 395 | ) 396 | ) 397 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 398 | self._feature_size += ch 399 | input_block_chans.append(ch) 400 | if level != len(channel_mult) - 1: 401 | out_ch = ch 402 | self.input_blocks.append( 403 | TimestepEmbedSequential( 404 | ResBlock( 405 | ch, 406 | time_embed_dim, 407 | dropout, 408 | out_channels=out_ch, 409 | dims=dims, 410 | use_checkpoint=use_checkpoint, 411 | use_scale_shift_norm=use_scale_shift_norm, 412 | down=True, 413 | ) 414 | if resblock_updown 415 | else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) 416 | ) 417 | ) 418 | ch = out_ch 419 | input_block_chans.append(ch) 420 | ds *= 2 421 | self._feature_size += ch 422 | 423 | self.middle_block = TimestepEmbedSequential( 424 | ResBlock( 425 | ch, 426 | time_embed_dim, 427 | dropout, 428 | dims=dims, 429 | use_checkpoint=use_checkpoint, 430 | use_scale_shift_norm=use_scale_shift_norm, 431 | ), 432 | AttentionBlock( 433 | ch, 434 | use_checkpoint=use_checkpoint, 435 | num_heads=num_heads, 436 | num_head_channels=num_head_channels, 437 | encoder_channels=encoder_channels, 438 | ), 439 | ResBlock( 440 | ch, 441 | 
time_embed_dim, 442 | dropout, 443 | dims=dims, 444 | use_checkpoint=use_checkpoint, 445 | use_scale_shift_norm=use_scale_shift_norm, 446 | ), 447 | ) 448 | self._feature_size += ch 449 | 450 | self.output_blocks = nn.ModuleList([]) 451 | for level, mult in list(enumerate(channel_mult))[::-1]: 452 | for i in range(num_res_blocks + 1): 453 | ich = input_block_chans.pop() 454 | layers = [ 455 | ResBlock( 456 | ch + ich, 457 | time_embed_dim, 458 | dropout, 459 | out_channels=int(model_channels * mult), 460 | dims=dims, 461 | use_checkpoint=use_checkpoint, 462 | use_scale_shift_norm=use_scale_shift_norm, 463 | ) 464 | ] 465 | ch = int(model_channels * mult) 466 | if ds in attention_resolutions: 467 | layers.append( 468 | AttentionBlock( 469 | ch, 470 | use_checkpoint=use_checkpoint, 471 | num_heads=num_heads_upsample, 472 | num_head_channels=num_head_channels, 473 | encoder_channels=encoder_channels, 474 | ) 475 | ) 476 | if level and i == num_res_blocks: 477 | out_ch = ch 478 | layers.append( 479 | ResBlock( 480 | ch, 481 | time_embed_dim, 482 | dropout, 483 | out_channels=out_ch, 484 | dims=dims, 485 | use_checkpoint=use_checkpoint, 486 | use_scale_shift_norm=use_scale_shift_norm, 487 | up=True, 488 | ) 489 | if resblock_updown 490 | else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) 491 | ) 492 | ds //= 2 493 | self.output_blocks.append(TimestepEmbedSequential(*layers)) 494 | self._feature_size += ch 495 | 496 | self.out = nn.Sequential( 497 | normalization(ch, swish=1.0), 498 | nn.Identity(), 499 | zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)), 500 | ) 501 | self.use_fp16 = use_fp16 502 | 503 | def convert_to_fp16(self): 504 | """ 505 | Convert the torso of the model to float16. 506 | """ 507 | self.input_blocks.apply(convert_module_to_f16) 508 | self.middle_block.apply(convert_module_to_f16) 509 | self.output_blocks.apply(convert_module_to_f16) 510 | 511 | def convert_to_fp32(self): 512 | """ 513 | Convert the torso of the model to float32. 514 | """ 515 | self.input_blocks.apply(convert_module_to_f32) 516 | self.middle_block.apply(convert_module_to_f32) 517 | self.output_blocks.apply(convert_module_to_f32) 518 | 519 | def forward(self, x, timesteps, y=None): 520 | """ 521 | Apply the model to an input batch. 522 | 523 | :param x: an [N x C x ...] Tensor of inputs. 524 | :param timesteps: a 1-D batch of timesteps. 525 | :param y: an [N] Tensor of labels, if class-conditional. 526 | :return: an [N x C x ...] Tensor of outputs. 527 | """ 528 | assert (y is not None) == ( 529 | self.num_classes is not None 530 | ), "must specify y if and only if the model is class-conditional" 531 | 532 | hs = [] 533 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 534 | 535 | if self.num_classes is not None: 536 | assert y.shape == (x.shape[0],) 537 | emb = emb + self.label_emb(y) 538 | 539 | h = x.type(self.dtype) 540 | for module in self.input_blocks: 541 | h = module(h, emb) 542 | hs.append(h) 543 | h = self.middle_block(h, emb) 544 | for module in self.output_blocks: 545 | h = th.cat([h, hs.pop()], dim=1) 546 | h = module(h, emb) 547 | h = h.type(x.dtype) 548 | return self.out(h) 549 | 550 | class SuperResUNetModel(UNetModel): 551 | """ 552 | A UNetModel that performs super-resolution. 553 | 554 | Expects an extra kwarg `low_res` to condition on a low-resolution image. 
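    The low-resolution image is bilinearly upsampled to the current resolution and concatenated with the noisy input along the channel dimension, which is why `in_channels` is doubled in `__init__`.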
555 | """ 556 | 557 | def __init__(self, *args, **kwargs): 558 | if "in_channels" in kwargs: 559 | kwargs = dict(kwargs) 560 | kwargs["in_channels"] = kwargs["in_channels"] * 2 561 | else: 562 | # Curse you, Python. Or really, just curse positional arguments :|. 563 | args = list(args) 564 | args[1] = args[1] * 2 565 | super().__init__(*args, **kwargs) 566 | 567 | def forward(self, x, timesteps, low_res=None, **kwargs): 568 | _, _, new_height, new_width = x.shape 569 | upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear") 570 | x = th.cat([x, upsampled], dim=1) 571 | return super().forward(x, timesteps, **kwargs) 572 | 573 | 574 | class InpaintUNetModel(UNetModel): 575 | """ 576 | A UNetModel which can perform inpainting. 577 | """ 578 | 579 | def __init__(self, *args, **kwargs): 580 | if "in_channels" in kwargs: 581 | kwargs = dict(kwargs) 582 | kwargs["in_channels"] = kwargs["in_channels"] * 2 + 1 583 | else: 584 | # Curse you, Python. Or really, just curse positional arguments :|. 585 | args = list(args) 586 | args[1] = args[1] * 2 + 1 587 | super().__init__(*args, **kwargs) 588 | 589 | def forward(self, x, timesteps, inpaint_image=None, inpaint_mask=None, **kwargs): 590 | if inpaint_image is None: 591 | inpaint_image = th.zeros_like(x) 592 | if inpaint_mask is None: 593 | inpaint_mask = th.zeros_like(x[:, :1]) 594 | return super().forward( 595 | th.cat([x, inpaint_image * inpaint_mask, inpaint_mask], dim=1), 596 | timesteps, 597 | **kwargs, 598 | ) 599 | 600 | 601 | class SuperResInpaintUNetModel(UNetModel): 602 | """ 603 | A UNetModel which can perform both upsampling and inpainting. 604 | """ 605 | 606 | def __init__(self, *args, **kwargs): 607 | if "in_channels" in kwargs: 608 | kwargs = dict(kwargs) 609 | kwargs["in_channels"] = kwargs["in_channels"] * 3 + 1 610 | else: 611 | # Curse you, Python. Or really, just curse positional arguments :|. 612 | args = list(args) 613 | args[1] = args[1] * 3 + 1 614 | super().__init__(*args, **kwargs) 615 | 616 | def forward( 617 | self, 618 | x, 619 | timesteps, 620 | inpaint_image=None, 621 | inpaint_mask=None, 622 | low_res=None, 623 | **kwargs, 624 | ): 625 | if inpaint_image is None: 626 | inpaint_image = th.zeros_like(x) 627 | if inpaint_mask is None: 628 | inpaint_mask = th.zeros_like(x[:, :1]) 629 | _, _, new_height, new_width = x.shape 630 | upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear") 631 | return super().forward( 632 | th.cat([x, inpaint_image * inpaint_mask, inpaint_mask, upsampled], dim=1), 633 | timesteps, 634 | **kwargs, 635 | ) 636 | -------------------------------------------------------------------------------- /imagen_pytorch/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch as th 4 | import torch.nn as nn 5 | 6 | def mean_flat(tensor): 7 | """ 8 | Take the mean over all non-batch dimensions. 9 | """ 10 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) -------------------------------------------------------------------------------- /imagen_pytorch/xf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Transformer implementation adapted from CLIP ViT: 3 | https://github.com/openai/CLIP/blob/4c0275784d6d9da97ca1f47eaaee31de1867da91/clip/model.py 4 | """ 5 | 6 | import math 7 | 8 | import torch as th 9 | import torch.nn as nn 10 | 11 | 12 | def convert_module_to_f16(l): 13 | """ 14 | Convert primitive modules to float16. 
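    Only the weights and biases of nn.Linear, nn.Conv2d and nn.ConvTranspose2d modules are cast; other parameters (e.g. LayerNorm gains/biases) remain in float32.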
15 | """ 16 | if isinstance(l, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): 17 | l.weight.data = l.weight.data.half() 18 | if l.bias is not None: 19 | l.bias.data = l.bias.data.half() 20 | 21 | 22 | class LayerNorm(nn.LayerNorm): 23 | """ 24 | Implementation that supports fp16 inputs but fp32 gains/biases. 25 | """ 26 | 27 | def forward(self, x: th.Tensor): 28 | return super().forward(x.float()).to(x.dtype) 29 | 30 | 31 | class MultiheadAttention(nn.Module): 32 | def __init__(self, n_ctx, width, heads): 33 | super().__init__() 34 | self.n_ctx = n_ctx 35 | self.width = width 36 | self.heads = heads 37 | self.c_qkv = nn.Linear(width, width * 3) 38 | self.c_proj = nn.Linear(width, width) 39 | self.attention = QKVMultiheadAttention(heads, n_ctx) 40 | 41 | def forward(self, x): 42 | x = self.c_qkv(x) 43 | x = self.attention(x) 44 | x = self.c_proj(x) 45 | return x 46 | 47 | 48 | class MLP(nn.Module): 49 | def __init__(self, width): 50 | super().__init__() 51 | self.width = width 52 | self.c_fc = nn.Linear(width, width * 4) 53 | self.c_proj = nn.Linear(width * 4, width) 54 | self.gelu = nn.GELU() 55 | 56 | def forward(self, x): 57 | return self.c_proj(self.gelu(self.c_fc(x))) 58 | 59 | 60 | class QKVMultiheadAttention(nn.Module): 61 | def __init__(self, n_heads: int, n_ctx: int): 62 | super().__init__() 63 | self.n_heads = n_heads 64 | self.n_ctx = n_ctx 65 | 66 | def forward(self, qkv): 67 | bs, n_ctx, width = qkv.shape 68 | attn_ch = width // self.n_heads // 3 69 | scale = 1 / math.sqrt(math.sqrt(attn_ch)) 70 | qkv = qkv.view(bs, n_ctx, self.n_heads, -1) 71 | q, k, v = th.split(qkv, attn_ch, dim=-1) 72 | weight = th.einsum( 73 | "bthc,bshc->bhts", q * scale, k * scale 74 | ) # More stable with f16 than dividing afterwards 75 | wdtype = weight.dtype 76 | weight = th.softmax(weight.float(), dim=-1).type(wdtype) 77 | return th.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) 78 | 79 | 80 | class ResidualAttentionBlock(nn.Module): 81 | def __init__( 82 | self, 83 | n_ctx: int, 84 | width: int, 85 | heads: int, 86 | ): 87 | super().__init__() 88 | 89 | self.attn = MultiheadAttention( 90 | n_ctx, 91 | width, 92 | heads, 93 | ) 94 | self.ln_1 = LayerNorm(width) 95 | self.mlp = MLP(width) 96 | self.ln_2 = LayerNorm(width) 97 | 98 | def forward(self, x: th.Tensor): 99 | x = x + self.attn(self.ln_1(x)) 100 | x = x + self.mlp(self.ln_2(x)) 101 | return x 102 | 103 | 104 | class Transformer(nn.Module): 105 | def __init__( 106 | self, 107 | n_ctx: int, 108 | width: int, 109 | layers: int, 110 | heads: int, 111 | ): 112 | super().__init__() 113 | self.n_ctx = n_ctx 114 | self.width = width 115 | self.layers = layers 116 | self.resblocks = nn.ModuleList( 117 | [ 118 | ResidualAttentionBlock( 119 | n_ctx, 120 | width, 121 | heads, 122 | ) 123 | for _ in range(layers) 124 | ] 125 | ) 126 | 127 | def forward(self, x: th.Tensor): 128 | for block in self.resblocks: 129 | x = block(x) 130 | return x 131 | -------------------------------------------------------------------------------- /images/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cene555/Imagen-pytorch/20318089b42cf25fea27618319aa4f7a105be2ec/images/1.jpg -------------------------------------------------------------------------------- /images/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cene555/Imagen-pytorch/20318089b42cf25fea27618319aa4f7a105be2ec/images/2.jpg 
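The transformer blocks defined in imagen_pytorch/xf.py above can be smoke-tested on their own. The snippet below is a minimal sketch rather than repository code: the shapes (batch of 2, a 16-token context, width 64, 4 layers, 8 heads) are illustrative assumptions, not the settings of any released checkpoint.

import torch as th
from imagen_pytorch.xf import LayerNorm, Transformer

# Illustrative hyper-parameters (assumptions, not release settings).
xf = Transformer(n_ctx=16, width=64, layers=4, heads=8)
final_ln = LayerNorm(64)

x = th.randn(2, 16, 64)            # [batch, tokens, width]
out = final_ln(xf(x))              # residual attention blocks preserve the shape
assert out.shape == (2, 16, 64)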
-------------------------------------------------------------------------------- /images/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cene555/Imagen-pytorch/20318089b42cf25fea27618319aa4f7a105be2ec/images/3.jpg -------------------------------------------------------------------------------- /images/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cene555/Imagen-pytorch/20318089b42cf25fea27618319aa4f7a105be2ec/images/4.jpg -------------------------------------------------------------------------------- /images/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cene555/Imagen-pytorch/20318089b42cf25fea27618319aa4f7a105be2ec/images/5.jpg -------------------------------------------------------------------------------- /images/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cene555/Imagen-pytorch/20318089b42cf25fea27618319aa4f7a105be2ec/images/6.jpg -------------------------------------------------------------------------------- /images/7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cene555/Imagen-pytorch/20318089b42cf25fea27618319aa4f7a105be2ec/images/7.jpg -------------------------------------------------------------------------------- /images/8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cene555/Imagen-pytorch/20318089b42cf25fea27618319aa4f7a105be2ec/images/8.jpg -------------------------------------------------------------------------------- /notebooks/Imagen_pytorch_inference_new.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "tXkwem58FrAj" 7 | }, 8 | "source": [ 9 | "## Installation" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 191, 15 | "metadata": { 16 | "id": "jFduUePWDiRr" 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "%%capture\n", 21 | "!git lfs install\n", 22 | "!git clone https://huggingface.co/Cene655/ImagenT5-3B" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": { 29 | "id": "eBuoRWpcY3ph" 30 | }, 31 | "outputs": [], 32 | "source": [ 33 | "%%capture\n", 34 | "!pip install git+https://github.com/cene555/Imagen-pytorch.git\n", 35 | "!pip install git+https://github.com/openai/CLIP.git" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": { 42 | "id": "mH-AFZAEfyEJ" 43 | }, 44 | "outputs": [], 45 | "source": [ 46 | "%%capture\n", 47 | "!git clone https://github.com/xinntao/Real-ESRGAN.git" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": { 54 | "id": "MvW2RHMHQl9g" 55 | }, 56 | "outputs": [], 57 | "source": [ 58 | "%cd Real-ESRGAN" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": { 65 | "id": "xIiOlXJzQosB" 66 | }, 67 | "outputs": [], 68 | "source": [ 69 | "%%capture\n", 70 | "!pip install basicsr\n", 71 | "# facexlib and gfpgan are for face enhancement\n", 72 | "!pip install facexlib\n", 73 | "!pip install gfpgan" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": { 80 | 
"id": "fChuITUYQstj" 81 | }, 82 | "outputs": [], 83 | "source": [ 84 | "%%capture\n", 85 | "!pip install -r requirements.txt\n", 86 | "!python setup.py develop\n", 87 | "!wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P experiments/pretrained_models" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": { 93 | "id": "h5_rnwU5GTry" 94 | }, 95 | "source": [ 96 | "## Imports " 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "metadata": { 103 | "id": "HAGoFjvGZJ6s" 104 | }, 105 | "outputs": [], 106 | "source": [ 107 | "from PIL import Image\n", 108 | "from IPython.display import display\n", 109 | "import torch as th\n", 110 | "from imagen_pytorch.model_creation import create_model_and_diffusion as create_model_and_diffusion_dalle2\n", 111 | "from imagen_pytorch.model_creation import model_and_diffusion_defaults as model_and_diffusion_defaults_dalle2\n", 112 | "from transformers import AutoTokenizer\n", 113 | "import cv2\n", 114 | "\n", 115 | "import glob\n", 116 | "import os\n", 117 | "from basicsr.archs.rrdbnet_arch import RRDBNet\n", 118 | "from realesrgan import RealESRGANer\n", 119 | "from realesrgan.archs.srvgg_arch import SRVGGNetCompact\n", 120 | "from gfpgan import GFPGANer\n", 121 | "\n", 122 | "has_cuda = th.cuda.is_available()\n", 123 | "device = th.device('cpu' if not has_cuda else 'cuda')" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": { 129 | "id": "qhCZMEueGk3Q" 130 | }, 131 | "source": [ 132 | "## Setting Up" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": { 139 | "id": "Kp6HRM1vdFZq" 140 | }, 141 | "outputs": [], 142 | "source": [ 143 | "def model_fn(x_t, ts, **kwargs):\n", 144 | " guidance_scale = 5\n", 145 | " half = x_t[: len(x_t) // 2]\n", 146 | " combined = th.cat([half, half], dim=0)\n", 147 | " model_out = model(combined, ts, **kwargs)\n", 148 | " eps, rest = model_out[:, :3], model_out[:, 3:]\n", 149 | " cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)\n", 150 | " half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)\n", 151 | " eps = th.cat([half_eps, half_eps], dim=0)\n", 152 | " return th.cat([eps, rest], dim=1)" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "metadata": { 159 | "id": "RyQgjL0OdMeu" 160 | }, 161 | "outputs": [], 162 | "source": [ 163 | "def show_images(batch: th.Tensor):\n", 164 | " \"\"\" Display a batch of images inline.\"\"\"\n", 165 | " scaled = ((batch + 1)*127.5).round().clamp(0,255).to(th.uint8).cpu()\n", 166 | " reshaped = scaled.permute(2, 0, 3, 1).reshape([batch.shape[2], -1, 3])\n", 167 | " display(Image.fromarray(reshaped.numpy()))" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": { 174 | "id": "rVHdkvoGPqJJ" 175 | }, 176 | "outputs": [], 177 | "source": [ 178 | "def get_numpy_img(img):\n", 179 | " scaled = ((img + 1)*127.5).round().clamp(0,255).to(th.uint8).cpu()\n", 180 | " reshaped = scaled.permute(2, 0, 3, 1).reshape([img.shape[2], -1, 3])\n", 181 | " return cv2.cvtColor(reshaped.numpy(), cv2.COLOR_BGR2RGB)" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "metadata": { 188 | "id": "qhKnHPtJZCL_" 189 | }, 190 | "outputs": [], 191 | "source": [ 192 | "def _fix_path(path):\n", 193 | " d = th.load(path)\n", 194 | " checkpoint = {}\n", 195 | " for key in d.keys():\n", 196 | " 
checkpoint[key.replace('module.','')] = d[key]\n", 197 | " return checkpoint" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": null, 203 | "metadata": { 204 | "id": "veXmk5XTZOuy" 205 | }, 206 | "outputs": [], 207 | "source": [ 208 | "options = model_and_diffusion_defaults_dalle2()\n", 209 | "options['use_fp16'] = False\n", 210 | "options['diffusion_steps'] = 200\n", 211 | "options['num_res_blocks'] = 3\n", 212 | "options['t5_name'] = 't5-3b'\n", 213 | "options['cache_text_emb'] = True\n", 214 | "model, diffusion = create_model_and_diffusion_dalle2(**options)\n", 215 | "\n", 216 | "model.eval()\n", 217 | "\n", 218 | "#if has_cuda:\n", 219 | "# model.convert_to_fp16()\n", 220 | "\n", 221 | "model.to(device)" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 192, 227 | "metadata": { 228 | "colab": { 229 | "base_uri": "https://localhost:8080/" 230 | }, 231 | "id": "dtqqsCzzaPzo", 232 | "outputId": "c8da3d0a-b897-4743-bc48-ab8440f9354b" 233 | }, 234 | "outputs": [ 235 | { 236 | "output_type": "stream", 237 | "name": "stdout", 238 | "text": [ 239 | "total base parameters 1550556742\n" 240 | ] 241 | } 242 | ], 243 | "source": [ 244 | "model.load_state_dict(_fix_path('/content/ImagenT5-3B/model.pt'))\n", 245 | "print('total base parameters', sum(x.numel() for x in model.parameters()))" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 193, 251 | "metadata": { 252 | "colab": { 253 | "base_uri": "https://localhost:8080/" 254 | }, 255 | "id": "_mdhjDRmejr5", 256 | "outputId": "61049fe6-00fa-4cd3-ef48-c3c8932e15f0" 257 | }, 258 | "outputs": [ 259 | { 260 | "output_type": "execute_result", 261 | "data": { 262 | "text/plain": [ 263 | "1550556742" 264 | ] 265 | }, 266 | "metadata": {}, 267 | "execution_count": 193 268 | } 269 | ], 270 | "source": [ 271 | "num_params = sum(param.numel() for param in model.parameters())\n", 272 | "num_params" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": 194, 278 | "metadata": { 279 | "id": "4oDZRKP_NcV0" 280 | }, 281 | "outputs": [], 282 | "source": [ 283 | "realesrgan_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,\n", 284 | " num_block=23, num_grow_ch=32, scale=4)" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 195, 290 | "metadata": { 291 | "id": "BFl4yR5ONtil" 292 | }, 293 | "outputs": [], 294 | "source": [ 295 | "netscale = 4" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 196, 301 | "metadata": { 302 | "id": "Vwy6nPleNuUN" 303 | }, 304 | "outputs": [], 305 | "source": [ 306 | "upsampler = RealESRGANer(\n", 307 | " scale=netscale,\n", 308 | " model_path='/content/Real-ESRGAN/experiments/pretrained_models/RealESRGAN_x4plus.pth',\n", 309 | " model=realesrgan_model,\n", 310 | " tile=0,\n", 311 | " tile_pad=10,\n", 312 | " pre_pad=0,\n", 313 | " half=True\n", 314 | ")" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": 197, 320 | "metadata": { 321 | "id": "P_PiM5y5PHCe" 322 | }, 323 | "outputs": [], 324 | "source": [ 325 | "face_enhancer = GFPGANer(\n", 326 | " model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',\n", 327 | " upscale=4,\n", 328 | " arch='clean',\n", 329 | " channel_multiplier=2,\n", 330 | " bg_upsampler=upsampler\n", 331 | ")" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": 198, 337 | "metadata": { 338 | "colab": { 339 | "base_uri": "https://localhost:8080/" 340 | }, 341 | "id": 
"ilBtcMcGcUSZ", 342 | "outputId": "063933dc-014d-422b-bc28-f11ffa24ff61" 343 | }, 344 | "outputs": [ 345 | { 346 | "output_type": "stream", 347 | "name": "stderr", 348 | "text": [ 349 | "/usr/local/lib/python3.7/dist-packages/transformers/models/t5/tokenization_t5_fast.py:161: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n", 350 | "For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n", 351 | "- Be aware that you SHOULD NOT rely on t5-3b automatically truncating your input to 512 when padding/encoding.\n", 352 | "- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n", 353 | "- To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.\n", 354 | " FutureWarning,\n" 355 | ] 356 | } 357 | ], 358 | "source": [ 359 | "tokenizer = AutoTokenizer.from_pretrained(options['t5_name'])" 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": 312, 365 | "metadata": { 366 | "id": "I1E6jjzhvV40" 367 | }, 368 | "outputs": [], 369 | "source": [ 370 | "#@title What do you want to generate?\n", 371 | "\n", 372 | "prompt = 'A photo of cat'#@param {type:\"string\"}" 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": 313, 378 | "metadata": { 379 | "id": "h3FFIeLrcuTX" 380 | }, 381 | "outputs": [], 382 | "source": [ 383 | "text_encoding = tokenizer(\n", 384 | " prompt,\n", 385 | " max_length=128,\n", 386 | " padding=\"max_length\",\n", 387 | " truncation=True,\n", 388 | " return_attention_mask=True,\n", 389 | " add_special_tokens=True,\n", 390 | " return_tensors=\"pt\"\n", 391 | ")\n", 392 | "\n", 393 | "uncond_text_encoding = tokenizer(\n", 394 | " '',\n", 395 | " max_length=128,\n", 396 | " padding=\"max_length\",\n", 397 | " truncation=True,\n", 398 | " return_attention_mask=True,\n", 399 | " add_special_tokens=True,\n", 400 | " return_tensors=\"pt\"\n", 401 | ")" 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "execution_count": 314, 407 | "metadata": { 408 | "id": "1Imtn9DVZD99" 409 | }, 410 | "outputs": [], 411 | "source": [ 412 | "import numpy as np\n", 413 | "batch_size = 4\n", 414 | "cond_tokens = th.from_numpy(np.array([text_encoding['input_ids'][0].numpy() for i in range(batch_size)]))\n", 415 | "uncond_tokens = th.from_numpy(np.array([uncond_text_encoding['input_ids'][0].numpy() for i in range(batch_size)]))\n", 416 | "cond_attention_mask = th.from_numpy(np.array([text_encoding['attention_mask'][0].numpy() for i in range(batch_size)]))\n", 417 | "uncond_attention_mask = th.from_numpy(np.array([uncond_text_encoding['attention_mask'][0].numpy() for i in range(batch_size)]))\n", 418 | "model_kwargs = {}\n", 419 | "model_kwargs[\"tokens\"] = th.cat((cond_tokens,\n", 420 | " uncond_tokens)).to(device)\n", 421 | "model_kwargs[\"mask\"] = th.cat((cond_attention_mask,\n", 422 | " uncond_attention_mask)).to(device)" 423 | ] 424 | }, 425 | { 426 | "cell_type": "markdown", 427 | "metadata": { 428 | "id": "y0tzvo1tG9vS" 429 | }, 430 | "source": [ 431 | "## Generation" 432 | ] 433 | }, 434 | { 435 | "cell_type": "code", 436 | "execution_count": null, 437 | "metadata": { 438 | "id": "D-4GzOBadCbj" 439 | }, 440 | "outputs": [], 441 | "source": [ 442 | "model.del_cache()\n", 443 | "sample = diffusion.p_sample_loop(\n", 444 | " model_fn,\n", 
445 | " (batch_size * 2, 3, 64, 64),\n", 446 | " clip_denoised=True,\n", 447 | " model_kwargs=model_kwargs,\n", 448 | " device='cuda',\n", 449 | " progress=True,\n", 450 | ")[:batch_size]\n", 451 | "model.del_cache()" 452 | ] 453 | }, 454 | { 455 | "cell_type": "code", 456 | "execution_count": null, 457 | "metadata": { 458 | "id": "hvsr_S-xMwyG" 459 | }, 460 | "outputs": [], 461 | "source": [ 462 | "show_images(sample)" 463 | ] 464 | }, 465 | { 466 | "cell_type": "code", 467 | "execution_count": null, 468 | "metadata": { 469 | "id": "lsYwSFrE6zrJ" 470 | }, 471 | "outputs": [], 472 | "source": [ 473 | "for i in sample:\n", 474 | " show_images(i.unsqueeze(0))" 475 | ] 476 | }, 477 | { 478 | "cell_type": "code", 479 | "execution_count": 318, 480 | "metadata": { 481 | "id": "h4hoHPBw2DQh" 482 | }, 483 | "outputs": [], 484 | "source": [ 485 | "new_img = get_numpy_img(sample)" 486 | ] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "execution_count": null, 491 | "metadata": { 492 | "id": "VAHtiUiyPZWg" 493 | }, 494 | "outputs": [], 495 | "source": [ 496 | "%%time\n", 497 | "for j in range(batch_size):\n", 498 | " new_img = get_numpy_img(sample[j].unsqueeze(0))\n", 499 | " for i in range(1):\n", 500 | " _, _, new_img = face_enhancer.enhance(new_img, has_aligned=False,\n", 501 | " only_center_face=False, paste_back=True)\n", 502 | " cv2.imwrite(f'/content/test_out{j}.jpg', new_img)" 503 | ] 504 | }, 505 | { 506 | "cell_type": "code", 507 | "execution_count": 319, 508 | "metadata": { 509 | "id": "f_BlQypdaQ92" 510 | }, 511 | "outputs": [], 512 | "source": [ 513 | "" 514 | ] 515 | } 516 | ], 517 | "metadata": { 518 | "accelerator": "GPU", 519 | "colab": { 520 | "collapsed_sections": [], 521 | "machine_shape": "hm", 522 | "name": "Imagen_pytorch_inference_new.ipynb", 523 | "provenance": [] 524 | }, 525 | "kernelspec": { 526 | "display_name": "Python 3", 527 | "name": "python3" 528 | }, 529 | "language_info": { 530 | "name": "python" 531 | } 532 | }, 533 | "nbformat": 4, 534 | "nbformat_minor": 0 535 | } -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="imagen_pytorch", 5 | packages=[ 6 | "imagen_pytorch", 7 | "imagen_pytorch.clip", 8 | "imagen_pytorch.tokenizer", 9 | ], 10 | package_data={ 11 | "imagen_pytorch.tokenizer": [ 12 | "bpe_simple_vocab_16e6.txt.gz", 13 | "encoder.json.gz", 14 | "vocab.bpe.gz", 15 | ], 16 | "imagen_pytorch.clip": ["config.yaml"], 17 | }, 18 | install_requires=[ 19 | "Pillow", 20 | "attrs", 21 | "torch", 22 | "filelock", 23 | "requests", 24 | "tqdm", 25 | "ftfy", 26 | "regex", 27 | "numpy", 28 | "blobfile", 29 | "accelerate", 30 | "transformers", 31 | ], 32 | author="cene655", 33 | ) 34 | --------------------------------------------------------------------------------
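setup.py lists the gzipped tokenizer assets under package_data, so they are shipped with the package and the BPE encoder can be loaded without extra downloads. The snippet below is a minimal sketch; the prompt string and the context length of 128 are illustrative assumptions, while the install command mirrors the one used in the notebook above.

# pip install git+https://github.com/cene555/Imagen-pytorch.git
from imagen_pytorch.tokenizer.bpe import get_encoder

enc = get_encoder()                                # reads encoder.json.gz and vocab.bpe.gz
tokens = enc.encode("a red cube on top of a blue cube")
padded, mask = enc.padded_tokens_and_mask(tokens, text_ctx=128)
assert len(padded) == 128 and len(mask) == 128     # padded with the end-of-text token
print(enc.decode(tokens))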