├── LICENSE
├── README.md
├── celle
    ├── __init__.py
    ├── attention.py
    ├── celle.py
    ├── reversible.py
    ├── transformer.py
    └── vae.py
├── celle_main.py
├── celle_taming_main.py
├── configs
    ├── celle.yaml
    ├── nucleus_vqgan.yaml
    └── threshold_vqgan.yaml
├── data
    └── aaDescriptors.csv
├── dataloader.py
├── images
    ├── generate.gif
    ├── huanglogo.jpeg
    ├── nucleus.jpg
    └── preview.png
├── notebooks
    ├── Demo.ipynb
    ├── grad_map.py
    └── visualize_attention.ipynb
├── requirements.txt
└── taming
    ├── lr_scheduler.py
    ├── models
        ├── cond_transformer.py
        ├── dummy_cond_stage.py
        └── vqgan.py
    ├── modules
        ├── autoencoder
        │   └── lpips
        │   │   └── vgg.pth
        ├── diffusionmodules
        │   └── model.py
        ├── discriminator
        │   └── model.py
        ├── losses
        │   ├── __init__.py
        │   ├── lpips.py
        │   ├── segmentation.py
        │   └── vqperceptual.py
        ├── misc
        │   └── coord.py
        ├── transformer
        │   ├── mingpt.py
        │   └── permuter.py
        ├── util.py
        └── vqvae
        │   └── quantize.py
    └── util.py


/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Emaad Khwaja, Yun Song, & Bo Huang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CELL-E: Biological Zero-Shot Text-to-Image Synthesis for Protein Localization Prediction

This repository is the official implementation of [CELL-E: Biological Zero-Shot Text-to-Image Synthesis for Protein Localization Prediction](https://www.biorxiv.org/content/10.1101/2022.05.27.493774v1).

![schematic](images/preview.png)

## Requirements

Create a virtual environment and install the required packages via:

```setup
pip install -r requirements.txt
```

Next, install ```torch==1.7.1``` and ```torchvision==0.8.2``` with the [appropriate CUDA version](https://pytorch.org/get-started/previous-versions/#v171).

## Preparing Dataset

### Images

We used OpenCell for CELL-E, which has [information on downloading the entire dataset](https://opencell.czbiohub.org/download). The dataloader needs a ```data_csv```. You must generate a CSV file containing the columns ```nucleus_image_path```, ```protein_image_path```, ```metadata_path```, and ```split``` (train or val). This file is assumed to be in the same general ```data``` folder as the images and metadata files.

### Metadata

Metadata is a JSON file that should accompany every protein sequence. If a sequence does not appear in the ```data_csv```, it must appear in ```metadata.json``` under a key named ```protein_sequence```.

Adding more information here can be useful for querying individual proteins. This information can be retrieved via ```retrieve_metadata```, which creates a ```self.metadata``` variable within the dataset object.

## Training

Training for CELL-E occurs in 2 (or 3) stages:

- Training a Protein Threshold Image Encoder
- (Optional, but recommended) Training a Nucleus Image (Conditional Image) Encoder
- Training the CELL-E Transformer

### Image Encoders

This repository provides two image encoders: a Discrete VAE (similar to the original OpenAI implementation) and a VQGAN (recommended). If you are using the protein threshold image, set ```threshold: True``` for the dataset.
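Both encoders are trained from images listed in the `data_csv` described under Preparing Dataset. The following is a minimal sketch of writing that CSV and a per-protein metadata JSON that carries the `protein_sequence` key, using only the standard library. It assumes plain `train`/`val` split labels and paths laid out however `dataloader.py` expects. The helper names `write_data_csv` and `write_metadata`, and the example `gene` field, are hypothetical and are not part of this repository.

```python
import csv
import json
from pathlib import Path

# Column names taken from the "Preparing Dataset" section of this README.
REQUIRED_COLUMNS = ["nucleus_image_path", "protein_image_path", "metadata_path", "split"]
VALID_SPLITS = {"train", "val"}


def write_metadata(path, protein_sequence, **extra):
    """Hypothetical helper: write one metadata JSON holding `protein_sequence`
    plus any extra fields that may be useful for querying proteins later."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps({"protein_sequence": protein_sequence, **extra}))
    return path


def write_data_csv(rows, out_path):
    """Hypothetical helper: validate rows (dicts) and write them as data_csv."""
    # Check every row before opening the file, so a bad row never
    # leaves a half-written CSV behind.
    for row in rows:
        missing = [col for col in REQUIRED_COLUMNS if col not in row]
        if missing:
            raise ValueError(f"row is missing columns: {missing}")
        if row["split"] not in VALID_SPLITS:
            raise ValueError(f"split must be 'train' or 'val', got {row['split']!r}")
    out_path = Path(out_path)
    with out_path.open("w", newline="") as f:
        # DictWriter also raises ValueError if a row carries unexpected keys.
        writer = csv.DictWriter(f, fieldnames=REQUIRED_COLUMNS)
        writer.writeheader()
        writer.writerows(rows)
    return out_path
```

Writing the CSV into the same `data` folder as the images matches the layout that this README says the dataloader assumes. Whether the paths inside the CSV should be relative or absolute depends on `dataloader.py` and is not confirmed here.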

#### Discrete VAE

The discrete VAE can be trained using the following code:

```python
import torch
from celle import DiscreteVAE

vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,           # number of downsamples - ex. 256 / (2 ** 3) = (32 x 32 feature map)
    channels = 1,
    num_tokens = 512,         # number of visual tokens. in the paper, they used 8192, but could be smaller for downsized projects
    codebook_dim = 512,       # codebook dimension
    hidden_dim = 64,          # hidden dimension
    num_resnet_blocks = 1,    # number of resnet blocks
    temperature = 0.9,        # gumbel softmax temperature, the lower this is, the harder the discretization
    straight_through = False, # straight-through for gumbel softmax. unclear if it is better one way or the other
)

# placeholder batch of single-channel images: (batch, channels, height, width)
images = torch.rand(4, 1, 256, 256)

loss = vae(images, return_loss = True)
loss.backward()
```

#### VQGAN (Recommended)

We use a slightly modified version of the [taming-transformers](https://github.com/CompVis/taming-transformers) code.

To train, run the following script:

```python celle_taming_main.py --base configs/threshold_vqgan.yaml -t True```

Please refer to the original repo for additional flags, such as ```--gpus```.


### CELL-E

To train, run the following script:

```python celle_main.py --base configs/celle.yaml -t True```

Specify ```--gpus``` in the same format as for VQGAN.

CELL-E contains the following options from [dalle-pytorch](https://github.com/lucidrains/DALLE-pytorch).

- ```ckpt_path```: Resume previous CELL-E training. Saved model with state_dict
- ```vqgan_model_path```: Saved protein image model (with state_dict) for the protein image encoder
- ```vqgan_config_path```: Saved protein image model yaml
- ```condition_model_path```: (Optional) Saved condition (nucleus) model (with state_dict) for the condition image encoder
- ```condition_config_path```: (Optional) Saved condition (nucleus) model yaml
- ```num_images```: 1 if only using the protein image encoder, 2 if also including the condition image encoder
- ```image_key```: ```nucleus```, ```target```, or ```threshold```
- ```dim```: Dimension of the language model embedding (768 for BERT)
- ```num_text_tokens```: Total number of tokens in the language model (30 for BERT)
- ```text_seq_len```: Total number of amino acids considered


- ```depth```: Transformer model depth; deeper is usually better at the cost of VRAM
- ```heads```: Number of heads used in multi-headed attention
- ```dim_head```: Size of attention heads
- ```reversible```: See [https://github.com/lucidrains/DALLE-pytorch#scaling-depth](https://github.com/lucidrains/DALLE-pytorch#scaling-depth)
- ```attn_dropout```: Attention dropout rate in training
- ```ff_dropout```: Feed-forward dropout rate in training
- ```attn_types```: See [https://github.com/lucidrains/DALLE-pytorch#sparse-attention](https://github.com/lucidrains/DALLE-pytorch#sparse-attention). Sparse attention is not supported
- ```loss_img_weight```: Weighting applied to image reconstruction (text weight = 1)
- ```loss_cond_weight```: Weighting applied to condition image reconstruction
- ```stable```: Normalizes weights (for when exploding gradients occur)
- ```sandwich_norm```: See [https://github.com/lucidrains/x-transformers#sandwich-norm](https://github.com/lucidrains/x-transformers#sandwich-norm)
- ```shift_tokens```: Applies shift in feature dimension. Only applied to images.
- ```rotary_emb```: [Rotary embedding](https://github.com/lucidrains/x-transformers#rotary-positional-embeddings) scheme for positional encoding
- ```text_embedding```: Protein sequence embedding used by the model. ```no_text```, ```unirep```, ```bert```, ```esm1b```, ```onehot```, and ```aadescriptors``` are available
- ```fixed_embedding```: Setting to ```True``` allows for protein sequence embeddings to be updated during training
- ```learning_rate```: Learning rate for the Adam optimizer
- ```monitor```: Param used to save models

## Generating Images

To generate images, set ```ckpt_path``` in the config to the saved model. This method can be unstable, so refer to ```Demo.ipynb``` for another way of loading the model.

```python
import torch
from omegaconf import OmegaConf
from celle_main import instantiate_from_config

device = "cuda" if torch.cuda.is_available() else "cpu"

configs = OmegaConf.load("configs/celle.yaml")

model = instantiate_from_config(configs.model).to(device)

# `sequence` (protein sequence input) and `nucleus` (nucleus image) are
# placeholders here; see Demo.ipynb for how they are prepared.
model.generate_images(text=sequence,
                      condition=nucleus,
                      return_logits=True,
                      progress=True,
                      use_cache=True)
```

## Citation

Please cite us if you use our code for any part of your research.
```
CELL-E: Biological Zero-Shot Text-to-Image Synthesis for Protein Localization Prediction
Emaad Khwaja, Yun S. Song, Bo Huang
bioRxiv 2022.05.27.493774; doi: https://doi.org/10.1101/2022.05.27.493774
```

## Acknowledgements

Huge shoutout to [@lucidrains](https://github.com/lucidrains) for putting out [dalle-pytorch](https://github.com/lucidrains/DALLE-pytorch), which this code is based on. This work would not have been possible without that invaluable contribution.
--------------------------------------------------------------------------------
/celle/__init__.py:
--------------------------------------------------------------------------------
from celle.celle import CELLE, DiscreteVAE
from celle.vae import VQGanVAE

from pkg_resources import get_distribution

# __version__ = get_distribution('dalle_pytorch').version
__version__ = "1.4.1"

--------------------------------------------------------------------------------
/celle/attention.py:
--------------------------------------------------------------------------------
from inspect import isfunction

import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat

from rotary_embedding_torch import apply_rotary_emb

# helpers


def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def max_neg_value(t):
    return -torch.finfo(t.dtype).max


def stable_softmax(t, dim=-1, alpha=32 ** 2):
    t = t / alpha
    t = t - torch.amax(t, dim=dim, keepdim=True).detach()
    return (t * alpha).softmax(dim=dim)


def apply_pos_emb(pos_emb, qkv):
    n = qkv[0].shape[-2]
    pos_emb = pos_emb[..., :n, :]
    return tuple(map(lambda t: apply_rotary_emb(pos_emb, t), qkv))


# classes


class Attention(nn.Module):
    def __init__(
        self,
        dim,
        seq_len,
        causal=True,
        heads=8,
        dim_head=64,
        dropout=0.0,
        stable=False,
        static_mask=None,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.seq_len = seq_len
        self.scale = dim_head ** -0.5

        self.stable = stable
        self.causal = causal
        self.register_buffer("static_mask", static_mask, persistent=False)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

        self.save_attn = nn.Identity()

    def forward(self, x, mask=None, rotary_pos_emb=None, cache=None, cache_key=None):
        b, n, _, h, device = *x.shape, self.heads, x.device
        softmax = torch.softmax if not self.stable else stable_softmax

        offset = cache.get("offset", 0) if exists(cache) else 0

        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)

        if exists(rotary_pos_emb):
            q, k, v = apply_pos_emb(rotary_pos_emb[..., offset:, :], (q, k, v))

        q = q * self.scale

        if offset > 0:
            k_top, v_top = cache[cache_key]
            k = torch.cat([k_top, k], dim=-2)
            v = torch.cat([v_top, v], dim=-2)
        if exists(cache):
            cache[cache_key] = k, v

        dots = torch.einsum("b h i d, b h j d -> b h i j", q, k)
        mask_value = max_neg_value(dots)

        if exists(mask):
            mask = rearrange(mask, "b j -> b () () j")
            dots.masked_fill_(~mask, mask_value)
            del mask

        if (
            self.causal and offset == 0
        ):  # causality is naturally enforced for the cached inference
            i, j = dots.shape[-2:]
            mask = torch.ones(i, j, device=device).triu_(j - i + 1).bool()
            dots.masked_fill_(mask, mask_value)

        if exists(self.static_mask):
            dots.masked_fill_(
                ~self.static_mask[offset : offset + n, : offset + n], mask_value
            )

        attn = softmax(dots, dim=-1)

        self.save_attn(attn)

        out = torch.einsum("b h i j, b h j d -> b h i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.to_out(out)
        return out


# sparse attention with convolutional pattern, as mentioned in the blog post.
# customizable kernel size and dilation


class SparseConvCausalAttention(nn.Module):
    def __init__(
        self,
        dim,
        seq_len,
        image_size=32,
        kernel_size=5,
        dilation=1,
        heads=8,
        dim_head=64,
        dropout=0.0,
        stable=False,
        **kwargs,
    ):
        super().__init__()
        assert kernel_size % 2 == 1, "kernel size must be odd"

        inner_dim = dim_head * heads
        self.seq_len = seq_len
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.image_size = image_size
        self.kernel_size = kernel_size
        self.dilation = dilation

        self.stable = stable

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

    def forward(self, x, mask=None, rotary_pos_emb=None):
        b, n, _, h, img_size, kernel_size, dilation, seq_len, device = (
            *x.shape,
            self.heads,
            self.image_size,
            self.kernel_size,
            self.dilation,
            self.seq_len,
            x.device,
        )
        softmax = torch.softmax if not self.stable else stable_softmax

        img_seq_len = img_size ** 2
        text_len = seq_len + 1 - img_seq_len

        # padding

        padding = seq_len - n + 1
        mask = default(mask, lambda: torch.ones(b, text_len, device=device).bool())

        x = F.pad(x, (0, 0, 0, padding), value=0)
        mask = mask[:, :text_len]

        # derive query / keys / values

        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), qkv)

        if exists(rotary_pos_emb):
            q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))

        q *= self.scale

        ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(
            lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v)
        )

        # text attention

        dots_text = einsum("b i d, b j d -> b i j", q_text, k_text)
        mask_value = max_neg_value(dots_text)

        i, j = dots_text.shape[-2:]
        text_causal_mask = torch.ones(i, j, device=device).triu_(j - i + 1).bool()
        dots_text.masked_fill_(text_causal_mask, mask_value)

        attn_text = softmax(dots_text, dim=-1)
        out_text = einsum("b i j, b j d -> b i d", attn_text, v_text)

        # image attention

        effective_kernel_size = (kernel_size - 1) * dilation + 1
        padding = effective_kernel_size // 2

        k_img, v_img = map(
            lambda t: rearrange(t, "b (h w) c -> b c h w", h=img_size), (k_img, v_img)
        )
        k_img, v_img = map(
            lambda t: F.unfold(t, kernel_size, padding=padding, dilation=dilation),
            (k_img, v_img),
        )
        k_img, v_img = map(
            lambda t: rearrange(t, "b (d j) i -> b i j d", j=kernel_size ** 2),
            (k_img, v_img),
        )

        # let image attend to all of text

        dots_image = einsum("b i d, b i j d -> b i j", q_img, k_img)
        dots_image_to_text = einsum("b i d, b j d -> b i j", q_img, k_text)

        # calculate causal attention for local convolution

        i, j = dots_image.shape[-2:]
        img_seq = torch.arange(img_seq_len, device=device)
        k_img_indices = rearrange(img_seq.float(), "(h w) -> () () h w", h=img_size)
        k_img_indices = F.pad(
            k_img_indices, (padding,) * 4, value=img_seq_len
        )  # padding set to be max, so it is never attended to
        k_img_indices = F.unfold(k_img_indices, kernel_size, dilation=dilation)
        k_img_indices = rearrange(k_img_indices, "b j i -> b i j")

        # mask image attention

        q_img_indices = rearrange(img_seq, "i -> () i ()")
        causal_mask = q_img_indices < k_img_indices

        # concat text mask with image causal mask

        causal_mask = repeat(causal_mask, "() i j -> b i j", b=b * h)
        mask = repeat(mask, "b j -> (b h) i j", i=i, h=h)
        mask = torch.cat((~mask, causal_mask), dim=-1)

        # image can attend to all of text

        dots = torch.cat((dots_image_to_text, dots_image), dim=-1)
        dots.masked_fill_(mask, mask_value)

        attn = softmax(dots, dim=-1)

        # aggregate

        attn_image_to_text, attn_image = attn[..., :text_len], attn[..., text_len:]

        out_image_to_image = einsum("b i j, b i j d -> b i d", attn_image, v_img)
        out_image_to_text = einsum("b i j, b j d -> b i d", attn_image_to_text, v_text)

        out_image = out_image_to_image + out_image_to_text

        # combine attended values for both text and image

        out = torch.cat((out_text, out_image), dim=1)

        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
        out = self.to_out(out)
        return out[:, :n]


# sparse axial causal attention


class SparseAxialCausalAttention(nn.Module):
    def __init__(
        self,
        dim,
        seq_len,
        image_size=32,
        axis=0,
        heads=8,
        dim_head=64,
        dropout=0.0,
        stable=False,
        **kwargs,
    ):
        super().__init__()
        assert axis in {0, 1}, "axis must be either 0 (along height) or 1 (along width)"
        self.axis = axis

        inner_dim = dim_head * heads
        self.seq_len = seq_len
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.image_size = image_size

        self.stable = stable

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

    def forward(self, x, mask=None, rotary_pos_emb=None):
        b, n, _, h, img_size, axis, seq_len, device = (
            *x.shape,
            self.heads,
            self.image_size,
            self.axis,
            self.seq_len,
            x.device,
        )
        softmax = torch.softmax if not self.stable else stable_softmax

        img_seq_len = img_size ** 2
        text_len = seq_len + 1 - img_seq_len

        # padding

        padding = seq_len - n + 1
        mask = default(mask, lambda: torch.ones(b, text_len, device=device).bool())

        x = F.pad(x, (0, 0, 0, padding), value=0)
        mask = mask[:, :text_len]

        # derive queries / keys / values

        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), qkv)

        if exists(rotary_pos_emb):
            q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))

        q *= self.scale

        ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(
            lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v)
        )

        # text attention

        dots_text = einsum("b i d, b j d -> b i j", q_text, k_text)
        mask_value = max_neg_value(dots_text)

        i, j = dots_text.shape[-2:]
        text_causal_mask = torch.ones(i, j, device=device).triu_(j - i + 1).bool()
        dots_text.masked_fill_(text_causal_mask, mask_value)

        attn_text = softmax(dots_text, dim=-1)
        out_text = einsum("b i j, b j d -> b i d", attn_text, v_text)

        # image attention

        split_axis_einops = (
            "b (h w) c -> b h w c" if axis == 0 else "b (h w) c -> b w h c"
        )
        merge_axis_einops = (
            "b x n d -> b (x n) d" if axis == 0 else "b x n d -> b (n x) d"
        )

        # split out axis

        q_img, k_img, v_img = map(
            lambda t: rearrange(t, split_axis_einops, h=img_size), (q_img, k_img, v_img)
        )

        # similarity

        dots_image_to_image = einsum("b x i d, b x j d -> b x i j", q_img, k_img)
        dots_image_to_text = einsum("b x i d, b j d -> b x i j", q_img, k_text)

        dots = torch.cat((dots_image_to_text, dots_image_to_image), dim=-1)

        # mask so image has full attention to text, but causal along axis

        bh, x, i, j = dots.shape
        causal_mask = (
            torch.ones(i, img_size, device=device).triu_(img_size - i + 1).bool()
        )
        causal_mask = repeat(causal_mask, "i j -> b x i j", b=bh, x=x)

        mask = repeat(mask, "b j -> (b h) x i j", h=h, x=x, i=i)
        mask = torch.cat((~mask, causal_mask), dim=-1)

        dots.masked_fill_(mask, mask_value)

        # attention.

        attn = softmax(dots, dim=-1)

        # aggregate

        attn_image_to_text, attn_image_to_image = (
            attn[..., :text_len],
            attn[..., text_len:],
        )

        out_image_to_image = einsum(
            "b x i j, b x j d -> b x i d", attn_image_to_image, v_img
        )
        out_image_to_text = einsum(
            "b x i j, b j d -> b x i d", attn_image_to_text, v_text
        )

        out_image = out_image_to_image + out_image_to_text

        # merge back axis

        out_image = rearrange(out_image, merge_axis_einops, x=img_size)

        # combine attended values for both text and image

        out = torch.cat((out_text, out_image), dim=1)

        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
        out = self.to_out(out)
        return out[:, :n]

--------------------------------------------------------------------------------
/celle/reversible.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from operator import itemgetter
from torch.autograd.function import Function
from torch.utils.checkpoint import get_device_states, set_device_states

# for routing arguments into the functions of the reversible layer
def route_args(router, args, depth):
    routed_args = [(dict(), dict()) for _ in range(depth)]
    matched_keys = [key for key in args.keys() if key in router]

    for key in matched_keys:
        val = args[key]
        for depth, ((f_args, g_args), routes) in enumerate(
            zip(routed_args, router[key])
        ):
            new_f_args, new_g_args = map(
                lambda route: ({key: val} if route else {}), routes
            )
            routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
    return routed_args


# following example for saving and setting rng here https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html
class Deterministic(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.cpu_state = None
        self.cuda_in_fwd = None
        self.gpu_devices = None
        self.gpu_states = None

    def record_rng(self, *args):
        self.cpu_state = torch.get_rng_state()
        if torch.cuda._initialized:
            self.cuda_in_fwd = True
            self.gpu_devices, self.gpu_states = get_device_states(*args)

    def forward(self, *args, record_rng=False, set_rng=False, **kwargs):
        if record_rng:
            self.record_rng(*args)

        if not set_rng:
            return self.net(*args, **kwargs)

        rng_devices = []
        if self.cuda_in_fwd:
            rng_devices = self.gpu_devices

        with torch.random.fork_rng(devices=rng_devices, enabled=True):
            torch.set_rng_state(self.cpu_state)
            if self.cuda_in_fwd:
                set_device_states(self.gpu_devices, self.gpu_states)
            return self.net(*args, **kwargs)


# heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
# once multi-GPU is confirmed working, refactor and send PR back to source
class ReversibleBlock(nn.Module):
    def __init__(self, f, g):
        super().__init__()
        self.f = Deterministic(f)
        self.g = Deterministic(g)

    def forward(self, x, f_args={}, g_args={}):
        x1, x2 = torch.chunk(x, 2, dim=2)
        y1, y2 = None, None

        with torch.no_grad():
            y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
            y2 = x2 + self.g(y1, record_rng=self.training, **g_args)

        return torch.cat([y1, y2], dim=2)

    def backward_pass(self, y, dy, f_args={}, g_args={}):
        y1, y2 = torch.chunk(y, 2, dim=2)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim=2)
        del dy

        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.g(y1, set_rng=True, **g_args)
            torch.autograd.backward(gy1, dy2)

        with torch.no_grad():
            x2 = y2 - gy1
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f(x2, set_rng=True, **f_args)
            torch.autograd.backward(fx2, dx1, retain_graph=True)

        with torch.no_grad():
            x1 = y1 - fx2
            del y1, fx2

            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            x = torch.cat([x1, x2.detach()], dim=2)
            dx = torch.cat([dx1, dx2], dim=2)

        return x, dx


class _ReversibleFunction(Function):
    @staticmethod
    def forward(ctx, x, blocks, args):
        ctx.args = args
        for block, kwarg in zip(blocks, args):
            x = block(x, **kwarg)
        ctx.y = x.detach()
        ctx.blocks = blocks
        return x

    @staticmethod
    def backward(ctx, dy):
        y = ctx.y
        args = ctx.args
        for block, kwargs in zip(ctx.blocks[::-1], args[::-1]):
            y, dy = block.backward_pass(y, dy, **kwargs)
        return dy, None, None


class SequentialSequence(nn.Module):
    def __init__(self, layers, args_route={}, layer_dropout=0.0):
        super().__init__()
        assert all(
            len(route) == len(layers) for route in args_route.values()
        ), "each argument route map must have the same depth as the number of sequential layers"
        self.layers = layers
        self.args_route = args_route
        self.layer_dropout = layer_dropout

    def forward(self, x, **kwargs):
        args = route_args(self.args_route, kwargs, len(self.layers))
        layers_and_args = list(zip(self.layers, args))

        for (f, g), (f_args, g_args) in layers_and_args:
            x = x + f(x, **f_args)
            x = x + g(x, **g_args)
        return x


class ReversibleSequence(nn.Module):
    def __init__(self, blocks, args_route={}):
        super().__init__()
        self.args_route = args_route
        self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks])

    def forward(self, x, **kwargs):
        x = torch.cat([x, x], dim=-1)

        blocks = self.blocks
        args = route_args(self.args_route, kwargs, len(blocks))
        args = list(map(lambda x: {"f_args": x[0], "g_args": x[1]}, args))

        out = _ReversibleFunction.apply(x, blocks, args)
        return torch.stack(out.chunk(2, dim=-1)).mean(dim=0)

--------------------------------------------------------------------------------
/celle/transformer.py:
--------------------------------------------------------------------------------
from collections import deque
from functools import partial
from itertools import islice, cycle

import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange

from celle.reversible import ReversibleSequence, SequentialSequence
from celle.attention import (
    Attention,
    SparseConvCausalAttention,
    SparseAxialCausalAttention,
)

from rotary_embedding_torch import RotaryEmbedding, broadcat

# helpers


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def cast_tuple(val, depth=1):
    if isinstance(val, list):
        val = tuple(val)
    return val if isinstance(val, tuple) else (val,) * depth


# classes


class DivideMax(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        maxes = x.amax(dim=self.dim, keepdim=True).detach()
        return x / maxes


class NonCached(nn.Module):
    """
    A wrapper for layers that don't support the inference cache themselves.
    Reconstructs the full sequence before the layer and
    cuts the suffix of the outputs after the layer.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *, cache=None, cache_key=None, **kwargs):
        n = x.shape[-2]
        if exists(cache):
            if cache_key in cache:
                x = torch.cat([cache[cache_key], x], dim=-2)
            cache[cache_key] = x

        out = self.fn(x, **kwargs)

        return out[:, -n:]


class CachedAs(nn.Module):
    """
    A wrapper that defines a key for the inference cache.
    """

    def __init__(self, cache_key, fn):
        super().__init__()
        self.cache_key = cache_key
        self.fn = fn

    def forward(self, x, *, cache=None, **kwargs):
        return self.fn(x, cache=cache, cache_key=self.cache_key, **kwargs)


# https://arxiv.org/abs/2103.17239
class LayerScale(nn.Module):
    def __init__(self, dim, depth, fn):
        super().__init__()
        if depth <= 18:
            init_eps = 0.1
        elif depth > 18 and depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        scale = torch.zeros(1, 1, dim).fill_(init_eps)
        self.scale = nn.Parameter(scale)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale


# layer norm


class PreNorm(nn.Module):
    def __init__(self, dim, fn, sandwich=False):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity()
        self.fn = fn

    def forward(self, x, **kwargs):
        x = self.norm(x)
        x = self.fn(x, **kwargs)
        return self.norm_out(x)


# feed forward


class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim=-1)
        return x * F.gelu(gates)


class FeedForward(nn.Module):
    def __init__(self, dim, dropout=0.0, mult=4.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim),
        )

    def forward(self, x, cache=None, cache_key=None):
        return self.net(x)


# token shift classes


class PreShiftToken(nn.Module):
    def __init__(self, fn, image_size, num_images, seq_len):
        super().__init__()
        self.fn = fn
        self.image_size = image_size
        self.seq_len = seq_len
        img_seq_len = ((image_size // num_images) ** 2) * num_images
        self.text_len = seq_len - img_seq_len + 1

    def forward(self, x, cache=None, cache_key=None, **kwargs):

        seq_len, image_size, text_len = self.seq_len, self.image_size, self.text_len

        if exists(cache) and cache_key in cache:
            offset = cache["offset"]
            assert offset >= text_len, "cached inference for text is not supported"
            q = cache[cache_key]
            assert isinstance(q, deque) and len(q) == image_size

            x_top, x_left, *x_pass = x[:, -1].chunk(4, dim=-1)

            q.append((x_top, x_left))
            x_top = q.popleft()[0]
            x_left = q[-2][1]
            if (offset - text_len) % image_size == 0:
                x_left = torch.zeros_like(x_left)

            x = torch.cat((x_top, x_left, *x_pass), dim=-1)
            return self.fn(x[:, None], cache=cache, **kwargs)

        n = x.shape[1]

        padding = seq_len - n + 1

        if n < text_len:
            return self.fn(x, **kwargs)

        # get text and image tokens

        x_text, x_img = x[:, :text_len], x[:, text_len:]
        x_img = F.pad(x_img, (0, 0, 0, padding))
        x_img = rearrange(x_img, "b (h w) d -> b h w d", h=image_size)

        # shift 1 from the left for text tokens

        # x_text_shift, x_text_pass = x_text.chunk(2, dim=-1)
        # x_text_shift = F.pad(x_text_shift, (0, 0, 1, -1))
        # x_text = torch.cat((x_text_shift, x_text_pass), dim=-1)

        # shift from top, left for image tokens

        x_img_shift_top, x_img_shift_left, *x_img_pass = x_img.chunk(4, dim=-1)
        x_img_shift_left = F.pad(x_img_shift_left, (0, 0, 1, -1))
        x_img_shift_top = F.pad(x_img_shift_top, (0, 0, 0, 0, 1, -1))
        x_img = torch.cat((x_img_shift_top, x_img_shift_left, *x_img_pass), dim=-1)

        # merge text and image sequence back together

        x_img = rearrange(x_img, "b h w d -> b (h w) d")
        x_img = x_img[:, :-padding]
        x = torch.cat((x_text, x_img), dim=1)

        if exists(cache):
            dummy_top, dummy_left, *_ = x[:, -1].chunk(4, dim=-1)
            dummy_top, dummy_left = torch.zeros_like(dummy_top), torch.zeros_like(
                dummy_left
            )

            q = deque()
            x_img = x_img[:, -image_size:]
            for _ in range(image_size - x_img.shape[1]):
                q.append((dummy_top, dummy_left))
            for i in range(x_img.shape[1]):
                q.append(x_img[:, i].chunk(4, dim=-1)[:2])
            cache[cache_key] = q

        return self.fn(x, cache=cache, **kwargs)


# main transformer class


class Transformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        seq_len,
        reversible=False,
        causal=True,
        heads=8,
        dim_head=64,
        ff_mult=4,
        attn_dropout=0.0,
        ff_dropout=0.0,
        attn_types=None,
        image_fmap_size=None,
        num_images=None,
        stable=False,
        sandwich_norm=False,
        shift_tokens=False,
        rotary_emb=True,
        shared_attn_ids=None,
        shared_ff_ids=None,
        optimize_for_inference=False,  # use cache-friendly masked attention instead of sparse one
    ):
        super().__init__()
        layers = nn.ModuleList([])

        self.seq_len = seq_len
        self.image_fmap_size = image_fmap_size

        attn_types = default(attn_types, ("full",))
        attn_types = cast_tuple(attn_types)
        attn_type_layer = islice(cycle(attn_types), depth)
263 | 264 | shared_attn_ids = cycle(default(shared_attn_ids, range(depth))) 265 | shared_ff_ids = cycle(default(shared_ff_ids, range(depth))) 266 | shared_attn_layers = {} 267 | shared_ff_layers = {} 268 | 269 | for (ind, attn_type, attn_id, ff_id) in zip( 270 | range(depth), attn_type_layer, shared_attn_ids, shared_ff_ids 271 | ): 272 | 273 | if attn_type == "full": 274 | attn_class = partial(Attention, stable=stable) 275 | 276 | elif attn_type == "axial_row": 277 | if optimize_for_inference: 278 | attn_class = partial( 279 | Attention, 280 | stable=stable, 281 | static_mask=self._get_attention_mask(attn_type), 282 | ) 283 | else: 284 | attn_class = partial( 285 | SparseAxialCausalAttention, 286 | seq_len=seq_len, 287 | axis=0, 288 | image_size=image_fmap_size, 289 | stable=stable, 290 | ) 291 | elif attn_type == "axial_col": 292 | if optimize_for_inference: 293 | attn_class = partial( 294 | Attention, 295 | stable=stable, 296 | static_mask=self._get_attention_mask(attn_type), 297 | ) 298 | else: 299 | attn_class = partial( 300 | SparseAxialCausalAttention, 301 | seq_len=seq_len, 302 | axis=1, 303 | image_size=image_fmap_size, 304 | stable=stable, 305 | ) 306 | elif attn_type == "conv_like": 307 | attn_class = partial( 308 | SparseConvCausalAttention, 309 | seq_len=seq_len, 310 | image_size=image_fmap_size, 311 | stable=stable, 312 | ) 313 | 314 | else: 315 | raise ValueError(f'attention type "{attn_type}" is not valid') 316 | 317 | attn, reused_attn_type = shared_attn_layers.get(attn_id, (None, None)) 318 | 319 | if not exists(attn): 320 | attn = attn_class( 321 | dim, 322 | causal=causal, 323 | seq_len=seq_len, 324 | heads=heads, 325 | dim_head=dim_head, 326 | dropout=attn_dropout, 327 | ) 328 | 329 | shared_attn_layers[attn_id] = (attn, attn_type) 330 | elif attn_type != reused_attn_type: 331 | raise ValueError( 332 | "attn_types do not match shared_attn_ids " 333 | f'(ind = {ind}, attn_type = "{attn_type}", reused_attn_type = "{reused_attn_type}")' 334 | ) 335 
| 336 | ff = shared_ff_layers.get(ff_id) 337 | if not exists(ff): 338 | ff = FeedForward(dim, mult=ff_mult, dropout=ff_dropout) 339 | shared_ff_layers[ff_id] = ff 340 | 341 | if isinstance(attn, Attention): 342 | attn = CachedAs(f"attn_{ind}", attn) 343 | else: 344 | # at the moment, other attention classes don't support cache 345 | attn = NonCached(attn) 346 | 347 | # reuse the (possibly shared) feed-forward layer resolved above instead of re-creating it 348 | 349 | if shift_tokens: 350 | attn = CachedAs( 351 | f"preshift_attn_{ind}", 352 | PreShiftToken( 353 | attn, 354 | image_size=image_fmap_size, 355 | num_images=num_images, 356 | seq_len=seq_len, 357 | ), 358 | ) 359 | ff = CachedAs( 360 | f"preshift_ff_{ind}", 361 | PreShiftToken( 362 | ff, 363 | image_size=image_fmap_size, 364 | num_images=num_images, 365 | seq_len=seq_len, 366 | ), 367 | ) 368 | 369 | layers.append( 370 | nn.ModuleList( 371 | [ 372 | LayerScale( 373 | dim, ind + 1, PreNorm(dim, attn, sandwich=sandwich_norm) 374 | ), 375 | LayerScale( 376 | dim, ind + 1, PreNorm(dim, ff, sandwich=sandwich_norm) 377 | ), 378 | ] 379 | ) 380 | ) 381 | 382 | execute_type = ReversibleSequence if reversible else SequentialSequence 383 | route_attn = ((True, False),) * depth 384 | route_all = ((True, True),) * depth 385 | attn_route_map = { 386 | "mask": route_attn, 387 | "rotary_pos_emb": route_attn, 388 | "cache": route_all, 389 | } 390 | 391 | self.layers = execute_type(layers, args_route=attn_route_map) 392 | 393 | # generate positional embeddings for rotary 394 | 395 | pos_emb = None 396 | if rotary_emb: 397 | 398 | rot_dim = dim_head // 3 399 | img_seq_len = ((image_fmap_size // num_images) ** 2) * num_images 400 | 401 | text_len = seq_len - img_seq_len + 1 402 | 403 | text_pos_emb = RotaryEmbedding(dim=rot_dim) 404 | 405 | img_axial_pos_emb = RotaryEmbedding(dim=rot_dim, freqs_for="pixel") 406 | 407 | text_freqs = text_pos_emb(torch.arange(text_len)) 408 | 409 | img_to_text_freqs = text_pos_emb( 410 | torch.full((img_seq_len,), 8192) 411 |
) # image is given a position far away from text 412 | 413 | text_freqs = torch.cat((text_freqs, img_to_text_freqs), dim=0) 414 | 415 | img_freqs_axial = img_axial_pos_emb( 416 | torch.linspace(-1, 1, steps=image_fmap_size) 417 | ) 418 | 419 | if num_images > 1: 420 | 421 | split_img_freqs_axial = torch.split( 422 | img_freqs_axial, image_fmap_size // num_images, dim=0 423 | ) 424 | 425 | split_img_freqs = [ 426 | broadcat( 427 | ( 428 | rearrange(img_freqs_axial_per_image, "i d -> i () d"), 429 | rearrange(img_freqs_axial_per_image, "j d -> () j d"), 430 | ), 431 | dim=-1, 432 | ) 433 | for img_freqs_axial_per_image in split_img_freqs_axial 434 | ] 435 | 436 | split_img_freqs = [ 437 | rearrange(img_freqs_per_image, "h w d -> (h w) d") 438 | for img_freqs_per_image in split_img_freqs 439 | ] 440 | 441 | # concat per image-image_freqs 442 | 443 | img_freqs = torch.cat(split_img_freqs, dim=0) 444 | 445 | elif num_images == 1: 446 | img_freqs = broadcat( 447 | ( 448 | rearrange(img_freqs_axial, "i d -> i () d"), 449 | rearrange(img_freqs_axial, "j d -> () j d"), 450 | ), 451 | dim=-1, 452 | ) 453 | 454 | img_freqs = rearrange(img_freqs, "h w d -> (h w) d") 455 | 456 | else: 457 | assert False, "num_images must be int greater than 0" 458 | self.img_axial_pos_emb = img_axial_pos_emb 459 | self.text_pos_emb = text_pos_emb 460 | 461 | text_axial_freqs = img_axial_pos_emb( 462 | torch.full((text_len,), -10.0) 463 | ) # text is given a position of -10 apart from the image axial positions, which is from range [-1, 1] 464 | 465 | text_axial_freqs = torch.cat((text_axial_freqs, text_axial_freqs), dim=-1) 466 | 467 | img_freqs = torch.cat((text_axial_freqs, img_freqs), dim=0) 468 | 469 | pos_emb = torch.cat((text_freqs, img_freqs), dim=-1) 470 | 471 | pos_emb = rearrange(pos_emb, "n d -> () n d") 472 | 473 | self.register_buffer("pos_emb", pos_emb) 474 | 475 | def forward(self, x, **kwargs): 476 | return self.layers(x, rotary_pos_emb=self.pos_emb, **kwargs) 477 | 478 | def 
_get_attention_mask(self, attn_type): 479 | img_seq_len = self.image_fmap_size ** 2 480 | text_len = self.seq_len + 1 - img_seq_len 481 | 482 | static_mask = torch.zeros(self.seq_len, self.seq_len, dtype=torch.bool) 483 | static_mask[:, :text_len] = True 484 | if attn_type == "axial_row": 485 | for row in range(self.image_fmap_size): 486 | begin = text_len + row * self.image_fmap_size 487 | end = text_len + (row + 1) * self.image_fmap_size 488 | static_mask[begin:end, begin:end] = True 489 | elif attn_type == "axial_col": 490 | for col in range(self.image_fmap_size): 491 | begin = text_len + col 492 | static_mask[ 493 | begin :: self.image_fmap_size, begin :: self.image_fmap_size 494 | ] = True 495 | else: 496 | raise ValueError( 497 | f'attention type "{attn_type}" can\'t be simulated with a static mask' 498 | ) 499 | return static_mask -------------------------------------------------------------------------------- /celle/vae.py: -------------------------------------------------------------------------------- 1 | import os 2 | import urllib 3 | from tqdm import tqdm 4 | from math import sqrt, log 5 | from omegaconf import OmegaConf 6 | from taming.models.vqgan import VQModel, GumbelVQ 7 | import importlib 8 | 9 | import torch 10 | from torch import nn 11 | import torch.nn.functional as F 12 | 13 | from einops import rearrange 14 | 15 | # constants 16 | 17 | CACHE_PATH = os.path.expanduser("~/.cache/dalle") 18 | 19 | # helpers methods 20 | 21 | 22 | def exists(val): 23 | return val is not None 24 | 25 | 26 | def default(val, d): 27 | return val if exists(val) else d 28 | 29 | 30 | def load_model(path): 31 | with open(path, "rb") as f: 32 | return torch.load(f, map_location=torch.device("cpu")) 33 | 34 | 35 | def map_pixels(x, eps=0.1): 36 | return (1 - 2 * eps) * x + eps 37 | 38 | 39 | def unmap_pixels(x, eps=0.1): 40 | return torch.clamp((x - eps) / (1 - 2 * eps), 0, 1) 41 | 42 | def make_contiguous(module): 43 | with torch.no_grad(): 44 | for param in 
module.parameters(): 45 | param.set_(param.contiguous()) 46 | 47 | 48 | # VQGAN from Taming Transformers paper 49 | # https://arxiv.org/abs/2012.09841 50 | 51 | 52 | def get_obj_from_str(string, reload=False): 53 | module, cls = string.rsplit(".", 1) 54 | if reload: 55 | module_imp = importlib.import_module(module) 56 | importlib.reload(module_imp) 57 | return getattr(importlib.import_module(module, package=None), cls) 58 | 59 | 60 | def instantiate_from_config(config): 61 | if not "target" in config: 62 | raise KeyError("Expected key `target` to instantiate.") 63 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 64 | 65 | 66 | class VQGanVAE(nn.Module): 67 | def __init__(self, vqgan_model_path=None, vqgan_config_path=None, channels = 1): 68 | super().__init__() 69 | 70 | assert(vqgan_model_path is not None) 71 | 72 | model_path = vqgan_model_path 73 | config_path = vqgan_config_path 74 | 75 | config = OmegaConf.load(config_path) 76 | 77 | model = instantiate_from_config(config["model"]) 78 | 79 | state = torch.load(model_path, map_location="cpu")["state_dict"] 80 | model.load_state_dict(state, strict=False) 81 | 82 | print(f"Loaded VQGAN from {model_path} and {config_path}") 83 | 84 | self.model = model 85 | 86 | # f as used in https://github.com/CompVis/taming-transformers#overview-of-pretrained-models 87 | f = ( 88 | config.model.params.ddconfig.resolution 89 | / config.model.params.ddconfig.attn_resolutions[0] 90 | ) 91 | self.num_layers = int(log(f) / log(2)) 92 | self.image_size = config.model.params.ddconfig.resolution 93 | self.num_tokens = config.model.params.n_embed 94 | self.is_gumbel = isinstance(self.model, GumbelVQ) 95 | self.channels = config.model.params.ddconfig.in_channels 96 | 97 | def get_codebook_indices(self, img): 98 | b = img.shape[0] 99 | # img = (2 * img) - 1 100 | _, _, [_, _, indices] = self.model.encode(img) 101 | if self.is_gumbel: 102 | return rearrange(indices, "b h w -> b (h w)", b=b) 103 | return 
rearrange(indices, "(b n) -> b n", b=b) 104 | 105 | def decode(self, img_seq): 106 | b, n = img_seq.shape 107 | one_hot_indices = F.one_hot(img_seq, num_classes=self.num_tokens).float() 108 | z = ( 109 | one_hot_indices @ self.model.quantize.embed.weight 110 | if self.is_gumbel 111 | else (one_hot_indices @ self.model.quantize.embedding.weight) 112 | ) 113 | 114 | z = rearrange(z, "b (h w) c -> b c h w", h=int(sqrt(n))) 115 | img = self.model.decode(z) 116 | 117 | # img = (img.clamp(-1.0, 1.0) + 1) * 0.5 118 | return img 119 | 120 | def forward(self, img, optimizer_idx=1): 121 | return self.model.training_step(img, optimizer_idx=optimizer_idx) 122 | -------------------------------------------------------------------------------- /celle_main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | import torch 5 | import torch.random 6 | from torch.optim import Adam 7 | from torch.utils.data import DataLoader 8 | import pytorch_lightning as pl 9 | from pytorch_lightning import seed_everything 10 | from pytorch_lightning.trainer import Trainer 11 | 12 | from dataloader import OpenCellLoader 13 | from celle import VQGanVAE, CELLE 14 | from omegaconf import OmegaConf 15 | import argparse, os, sys, datetime, glob 16 | 17 | 18 | from celle.celle import gumbel_sample, top_k 19 | 20 | torch.random.manual_seed(42) 21 | np.random.seed(42) 22 | 23 | from celle_taming_main import ( 24 | instantiate_from_config, 25 | nondefault_trainer_args, 26 | get_parser, 27 | ) 28 | 29 | 30 | class CellDataModule(pl.LightningDataModule): 31 | def __init__( 32 | self, 33 | data_csv, 34 | sequence_mode="simple", 35 | vocab="bert", 36 | crop_size=256, 37 | batch_size=1, 38 | threshold=False, 39 | text_seq_len=1000, 40 | num_workers=1, 41 | **kwargs, 42 | ): 43 | super().__init__() 44 | 45 | self.data_csv = data_csv 46 | self.protein_sequence_length = 0 47 | self.image_folders = [] 48 | self.crop_size = crop_size 49 | 
self.batch_size = batch_size 50 | self.sequence_mode = sequence_mode 51 | self.threshold = threshold 52 | self.text_seq_len = int(text_seq_len) 53 | self.vocab = vocab 54 | self.num_workers = num_workers if num_workers is not None else batch_size * 2 55 | 56 | def setup(self, stage=None): 57 | # called on every GPU 58 | self.cell_dataset_train = OpenCellLoader( 59 | data_csv=self.data_csv, 60 | crop_size=self.crop_size, 61 | split_key="train", 62 | crop_method="random", 63 | sequence_mode=self.sequence_mode, 64 | vocab=self.vocab, 65 | text_seq_len=self.text_seq_len, 66 | threshold=self.threshold, 67 | ) 68 | 69 | self.cell_dataset_val = OpenCellLoader( 70 | data_csv=self.data_csv, 71 | crop_size=self.crop_size, 72 | crop_method="center", 73 | split_key="val", 74 | sequence_mode=self.sequence_mode, 75 | vocab=self.vocab, 76 | text_seq_len=self.text_seq_len, 77 | threshold=self.threshold, 78 | ) 79 | 80 | def prepare_data(self): 81 | 82 | pass 83 | 84 | def train_dataloader(self): 85 | return DataLoader( 86 | self.cell_dataset_train, 87 | num_workers=self.num_workers, 88 | shuffle=True, 89 | batch_size=self.batch_size, 90 | ) 91 | 92 | def val_dataloader(self): 93 | return DataLoader( 94 | self.cell_dataset_val, 95 | num_workers=self.num_workers, 96 | batch_size=self.batch_size, 97 | ) 98 | 99 | # def test_dataloader(self): 100 | # transforms = ...
101 | # return DataLoader(self.test, batch_size=64) 102 | 103 | 104 | class CELLE_trainer(pl.LightningModule): 105 | def __init__( 106 | self, 107 | vqgan_model_path, 108 | vqgan_config_path, 109 | ckpt_path=None, 110 | image_key="threshold", 111 | condition_model_path=None, 112 | condition_config_path=None, 113 | num_images=2, 114 | dim=2, 115 | num_text_tokens=30, 116 | text_seq_len=1000, 117 | depth=16, 118 | heads=16, 119 | dim_head=64, 120 | reversible=False, 121 | attn_dropout=0.1, 122 | ff_dropout=0.1, 123 | attn_types="full", 124 | loss_img_weight=7, 125 | stable=False, 126 | sandwich_norm=False, 127 | shift_tokens=True, 128 | rotary_emb=True, 129 | text_embedding="bert", 130 | fixed_embedding=True, 131 | loss_cond_weight=1, 132 | learning_rate=3e-4, 133 | monitor="val_loss", 134 | ): 135 | super().__init__() 136 | 137 | vae = VQGanVAE( 138 | vqgan_model_path=vqgan_model_path, vqgan_config_path=vqgan_config_path 139 | ) 140 | 141 | self.image_key = image_key 142 | 143 | if condition_config_path: 144 | condition_vae = VQGanVAE( 145 | vqgan_model_path=condition_model_path, 146 | vqgan_config_path=condition_config_path, 147 | ) 148 | else: 149 | condition_vae = None 150 | 151 | self.celle = CELLE( 152 | dim=dim, 153 | vae=vae, # automatically infer (1) image sequence length and (2) number of image tokens 154 | condition_vae=condition_vae, 155 | num_images=num_images, 156 | num_text_tokens=num_text_tokens, # vocab size for text 157 | text_seq_len=text_seq_len, # text sequence length 158 | depth=depth, # should aim to be 64 159 | heads=heads, # attention heads 160 | reversible=reversible, # should aim to be True 161 | dim_head=dim_head, # attention head dimension 162 | attn_dropout=attn_dropout, # attention dropout 163 | ff_dropout=ff_dropout, 164 | attn_types=attn_types, 165 | loss_img_weight=loss_img_weight, 166 | stable=stable, 167 | sandwich_norm=sandwich_norm, 168 | shift_tokens=shift_tokens, 169 | rotary_emb=rotary_emb, 170 | text_embedding=text_embedding, 
171 | fixed_embedding=fixed_embedding, 172 | loss_cond_weight=loss_cond_weight 173 | # feedforward dropout 174 | ) 175 | 176 | self.learning_rate = learning_rate 177 | self.num_text_tokens = num_text_tokens 178 | self.num_images = num_images 179 | 180 | if monitor is not None: 181 | self.monitor = monitor 182 | 183 | if ckpt_path is not None: 184 | self.init_from_ckpt(ckpt_path, ignore_keys=[]) 185 | 186 | def init_from_ckpt(self, path, ignore_keys=list()): 187 | sd = torch.load(path, map_location="cpu")["state_dict"] 188 | for k in list(sd.keys()): # copy keys so entries can be deleted while iterating 189 | for ik in ignore_keys: 190 | if k.startswith(ik): 191 | self.print("Deleting key {} from state_dict.".format(k)) 192 | del sd[k] 193 | self.celle.load_state_dict(sd, strict=False) 194 | print(f"Restored from {path}") 195 | 196 | def forward(self, text, condition, target, return_loss=True): 197 | 198 | return self.celle( 199 | text=text, condition=condition, image=target, return_loss=return_loss 200 | ) 201 | 202 | def get_input(self, batch): 203 | text = batch["sequence"].squeeze(1) 204 | condition = batch["nucleus"] 205 | target = batch[self.image_key] 206 | 207 | return text, condition, target 208 | 209 | def get_image_from_logits(self, logits, temperature=0.9): 210 | 211 | filtered_logits = top_k(logits, thres=0.5) 212 | sample = gumbel_sample(filtered_logits, temperature=temperature, dim=-1) 213 | 214 | self.celle.vae.eval() 215 | out = self.celle.vae.decode( 216 | sample[:, self.celle.text_seq_len + self.celle.condition_seq_len :] 217 | - (self.celle.num_text_tokens + self.celle.num_condition_tokens) 218 | ) 219 | 220 | return out 221 | 222 | def get_loss(self, text, condition, target): 223 | 224 | 225 | loss_dict = {} 226 | 227 | loss, loss_dict, logits = self( 228 | text, condition, target, return_loss=True 229 | ) 230 | 231 | return loss, loss_dict 232 | 233 | def total_loss( 234 | self, 235 | loss, 236 | loss_dict={"loss_text": 0, "loss_cond": 0, "loss_img": 0}, 237 | mode="train", 238 | ): 239 | 240 |
loss_dict = {f"{mode}/{key}": value for key, value in loss_dict.items()} 241 | 242 | self.log( 243 | f"{mode}/loss_text", 244 | loss_dict[f"{mode}/loss_text"], 245 | prog_bar=True, 246 | logger=True, 247 | on_step=True, 248 | on_epoch=True, 249 | ) 250 | 251 | self.log( 252 | f"{mode}/loss_cond", 253 | loss_dict[f"{mode}/loss_cond"], 254 | prog_bar=True, 255 | logger=True, 256 | on_step=True, 257 | on_epoch=True, 258 | ) 259 | 260 | self.log( 261 | f"{mode}/loss_img", 262 | loss_dict[f"{mode}/loss_img"], 263 | prog_bar=True, 264 | logger=True, 265 | on_step=True, 266 | on_epoch=True, 267 | ) 268 | 269 | return loss 270 | 271 | def training_step(self, batch, batch_idx): 272 | 273 | text, condition, target = self.get_input(batch) 274 | loss, log_dict = self.get_loss(text, condition, target) 275 | 276 | loss = self.total_loss(loss, log_dict, mode="train") 277 | 278 | return loss 279 | 280 | def validation_step(self, batch, batch_idx): 281 | 282 | with torch.no_grad(): 283 | 284 | text, condition, target = self.get_input(batch) 285 | loss, log_dict = self.get_loss(text, condition, target) 286 | 287 | loss = self.total_loss(loss, log_dict, mode="val") 288 | 289 | return loss 290 | 291 | def configure_optimizers(self): 292 | 293 | 294 | optimizer = Adam(self.parameters(), lr=self.learning_rate) 295 | 296 | return optimizer 297 | 298 | def scale_image(self, image): 299 | 300 | for tensor in image: 301 | if torch.min(tensor) < 0: 302 | tensor += -torch.min(tensor) 303 | else: 304 | tensor -= torch.min(tensor) 305 | 306 | tensor /= torch.max(tensor) 307 | 308 | return image 309 | 310 | @torch.no_grad() 311 | def log_images(self, batch, **kwargs): 312 | 313 | log = dict() 314 | text, condition, target = self.get_input(batch) 315 | text = text.squeeze(1).to(self.device) 316 | condition = condition.to(self.device) 317 | 318 | out = self.celle.generate_images(text=text, condition=condition, use_cache=True) 319 | 320 | log["condition"] = self.scale_image(condition) 321 | 
log["output"] = self.scale_image(out) 322 | if self.image_key == "threshold": 323 | log["threshold"] = self.scale_image(target) 324 | log["target"] = self.scale_image(batch["target"]) 325 | else: 326 | log["target"] = self.scale_image(target) 327 | 328 | return log 329 | 330 | 331 | # adapted from https://github.com/CompVis/taming-transformers/blob/master/main.py 332 | 333 | if __name__ == "__main__": 334 | # custom parser to specify config files, train, test and debug mode, 335 | # postfix, resume. 336 | # `--key value` arguments are interpreted as arguments to the trainer. 337 | # `nested.key=value` arguments are interpreted as config parameters. 338 | # configs are merged from left-to-right followed by command line parameters. 339 | 340 | # model: 341 | # learning_rate: float 342 | # target: path to lightning module 343 | # params: 344 | # key: value 345 | # data: 346 | # target: celle_main.DataModuleFromConfig 347 | # params: 348 | # batch_size: int 349 | # wrap: bool 350 | # train: 351 | # target: path to train dataset 352 | # params: 353 | # key: value 354 | # validation: 355 | # target: path to validation dataset 356 | # params: 357 | # key: value 358 | # test: 359 | # target: path to test dataset 360 | # params: 361 | # key: value 362 | # lightning: (optional, has sane defaults and can be specified on cmdline) 363 | # trainer: 364 | # additional arguments to trainer 365 | # logger: 366 | # logger to instantiate 367 | # modelcheckpoint: 368 | # modelcheckpoint to instantiate 369 | # callbacks: 370 | # callback1: 371 | # target: importpath 372 | # params: 373 | # key: value 374 | 375 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 376 | 377 | # add cwd for convenience and to make classes in this file available when 378 | # running as `python celle_main.py` 379 | # (in particular `celle_main.CellDataModule` and `celle_main.CELLE_trainer`) 380 | sys.path.append(os.getcwd()) 381 | 382 | parser = get_parser() 383 | parser = Trainer.add_argparse_args(parser) 384 | 385 | opt,
unknown = parser.parse_known_args() 386 | if opt.name and opt.resume: 387 | raise ValueError( 388 | "-n/--name and -r/--resume cannot both be specified. " 389 | "If you want to resume training in a new log folder, " 390 | "use -n/--name in combination with --resume_from_checkpoint" 391 | ) 392 | if opt.resume: 393 | if not os.path.exists(opt.resume): 394 | raise ValueError("Cannot find {}".format(opt.resume)) 395 | if os.path.isfile(opt.resume): 396 | paths = opt.resume.split("/") 397 | idx = len(paths) - paths[::-1].index("logs") + 1 398 | logdir = "/".join(paths[:idx]) 399 | ckpt = opt.resume 400 | else: 401 | assert os.path.isdir(opt.resume), opt.resume 402 | logdir = opt.resume.rstrip("/") 403 | ckpt = os.path.join(logdir, "checkpoints", "last.ckpt") 404 | 405 | opt.resume_from_checkpoint = ckpt 406 | base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml"))) 407 | opt.base = base_configs + opt.base 408 | _tmp = logdir.split("/") 409 | nowname = _tmp[_tmp.index("logs") + 1] 410 | else: 411 | if opt.name: 412 | name = "_" + opt.name 413 | elif opt.base: 414 | cfg_fname = os.path.split(opt.base[0])[-1] 415 | cfg_name = os.path.splitext(cfg_fname)[0] 416 | name = "_" + cfg_name 417 | else: 418 | name = "" 419 | nowname = now + name + opt.postfix 420 | logdir = os.path.join("logs", nowname) 421 | 422 | ckptdir = os.path.join(logdir, "checkpoints") 423 | cfgdir = os.path.join(logdir, "configs") 424 | seed_everything(opt.seed) 425 | 426 | try: 427 | # init and save configs 428 | configs = [OmegaConf.load(cfg) for cfg in opt.base] 429 | cli = OmegaConf.from_dotlist(unknown) 430 | config = OmegaConf.merge(*configs, cli) 431 | lightning_config = config.pop("lightning", OmegaConf.create()) 432 | # merge trainer cli with config 433 | trainer_config = lightning_config.get("trainer", OmegaConf.create()) 434 | # default to ddp 435 | # trainer_config["distributed_backend"] = "ddp" 436 | for k in nondefault_trainer_args(opt): 437 | trainer_config[k] = getattr(opt,
k) 438 | if not "gpus" in trainer_config: 439 | trainer_config.pop("distributed_backend", None) # key is only present if set above or via config 440 | cpu = True 441 | else: 442 | gpuinfo = trainer_config["gpus"] 443 | print(f"Running on GPUs {gpuinfo}") 444 | cpu = False 445 | trainer_opt = argparse.Namespace(**trainer_config) 446 | lightning_config.trainer = trainer_config 447 | 448 | # model 449 | model = instantiate_from_config(config.model) 450 | 451 | # trainer and callbacks 452 | trainer_kwargs = dict() 453 | 454 | # default logger configs 455 | # NOTE wandb < 0.10.0 interferes with shutdown 456 | # wandb >= 0.10.0 seems to fix it but still interferes with pudb 457 | # debugging (wrongly sized pudb ui) 458 | # thus prefer testtube for now 459 | default_logger_cfgs = { 460 | "wandb": { 461 | "target": "pytorch_lightning.loggers.WandbLogger", 462 | "params": { 463 | "name": nowname, 464 | "save_dir": logdir, 465 | "offline": opt.debug, 466 | "id": nowname, 467 | }, 468 | }, 469 | "testtube": { 470 | "target": "pytorch_lightning.loggers.TestTubeLogger", 471 | "params": { 472 | "name": "testtube", 473 | "save_dir": logdir, 474 | }, 475 | }, 476 | } 477 | default_logger_cfg = default_logger_cfgs["testtube"] 478 | logger_cfg = lightning_config.logger or OmegaConf.create() 479 | logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg) 480 | trainer_kwargs["logger"] = instantiate_from_config(logger_cfg) 481 | 482 | # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to 483 | # specify which metric is used to determine best models 484 | default_modelckpt_cfg = { 485 | "target": "pytorch_lightning.callbacks.ModelCheckpoint", 486 | "params": { 487 | "dirpath": ckptdir, 488 | "filename": "{epoch:06}", 489 | "verbose": True, 490 | "save_last": True, 491 | }, 492 | } 493 | if hasattr(model, "monitor"): 494 | print(f"Monitoring {model.monitor} as checkpoint metric.") 495 | default_modelckpt_cfg["params"]["monitor"] = model.monitor 496 | default_modelckpt_cfg["params"]["save_top_k"] = 3 497 | 498 |
modelckpt_cfg = lightning_config.modelcheckpoint or OmegaConf.create() 499 | modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg) 500 | trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg) 501 | 502 | # add callback which sets up log directory 503 | default_callbacks_cfg = { 504 | "setup_callback": { 505 | "target": "celle_taming_main.SetupCallback", 506 | "params": { 507 | "resume": opt.resume, 508 | "now": now, 509 | "logdir": logdir, 510 | "ckptdir": ckptdir, 511 | "cfgdir": cfgdir, 512 | "config": config, 513 | "lightning_config": lightning_config, 514 | }, 515 | }, 516 | "image_logger": { 517 | "target": "celle_taming_main.ImageLogger", 518 | "params": { 519 | "batch_frequency": 1500, 520 | "max_images": 5, 521 | "clamp": False, 522 | "increase_log_steps": False, 523 | }, 524 | }, 525 | # "learning_rate_logger": { 526 | # "target": "celle_taming_main.LearningRateMonitor", 527 | # "params": { 528 | # "logging_interval": "step", 529 | # # "log_momentum": True 530 | # }, 531 | # }, 532 | } 533 | callbacks_cfg = lightning_config.callbacks or OmegaConf.create() 534 | callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg) 535 | trainer_kwargs["callbacks"] = [ 536 | instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg 537 | ] 538 | 539 | trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs) 540 | 541 | # data 542 | data = instantiate_from_config(config.data) 543 | # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html 544 | # calling these ourselves should not be necessary but it is. 
545 | # lightning still takes care of proper multiprocessing though 546 | data.setup() 547 | data.prepare_data() 548 | 549 | # configure learning rate 550 | bs, lr = config.data.params.batch_size, config.model.learning_rate 551 | 552 | if not cpu: 553 | ngpu = len(lightning_config.trainer.gpus.strip(",").split(",")) 554 | else: 555 | ngpu = 1 556 | accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches or 1 557 | print(f"accumulate_grad_batches = {accumulate_grad_batches}") 558 | lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches 559 | model.learning_rate = accumulate_grad_batches * ngpu * bs * lr 560 | 561 | print( 562 | "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (lr)".format( 563 | model.learning_rate, accumulate_grad_batches, ngpu, bs, lr 564 | ) 565 | ) 566 | 567 | # allow checkpointing via USR1 568 | def melk(*args, **kwargs): 569 | # run all checkpoint hooks 570 | if trainer.global_rank == 0: 571 | print("Summoning checkpoint.") 572 | ckpt_path = os.path.join(ckptdir, "last.ckpt") 573 | trainer.save_checkpoint(ckpt_path) 574 | 575 | def divein(*args, **kwargs): 576 | if trainer.global_rank == 0: 577 | import pudb 578 | 579 | pudb.set_trace() 580 | 581 | import signal 582 | 583 | signal.signal(signal.SIGUSR1, melk) 584 | signal.signal(signal.SIGUSR2, divein) 585 | 586 | # run 587 | if opt.train: 588 | try: 589 | trainer.fit(model, data) 590 | except Exception: 591 | melk() 592 | raise 593 | if not opt.no_test and not trainer.interrupted: 594 | trainer.test(model, data) 595 | except Exception: 596 | if opt.debug and trainer.global_rank == 0: 597 | try: 598 | import pudb as debugger 599 | except ImportError: 600 | import pdb as debugger 601 | debugger.post_mortem() 602 | raise 603 | finally: 604 | # move newly created debug project to debug_runs 605 | if opt.debug and not opt.resume and trainer.global_rank == 0: 606 | dst, name = os.path.split(logdir) 607 | 
dst = os.path.join(dst, "debug_runs", name) 608 | os.makedirs(os.path.split(dst)[0], exist_ok=True) 609 | os.rename(logdir, dst) 610 | -------------------------------------------------------------------------------- /configs/celle.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | learning_rate: 0.0003 3 | target: celle_main.CELLE_trainer 4 | params: 5 | ckpt_path: 6 | condition_model_path: 7 | condition_config_path: 8 | vqgan_model_path: 9 | vqgan_config_path: 10 | image_key: threshold 11 | num_images: 2 12 | dim: 768 13 | num_text_tokens: 30 14 | text_seq_len: 1000 15 | depth: 32 16 | heads: 16 17 | dim_head: 64 18 | attn_dropout: 0.1 19 | ff_dropout: 0.1 20 | attn_types: full 21 | rotary_emb: true 22 | fixed_embedding: True 23 | monitor: val/loss_img_epoch 24 | text_embedding: bert 25 | 26 | data: 27 | target: celle_main.CellDataModule 28 | params: 29 | data_csv: 30 | crop_size: 256 31 | batch_size: 1 32 | sequence_mode: embedding 33 | vocab: bert 34 | threshold: true 35 | text_seq_len: 1000 36 | num_workers: 4 37 | -------------------------------------------------------------------------------- /configs/nucleus_vqgan.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: taming.models.vqgan.VQModel 4 | params: 5 | image_key: nucleus 6 | monitor: 7 | ckpt_path: 8 | embed_dim: 256 9 | n_embed: 512 10 | ddconfig: 11 | double_z: false 12 | z_channels: 256 13 | resolution: 256 14 | in_channels: 1 15 | out_ch: 1 16 | ch: 128 17 | ch_mult: 18 | - 1 19 | - 1 20 | - 2 21 | - 2 22 | - 4 23 | num_res_blocks: 2 24 | attn_resolutions: 25 | - 16 26 | dropout: 0.0 27 | lossconfig: 28 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 29 | params: 30 | disc_conditional: false 31 | disc_in_channels: 1 32 | disc_start: 50000 33 | disc_weight: 0.2 34 | codebook_weight: 1.0 35 | 36 | data: 37 | target: 
celle_main.CellDataModule 38 | params: 39 | data_csv: 40 | crop_size: 256 41 | batch_size: 1 42 | num_workers: 8 43 | -------------------------------------------------------------------------------- /configs/threshold_vqgan.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: taming.models.vqgan.VQModel 4 | params: 5 | image_key: threshold 6 | monitor: 7 | embed_dim: 256 8 | n_embed: 512 9 | ddconfig: 10 | double_z: false 11 | z_channels: 256 12 | resolution: 256 13 | in_channels: 1 14 | out_ch: 1 15 | ch: 128 16 | ch_mult: 17 | - 1 18 | - 1 19 | - 2 20 | - 2 21 | - 4 22 | num_res_blocks: 2 23 | attn_resolutions: 24 | - 16 25 | dropout: 0.0 26 | lossconfig: 27 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 28 | params: 29 | disc_conditional: false 30 | disc_in_channels: 1 31 | disc_start: 50000 32 | disc_weight: 0.2 33 | codebook_weight: 1.0 34 | 35 | data: 36 | target: celle_main.CellDataModule 37 | params: 38 | data_csv: 39 | crop_size: 256 40 | batch_size: 1 41 | num_workers: 8 42 | threshold: True -------------------------------------------------------------------------------- /data/aaDescriptors.csv: -------------------------------------------------------------------------------- 1 | "AA","PP1","PP2","PP3","KF1","KF2","KF3","KF4","KF5","KF6","KF7","KF8","KF9","KF10","Z1","Z2","Z3","Z4","Z5","F1","F2","F3","F4","F5","F6","T1","T2","T3","T4","T5","VHSE1","VHSE2","VHSE3","VHSE4","VHSE5","VHSE6","VHSE7","VHSE8","ProtFP1","ProtFP2","ProtFP3","ProtFP4","ProtFP5","ProtFP6","ProtFP7","ProtFP8","ST1","ST2","ST3","ST4","ST5","ST6","ST7","ST8","BLOSUM1","BLOSUM2","BLOSUM3","BLOSUM4","BLOSUM5","BLOSUM6","BLOSUM7","BLOSUM8","BLOSUM9","BLOSUM10","MSWHIM1","MSWHIM2","MSWHIM3" 2 | 
"A",-0.96,-0.76,0.31,-1.56,-1.67,-0.97,-0.27,-0.93,-0.78,-0.2,-0.08,0.21,-0.48,0.24,-2.32,0.6,-0.14,1.3,0.207,0.821,-1.009,1.387,0.063,-0.6,-9.11,-1.63,0.63,1.04,2.26,0.15,-1.11,-1.35,-0.92,0.02,-0.91,0.36,-0.48,-0.1,-4.94,-2.13,1.7,-0.39,1.06,-1.39,0.97,-1.552,-0.791,-0.627,0.237,-0.461,-2.229,0.283,1.221,0.08,-0.92,0.53,0,0.24,0.19,0.66,-0.05,1.36,0.33,-0.73,0.2,-0.62 3 | "R",0.8,0.63,0.99,0.22,1.27,1.37,1.87,-1.7,0.46,0.92,-0.39,0.23,0.93,3.52,2.5,-3.5,1.99,-0.17,-1.229,0.378,0.516,-0.328,-0.052,2.728,0.23,3.89,-1.16,-0.39,-0.06,-1.47,1.45,1.24,1.27,1.55,1.47,1.3,0.83,-2.79,6.6,1.21,2.07,1.67,0.76,0,0.32,-0.059,0.731,-0.013,-0.096,-0.253,0.3,1.256,0.854,1.01,0.19,-0.86,-0.61,1.28,0.2,0.66,0.18,-0.22,-0.52,-0.22,0.27,1 4 | "N",0.82,-0.57,0.02,1.14,-0.07,-0.12,0.81,0.18,0.37,-0.09,1.23,1.1,-1.73,3.05,1.62,1.04,-1.15,1.61,-1.009,-0.939,-0.428,-0.397,-0.539,-0.605,-4.62,0.66,1.16,-0.22,0.93,-0.99,0,-0.37,0.69,-0.55,0.85,0.73,-0.8,-4.88,0.81,0.14,-0.14,1.23,-0.65,1.02,-1.94,-0.888,-0.057,-0.651,-0.214,0.917,0.164,-0.14,-0.166,1.51,0.22,-0.05,1.01,0.12,0.83,-0.03,-0.57,-1.2,-0.14,0.14,0.2,-0.66 5 | "D",1,-0.89,-1,0.58,-0.22,-1.58,0.81,-0.92,0.15,-1.52,0.47,0.76,0.7,3.98,0.93,1.93,-2.46,0.75,-1.298,-0.444,-0.584,-0.175,-0.259,-1.762,-4.65,0.75,1.39,-0.4,1.05,-1.15,0.67,-0.41,-0.01,-2.68,1.31,0.03,0.56,-6.61,0.94,-3.04,-4.58,0.48,-1.31,0.1,0.94,-0.907,-0.054,-0.781,-0.248,1.12,0.101,-0.245,-0.075,1.55,0.01,0.32,0.49,-0.99,0.01,-1.62,0.53,-0.15,-0.28,0.11,-1,-0.96 6 | "C",-0.55,-0.47,0.19,0.12,-0.89,0.45,-1.05,-0.71,2.41,1.52,-0.69,1.13,1.1,0.84,-1.67,3.75,0.18,-2.65,0.997,0.021,-1.419,-2.08,-0.799,0.502,-7.35,-0.86,-0.33,0.8,0.98,0.18,-1.67,-0.46,-0.21,0,1.2,-1.61,-0.19,4.62,-3.54,1.5,-1.26,3.27,-0.34,-0.47,-0.23,-1.276,-0.401,0.134,0.859,-0.196,-0.72,0.639,-0.857,-1.08,-1.11,1.56,0.81,1.83,-1.05,-0.74,0.38,-0.12,-0.1,-0.66,0.26,-0.27 7 | 
"Q",0.78,-0.3,-0.38,-0.47,0.24,0.07,1.1,1.1,0.59,0.84,-0.71,-0.03,-2.33,1.75,0.5,-1.44,-1.34,0.66,-0.88,0.381,-0.044,-0.455,-0.04,0.405,-3,1.72,0.28,-0.39,0.33,-0.96,0.12,0.18,0.16,0.09,0.42,-0.2,-0.41,-3.95,2.88,-0.83,0.52,0.9,0.55,-0.08,0.64,-0.662,0.228,-0.193,-0.105,0.418,0.474,0.172,0.408,1.09,0.3,-0.87,-0.72,0.5,-0.08,-0.44,0.2,0.38,0.67,0.3,1,-0.3 8 | "E",0.94,-0.54,-0.99,-1.45,0.19,-1.61,1.17,-1.31,0.4,0.04,0.38,-0.35,-0.12,3.11,0.26,-0.11,-3.04,-0.25,-1.349,1.388,-0.361,0.213,0.424,-1.303,-3.03,1.82,0.51,-0.58,0.43,-1.18,0.4,0.1,0.36,-2.16,-0.17,0.91,0.02,-5.1,2.2,-3.59,-2.26,-2.14,1.35,-0.45,-1.31,-0.629,-0.39,-0.38,-0.366,0.635,0.514,0.175,0.367,1.48,0.23,-0.67,-0.36,-0.28,-0.08,-1.01,0.36,0.77,0.3,0.24,-0.39,-0.04 9 | "G",-0.88,-1,0.49,1.46,-1.96,-0.23,-0.16,0.1,-0.11,1.32,2.36,-1.66,0.46,2.05,-4.06,0.36,-0.82,-0.38,-0.205,-2.219,-1.656,1.229,-1.115,-1.146,-10.61,-1.21,-0.12,0.75,3.25,-0.2,-1.53,-2.63,2.28,-0.53,-1.18,2.01,-1.34,-5.7,-8.72,4.18,-1.35,-0.31,2.91,0.32,-0.11,-1.844,-0.018,-0.184,0.573,-0.728,-3.317,0.166,2.522,0.85,0.17,1.73,0.09,-0.55,1.19,1.21,0.87,0.01,0.24,-0.31,-0.28,-0.75 10 | "H",0.67,-0.11,0.37,-0.41,0.52,-0.28,0.28,1.61,1.01,-1.85,0.47,1.13,1.63,2.47,1.95,0.26,3.9,0.09,-0.27,0.461,-0.024,-1.407,0.001,0.169,-1.01,-1.31,0.01,-1.81,-0.21,-0.43,-0.25,0.37,0.19,0.51,1.28,0.93,0.65,0.17,2.14,1.2,0.71,1.16,-0.38,-1.85,-2.79,-0.225,0.361,0.079,-1.037,0.568,0.273,1.208,-0.001,0.72,1.55,-0.8,1.55,0.35,-0.79,0.66,-0.08,-0.19,0.99,0.84,0.67,-0.78 11 | "I",-0.94,-0.05,-0.18,-0.73,-0.16,1.79,-0.77,-0.54,0.03,-0.83,0.51,0.66,-1.78,-3.89,-1.73,-1.71,-0.84,0.26,1.524,0.536,0.809,0.734,-0.196,0.427,-4.25,-0.28,-0.15,1.4,-0.21,1.27,-0.14,0.3,-1.8,0.3,-1.61,-0.16,-0.13,6.58,-1.73,-2.49,1.09,-0.34,-0.28,1.97,-0.92,-0.785,-1.01,-0.349,-0.097,-0.402,1.091,-0.139,-0.764,-1.46,-1.13,-0.76,0.38,-0.6,0.28,-0.13,0.2,-0.22,0.21,-0.91,0.83,-0.25 12 | 
"L",-0.9,0.03,-0.24,-1.04,0,-0.24,-1.1,-0.55,-2.05,0.96,-0.76,0.45,0.93,-4.28,-1.3,-1.49,-0.72,0.84,1.2,1.128,0.703,1.904,0.536,-0.141,-4.38,0.28,-0.49,1.45,0.02,1.36,0.07,0.36,-0.8,0.22,-1.37,0.08,-0.62,5.76,-1.33,-1.71,0.63,-1.7,0.71,-0.05,-0.51,-0.826,-0.379,0.038,-0.059,-0.625,1.025,-0.229,-0.129,-1.41,-0.86,-0.88,-0.17,0.03,0.34,0.11,0.15,-0.44,-0.02,-0.74,0.72,-0.16 13 | "K",0.6,0.1,1,-0.34,0.82,-0.23,1.7,1.54,-1.62,1.15,-0.08,-0.48,0.6,2.29,0.89,-2.49,1.49,0.31,-1.387,0.572,0.285,0.333,-0.169,1.157,-2.59,2.34,-1.69,0.41,-0.21,-1.17,0.7,0.7,0.8,1.64,0.67,1.63,0.13,-4.99,5,0.7,3,-1.23,1.41,0.19,0.87,-0.504,0.245,0.297,-0.065,-0.387,1.011,0.525,0.553,1.14,-0.04,-0.8,-0.85,0.82,0.1,0.21,0.13,0.18,-0.85,-0.51,0.08,0.6 14 | "M",-0.82,0.03,-0.08,-1.4,0.18,-0.42,-0.73,2,1.52,0.26,0.11,-1.27,0.27,-2.85,-0.22,0.47,1.94,-0.98,0.886,1.346,0.277,-0.913,0.007,-0.265,-4.08,0.98,-2.34,1.64,-0.79,1.01,-0.53,0.43,0,0.23,0.1,-0.86,-0.68,5.11,0.19,-1.02,0.15,0.13,-0.3,-2.95,0.5,-0.693,0.498,0.658,0.457,-0.231,1.064,0.248,-0.778,-0.96,-0.59,-0.97,-0.53,0.24,0.37,0.06,0.21,-0.56,0.36,-0.7,1,-0.32 15 | "F",-0.85,0.48,-0.58,-0.21,0.98,-0.36,-1.43,0.22,-0.81,0.67,1.1,1.71,-0.44,-4.22,1.94,1.06,0.54,-0.62,1.247,0.293,1.336,-0.026,0.012,-0.015,0.49,-0.94,-0.63,-1.27,-0.44,1.52,0.61,0.96,-0.16,0.25,0.28,-1.33,-0.2,6.76,0.88,0.89,-1.12,-0.49,-0.55,-0.87,1.05,-0.019,0.024,1.08,-0.22,-0.937,0.57,-0.357,0.278,-1.62,1.01,-0.31,0.62,-0.55,0.29,-0.02,0.1,0.43,-1.29,0.76,0.85,-0.34 16 | "P",-0.81,-0.4,-0.07,2.06,-0.33,-1.15,-0.75,0.88,-0.45,0.3,-2.3,0.74,-0.28,-1.66,0.27,1.84,0.7,2,-0.407,-2.038,-0.564,-0.128,3.847,-1.108,-5.11,-3.54,-0.53,-0.36,-0.29,0.22,-0.17,-0.5,0.05,-0.01,-1.34,-0.19,3.56,-3.82,-2.31,3.45,1,-3.22,-3.54,-0.36,-0.3,-1.049,-0.407,-0.067,-0.066,-0.813,-0.89,0.021,-0.894,0.88,-0.68,0.38,-0.87,-1.24,-2.02,0.85,-0.35,-0.42,-0.3,-0.43,0.73,-0.6 17 | 
"S",0.41,-0.82,0.57,0.81,-1.08,0.16,0.42,-0.21,-0.43,-1.89,-1.15,-0.97,-0.23,2.39,-1.07,1.15,-1.39,0.67,-0.495,-0.847,-1.079,0.582,0.035,-0.068,-7.44,-0.65,0.68,-0.17,1.58,-0.67,-0.86,-1.07,-0.41,-0.32,0.27,-0.64,0.11,-4.57,-2.55,-0.67,1.11,0.99,-1.02,0.11,0.65,-1.343,-0.311,-0.917,-0.049,0.549,-1.533,0.166,0.28,0.84,-0.45,0.42,0.32,0.2,0.54,0.01,-0.8,0.62,-0.13,-0.8,0.61,-1 18 | "T",0.4,-0.64,0.37,0.26,-0.7,1.21,0.63,-0.1,0.21,0.24,-1.15,-0.56,0.19,0.75,-2.18,-1.12,-1.46,-0.4,-0.032,-0.45,-0.61,0.341,0.117,0.577,-5.97,-0.62,1.11,0.31,0.95,-0.34,-0.51,-0.55,-1.06,-0.06,-0.01,-0.79,0.39,-2,-1.77,-0.7,1.02,1.06,-1.2,0.74,1.65,-1.061,-0.928,-0.911,-0.063,0.538,-0.775,-0.147,-0.717,0.19,-0.73,0.18,-0.01,0.02,0.38,-0.3,-1.96,0.15,0.06,-0.58,0.85,-0.89 19 | "W",0.06,1,-0.47,0.3,2.1,-0.72,-1.57,-1.16,0.57,-0.48,-0.4,-2.3,-0.6,-4.36,3.94,0.59,3.44,-1.59,0.844,-0.075,2.069,-1.36,-0.81,-0.38,5.73,-2.67,-0.07,-1.96,-0.54,1.5,2.06,1.79,0.75,0.75,-0.13,-1.01,-0.85,7.33,4.55,2.77,-2.41,-1.08,1.04,0.23,0.59,0.853,0.039,0.26,-1.163,0.16,-0.202,1.01,0.195,-1.58,2.28,1.17,-1.61,0.12,0.24,-0.54,-0.4,-0.35,0.5,1,0.98,-0.47 20 | "Y",0.31,0.42,-0.2,1.38,1.48,0.8,-0.56,0,-0.68,-0.31,1.03,-0.05,0.53,-2.54,2.44,0.43,0.04,-1.47,0.329,-0.858,1.753,-0.479,-0.835,0.289,2.08,-0.47,0.07,-1.67,-0.35,0.61,1.6,1.17,0.73,0.53,0.25,-0.96,-0.52,3.14,3.59,2.45,-1.27,-0.06,-0.29,1.99,0.3,0.308,0.569,1.1,-0.464,-0.144,-0.354,-1.099,0.162,-1.14,1.74,-0.58,0.75,-0.12,-0.48,0.24,-0.25,0.71,-0.25,0.97,0.66,-0.16 21 | "V",-1,-0.43,-0.14,-0.74,-0.71,2.04,-0.4,0.5,-0.81,-1.07,0.06,-0.46,0.65,-2.59,-2.64,-1.54,-0.85,-0.02,-1.332,0.545,0.029,1.026,-0.229,1.038,-5.87,-0.94,0.28,1.1,0.48,0.76,-0.92,-0.17,-1.91,0.22,-1.4,-0.24,-0.03,5.04,-2.9,-2.29,1.38,0.06,0.08,1.79,-0.38,-1.133,-0.893,-0.325,0.303,-0.561,-0.175,-0.02,-0.311,-1.13,-1.23,-0.63,0.06,-0.6,0.16,0.01,0.02,0.25,0.61,-1,0.79,-0.58 22 | -------------------------------------------------------------------------------- /dataloader.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | import random 5 | import json 6 | import pandas as pd 7 | 8 | import torch 9 | from torch.utils.data import Dataset 10 | from torchvision import transforms 11 | import torchvision.transforms.functional as TF 12 | 13 | 14 | def simple_conversion(seq): 15 | """Map each residue character to an integer index in a 27-symbol vocabulary ("-" = 0)""" 16 | chars = [ 17 | "-", 18 | "M", 19 | "R", 20 | "H", 21 | "K", 22 | "D", 23 | "E", 24 | "S", 25 | "T", 26 | "N", 27 | "Q", 28 | "C", 29 | "U", 30 | "G", 31 | "P", 32 | "A", 33 | "V", 34 | "I", 35 | "F", 36 | "Y", 37 | "W", 38 | "L", 39 | "O", 40 | "X", 41 | "Z", 42 | "B", 43 | "J", 44 | ] 45 | 46 | nums = range(len(chars)) 47 | 48 | seqs_x = np.zeros(len(seq)) 49 | 50 | for idx, char in enumerate(seq): 51 | 52 | lui = chars.index(char) 53 | 54 | seqs_x[idx] = nums[lui] 55 | 56 | return torch.tensor([seqs_x]).long() 57 | 58 | 59 | def convert_descriptor(seq): 60 | seq_dict = { 61 | "": 0, 62 | "M": 1, 63 | "R": 2, 64 | "H": 3, 65 | "K": 4, 66 | "D": 5, 67 | "E": 6, 68 | "S": 7, 69 | "T": 8, 70 | "N": 9, 71 | "Q": 10, 72 | "C": 11, 73 | "G": 12, 74 | "P": 13, 75 | "A": 14, 76 | "V": 15, 77 | "I": 16, 78 | "F": 17, 79 | "Y": 18, 80 | "W": 19, 81 | "L": 20, 82 | "": 21, 83 | } 84 | seq = seq.upper() 85 | return torch.tensor([seq_dict[char] for char in seq]).long() 86 | 87 | 88 | class OpenCellLoader(Dataset): 89 | """Loads OpenCell images paired with their protein sequences""" 90 | 91 | def __init__( 92 | self, 93 | data_csv, 94 | split_key=None, 95 | crop_size=600, 96 | crop_method="random", 97 | sequence_mode="simple", 98 | vocab="bert", 99 | threshold=False, 100 | text_seq_len=0, 101 | ): 102 | self.data_csv = data_csv 103 | self.image_folders = [] 104 | self.crop_method = crop_method 105 | self.crop_size = crop_size 106 | self.sequence_mode = sequence_mode 107 | self.threshold = threshold 108 | self.text_seq_len = int(text_seq_len) 109 |
self.vocab = vocab 110 | 111 | if self.sequence_mode == "embedding" or self.sequence_mode == "onehot": 112 | 113 | from tape import TAPETokenizer 114 | 115 | if self.vocab == "unirep" or self.sequence_mode == "onehot": 116 | self.tokenizer = TAPETokenizer(vocab="unirep") 117 | 118 | elif self.vocab == "bert": 119 | self.tokenizer = TAPETokenizer(vocab="iupac") 120 | 121 | elif self.vocab == "esm1b": 122 | from esm import Alphabet 123 | 124 | self.tokenizer = Alphabet.from_architecture( 125 | "ESM-1b" 126 | ).get_batch_converter() 127 | 128 | data = pd.read_csv(data_csv) 129 | 130 | self.parent_path = os.path.dirname(data_csv).split(data_csv)[0] 131 | 132 | if split_key == "train": 133 | self.data = data[data["split"] == "train"] 134 | elif split_key == "val": 135 | self.data = data[data["split"] == "val"] 136 | else: 137 | self.data = data 138 | 139 | self.data = self.data.reset_index(drop=True) 140 | 141 | def __len__(self): 142 | return len(self.data) 143 | 144 | def __getitem__(self, idx): 145 | 146 | protein_vector = self.get_protein_vector(idx) 147 | 148 | nucleus, target, threshold = self.get_images(idx) 149 | 150 | data_dict = { 151 | "nucleus": nucleus.float(), 152 | "target": target.float(), 153 | "threshold": threshold.float(), 154 | "sequence": protein_vector.long(), 155 | } 156 | 157 | return data_dict 158 | 159 | def get_protein_vector(self, idx): 160 | 161 | if "protein_sequence" not in self.data.columns: 162 | 163 | metadata = self.retrieve_metadata(idx) 164 | protein_sequence = metadata["sequence"] 165 | else: 166 | protein_sequence = self.data.iloc[idx]["protein_sequence"] 167 | 168 | protein_vector = self.tokenize_seqeuence(protein_sequence) 169 | 170 | return protein_vector 171 | 172 | def get_images(self, idx): 173 | 174 | nucleus = Image.open( 175 | os.path.join(self.parent_path, self.data.iloc[idx]["nucleus_image_path"]) 176 | ) 177 | target = Image.open( 178 | os.path.join(self.parent_path, self.data.iloc[idx]["protein_image_path"]) 179 | ) 
180 | 181 | # from https://discuss.pytorch.org/t/how-to-apply-same-transform-on-a-pair-of-picture/14914 182 | 183 | if self.crop_method == "random": 184 | 185 | # Random crop 186 | i, j, h, w = transforms.RandomCrop.get_params( 187 | nucleus, output_size=(self.crop_size, self.crop_size) 188 | ) 189 | 190 | nucleus = TF.crop(nucleus, i, j, h, w) 191 | target = TF.crop(target, i, j, h, w) 192 | 193 | # Random horizontal flipping 194 | if random.random() > 0.5: 195 | nucleus = TF.hflip(nucleus) 196 | target = TF.hflip(target) 197 | 198 | # Random vertical flipping 199 | if random.random() > 0.5: 200 | nucleus = TF.vflip(nucleus) 201 | target = TF.vflip(target) 202 | 203 | elif self.crop_method == "center": 204 | nucleus = TF.center_crop(nucleus, self.crop_size) 205 | target = TF.center_crop(target, self.crop_size) 206 | 207 | nucleus = TF.to_tensor(nucleus) 208 | target = TF.to_tensor(target) 209 | 210 | threshold = target 211 | 212 | if self.threshold: 213 | threshold = 1.0 * (threshold > (torch.mean(threshold))) 214 | 215 | return nucleus, target, threshold 216 | 217 | def retrieve_metadata(self, idx): 218 | 219 | with open( 220 | os.path.join(self.parent_path, self.data.iloc[idx]["metadata_path"]) 221 | ) as f: 222 | metadata = json.load(f) 223 | 224 | return metadata 225 | 226 | def tokenize_seqeuence(self, protein_sequence): 227 | 228 | prot_len = len(protein_sequence) 229 | 230 | if prot_len > self.text_seq_len: 231 | start_int = np.random.randint(0, len(protein_sequence) - self.text_seq_len) 232 | protein_sequence = protein_sequence[ 233 | start_int : start_int + self.text_seq_len 234 | ] 235 | 236 | if self.sequence_mode == "simple": 237 | protein_vector = simple_conversion(protein_sequence) 238 | 239 | elif self.sequence_mode == "center": 240 | protein_sequence = protein_sequence.center(self.text_seq_len, "-") 241 | protein_vector = simple_conversion(protein_sequence) 242 | 243 | elif self.sequence_mode == "alternating": 244 | protein_sequence =
protein_sequence.center(self.text_seq_len, "-") 245 | protein_sequence = protein_sequence[::18] 246 | protein_sequence = protein_sequence.center( 247 | int(self.text_seq_len / 18) + 1, "-" 248 | ) 249 | protein_vector = simple_conversion(protein_sequence) 250 | 251 | elif self.sequence_mode == "onehot": 252 | 253 | protein_vector = torch.tensor([self.tokenizer.encode(protein_sequence)])[ 254 | :, 1:-1 255 | ] 256 | 257 | elif self.sequence_mode == "aadescriptors": 258 | 259 | protein_vector = convert_descriptor(protein_sequence).long().unsqueeze(0) 260 | 261 | elif self.sequence_mode == "embedding": 262 | 263 | if self.vocab == "esm1b": 264 | pad_token = 1 265 | 266 | protein_vector = self.tokenizer([("", protein_sequence)])[-1][:, 1:] 267 | 268 | elif self.vocab == "unirep" or self.vocab == "bert": 269 | pad_token = 0 270 | protein_vector = torch.tensor( 271 | [self.tokenizer.encode(protein_sequence)] 272 | )[:, 1:] 273 | 274 | 275 | if prot_len > self.text_seq_len: 276 | protein_vector = protein_vector[:, :-1] 277 | elif prot_len == self.text_seq_len: 278 | protein_vector = protein_vector[:, :-2] 279 | 280 | if protein_vector.shape[-1] < self.text_seq_len: 281 | diff = self.text_seq_len - protein_vector.shape[-1] 282 | protein_vector = torch.nn.functional.pad( 283 | protein_vector, (0, diff), "constant", pad_token 284 | ) 285 | 286 | return protein_vector.long() 287 | 288 | else: 289 | 290 | raise ValueError("No valid sequence mode selected") 291 | 292 | if protein_vector.shape[-1] + 1 < self.text_seq_len: 293 | diff = self.text_seq_len - protein_vector.shape[-1] 294 | protein_vector = torch.nn.functional.pad( 295 | protein_vector, (0, diff), "constant", 0 296 | ) 297 | 298 | return protein_vector.long() 299 | -------------------------------------------------------------------------------- /images/generate.gif: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/BoHuangLab/Protein-Localization-Transformer/9d7d0bb4296a0363d1af21a7e50ddc8672e0eca6/images/generate.gif -------------------------------------------------------------------------------- /images/huanglogo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoHuangLab/Protein-Localization-Transformer/9d7d0bb4296a0363d1af21a7e50ddc8672e0eca6/images/huanglogo.jpeg -------------------------------------------------------------------------------- /images/nucleus.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoHuangLab/Protein-Localization-Transformer/9d7d0bb4296a0363d1af21a7e50ddc8672e0eca6/images/nucleus.jpg -------------------------------------------------------------------------------- /images/preview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoHuangLab/Protein-Localization-Transformer/9d7d0bb4296a0363d1af21a7e50ddc8672e0eca6/images/preview.png -------------------------------------------------------------------------------- /notebooks/Demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "protein_sequence = ''\n", 10 | "nucleus_image = 'images/nucleus.jpg'\n", 11 | "protein_name = None\n", 12 | "device = \"cuda:0\"\n", 13 | "config_file = 'configs/celle.yaml'\n", 14 | "ckpt_path = None" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "#run once\n", 24 | "import os\n", 25 | "\n", 26 | "if 'notebooks' in os.getcwd():\n", 27 | " os.chdir('..')\n", 28 | "\n", 29 | "import torch\n", 30 | "import numpy as np\n", 31 | "from matplotlib import pyplot as plt\n", 
32 | "from matplotlib import cm\n", 33 | "from matplotlib.colors import LinearSegmentedColormap\n", 34 | "import torchvision\n", 35 | "\n", 36 | "from einops import rearrange\n", 37 | "from omegaconf import OmegaConf\n", 38 | "\n", 39 | "from celle_main import instantiate_from_config\n", 40 | "from dataloader import OpenCellLoader\n", 41 | "\n", 42 | "# color map for plot\n", 43 | "color_array = plt.get_cmap('gist_rainbow')(range(256))\n", 44 | "color_array[:,-1] = np.linspace(1.0,0.0,256)\n", 45 | "map_object = LinearSegmentedColormap.from_list(name='rainbow_alpha',colors=color_array[::-1])\n", 46 | "plt.register_cmap(cmap=map_object)\n", 47 | "\n", 48 | "device = torch.device(device)\n", 49 | "\n", 50 | "#load model\n", 51 | "configs = OmegaConf.load(config_file);\n", 52 | "model = instantiate_from_config(configs.model).to(device);\n", 53 | "if ckpt_path:\n", 54 | " t = torch.load(ckpt_path,map_location = 'cpu')['state_dict'];\n", 55 | " for key in list(t.keys()):\n", 56 | " t[key.replace('celle.','')] = t.pop(key);\n", 57 | "model.celle.load_state_dict(t,strict=False);\n", 58 | "model = model.celle\n", 59 | "model = model.to(device)\n", 60 | "model = model.eval()\n", 61 | "\n", 62 | "# get some params\n", 63 | "crop_size = configs.data.params.crop_size\n", 64 | "sequence_mode = configs.data.params.sequence_mode\n", 65 | "vocab = configs.data.params.vocab\n", 66 | "threshold = configs.data.params.threshold\n", 67 | "text_seq_len = configs.data.params.text_seq_len\n", 68 | "\n", 69 | "# convert string to numbered index\n", 70 | "dataset = OpenCellLoader(crop_size=crop_size, sequence_mode=sequence_mode, vocab=vocab, threshold=threshold, text_seq_len=text_seq_len)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "protein_sequence = ''.join(filter(str.isalpha, protein_sequence)) \n", 80 | "protein_sequence = dataset.tokenize_seqeuence(protein_sequence)\n", 81 | "\n", 82 | "# 
import nucleus, scale and crop\n", 83 | "nucleus = torch.tensor(plt.imread(nucleus_image)).float()\n", 84 | "nucleus /= 255\n", 85 | "nucleus = torchvision.transforms.RandomCrop(256)(nucleus).unsqueeze(0).unsqueeze(0)\n", 86 | "\n", 87 | "# generate image\n", 88 | "with torch.no_grad():\n", 89 | " output = model.generate_images(text=protein_sequence.to(device), condition = nucleus.to(device), return_logits=True, use_cache=True, progress=True)\n", 90 | " \n", 91 | " logits = output[-1][:,-256:,-512:]\n", 92 | " image_tokens = logits @ model.vae.model.quantize.embedding.weight\n", 93 | " image_tokens = rearrange(image_tokens, \"b (h w) c -> b c h w\", h=int(np.sqrt(256)))\n", 94 | " pdf = model.vae.model.decode(image_tokens)\n", 95 | " pdf = torch.clip(pdf,0,1)\n", 96 | " \n", 97 | " plt.figure(dpi=300, clear=True) \n", 98 | " plt.axis('off')\n", 99 | " plt.imshow(nucleus[0,0],cmap='gray',interpolation='bicubic')\n", 100 | " plt.imshow(pdf.cpu()[0,0],cmap='rainbow_alpha',alpha = .75,interpolation='bicubic')\n", 101 | " plt.colorbar(mappable=cm.ScalarMappable(cmap='rainbow_alpha'))\n", 102 | " \n", 103 | " if protein_name:\n", 104 | " plt.title(protein_name)" 105 | ] 106 | } 107 | ], 108 | "metadata": { 109 | "interpreter": { 110 | "hash": "f703332b1593e5986aec844d60dd2796d9b0ddf157e3991cb22534f7b76c19d6" 111 | }, 112 | "kernelspec": { 113 | "display_name": "Python 3.8.5 ('env': venv)", 114 | "language": "python", 115 | "name": "python3" 116 | }, 117 | "language_info": { 118 | "codemirror_mode": { 119 | "name": "ipython", 120 | "version": 3 121 | }, 122 | "file_extension": ".py", 123 | "mimetype": "text/x-python", 124 | "name": "python", 125 | "nbconvert_exporter": "python", 126 | "pygments_lexer": "ipython3", 127 | "version": "3.8.5" 128 | }, 129 | "orig_nbformat": 4 130 | }, 131 | "nbformat": 4, 132 | "nbformat_minor": 2 133 | } 134 | -------------------------------------------------------------------------------- /notebooks/grad_map.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class ActivationsAndGradients: 5 | """Class for extracting activations and 6 | registering gradients from targeted intermediate layers""" 7 | 8 | def __init__(self, model, target_layers, reshape_transform): 9 | self.model = model 10 | self.gradients = [] 11 | self.activations = [] 12 | self.reshape_transform = reshape_transform 13 | self.handles = [] 14 | for target_layer in target_layers: 15 | self.handles.append( 16 | target_layer.register_forward_hook(self.save_activation) 17 | ) 18 | # Because of https://github.com/pytorch/pytorch/issues/61519, 19 | # we don't use backward hook to record gradients. 20 | self.handles.append(target_layer.register_forward_hook(self.save_gradient)) 21 | 22 | def save_activation(self, module, input, output): 23 | 24 | if len(output) == 2: 25 | output = output["attn"] 26 | 27 | if len(self.activations) == 0: 28 | self.min_val = torch.min(output) 29 | 30 | attn = torch.nn.functional.pad( 31 | output, 32 | (0, 1512 - (output.shape[-1]), 0, 0, 0, 0, 0, 0), 33 | "constant", 34 | self.min_val, 35 | ) 36 | 37 | self.activations.append(attn.cpu().detach()) 38 | 39 | def save_gradient(self, module, input, output): 40 | # if not hasattr(output, "requires_grad") or not output.requires_grad: 41 | # # You can only register hooks on tensor requires grad.
42 | # return 43 | 44 | # Gradients are computed in reverse order 45 | def _store_grad(grad): 46 | if self.reshape_transform is not None: 47 | grad = self.reshape_transform(grad) 48 | self.gradients = [grad.cpu().detach()] + self.gradients 49 | 50 | # output.register_hook(_store_grad(output)) 51 | 52 | def __call__(self, x, filter_thres=0.5): 53 | self.gradients = [] 54 | self.activations = [] 55 | sequence, nucleus, target = x 56 | return self.model.generate_images( 57 | text=sequence, 58 | condition=nucleus, 59 | return_logits=True, 60 | progress=True, 61 | use_cache=True, 62 | filter_thres=filter_thres, 63 | ) 64 | 65 | def release(self): 66 | for handle in self.handles: 67 | handle.remove() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | antlr4-python3-runtime==4.8 3 | argon2-cffi==21.1.0 4 | asttokens==2.0.5 5 | attrs==21.2.0 6 | axial-positional-embedding==0.2.1 7 | backcall==0.2.0 8 | biopython==1.79 9 | bleach==4.1.0 10 | boto3==1.20.5 11 | botocore==1.23.5 12 | cachetools==4.2.4 13 | calmsize==0.1.3 14 | certifi==2021.10.8 15 | cffi==1.15.0 16 | charset-normalizer==2.0.7 17 | click==7.1.2 18 | cycler==0.11.0 19 | dataclasses==0.6 20 | debugpy==1.5.1 21 | decorator==5.1.0 22 | defusedxml==0.7.1 23 | einops==0.3.2 24 | entrypoints==0.3 25 | executing==0.8.3 26 | filelock==3.3.2 27 | fsspec==2021.11.0 28 | future==0.18.2 29 | g-mlp-pytorch==0.1.5 30 | google-auth==2.3.3 31 | google-auth-oauthlib==0.4.6 32 | grpcio==1.41.1 33 | idna==3.3 34 | imageio==2.10.3 35 | importlib-metadata==4.11.3 36 | importlib-resources==5.4.0 37 | ipykernel==6.6.0 38 | ipython==7.30.1 39 | ipython-genutils==0.2.0 40 | ipywidgets==7.6.5 41 | jedi==0.18.0 42 | Jinja2==3.0.3 43 | jmespath==0.10.0 44 | joblib==1.1.0 45 | jsonschema==4.2.1 46 | jupyter==1.0.0 47 | jupyter-client==7.0.6 48 | jupyter-console==6.4.0 49 | 
jupyter-core==4.9.1 50 | jupyterlab-pygments==0.1.2 51 | jupyterlab-widgets==1.0.2 52 | kiwisolver==1.3.2 53 | llvmlite==0.38.0 54 | lmdb==1.2.1 55 | Markdown==3.3.6 56 | MarkupSafe==2.0.1 57 | matplotlib==3.4.3 58 | matplotlib-inline==0.1.3 59 | mistune==0.8.4 60 | nbclient==0.5.8 61 | nbconvert==6.4.3 62 | nbformat==5.1.3 63 | nest-asyncio==1.5.1 64 | nopdb==0.2.0 65 | notebook==6.4.5 66 | numba==0.55.1 67 | numpy==1.21.4 68 | oauthlib==3.1.1 69 | omegaconf==2.0.0 70 | opencv-python==4.5.5.64 71 | packaging==21.2 72 | pandas==1.3.4 73 | pandocfilters==1.5.0 74 | parso==0.8.2 75 | pexpect==4.8.0 76 | pickleshare==0.7.5 77 | Pillow==8.4.0 78 | prometheus-client==0.12.0 79 | prompt-toolkit==3.0.22 80 | protobuf==3.19.1 81 | psutil==5.9.1 82 | ptyprocess==0.7.0 83 | pure-eval==0.2.2 84 | pyasn1==0.4.8 85 | pyasn1-modules==0.2.8 86 | pycparser==2.21 87 | pyDeprecate==0.3.1 88 | Pygments==2.10.0 89 | pynndescent==0.5.6 90 | pyparsing==2.4.7 91 | pyrsistent==0.18.0 92 | python-dateutil==2.8.2 93 | pytorch-lightning==1.0.8 94 | pytorch-memlab==0.2.4 95 | pytz==2021.3 96 | PyYAML==6.0 97 | pyzmq==22.3.0 98 | qtconsole==5.2.0 99 | QtPy==1.11.2 100 | regex==2021.11.10 101 | requests==2.26.0 102 | requests-oauthlib==1.3.0 103 | rotary-embedding-torch==0.1.2 104 | rsa==4.7.2 105 | s3transfer==0.5.0 106 | sacremoses==0.0.46 107 | scikit-learn==1.0.2 108 | scipy==1.7.2 109 | Send2Trash==1.8.0 110 | six==1.16.0 111 | stack-data==0.2.0 112 | tape-proteins==0.5 113 | tensorboard==2.7.0 114 | tensorboard-data-server==0.6.1 115 | tensorboard-plugin-wit==1.8.0 116 | tensorboardX==2.4 117 | terminado==0.12.1 118 | test-tube==0.7.5 119 | testpath==0.5.0 120 | threadpoolctl==3.1.0 121 | tokenizers==0.10.3 122 | torchmetrics==0.6.0 123 | tornado==6.1 124 | tqdm==4.62.3 125 | traitlets==5.1.1 126 | transformers==4.3.1 127 | typing-extensions==4.1.1 128 | umap-learn==0.5.2 129 | urllib3==1.26.7 130 | wcwidth==0.2.5 131 | webencodings==0.5.1 132 | Werkzeug==2.0.2 133 | 
widgetsnbextension==3.5.2 134 | zipp==3.6.0 135 | -------------------------------------------------------------------------------- /taming/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n): 33 | return self.schedule(n) 34 | 35 | -------------------------------------------------------------------------------- /taming/models/cond_transformer.py: -------------------------------------------------------------------------------- 1 | import os, math 2 | import torch 3 | import torch.nn.functional as F 4 | import pytorch_lightning as pl 5 | 6 | from main import instantiate_from_config 7 | from taming.modules.util import SOSProvider 8 | 9 | 10 | def disabled_train(self, mode=True): 11 | """Overwrite model.train with this function to make sure train/eval mode 12 | does not change anymore.""" 13 | return self 14 | 15 | 16 | class 
Net2NetTransformer(pl.LightningModule): 17 | def __init__(self, 18 | transformer_config, 19 | first_stage_config, 20 | cond_stage_config, 21 | permuter_config=None, 22 | ckpt_path=None, 23 | ignore_keys=[], 24 | first_stage_key="image", 25 | cond_stage_key="depth", 26 | downsample_cond_size=-1, 27 | pkeep=1.0, 28 | sos_token=0, 29 | unconditional=False, 30 | ): 31 | super().__init__() 32 | self.be_unconditional = unconditional 33 | self.sos_token = sos_token 34 | self.first_stage_key = first_stage_key 35 | self.cond_stage_key = cond_stage_key 36 | self.init_first_stage_from_ckpt(first_stage_config) 37 | self.init_cond_stage_from_ckpt(cond_stage_config) 38 | if permuter_config is None: 39 | permuter_config = {"target": "taming.modules.transformer.permuter.Identity"} 40 | self.permuter = instantiate_from_config(config=permuter_config) 41 | self.transformer = instantiate_from_config(config=transformer_config) 42 | 43 | if ckpt_path is not None: 44 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 45 | self.downsample_cond_size = downsample_cond_size 46 | self.pkeep = pkeep 47 | 48 | def init_from_ckpt(self, path, ignore_keys=list()): 49 | sd = torch.load(path, map_location="cpu")["state_dict"] 50 | for k in list(sd.keys()): 51 | for ik in ignore_keys: 52 | if k.startswith(ik): 53 | self.print("Deleting key {} from state_dict.".format(k)) 54 | del sd[k] 55 | self.load_state_dict(sd, strict=False) 56 | print(f"Restored from {path}") 57 | 58 | def init_first_stage_from_ckpt(self, config): 59 | model = instantiate_from_config(config) 60 | model = model.eval() 61 | model.train = disabled_train 62 | self.first_stage_model = model 63 | 64 | def init_cond_stage_from_ckpt(self, config): 65 | if config == "__is_first_stage__": 66 | print("Using first stage also as cond stage.") 67 | self.cond_stage_model = self.first_stage_model 68 | elif config == "__is_unconditional__" or self.be_unconditional: 69 | print(f"Using no cond stage.
Assuming the training is intended to be unconditional. " 70 | f"Prepending {self.sos_token} as a sos token.") 71 | self.be_unconditional = True 72 | self.cond_stage_key = self.first_stage_key 73 | self.cond_stage_model = SOSProvider(self.sos_token) 74 | else: 75 | model = instantiate_from_config(config) 76 | model = model.eval() 77 | model.train = disabled_train 78 | self.cond_stage_model = model 79 | 80 | def forward(self, x, c): 81 | # one step to produce the logits 82 | # x = target 83 | # c = nucleus 84 | _, z_indices = self.encode_to_z(x) 85 | _, c_indices = self.encode_to_c(c) 86 | 87 | if self.training and self.pkeep < 1.0: 88 | mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape, 89 | device=z_indices.device)) 90 | mask = mask.round().to(dtype=torch.int64) 91 | r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size) 92 | a_indices = mask*z_indices+(1-mask)*r_indices 93 | else: 94 | a_indices = z_indices 95 | 96 | cz_indices = torch.cat((c_indices, a_indices), dim=1) 97 | 98 | # target includes all sequence elements (no need to handle first one 99 | # differently because we are conditioning) 100 | target = z_indices 101 | # make the prediction 102 | logits, _ = self.transformer(cz_indices[:, :-1]) 103 | # cut off conditioning outputs - output i corresponds to p(z_i | z_{ -1: 180 | c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size)) 181 | 182 | #quant_c, _, info = self.cond_stage_model.encode(x) 183 | #indices = info[2].view(quant_c.shape[0], -1) 184 | #indices = self.permuter(indices) 185 | quant_c, _, [_,_,indices] = self.cond_stage_model.encode(c) 186 | if len(indices.shape) != 2: 187 | indices = indices.view(c.shape[0], -1) 188 | return quant_c, indices 189 | 190 | @torch.no_grad() 191 | def decode_to_img(self, index, zshape): 192 | index = self.permuter(index, reverse=True) 193 | bhwc = (zshape[0],zshape[2],zshape[3],zshape[1]) 194 | quant_z = self.first_stage_model.quantize.get_codebook_entry( 
195 | index.reshape(-1), shape=bhwc) 196 | x = self.first_stage_model.decode(quant_z) 197 | return x 198 | 199 | @torch.no_grad() 200 | def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs): 201 | log = dict() 202 | 203 | N = 4 204 | if lr_interface: 205 | x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8) 206 | else: 207 | x, c = self.get_xc(batch, N) 208 | x = x.to(device=self.device) 209 | c = c.to(device=self.device) 210 | 211 | quant_z, z_indices = self.encode_to_z(x) 212 | quant_c, c_indices = self.encode_to_c(c) 213 | 214 | # create a "half" sample 215 | z_start_indices = z_indices[:,:z_indices.shape[1]//2] 216 | index_sample = self.sample(z_start_indices, c_indices, 217 | steps=z_indices.shape[1]-z_start_indices.shape[1], 218 | temperature=temperature if temperature is not None else 1.0, 219 | sample=True, 220 | top_k=top_k if top_k is not None else 100, 221 | callback=callback if callback is not None else lambda k: None) 222 | x_sample = self.decode_to_img(index_sample, quant_z.shape) 223 | 224 | # sample 225 | z_start_indices = z_indices[:, :0] 226 | index_sample = self.sample(z_start_indices, c_indices, 227 | steps=z_indices.shape[1], 228 | temperature=temperature if temperature is not None else 1.0, 229 | sample=True, 230 | top_k=top_k if top_k is not None else 100, 231 | callback=callback if callback is not None else lambda k: None) 232 | x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape) 233 | 234 | # det sample 235 | z_start_indices = z_indices[:, :0] 236 | index_sample = self.sample(z_start_indices, c_indices, 237 | steps=z_indices.shape[1], 238 | sample=False, 239 | callback=callback if callback is not None else lambda k: None) 240 | x_sample_det = self.decode_to_img(index_sample, quant_z.shape) 241 | 242 | # reconstruction 243 | x_rec = self.decode_to_img(z_indices, quant_z.shape) 244 | 245 | log["inputs"] = x 246 | log["reconstructions"] = x_rec 247 | 248 | if
self.cond_stage_key != "image" or self.cond_stage_key != "nucleus" or self.cond_stage_key != "target": 249 | cond_rec = self.cond_stage_model.decode(quant_c) 250 | if self.cond_stage_key == "segmentation": 251 | # get image from segmentation mask 252 | num_classes = cond_rec.shape[1] 253 | 254 | c = torch.argmax(c, dim=1, keepdim=True) 255 | c = F.one_hot(c, num_classes=num_classes) 256 | c = c.squeeze(1).permute(0, 3, 1, 2).float() 257 | c = self.cond_stage_model.to_rgb(c) 258 | 259 | cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True) 260 | cond_rec = F.one_hot(cond_rec, num_classes=num_classes) 261 | cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float() 262 | cond_rec = self.cond_stage_model.to_rgb(cond_rec) 263 | log["conditioning_rec"] = cond_rec 264 | log["conditioning"] = c 265 | 266 | log["samples_half"] = x_sample 267 | log["samples_nopix"] = x_sample_nopix 268 | log["samples_det"] = x_sample_det 269 | return log 270 | 271 | def get_input(self, key, batch): 272 | x = batch[key] 273 | if len(x.shape) == 3: 274 | x = x[..., None] 275 | #if len(x.shape) == 4: 276 | # x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) 277 | if x.dtype == torch.double: 278 | x = x.float() 279 | return x 280 | 281 | def get_xc(self, batch, N=None): 282 | x = self.get_input(self.first_stage_key, batch) 283 | c = self.get_input(self.cond_stage_key, batch) 284 | if N is not None: 285 | x = x[:N] 286 | c = c[:N] 287 | return x, c 288 | 289 | def shared_step(self, batch): 290 | x, c = self.get_xc(batch) 291 | logits, target = self(x, c) 292 | loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1)) 293 | return loss 294 | 295 | def training_step(self, batch, batch_idx): 296 | loss = self.shared_step(batch) 297 | self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 298 | return loss 299 | 300 | def validation_step(self, batch, batch_idx): 301 | loss = self.shared_step(batch) 302 | self.log("val/loss", 
loss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 303 | return loss 304 | 305 | def configure_optimizers(self): 306 | """ 307 | Following minGPT: 308 | This long function is unfortunately doing something very simple and is being very defensive: 309 | We are separating out all parameters of the model into two buckets: those that will experience 310 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 311 | We are then returning the PyTorch optimizer object. 312 | """ 313 | # separate out all parameters to those that will and won't experience regularizing weight decay 314 | decay = set() 315 | no_decay = set() 316 | whitelist_weight_modules = (torch.nn.Linear, ) 317 | blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) 318 | for mn, m in self.transformer.named_modules(): 319 | for pn, p in m.named_parameters(): 320 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name 321 | 322 | if pn.endswith('bias'): 323 | # all biases will not be decayed 324 | no_decay.add(fpn) 325 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): 326 | # weights of whitelist modules will be weight decayed 327 | decay.add(fpn) 328 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): 329 | # weights of blacklist modules will NOT be weight decayed 330 | no_decay.add(fpn) 331 | 332 | # special case the position embedding parameter in the root GPT module as not decayed 333 | no_decay.add('pos_emb') 334 | 335 | # validate that we considered every parameter 336 | param_dict = {pn: p for pn, p in self.transformer.named_parameters()} 337 | inter_params = decay & no_decay 338 | union_params = decay | no_decay 339 | assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) 340 | assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" 
\ 341 | % (str(param_dict.keys() - union_params), ) 342 | 343 | # create the pytorch optimizer object 344 | optim_groups = [ 345 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01}, 346 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, 347 | ] 348 | optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95)) 349 | return optimizer 350 | -------------------------------------------------------------------------------- /taming/models/dummy_cond_stage.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | 4 | class DummyCondStage: 5 | def __init__(self, conditional_key): 6 | self.conditional_key = conditional_key 7 | self.train = None 8 | 9 | def eval(self): 10 | return self 11 | 12 | @staticmethod 13 | def encode(c: Tensor): 14 | return c, None, (None, None, c) 15 | 16 | @staticmethod 17 | def decode(c: Tensor): 18 | return c 19 | 20 | @staticmethod 21 | def to_rgb(c: Tensor): 22 | return c 23 | -------------------------------------------------------------------------------- /taming/models/vqgan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import pytorch_lightning as pl 4 | 5 | from celle_taming_main import instantiate_from_config 6 | 7 | from taming.modules.diffusionmodules.model import Encoder, Decoder 8 | from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer 9 | from taming.modules.vqvae.quantize import GumbelQuantize 10 | from taming.modules.vqvae.quantize import EMAVectorQuantizer 11 | 12 | 13 | class VQModel(pl.LightningModule): 14 | def __init__( 15 | self, 16 | ddconfig, 17 | lossconfig, 18 | n_embed, 19 | embed_dim, 20 | ckpt_path=None, 21 | ignore_keys=[], 22 | image_key="image", 23 | colorize_nlabels=None, 24 | monitor=None, 25 | remap=None, 26 | sane_index_shape=False, # tell 
vector quantizer to return indices as bhw 27 | ): 28 | super().__init__() 29 | self.image_key = image_key 30 | self.encoder = Encoder(**ddconfig) 31 | self.decoder = Decoder(**ddconfig) 32 | self.loss = instantiate_from_config(lossconfig) 33 | self.quantize = VectorQuantizer( 34 | n_embed, 35 | embed_dim, 36 | beta=0.25, 37 | remap=remap, 38 | sane_index_shape=sane_index_shape, 39 | ) 40 | self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) 41 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 42 | if ckpt_path is not None: 43 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 44 | self.image_key = image_key 45 | if colorize_nlabels is not None: 46 | assert type(colorize_nlabels) == int 47 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 48 | if monitor is not None: 49 | self.monitor = monitor 50 | 51 | def init_from_ckpt(self, path, ignore_keys=list()): 52 | sd = torch.load(path, map_location="cpu")["state_dict"] 53 | keys = list(sd.keys()) 54 | for k in keys: 55 | for ik in ignore_keys: 56 | if k.startswith(ik): 57 | print("Deleting key {} from state_dict.".format(k)) 58 | del sd[k] 59 | self.load_state_dict(sd, strict=False) 60 | print(f"Restored from {path}") 61 | 62 | def encode(self, x): 63 | h = self.encoder(x) 64 | h = self.quant_conv(h) 65 | quant, emb_loss, info = self.quantize(h) 66 | return quant, emb_loss, info 67 | 68 | def decode(self, quant): 69 | quant = self.post_quant_conv(quant) 70 | dec = self.decoder(quant) 71 | return dec 72 | 73 | def decode_code(self, code_b): 74 | quant_b = self.quantize.embed_code(code_b) 75 | dec = self.decode(quant_b) 76 | return dec 77 | 78 | def forward(self, input): 79 | quant, diff, _ = self.encode(input) 80 | dec = self.decode(quant) 81 | return dec, diff 82 | 83 | def get_input(self, batch, k): 84 | 85 | if k == "mixed": 86 | keys = ["nucleus", "target"] 87 | index = torch.randint(low=0, high=2, size=(1,), dtype=int).item() 88 | k = 
keys[index] 89 | 90 | x = batch[k] 91 | if len(x.shape) == 3: 92 | x = x[..., None] 93 | 94 | # x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) 95 | return x 96 | 97 | def training_step(self, batch, batch_idx=None, optimizer_idx=0): 98 | 99 | if type(batch) == dict: 100 | 101 | x = self.get_input(batch, self.image_key) 102 | 103 | else: 104 | x = batch 105 | 106 | xrec, qloss = self( 107 | x, 108 | ) 109 | 110 | if optimizer_idx == 0: 111 | # autoencode 112 | aeloss, log_dict_ae = self.loss( 113 | qloss, 114 | x, 115 | xrec, 116 | optimizer_idx, 117 | self.global_step, 118 | last_layer=self.get_last_layer(), 119 | split="train", 120 | ) 121 | 122 | self.log( 123 | "train/aeloss", 124 | aeloss, 125 | prog_bar=True, 126 | logger=True, 127 | on_step=True, 128 | on_epoch=True, 129 | ) 130 | self.log_dict( 131 | log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True 132 | ) 133 | return aeloss 134 | 135 | if optimizer_idx == 1: 136 | # discriminator 137 | discloss, log_dict_disc = self.loss( 138 | qloss, 139 | x, 140 | xrec, 141 | optimizer_idx, 142 | self.global_step, 143 | last_layer=self.get_last_layer(), 144 | split="train", 145 | ) 146 | self.log( 147 | "train/discloss", 148 | discloss, 149 | prog_bar=True, 150 | logger=True, 151 | on_step=True, 152 | on_epoch=True, 153 | ) 154 | self.log_dict( 155 | log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True 156 | ) 157 | return discloss 158 | 159 | def validation_step(self, batch, batch_idx): 160 | 161 | if type(batch) == dict: 162 | 163 | x = self.get_input(batch, self.image_key) 164 | 165 | else: 166 | x = batch 167 | 168 | xrec, qloss = self(x) 169 | aeloss, log_dict_ae = self.loss( 170 | qloss, 171 | x, 172 | xrec, 173 | 0, 174 | self.global_step, 175 | last_layer=self.get_last_layer(), 176 | split="val", 177 | ) 178 | 179 | discloss, log_dict_disc = self.loss( 180 | qloss, 181 | x, 182 | xrec, 183 | 1, 184 | self.global_step, 185 | 
last_layer=self.get_last_layer(), 186 | split="val", 187 | ) 188 | rec_loss = log_dict_ae["val/rec_loss"] 189 | self.log( 190 | "val/rec_loss", 191 | rec_loss, 192 | prog_bar=True, 193 | logger=True, 194 | on_step=True, 195 | on_epoch=True, 196 | sync_dist=True, 197 | ) 198 | self.log( 199 | "val/aeloss", 200 | aeloss, 201 | prog_bar=True, 202 | logger=True, 203 | on_step=True, 204 | on_epoch=True, 205 | sync_dist=True, 206 | ) 207 | self.log_dict(log_dict_ae) 208 | self.log_dict(log_dict_disc) 209 | return self.log_dict 210 | 211 | def configure_optimizers(self): 212 | lr = self.learning_rate 213 | opt_ae = torch.optim.Adam( 214 | list(self.encoder.parameters()) 215 | + list(self.decoder.parameters()) 216 | + list(self.quantize.parameters()) 217 | + list(self.quant_conv.parameters()) 218 | + list(self.post_quant_conv.parameters()), 219 | lr=lr, 220 | betas=(0.5, 0.9), 221 | ) 222 | opt_disc = torch.optim.Adam( 223 | self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9) 224 | ) 225 | return [opt_ae, opt_disc], [] 226 | 227 | def get_last_layer(self): 228 | return self.decoder.conv_out.weight 229 | 230 | def log_images(self, batch, **kwargs): 231 | log = dict() 232 | x = self.get_input(batch, self.image_key) 233 | x = x.to(self.device) 234 | xrec, _ = self(x) 235 | if x.shape[1] > 3: 236 | # colorize with random projection 237 | assert xrec.shape[1] > 3 238 | x = self.to_rgb(x) 239 | xrec = self.to_rgb(xrec) 240 | log["inputs"] = x 241 | log["reconstructions"] = xrec 242 | return log 243 | 244 | def to_rgb(self, x): 245 | assert self.image_key == "segmentation" 246 | if not hasattr(self, "colorize"): 247 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 248 | x = F.conv2d(x, weight=self.colorize) 249 | x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 250 | return x 251 | 252 | 253 | class VQSegmentationModel(VQModel): 254 | def __init__(self, n_labels, *args, **kwargs): 255 | super().__init__(*args, **kwargs) 256 | 
self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1)) 257 | 258 | def configure_optimizers(self): 259 | lr = self.learning_rate 260 | opt_ae = torch.optim.Adam( 261 | list(self.encoder.parameters()) 262 | + list(self.decoder.parameters()) 263 | + list(self.quantize.parameters()) 264 | + list(self.quant_conv.parameters()) 265 | + list(self.post_quant_conv.parameters()), 266 | lr=lr, 267 | betas=(0.5, 0.9), 268 | ) 269 | return opt_ae 270 | 271 | def training_step(self, batch, batch_idx): 272 | x = self.get_input(batch, self.image_key) 273 | xrec, qloss = self(x) 274 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train") 275 | self.log_dict( 276 | log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True 277 | ) 278 | return aeloss 279 | 280 | def validation_step(self, batch, batch_idx): 281 | x = self.get_input(batch, self.image_key) 282 | xrec, qloss = self(x) 283 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val") 284 | self.log_dict( 285 | log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True 286 | ) 287 | total_loss = log_dict_ae["val/total_loss"] 288 | self.log( 289 | "val/total_loss", 290 | total_loss, 291 | prog_bar=True, 292 | logger=True, 293 | on_step=True, 294 | on_epoch=True, 295 | sync_dist=True, 296 | ) 297 | return aeloss 298 | 299 | @torch.no_grad() 300 | def log_images(self, batch, **kwargs): 301 | log = dict() 302 | x = self.get_input(batch, self.image_key) 303 | x = x.to(self.device) 304 | xrec, _ = self(x) 305 | if x.shape[1] > 3: 306 | # colorize with random projection 307 | assert xrec.shape[1] > 3 308 | # convert logits to indices 309 | xrec = torch.argmax(xrec, dim=1, keepdim=True) 310 | xrec = F.one_hot(xrec, num_classes=x.shape[1]) 311 | xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float() 312 | x = self.to_rgb(x) 313 | xrec = self.to_rgb(xrec) 314 | log["inputs"] = x 315 | log["reconstructions"] = xrec 316 | return log 317 | 318 | 319 | class VQNoDiscModel(VQModel): 320 | def 
__init__( 321 | self, 322 | ddconfig, 323 | lossconfig, 324 | n_embed, 325 | embed_dim, 326 | ckpt_path=None, 327 | ignore_keys=[], 328 | image_key="image", 329 | colorize_nlabels=None, 330 | ): 331 | super().__init__( 332 | ddconfig=ddconfig, 333 | lossconfig=lossconfig, 334 | n_embed=n_embed, 335 | embed_dim=embed_dim, 336 | ckpt_path=ckpt_path, 337 | ignore_keys=ignore_keys, 338 | image_key=image_key, 339 | colorize_nlabels=colorize_nlabels, 340 | ) 341 | 342 | def training_step(self, batch, batch_idx): 343 | x = self.get_input(batch, self.image_key) 344 | xrec, qloss = self(x) 345 | # autoencode 346 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train") 347 | output = pl.TrainResult(minimize=aeloss) 348 | output.log( 349 | "train/aeloss", 350 | aeloss, 351 | prog_bar=True, 352 | logger=True, 353 | on_step=True, 354 | on_epoch=True, 355 | ) 356 | output.log_dict( 357 | log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True 358 | ) 359 | return output 360 | 361 | def validation_step(self, batch, batch_idx): 362 | x = self.get_input(batch, self.image_key) 363 | xrec, qloss = self(x) 364 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val") 365 | rec_loss = log_dict_ae["val/rec_loss"] 366 | output = pl.EvalResult(checkpoint_on=rec_loss) 367 | output.log( 368 | "val/rec_loss", 369 | rec_loss, 370 | prog_bar=True, 371 | logger=True, 372 | on_step=True, 373 | on_epoch=True, 374 | ) 375 | output.log( 376 | "val/aeloss", 377 | aeloss, 378 | prog_bar=True, 379 | logger=True, 380 | on_step=True, 381 | on_epoch=True, 382 | ) 383 | output.log_dict(log_dict_ae) 384 | 385 | return output 386 | 387 | def configure_optimizers(self): 388 | optimizer = torch.optim.Adam( 389 | list(self.encoder.parameters()) 390 | + list(self.decoder.parameters()) 391 | + list(self.quantize.parameters()) 392 | + list(self.quant_conv.parameters()) 393 | + list(self.post_quant_conv.parameters()), 394 | lr=self.learning_rate, 395 | 
betas=(0.5, 0.9), 396 | ) 397 | return optimizer 398 | 399 | 400 | class GumbelVQ(VQModel): 401 | def __init__( 402 | self, 403 | ddconfig, 404 | lossconfig, 405 | n_embed, 406 | embed_dim, 407 | temperature_scheduler_config, 408 | ckpt_path=None, 409 | ignore_keys=[], 410 | image_key="image", 411 | colorize_nlabels=None, 412 | monitor=None, 413 | kl_weight=1e-8, 414 | remap=None, 415 | ): 416 | 417 | z_channels = ddconfig["z_channels"] 418 | super().__init__( 419 | ddconfig, 420 | lossconfig, 421 | n_embed, 422 | embed_dim, 423 | ckpt_path=None, 424 | ignore_keys=ignore_keys, 425 | image_key=image_key, 426 | colorize_nlabels=colorize_nlabels, 427 | monitor=monitor, 428 | ) 429 | 430 | self.loss.n_classes = n_embed 431 | self.vocab_size = n_embed 432 | 433 | self.quantize = GumbelQuantize( 434 | z_channels, 435 | embed_dim, 436 | n_embed=n_embed, 437 | kl_weight=kl_weight, 438 | temp_init=1.0, 439 | remap=remap, 440 | ) 441 | 442 | self.temperature_scheduler = instantiate_from_config( 443 | temperature_scheduler_config 444 | ) # annealing of temp 445 | 446 | if ckpt_path is not None: 447 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 448 | 449 | def temperature_scheduling(self): 450 | self.quantize.temperature = self.temperature_scheduler(self.global_step) 451 | 452 | def encode_to_prequant(self, x): 453 | h = self.encoder(x) 454 | h = self.quant_conv(h) 455 | return h 456 | 457 | def decode_code(self, code_b): 458 | raise NotImplementedError 459 | 460 | def training_step(self, batch, batch_idx, optimizer_idx): 461 | self.temperature_scheduling() 462 | x = self.get_input(batch, self.image_key) 463 | xrec, qloss = self(x) 464 | 465 | if optimizer_idx == 0: 466 | # autoencode 467 | aeloss, log_dict_ae = self.loss( 468 | qloss, 469 | x, 470 | xrec, 471 | optimizer_idx, 472 | self.global_step, 473 | last_layer=self.get_last_layer(), 474 | split="train", 475 | ) 476 | 477 | self.log_dict( 478 | log_dict_ae, prog_bar=False, logger=True, on_step=True, 
on_epoch=True 479 | ) 480 | self.log( 481 | "temperature", 482 | self.quantize.temperature, 483 | prog_bar=False, 484 | logger=True, 485 | on_step=True, 486 | on_epoch=True, 487 | ) 488 | return aeloss 489 | 490 | if optimizer_idx == 1: 491 | # discriminator 492 | discloss, log_dict_disc = self.loss( 493 | qloss, 494 | x, 495 | xrec, 496 | optimizer_idx, 497 | self.global_step, 498 | last_layer=self.get_last_layer(), 499 | split="train", 500 | ) 501 | self.log_dict( 502 | log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True 503 | ) 504 | return discloss 505 | 506 | def validation_step(self, batch, batch_idx): 507 | x = self.get_input(batch, self.image_key) 508 | xrec, qloss = self(x) 509 | aeloss, log_dict_ae = self.loss( 510 | qloss, 511 | x, 512 | xrec, 513 | 0, 514 | self.global_step, 515 | last_layer=self.get_last_layer(), 516 | split="val", 517 | ) 518 | 519 | discloss, log_dict_disc = self.loss( 520 | qloss, 521 | x, 522 | xrec, 523 | 1, 524 | self.global_step, 525 | last_layer=self.get_last_layer(), 526 | split="val", 527 | ) 528 | rec_loss = log_dict_ae["val/rec_loss"] 529 | self.log( 530 | "val/rec_loss", 531 | rec_loss, 532 | prog_bar=True, 533 | logger=True, 534 | on_step=False, 535 | on_epoch=True, 536 | sync_dist=True, 537 | ) 538 | self.log( 539 | "val/aeloss", 540 | aeloss, 541 | prog_bar=True, 542 | logger=True, 543 | on_step=False, 544 | on_epoch=True, 545 | sync_dist=True, 546 | ) 547 | self.log_dict(log_dict_ae) 548 | self.log_dict(log_dict_disc) 549 | return self.log_dict 550 | 551 | def log_images(self, batch, **kwargs): 552 | log = dict() 553 | x = self.get_input(batch, self.image_key) 554 | x = x.to(self.device) 555 | # encode 556 | h = self.encoder(x) 557 | h = self.quant_conv(h) 558 | quant, _, _ = self.quantize(h) 559 | # decode 560 | x_rec = self.decode(quant) 561 | log["inputs"] = x 562 | log["reconstructions"] = x_rec 563 | return log 564 | 565 | 566 | class EMAVQ(VQModel): 567 | def __init__( 568 | self, 569 | 
ddconfig, 570 | lossconfig, 571 | n_embed, 572 | embed_dim, 573 | ckpt_path=None, 574 | ignore_keys=[], 575 | image_key="image", 576 | colorize_nlabels=None, 577 | monitor=None, 578 | remap=None, 579 | sane_index_shape=False, # tell vector quantizer to return indices as bhw 580 | ): 581 | super().__init__( 582 | ddconfig, 583 | lossconfig, 584 | n_embed, 585 | embed_dim, 586 | ckpt_path=None, 587 | ignore_keys=ignore_keys, 588 | image_key=image_key, 589 | colorize_nlabels=colorize_nlabels, 590 | monitor=monitor, 591 | ) 592 | self.quantize = EMAVectorQuantizer( 593 | n_embed=n_embed, embedding_dim=embed_dim, beta=0.25, remap=remap 594 | ) 595 | 596 | def configure_optimizers(self): 597 | lr = self.learning_rate 598 | # Remove self.quantize from parameter list since it is updated via EMA 599 | opt_ae = torch.optim.Adam( 600 | list(self.encoder.parameters()) 601 | + list(self.decoder.parameters()) 602 | + list(self.quant_conv.parameters()) 603 | + list(self.post_quant_conv.parameters()), 604 | lr=lr, 605 | betas=(0.5, 0.9), 606 | ) 607 | opt_disc = torch.optim.Adam( 608 | self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9) 609 | ) 610 | return [opt_ae, opt_disc], [] 611 | -------------------------------------------------------------------------------- /taming/modules/autoencoder/lpips/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoHuangLab/Protein-Localization-Transformer/9d7d0bb4296a0363d1af21a7e50ddc8672e0eca6/taming/modules/autoencoder/lpips/vgg.pth -------------------------------------------------------------------------------- /taming/modules/discriminator/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | 4 | 5 | from taming.modules.util import ActNorm 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | if classname.find('Conv') != -1: 11 | 
nn.init.normal_(m.weight.data, 0.0, 0.02) 12 | elif classname.find('BatchNorm') != -1: 13 | nn.init.normal_(m.weight.data, 1.0, 0.02) 14 | nn.init.constant_(m.bias.data, 0) 15 | 16 | 17 | class NLayerDiscriminator(nn.Module): 18 | """Defines a PatchGAN discriminator as in Pix2Pix 19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 20 | """ 21 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 22 | """Construct a PatchGAN discriminator 23 | Parameters: 24 | input_nc (int) -- the number of channels in input images 25 | ndf (int) -- the number of filters in the first conv layer (multiplied by up to 8x in later layers) 26 | n_layers (int) -- the number of conv layers in the discriminator 27 | use_actnorm (bool) -- if True, use ActNorm instead of BatchNorm2d as the normalization layer 28 | """ 29 | super(NLayerDiscriminator, self).__init__() 30 | if not use_actnorm: 31 | norm_layer = nn.BatchNorm2d 32 | else: 33 | norm_layer = ActNorm 34 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 35 | use_bias = norm_layer.func != nn.BatchNorm2d 36 | else: 37 | use_bias = norm_layer != nn.BatchNorm2d 38 | 39 | kw = 4 40 | padw = 1 41 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 42 | nf_mult = 1 43 | nf_mult_prev = 1 44 | for n in range(1, n_layers): # gradually increase the number of filters 45 | nf_mult_prev = nf_mult 46 | nf_mult = min(2 ** n, 8) 47 | sequence += [ 48 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 49 | norm_layer(ndf * nf_mult), 50 | nn.LeakyReLU(0.2, True) 51 | ] 52 | 53 | nf_mult_prev = nf_mult 54 | nf_mult = min(2 ** n_layers, 8) 55 | sequence += [ 56 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 57 | norm_layer(ndf * nf_mult), 58 | nn.LeakyReLU(0.2, True) 59 | ] 60 | 61 | sequence += [ 62 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # 
output 1 channel prediction map 63 | self.main = nn.Sequential(*sequence) 64 | 65 | def forward(self, input): 66 | """Standard forward.""" 67 | return self.main(input) 68 | -------------------------------------------------------------------------------- /taming/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from taming.modules.losses.vqperceptual import DummyLoss 2 | 3 | -------------------------------------------------------------------------------- /taming/modules/losses/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import models 6 | from collections import namedtuple 7 | 8 | from taming.util import get_ckpt_path 9 | 10 | 11 | class LPIPS(nn.Module): 12 | # Learned perceptual metric 13 | def __init__(self, use_dropout=True): 14 | super().__init__() 15 | self.scaling_layer = ScalingLayer() 16 | self.chns = [64, 128, 256, 512, 512] # vgg16 features 17 | self.net = vgg16(pretrained=True, requires_grad=False) 18 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 19 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 20 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 21 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 22 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 23 | self.load_from_pretrained() 24 | for param in self.parameters(): 25 | param.requires_grad = False 26 | 27 | def load_from_pretrained(self, name="vgg_lpips"): 28 | ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips") 29 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 30 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 31 | 32 | @classmethod 33 | def from_pretrained(cls, name="vgg_lpips"): 34
| if name != "vgg_lpips": 35 | raise NotImplementedError 36 | model = cls() 37 | ckpt = get_ckpt_path(name) 38 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 39 | return model 40 | 41 | def forward(self, input, target): 42 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 43 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 44 | feats0, feats1, diffs = {}, {}, {} 45 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 46 | for kk in range(len(self.chns)): 47 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 48 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 49 | 50 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 51 | val = res[0] 52 | for l in range(1, len(self.chns)): 53 | val += res[l] 54 | return val 55 | 56 | 57 | class ScalingLayer(nn.Module): 58 | def __init__(self): 59 | super(ScalingLayer, self).__init__() 60 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 61 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 62 | 63 | def forward(self, inp): 64 | return (inp - self.shift) / self.scale 65 | 66 | 67 | class NetLinLayer(nn.Module): 68 | """ A single linear layer which does a 1x1 conv """ 69 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 70 | super(NetLinLayer, self).__init__() 71 | layers = [nn.Dropout(), ] if (use_dropout) else [] 72 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 73 | self.model = nn.Sequential(*layers) 74 | 75 | 76 | class vgg16(torch.nn.Module): 77 | def __init__(self, requires_grad=False, pretrained=True): 78 | super(vgg16, self).__init__() 79 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 80 | self.slice1 = torch.nn.Sequential() 81 | self.slice2 = torch.nn.Sequential() 82 | self.slice3 = torch.nn.Sequential() 83 | 
self.slice4 = torch.nn.Sequential() 84 | self.slice5 = torch.nn.Sequential() 85 | self.N_slices = 5 86 | for x in range(4): 87 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 88 | for x in range(4, 9): 89 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 90 | for x in range(9, 16): 91 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 92 | for x in range(16, 23): 93 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 94 | for x in range(23, 30): 95 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 96 | if not requires_grad: 97 | for param in self.parameters(): 98 | param.requires_grad = False 99 | 100 | def forward(self, X): 101 | h = self.slice1(X) 102 | h_relu1_2 = h 103 | h = self.slice2(h) 104 | h_relu2_2 = h 105 | h = self.slice3(h) 106 | h_relu3_3 = h 107 | h = self.slice4(h) 108 | h_relu4_3 = h 109 | h = self.slice5(h) 110 | h_relu5_3 = h 111 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 112 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 113 | return out 114 | 115 | 116 | def normalize_tensor(x,eps=1e-10): 117 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 118 | return x/(norm_factor+eps) 119 | 120 | 121 | def spatial_average(x, keepdim=True): 122 | return x.mean([2,3],keepdim=keepdim) 123 | 124 | -------------------------------------------------------------------------------- /taming/modules/losses/segmentation.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class BCELoss(nn.Module): 6 | def forward(self, prediction, target): 7 | loss = F.binary_cross_entropy_with_logits(prediction,target) 8 | return loss, {} 9 | 10 | 11 | class BCELossWithQuant(nn.Module): 12 | def __init__(self, codebook_weight=1.): 13 | super().__init__() 14 | self.codebook_weight = codebook_weight 15 | 16 | def forward(self, 
qloss, target, prediction, split): 17 | bce_loss = F.binary_cross_entropy_with_logits(prediction,target) 18 | loss = bce_loss + self.codebook_weight*qloss 19 | return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(), 20 | "{}/bce_loss".format(split): bce_loss.detach().mean(), 21 | "{}/quant_loss".format(split): qloss.detach().mean() 22 | } 23 | -------------------------------------------------------------------------------- /taming/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from taming.modules.losses.lpips import LPIPS 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | 8 | 9 | class DummyLoss(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | 14 | def adopt_weight(weight, global_step, threshold=0, value=0.): 15 | if global_step < threshold: 16 | weight = value 17 | return weight 18 | 19 | 20 | def hinge_d_loss(logits_real, logits_fake): 21 | loss_real = torch.mean(F.relu(1. - logits_real)) 22 | loss_fake = torch.mean(F.relu(1. 
+ logits_fake)) 23 | d_loss = 0.5 * (loss_real + loss_fake) 24 | return d_loss 25 | 26 | 27 | def vanilla_d_loss(logits_real, logits_fake): 28 | d_loss = 0.5 * ( 29 | torch.mean(torch.nn.functional.softplus(-logits_real)) + 30 | torch.mean(torch.nn.functional.softplus(logits_fake))) 31 | return d_loss 32 | 33 | 34 | class VQLPIPSWithDiscriminator(nn.Module): 35 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 36 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 37 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 38 | disc_ndf=64, disc_loss="hinge"): 39 | super().__init__() 40 | assert disc_loss in ["hinge", "vanilla"] 41 | self.codebook_weight = codebook_weight 42 | self.pixel_weight = pixelloss_weight 43 | self.perceptual_loss = LPIPS().eval() 44 | self.perceptual_weight = perceptual_weight 45 | 46 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 47 | n_layers=disc_num_layers, 48 | use_actnorm=use_actnorm, 49 | ndf=disc_ndf 50 | ).apply(weights_init) 51 | self.discriminator_iter_start = disc_start 52 | if disc_loss == "hinge": 53 | self.disc_loss = hinge_d_loss 54 | elif disc_loss == "vanilla": 55 | self.disc_loss = vanilla_d_loss 56 | else: 57 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 58 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 59 | self.disc_factor = disc_factor 60 | self.discriminator_weight = disc_weight 61 | self.disc_conditional = disc_conditional 62 | 63 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 64 | if last_layer is not None: 65 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 66 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 67 | else: 68 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 69 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 70 | 71 | d_weight = 
torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 72 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 73 | d_weight = d_weight * self.discriminator_weight 74 | return d_weight 75 | 76 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 77 | global_step, last_layer=None, cond=None, split="train"): 78 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 79 | if self.perceptual_weight > 0: 80 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 81 | rec_loss = rec_loss + self.perceptual_weight * p_loss 82 | else: 83 | p_loss = torch.tensor([0.0]) 84 | 85 | nll_loss = rec_loss 86 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 87 | nll_loss = torch.mean(nll_loss) 88 | 89 | # now the GAN part 90 | if optimizer_idx == 0: 91 | # generator update 92 | if cond is None: 93 | assert not self.disc_conditional 94 | logits_fake = self.discriminator(reconstructions.contiguous()) 95 | else: 96 | assert self.disc_conditional 97 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 98 | g_loss = -torch.mean(logits_fake) 99 | 100 | try: 101 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 102 | except RuntimeError: 103 | assert not self.training 104 | d_weight = torch.tensor(0.0) 105 | 106 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 107 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 108 | 109 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 110 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 111 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 112 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 113 | "{}/p_loss".format(split): p_loss.detach().mean(), 114 | "{}/d_weight".format(split): d_weight.detach(), 115 | "{}/disc_factor".format(split): 
torch.tensor(disc_factor), 116 | "{}/g_loss".format(split): g_loss.detach().mean(), 117 | } 118 | return loss, log 119 | 120 | if optimizer_idx == 1: 121 | # second pass for discriminator update 122 | if cond is None: 123 | logits_real = self.discriminator(inputs.contiguous().detach()) 124 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 125 | else: 126 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 127 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 128 | 129 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 130 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 131 | 132 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 133 | "{}/logits_real".format(split): logits_real.detach().mean(), 134 | "{}/logits_fake".format(split): logits_fake.detach().mean() 135 | } 136 | return d_loss, log 137 | -------------------------------------------------------------------------------- /taming/modules/misc/coord.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class CoordStage(object): 4 | def __init__(self, n_embed, down_factor): 5 | self.n_embed = n_embed 6 | self.down_factor = down_factor 7 | 8 | def eval(self): 9 | return self 10 | 11 | def encode(self, c): 12 | """fake vqmodel interface""" 13 | assert 0.0 <= c.min() and c.max() <= 1.0 14 | b,ch,h,w = c.shape 15 | assert ch == 1 16 | 17 | c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor, 18 | mode="area") 19 | c = c.clamp(0.0, 1.0) 20 | c = self.n_embed*c 21 | c_quant = c.round() 22 | c_ind = c_quant.to(dtype=torch.long) 23 | 24 | info = None, None, c_ind 25 | return c_quant, None, info 26 | 27 | def decode(self, c): 28 | c = c/self.n_embed 29 | c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor, 30 | mode="nearest") 31 | 
return c 32 | -------------------------------------------------------------------------------- /taming/modules/transformer/mingpt.py: -------------------------------------------------------------------------------- 1 | """ 2 | taken from: https://github.com/karpathy/minGPT/ 3 | GPT model: 4 | - the initial stem consists of a combination of token encoding and a positional encoding 5 | - the meat of it is a uniform sequence of Transformer blocks 6 | - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block 7 | - all blocks feed into a central residual pathway similar to resnets 8 | - the final decoder is a linear projection into a vanilla Softmax classifier 9 | """ 10 | 11 | import math 12 | import logging 13 | 14 | import torch 15 | import torch.nn as nn 16 | from torch.nn import functional as F 17 | from transformers import top_k_top_p_filtering 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class GPTConfig: 23 | """ base GPT config, params common to all GPT versions """ 24 | embd_pdrop = 0.1 25 | resid_pdrop = 0.1 26 | attn_pdrop = 0.1 27 | 28 | def __init__(self, vocab_size, block_size, **kwargs): 29 | self.vocab_size = vocab_size 30 | self.block_size = block_size 31 | for k,v in kwargs.items(): 32 | setattr(self, k, v) 33 | 34 | 35 | class GPT1Config(GPTConfig): 36 | """ GPT-1 like network roughly 125M params """ 37 | n_layer = 12 38 | n_head = 12 39 | n_embd = 768 40 | 41 | 42 | class CausalSelfAttention(nn.Module): 43 | """ 44 | A vanilla multi-head masked self-attention layer with a projection at the end. 45 | It is possible to use torch.nn.MultiheadAttention here but I am including an 46 | explicit implementation here to show that there is nothing too scary here. 
47 | """ 48 | 49 | def __init__(self, config): 50 | super().__init__() 51 | assert config.n_embd % config.n_head == 0 52 | # key, query, value projections for all heads 53 | self.key = nn.Linear(config.n_embd, config.n_embd) 54 | self.query = nn.Linear(config.n_embd, config.n_embd) 55 | self.value = nn.Linear(config.n_embd, config.n_embd) 56 | # regularization 57 | self.attn_drop = nn.Dropout(config.attn_pdrop) 58 | self.resid_drop = nn.Dropout(config.resid_pdrop) 59 | # output projection 60 | self.proj = nn.Linear(config.n_embd, config.n_embd) 61 | # causal mask to ensure that attention is only applied to the left in the input sequence 62 | mask = torch.tril(torch.ones(config.block_size, 63 | config.block_size)) 64 | if hasattr(config, "n_unmasked"): 65 | mask[:config.n_unmasked, :config.n_unmasked] = 1 66 | self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size)) 67 | self.n_head = config.n_head 68 | 69 | def forward(self, x, layer_past=None): 70 | B, T, C = x.size() 71 | 72 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 73 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 74 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 75 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 76 | 77 | present = torch.stack((k, v)) 78 | if layer_past is not None: 79 | past_key, past_value = layer_past 80 | k = torch.cat((past_key, k), dim=-2) 81 | v = torch.cat((past_value, v), dim=-2) 82 | 83 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 84 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 85 | if layer_past is None: 86 | att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) 87 | 88 | att = F.softmax(att, dim=-1) 89 | att = self.attn_drop(att) 90 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, 
T, hs) 91 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 92 | 93 | # output projection 94 | y = self.resid_drop(self.proj(y)) 95 | return y, present # TODO: check that this does not break anything 96 | 97 | 98 | class Block(nn.Module): 99 | """ an unassuming Transformer block """ 100 | def __init__(self, config): 101 | super().__init__() 102 | self.ln1 = nn.LayerNorm(config.n_embd) 103 | self.ln2 = nn.LayerNorm(config.n_embd) 104 | self.attn = CausalSelfAttention(config) 105 | self.mlp = nn.Sequential( 106 | nn.Linear(config.n_embd, 4 * config.n_embd), 107 | nn.GELU(), # nice 108 | nn.Linear(4 * config.n_embd, config.n_embd), 109 | nn.Dropout(config.resid_pdrop), 110 | ) 111 | 112 | def forward(self, x, layer_past=None, return_present=False): 113 | # TODO: check that training still works 114 | if return_present: assert not self.training 115 | # layer past: tuple of length two with B, nh, T, hs 116 | attn, present = self.attn(self.ln1(x), layer_past=layer_past) 117 | 118 | x = x + attn 119 | x = x + self.mlp(self.ln2(x)) 120 | if layer_past is not None or return_present: 121 | return x, present 122 | return x 123 | 124 | 125 | class GPT(nn.Module): 126 | """ the full GPT language model, with a context size of block_size """ 127 | def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256, 128 | embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0): 129 | super().__init__() 130 | config = GPTConfig(vocab_size=vocab_size, block_size=block_size, 131 | embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop, 132 | n_layer=n_layer, n_head=n_head, n_embd=n_embd, 133 | n_unmasked=n_unmasked) 134 | # input embedding stem 135 | self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd) 136 | self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) 137 | self.drop = nn.Dropout(config.embd_pdrop) 138 | # transformer 139 | self.blocks = nn.Sequential(*[Block(config) for 
_ in range(config.n_layer)]) 140 | # decoder head 141 | self.ln_f = nn.LayerNorm(config.n_embd) 142 | self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 143 | self.block_size = config.block_size 144 | self.apply(self._init_weights) 145 | self.config = config 146 | logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters())) 147 | 148 | def get_block_size(self): 149 | return self.block_size 150 | 151 | def _init_weights(self, module): 152 | if isinstance(module, (nn.Linear, nn.Embedding)): 153 | module.weight.data.normal_(mean=0.0, std=0.02) 154 | if isinstance(module, nn.Linear) and module.bias is not None: 155 | module.bias.data.zero_() 156 | elif isinstance(module, nn.LayerNorm): 157 | module.bias.data.zero_() 158 | module.weight.data.fill_(1.0) 159 | 160 | def forward(self, idx, embeddings=None, targets=None): 161 | # forward the GPT model 162 | token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector 163 | 164 | if embeddings is not None: # prepend explicit embeddings 165 | token_embeddings = torch.cat((embeddings, token_embeddings), dim=1) 166 | 167 | t = token_embeddings.shape[1] 168 | assert t <= self.block_size, "Cannot forward, model block size is exhausted." 
169 | position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector 170 | x = self.drop(token_embeddings + position_embeddings) 171 | x = self.blocks(x) 172 | x = self.ln_f(x) 173 | logits = self.head(x) 174 | 175 | # if we are given some desired targets also calculate the loss 176 | loss = None 177 | if targets is not None: 178 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) 179 | 180 | return logits, loss 181 | 182 | def forward_with_past(self, idx, embeddings=None, targets=None, past=None, past_length=None): 183 | # inference only 184 | assert not self.training 185 | token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector 186 | if embeddings is not None: # prepend explicit embeddings 187 | token_embeddings = torch.cat((embeddings, token_embeddings), dim=1) 188 | 189 | if past is not None: 190 | assert past_length is not None 191 | past = torch.cat(past, dim=-2) # n_layer, 2, b, nh, len_past, dim_head 192 | past_shape = list(past.shape) 193 | expected_shape = [self.config.n_layer, 2, idx.shape[0], self.config.n_head, past_length, self.config.n_embd//self.config.n_head] 194 | assert past_shape == expected_shape, f"{past_shape} =/= {expected_shape}" 195 | position_embeddings = self.pos_emb[:, past_length, :] # each position maps to a (learnable) vector 196 | else: 197 | position_embeddings = self.pos_emb[:, :token_embeddings.shape[1], :] 198 | 199 | x = self.drop(token_embeddings + position_embeddings) 200 | presents = [] # accumulate over layers 201 | for i, block in enumerate(self.blocks): 202 | x, present = block(x, layer_past=past[i, ...] 
if past is not None else None, return_present=True) 203 | presents.append(present) 204 | 205 | x = self.ln_f(x) 206 | logits = self.head(x) 207 | # if we are given some desired targets also calculate the loss 208 | loss = None 209 | if targets is not None: 210 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) 211 | 212 | return logits, loss, torch.stack(presents) # _, _, n_layer, 2, b, nh, 1, dim_head 213 | 214 | 215 | class DummyGPT(nn.Module): 216 | # for debugging 217 | def __init__(self, add_value=1): 218 | super().__init__() 219 | self.add_value = add_value 220 | 221 | def forward(self, idx): 222 | return idx + self.add_value, None 223 | 224 | 225 | class CodeGPT(nn.Module): 226 | """Takes in semi-embeddings""" 227 | def __init__(self, vocab_size, block_size, in_channels, n_layer=12, n_head=8, n_embd=256, 228 | embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0): 229 | super().__init__() 230 | config = GPTConfig(vocab_size=vocab_size, block_size=block_size, 231 | embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop, 232 | n_layer=n_layer, n_head=n_head, n_embd=n_embd, 233 | n_unmasked=n_unmasked) 234 | # input embedding stem 235 | self.tok_emb = nn.Linear(in_channels, config.n_embd) 236 | self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) 237 | self.drop = nn.Dropout(config.embd_pdrop) 238 | # transformer 239 | self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) 240 | # decoder head 241 | self.ln_f = nn.LayerNorm(config.n_embd) 242 | self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 243 | self.block_size = config.block_size 244 | self.apply(self._init_weights) 245 | self.config = config 246 | logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters())) 247 | 248 | def get_block_size(self): 249 | return self.block_size 250 | 251 | def _init_weights(self, module): 252 | if isinstance(module, (nn.Linear, nn.Embedding)): 
253 | module.weight.data.normal_(mean=0.0, std=0.02) 254 | if isinstance(module, nn.Linear) and module.bias is not None: 255 | module.bias.data.zero_() 256 | elif isinstance(module, nn.LayerNorm): 257 | module.bias.data.zero_() 258 | module.weight.data.fill_(1.0) 259 | 260 | def forward(self, idx, embeddings=None, targets=None): 261 | # forward the GPT model 262 | token_embeddings = self.tok_emb(idx) # linearly project each input vector to n_embd 263 | 264 | if embeddings is not None: # prepend explicit embeddings 265 | token_embeddings = torch.cat((embeddings, token_embeddings), dim=1) 266 | 267 | t = token_embeddings.shape[1] 268 | assert t <= self.block_size, "Cannot forward, model block size is exhausted." 269 | position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector 270 | x = self.drop(token_embeddings + position_embeddings) 271 | x = self.blocks(x) 272 | x = self.ln_f(x) 273 | logits = self.head(x) 274 | 275 | # if we are given some desired targets also calculate the loss 276 | loss = None 277 | if targets is not None: 278 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) 279 | 280 | return logits, loss 281 | 282 | 283 | 284 | #### sampling utils 285 | 286 | def top_k_logits(logits, k): 287 | v, ix = torch.topk(logits, k) 288 | out = logits.clone() 289 | out[out < v[:, [-1]]] = -float('Inf') 290 | return out 291 | 292 | @torch.no_grad() 293 | def sample(model, x, steps, temperature=1.0, sample=False, top_k=None): 294 | """ 295 | take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in 296 | the sequence, feeding the predictions back into the model each time. Clearly the sampling 297 | has quadratic complexity unlike an RNN that is only linear, and has a finite context window 298 | of block_size, unlike an RNN that has an infinite context window.
299 | """ 300 | block_size = model.get_block_size() 301 | model.eval() 302 | for k in range(steps): 303 | x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed 304 | logits, _ = model(x_cond) 305 | # pluck the logits at the final step and scale by temperature 306 | logits = logits[:, -1, :] / temperature 307 | # optionally crop probabilities to only the top k options 308 | if top_k is not None: 309 | logits = top_k_logits(logits, top_k) 310 | # apply softmax to convert to probabilities 311 | probs = F.softmax(logits, dim=-1) 312 | # sample from the distribution or take the most likely 313 | if sample: 314 | ix = torch.multinomial(probs, num_samples=1) 315 | else: 316 | _, ix = torch.topk(probs, k=1, dim=-1) 317 | # append to the sequence and continue 318 | x = torch.cat((x, ix), dim=1) 319 | 320 | return x 321 | 322 | 323 | @torch.no_grad() 324 | def sample_with_past(x, model, steps, temperature=1., sample_logits=True, 325 | top_k=None, top_p=None, callback=None): 326 | # x is conditioning 327 | sample = x 328 | cond_len = x.shape[1] 329 | past = None 330 | for n in range(steps): 331 | if callback is not None: 332 | callback(n) 333 | logits, _, present = model.forward_with_past(x, past=past, past_length=(n+cond_len-1)) 334 | if past is None: 335 | past = [present] 336 | else: 337 | past.append(present) 338 | logits = logits[:, -1, :] / temperature 339 | if top_k is not None: 340 | logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) 341 | 342 | probs = F.softmax(logits, dim=-1) 343 | if not sample_logits: 344 | _, x = torch.topk(probs, k=1, dim=-1) 345 | else: 346 | x = torch.multinomial(probs, num_samples=1) 347 | # append to the sequence and continue 348 | sample = torch.cat((sample, x), dim=1) 349 | del past 350 | sample = sample[:, cond_len:] # cut conditioning off 351 | return sample 352 | 353 | 354 | #### clustering utils 355 | 356 | class KMeans(nn.Module): 357 | def __init__(self, ncluster=512, nc=3, niter=10): 
358 | super().__init__() 359 | self.ncluster = ncluster 360 | self.nc = nc 361 | self.niter = niter 362 | self.shape = (3,32,32) 363 | self.register_buffer("C", torch.zeros(self.ncluster,nc)) 364 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 365 | 366 | def is_initialized(self): 367 | return self.initialized.item() == 1 368 | 369 | @torch.no_grad() 370 | def initialize(self, x): 371 | N, D = x.shape 372 | assert D == self.nc, D 373 | c = x[torch.randperm(N)[:self.ncluster]] # init clusters at random 374 | for i in range(self.niter): 375 | # assign all pixels to the closest codebook element 376 | a = ((x[:, None, :] - c[None, :, :])**2).sum(-1).argmin(1) 377 | # move each codebook element to be the mean of the pixels that are assigned to it 378 | c = torch.stack([x[a==k].mean(0) for k in range(self.ncluster)]) 379 | # re-assign any poorly positioned codebook elements 380 | nanix = torch.any(torch.isnan(c), dim=1) 381 | ndead = nanix.sum().item() 382 | print('done step %d/%d, re-initialized %d dead clusters' % (i+1, self.niter, ndead)) 383 | c[nanix] = x[torch.randperm(N)[:ndead]] # re-init dead clusters 384 | 385 | self.C.copy_(c) 386 | self.initialized.fill_(1) 387 | 388 | 389 | def forward(self, x, reverse=False, shape=None): 390 | if not reverse: 391 | # flatten 392 | bs,c,h,w = x.shape 393 | assert c == self.nc 394 | x = x.reshape(bs,c,h*w,1) 395 | C = self.C.permute(1,0) 396 | C = C.reshape(1,c,1,self.ncluster) 397 | a = ((x-C)**2).sum(1).argmin(-1) # bs, h*w indices 398 | return a 399 | else: 400 | # look up each index's centroid and reshape back to an image 401 | bs, HW = x.shape 402 | """ 403 | c = self.C.reshape( 1, self.nc, 1, self.ncluster) 404 | c = c[bs*[0],:,:,:] 405 | c = c[:,:,HW*[0],:] 406 | x = x.reshape(bs, 1, HW, 1) 407 | x = x[:,3*[0],:,:] 408 | x = torch.gather(c, dim=3, index=x) 409 | """ 410 | x = self.C[x] 411 | x = x.permute(0,2,1) 412 | shape = shape if shape is not None else self.shape 413 | x = x.reshape(bs, *shape) 414 | 415 | return x 416 |
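The `layer_past`/`present` mechanism in `CausalSelfAttention` and `GPT.forward_with_past` caches keys and values so that each step of `sample_with_past` only projects the newest token. The causal mask is skipped on cached steps because a single new query is allowed to attend to every cached position. The sketch below is a self-contained illustration, not the repository's code: it assumes plain PyTorch, random bias-free projection matrices in place of the `nn.Linear` layers, no dropout or output projection, and float64 for a tight comparison. The helper names `heads`, `qkv` and `attend` are hypothetical. It checks that incremental decoding reproduces the rows of one full masked pass.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, C, nh = 2, 6, 8, 2   # batch, sequence length, embedding size, heads
hs = C // nh               # head size
dt = torch.float64         # double precision keeps the comparison tight

x = torch.randn(B, T, C, dtype=dt)
# Hypothetical stand-ins for the key/query/value nn.Linear layers (no bias).
Wq, Wk, Wv = (torch.randn(C, C, dtype=dt) for _ in range(3))


def heads(m):
    """(B, T, C) -> (B, nh, T, hs), mirroring the view/transpose in CausalSelfAttention."""
    return m.view(m.size(0), m.size(1), nh, hs).transpose(1, 2)


def qkv(xs):
    return heads(xs @ Wq), heads(xs @ Wk), heads(xs @ Wv)


def attend(q, k, v, causal):
    """Scaled dot-product attention; the lower-triangular mask assumes square scores."""
    att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
    if causal:
        assert q.size(-2) == k.size(-2)
        mask = torch.tril(torch.ones(q.size(-2), k.size(-2)))
        att = att.masked_fill(mask == 0, float("-inf"))
    return F.softmax(att, dim=-1) @ v


# 1) One masked pass over the whole sequence (as in GPT.forward).
full = attend(*qkv(x), causal=True)

# 2) Masked pass over a conditioning prefix, then one token per step with a growing
#    cache (as in the past=None first call and the later calls in sample_with_past).
P = 3
q, k_cache, v_cache = qkv(x[:, :P])
steps = [attend(q, k_cache, v_cache, causal=True)]
for t in range(P, T):
    q_t, k_t, v_t = qkv(x[:, t:t + 1])
    k_cache = torch.cat((k_cache, k_t), dim=-2)  # append along time, like layer_past
    v_cache = torch.cat((v_cache, v_t), dim=-2)
    # The single new query may see every cached key, so no mask is applied.
    steps.append(attend(q_t, k_cache, v_cache, causal=False))
incremental = torch.cat(steps, dim=-2)

print(torch.allclose(full, incremental))
```

The design trades memory for speed: each step computes projections for one token and a single row of attention scores, while the cache grows to the `n_layer, 2, b, nh, len_past, dim_head` shape asserted in `forward_with_past`.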
-------------------------------------------------------------------------------- /taming/modules/transformer/permuter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class AbstractPermuter(nn.Module): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__() 9 | def forward(self, x, reverse=False): 10 | raise NotImplementedError 11 | 12 | 13 | class Identity(AbstractPermuter): 14 | def __init__(self): 15 | super().__init__() 16 | 17 | def forward(self, x, reverse=False): 18 | return x 19 | 20 | 21 | class Subsample(AbstractPermuter): 22 | def __init__(self, H, W): 23 | super().__init__() 24 | C = 1 25 | indices = np.arange(H*W).reshape(C,H,W) 26 | while min(H, W) > 1: 27 | indices = indices.reshape(C,H//2,2,W//2,2) 28 | indices = indices.transpose(0,2,4,1,3) 29 | indices = indices.reshape(C*4,H//2, W//2) 30 | H = H//2 31 | W = W//2 32 | C = C*4 33 | assert H == W == 1 34 | idx = torch.tensor(indices.ravel()) 35 | self.register_buffer('forward_shuffle_idx', 36 | nn.Parameter(idx, requires_grad=False)) 37 | self.register_buffer('backward_shuffle_idx', 38 | nn.Parameter(torch.argsort(idx), requires_grad=False)) 39 | 40 | def forward(self, x, reverse=False): 41 | if not reverse: 42 | return x[:, self.forward_shuffle_idx] 43 | else: 44 | return x[:, self.backward_shuffle_idx] 45 | 46 | 47 | def mortonify(i, j): 48 | """(i,j) index to linear morton code""" 49 | i = np.uint64(i) 50 | j = np.uint64(j) 51 | 52 | z = np.uint(0) 53 | 54 | for pos in range(32): 55 | z = (z | 56 | ((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) | 57 | ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1)) 58 | ) 59 | return z 60 | 61 | 62 | class ZCurve(AbstractPermuter): 63 | def __init__(self, H, W): 64 | super().__init__() 65 | reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)] 66 | idx = np.argsort(reverseidx) 67 | idx = 
torch.tensor(idx) 68 | reverseidx = torch.tensor(reverseidx) 69 | self.register_buffer('forward_shuffle_idx', 70 | idx) 71 | self.register_buffer('backward_shuffle_idx', 72 | reverseidx) 73 | 74 | def forward(self, x, reverse=False): 75 | if not reverse: 76 | return x[:, self.forward_shuffle_idx] 77 | else: 78 | return x[:, self.backward_shuffle_idx] 79 | 80 | 81 | class SpiralOut(AbstractPermuter): 82 | def __init__(self, H, W): 83 | super().__init__() 84 | assert H == W 85 | size = W 86 | indices = np.arange(size*size).reshape(size,size) 87 | 88 | i0 = size//2 89 | j0 = size//2-1 90 | 91 | i = i0 92 | j = j0 93 | 94 | idx = [indices[i0, j0]] 95 | step_mult = 0 96 | for c in range(1, size//2+1): 97 | step_mult += 1 98 | # steps left 99 | for k in range(step_mult): 100 | i = i - 1 101 | j = j 102 | idx.append(indices[i, j]) 103 | 104 | # step down 105 | for k in range(step_mult): 106 | i = i 107 | j = j + 1 108 | idx.append(indices[i, j]) 109 | 110 | step_mult += 1 111 | if c < size//2: 112 | # step right 113 | for k in range(step_mult): 114 | i = i + 1 115 | j = j 116 | idx.append(indices[i, j]) 117 | 118 | # step up 119 | for k in range(step_mult): 120 | i = i 121 | j = j - 1 122 | idx.append(indices[i, j]) 123 | else: 124 | # end reached 125 | for k in range(step_mult-1): 126 | i = i + 1 127 | idx.append(indices[i, j]) 128 | 129 | assert len(idx) == size*size 130 | idx = torch.tensor(idx) 131 | self.register_buffer('forward_shuffle_idx', idx) 132 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) 133 | 134 | def forward(self, x, reverse=False): 135 | if not reverse: 136 | return x[:, self.forward_shuffle_idx] 137 | else: 138 | return x[:, self.backward_shuffle_idx] 139 | 140 | 141 | class SpiralIn(AbstractPermuter): 142 | def __init__(self, H, W): 143 | super().__init__() 144 | assert H == W 145 | size = W 146 | indices = np.arange(size*size).reshape(size,size) 147 | 148 | i0 = size//2 149 | j0 = size//2-1 150 | 151 | i = i0 152 | j = j0 153 | 
154 | idx = [indices[i0, j0]] 155 | step_mult = 0 156 | for c in range(1, size//2+1): 157 | step_mult += 1 158 | # steps left 159 | for k in range(step_mult): 160 | i = i - 1 161 | j = j 162 | idx.append(indices[i, j]) 163 | 164 | # step down 165 | for k in range(step_mult): 166 | i = i 167 | j = j + 1 168 | idx.append(indices[i, j]) 169 | 170 | step_mult += 1 171 | if c < size//2: 172 | # step right 173 | for k in range(step_mult): 174 | i = i + 1 175 | j = j 176 | idx.append(indices[i, j]) 177 | 178 | # step up 179 | for k in range(step_mult): 180 | i = i 181 | j = j - 1 182 | idx.append(indices[i, j]) 183 | else: 184 | # end reached 185 | for k in range(step_mult-1): 186 | i = i + 1 187 | idx.append(indices[i, j]) 188 | 189 | assert len(idx) == size*size 190 | idx = idx[::-1] 191 | idx = torch.tensor(idx) 192 | self.register_buffer('forward_shuffle_idx', idx) 193 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) 194 | 195 | def forward(self, x, reverse=False): 196 | if not reverse: 197 | return x[:, self.forward_shuffle_idx] 198 | else: 199 | return x[:, self.backward_shuffle_idx] 200 | 201 | 202 | class Random(nn.Module): 203 | def __init__(self, H, W): 204 | super().__init__() 205 | indices = np.random.RandomState(1).permutation(H*W) 206 | idx = torch.tensor(indices.ravel()) 207 | self.register_buffer('forward_shuffle_idx', idx) 208 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) 209 | 210 | def forward(self, x, reverse=False): 211 | if not reverse: 212 | return x[:, self.forward_shuffle_idx] 213 | else: 214 | return x[:, self.backward_shuffle_idx] 215 | 216 | 217 | class AlternateParsing(AbstractPermuter): 218 | def __init__(self, H, W): 219 | super().__init__() 220 | indices = np.arange(W*H).reshape(H,W) 221 | for i in range(1, H, 2): 222 | indices[i, :] = indices[i, ::-1] 223 | idx = indices.flatten() 224 | assert len(idx) == H*W 225 | idx = torch.tensor(idx) 226 | self.register_buffer('forward_shuffle_idx', idx) 227 | 
self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) 228 | 229 | def forward(self, x, reverse=False): 230 | if not reverse: 231 | return x[:, self.forward_shuffle_idx] 232 | else: 233 | return x[:, self.backward_shuffle_idx] 234 | 235 | 236 | if __name__ == "__main__": 237 | p0 = AlternateParsing(16, 16) 238 | print(p0.forward_shuffle_idx) 239 | print(p0.backward_shuffle_idx) 240 | 241 | x = torch.randint(0, 768, size=(11, 256)) 242 | y = p0(x) 243 | xre = p0(y, reverse=True) 244 | assert torch.equal(x, xre) 245 | 246 | p1 = SpiralOut(2, 2) 247 | print(p1.forward_shuffle_idx) 248 | print(p1.backward_shuffle_idx) 249 | -------------------------------------------------------------------------------- /taming/modules/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def count_params(model): 6 | total_params = sum(p.numel() for p in model.parameters()) 7 | return total_params 8 | 9 | 10 | class ActNorm(nn.Module): 11 | def __init__(self, num_features, logdet=False, affine=True, 12 | allow_reverse_init=False): 13 | assert affine 14 | super().__init__() 15 | self.logdet = logdet 16 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 17 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 18 | self.allow_reverse_init = allow_reverse_init 19 | 20 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 21 | 22 | def initialize(self, input): 23 | with torch.no_grad(): 24 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 25 | mean = ( 26 | flatten.mean(1) 27 | .unsqueeze(1) 28 | .unsqueeze(2) 29 | .unsqueeze(3) 30 | .permute(1, 0, 2, 3) 31 | ) 32 | std = ( 33 | flatten.std(1) 34 | .unsqueeze(1) 35 | .unsqueeze(2) 36 | .unsqueeze(3) 37 | .permute(1, 0, 2, 3) 38 | ) 39 | 40 | self.loc.data.copy_(-mean) 41 | self.scale.data.copy_(1 / (std + 1e-6)) 42 | 43 | def forward(self, input, reverse=False): 44 | if 
reverse: 45 | return self.reverse(input) 46 | if len(input.shape) == 2: 47 | input = input[:,:,None,None] 48 | squeeze = True 49 | else: 50 | squeeze = False 51 | 52 | _, _, height, width = input.shape 53 | 54 | if self.training and self.initialized.item() == 0: 55 | self.initialize(input) 56 | self.initialized.fill_(1) 57 | 58 | h = self.scale * (input + self.loc) 59 | 60 | if squeeze: 61 | h = h.squeeze(-1).squeeze(-1) 62 | 63 | if self.logdet: 64 | log_abs = torch.log(torch.abs(self.scale)) 65 | logdet = height*width*torch.sum(log_abs) 66 | logdet = logdet * torch.ones(input.shape[0]).to(input) 67 | return h, logdet 68 | 69 | return h 70 | 71 | def reverse(self, output): 72 | if self.training and self.initialized.item() == 0: 73 | if not self.allow_reverse_init: 74 | raise RuntimeError( 75 | "Initializing ActNorm in reverse direction is " 76 | "disabled by default. Use allow_reverse_init=True to enable." 77 | ) 78 | else: 79 | self.initialize(output) 80 | self.initialized.fill_(1) 81 | 82 | if len(output.shape) == 2: 83 | output = output[:,:,None,None] 84 | squeeze = True 85 | else: 86 | squeeze = False 87 | 88 | h = output / self.scale - self.loc 89 | 90 | if squeeze: 91 | h = h.squeeze(-1).squeeze(-1) 92 | return h 93 | 94 | 95 | class AbstractEncoder(nn.Module): 96 | def __init__(self): 97 | super().__init__() 98 | 99 | def encode(self, *args, **kwargs): 100 | raise NotImplementedError 101 | 102 | 103 | class Labelator(AbstractEncoder): 104 | """Net2Net Interface for Class-Conditional Model""" 105 | def __init__(self, n_classes, quantize_interface=True): 106 | super().__init__() 107 | self.n_classes = n_classes 108 | self.quantize_interface = quantize_interface 109 | 110 | def encode(self, c): 111 | c = c[:,None] 112 | if self.quantize_interface: 113 | return c, None, [None, None, c.long()] 114 | return c 115 | 116 | 117 | class SOSProvider(AbstractEncoder): 118 | # for unconditional training 119 | def __init__(self, sos_token, quantize_interface=True): 120 | 
super().__init__() 121 | self.sos_token = sos_token 122 | self.quantize_interface = quantize_interface 123 | 124 | def encode(self, x): 125 | # get batch size from data and replicate sos_token 126 | c = torch.ones(x.shape[0], 1)*self.sos_token 127 | c = c.long().to(x.device) 128 | if self.quantize_interface: 129 | return c, None, [None, None, c] 130 | return c 131 | -------------------------------------------------------------------------------- /taming/modules/vqvae/quantize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from torch import einsum 6 | from einops import rearrange 7 | 8 | 9 | class VectorQuantizer(nn.Module): 10 | """ 11 | see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py 12 | ____________________________________________ 13 | Discretization bottleneck part of the VQ-VAE. 14 | Inputs: 15 | - n_e : number of embeddings 16 | - e_dim : dimension of embedding 17 | - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 18 | _____________________________________________ 19 | """ 20 | 21 | # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for 22 | # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be 23 | # used wherever VectorQuantizer has been used before and is additionally 24 | # more efficient. 
25 | def __init__(self, n_e, e_dim, beta): 26 | super(VectorQuantizer, self).__init__() 27 | self.n_e = n_e 28 | self.e_dim = e_dim 29 | self.beta = beta 30 | 31 | self.embedding = nn.Embedding(self.n_e, self.e_dim) 32 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) 33 | 34 | def forward(self, z): 35 | """ 36 | Inputs the output of the encoder network z and maps it to a discrete 37 | one-hot vector that is the index of the closest embedding vector e_j 38 | z (continuous) -> z_q (discrete) 39 | z.shape = (batch, channel, height, width) 40 | quantization pipeline: 41 | 1. get encoder input (B,C,H,W) 42 | 2. flatten input to (B*H*W,C) 43 | """ 44 | # reshape z -> (batch, height, width, channel) and flatten 45 | z = z.permute(0, 2, 3, 1).contiguous() 46 | z_flattened = z.view(-1, self.e_dim) 47 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 48 | 49 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ 50 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \ 51 | torch.matmul(z_flattened, self.embedding.weight.t()) 52 | 53 | ## could possibly replace this here 54 | # #\start... 55 | # find closest encodings 56 | min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) 57 | 58 | min_encodings = torch.zeros( 59 | min_encoding_indices.shape[0], self.n_e).to(z) 60 | min_encodings.scatter_(1, min_encoding_indices, 1) 61 | 62 | # dtype min encodings: torch.float32 63 | # min_encodings shape: torch.Size([2048, 512]) 64 | # min_encoding_indices.shape: torch.Size([2048, 1]) 65 | 66 | # get quantized latent vectors 67 | z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) 68 | #.........\end 69 | 70 | # with: 71 | # .........\start 72 | #min_encoding_indices = torch.argmin(d, dim=1) 73 | #z_q = self.embedding(min_encoding_indices) 74 | # ......\end.........
(TODO) 75 | 76 | # compute loss for embedding 77 | loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ 78 | torch.mean((z_q - z.detach()) ** 2) 79 | 80 | # preserve gradients 81 | z_q = z + (z_q - z).detach() 82 | 83 | # perplexity 84 | e_mean = torch.mean(min_encodings, dim=0) 85 | perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) 86 | 87 | # reshape back to match original input shape 88 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 89 | 90 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices) 91 | 92 | def get_codebook_entry(self, indices, shape): 93 | # shape specifying (batch, height, width, channel) 94 | # TODO: check for easier handling with nn.Embedding 95 | min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices) 96 | min_encodings.scatter_(1, indices[:,None], 1) 97 | 98 | # get quantized latent vectors 99 | z_q = torch.matmul(min_encodings.float(), self.embedding.weight) 100 | 101 | if shape is not None: 102 | z_q = z_q.view(shape) 103 | 104 | # reshape back to match original input shape 105 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 106 | 107 | return z_q 108 | 109 | 110 | class GumbelQuantize(nn.Module): 111 | """ 112 | credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) 113 | Gumbel Softmax trick quantizer 114 | Categorical Reparameterization with Gumbel-Softmax, Jang et al.
2016 115 | https://arxiv.org/abs/1611.01144 116 | """ 117 | def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True, 118 | kl_weight=5e-4, temp_init=1.0, use_vqinterface=True, 119 | remap=None, unknown_index="random"): 120 | super().__init__() 121 | 122 | self.embedding_dim = embedding_dim 123 | self.n_embed = n_embed 124 | 125 | self.straight_through = straight_through 126 | self.temperature = temp_init 127 | self.kl_weight = kl_weight 128 | 129 | self.proj = nn.Conv2d(num_hiddens, n_embed, 1) 130 | self.embed = nn.Embedding(n_embed, embedding_dim) 131 | 132 | self.use_vqinterface = use_vqinterface 133 | 134 | self.remap = remap 135 | if self.remap is not None: 136 | self.register_buffer("used", torch.tensor(np.load(self.remap))) 137 | self.re_embed = self.used.shape[0] 138 | self.unknown_index = unknown_index # "random" or "extra" or integer 139 | if self.unknown_index == "extra": 140 | self.unknown_index = self.re_embed 141 | self.re_embed = self.re_embed+1 142 | print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. 
" 143 | f"Using {self.unknown_index} for unknown indices.") 144 | else: 145 | self.re_embed = n_embed 146 | 147 | def remap_to_used(self, inds): 148 | ishape = inds.shape 149 | assert len(ishape)>1 150 | inds = inds.reshape(ishape[0],-1) 151 | used = self.used.to(inds) 152 | match = (inds[:,:,None]==used[None,None,...]).long() 153 | new = match.argmax(-1) 154 | unknown = match.sum(2)<1 155 | if self.unknown_index == "random": 156 | new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) 157 | else: 158 | new[unknown] = self.unknown_index 159 | return new.reshape(ishape) 160 | 161 | def unmap_to_all(self, inds): 162 | ishape = inds.shape 163 | assert len(ishape)>1 164 | inds = inds.reshape(ishape[0],-1) 165 | used = self.used.to(inds) 166 | if self.re_embed > self.used.shape[0]: # extra token 167 | inds[inds>=self.used.shape[0]] = 0 # simply set to zero 168 | back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) 169 | return back.reshape(ishape) 170 | 171 | def forward(self, z, temp=None, return_logits=False): 172 | # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work 173 | hard = self.straight_through if self.training else True 174 | temp = self.temperature if temp is None else temp 175 | 176 | logits = self.proj(z) 177 | if self.remap is not None: 178 | # continue only with used logits 179 | full_zeros = torch.zeros_like(logits) 180 | logits = logits[:,self.used,...] 181 | 182 | soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) 183 | if self.remap is not None: 184 | # go back to all entries but unused set to zero 185 | full_zeros[:,self.used,...] 
= soft_one_hot 186 | soft_one_hot = full_zeros 187 | z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight) 188 | 189 | # + KL divergence to the uniform prior as a loss 190 | qy = F.softmax(logits, dim=1) 191 | diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean() 192 | 193 | ind = soft_one_hot.argmax(dim=1) 194 | if self.remap is not None: 195 | ind = self.remap_to_used(ind) 196 | if self.use_vqinterface: 197 | if return_logits: 198 | return z_q, diff, (None, None, ind), logits 199 | return z_q, diff, (None, None, ind) 200 | return z_q, diff, ind 201 | 202 | def get_codebook_entry(self, indices, shape): 203 | b, h, w, c = shape 204 | assert b*h*w == indices.shape[0] 205 | indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w) 206 | if self.remap is not None: 207 | indices = self.unmap_to_all(indices) 208 | one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float() 209 | z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight) 210 | return z_q 211 | 212 | 213 | class VectorQuantizer2(nn.Module): 214 | """ 215 | Improved version of VectorQuantizer that can be used as a drop-in replacement. Mostly 216 | avoids costly matrix multiplications and allows for post-hoc remapping of indices. 217 | """ 218 | # NOTE: due to a bug, the beta term was applied to the wrong term. For 219 | # backwards compatibility we use the buggy version by default, but you can 220 | # specify legacy=False to fix it.
221 | def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", 222 | sane_index_shape=False, legacy=True): 223 | super().__init__() 224 | self.n_e = n_e 225 | self.e_dim = e_dim 226 | self.beta = beta 227 | self.legacy = legacy 228 | 229 | self.embedding = nn.Embedding(self.n_e, self.e_dim) 230 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) 231 | 232 | self.remap = remap 233 | if self.remap is not None: 234 | self.register_buffer("used", torch.tensor(np.load(self.remap))) 235 | self.re_embed = self.used.shape[0] 236 | self.unknown_index = unknown_index # "random" or "extra" or integer 237 | if self.unknown_index == "extra": 238 | self.unknown_index = self.re_embed 239 | self.re_embed = self.re_embed+1 240 | print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " 241 | f"Using {self.unknown_index} for unknown indices.") 242 | else: 243 | self.re_embed = n_e 244 | 245 | self.sane_index_shape = sane_index_shape 246 | 247 | def remap_to_used(self, inds): 248 | ishape = inds.shape 249 | assert len(ishape)>1 250 | inds = inds.reshape(ishape[0],-1) 251 | used = self.used.to(inds) 252 | match = (inds[:,:,None]==used[None,None,...]).long() 253 | new = match.argmax(-1) 254 | unknown = match.sum(2)<1 255 | if self.unknown_index == "random": 256 | new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) 257 | else: 258 | new[unknown] = self.unknown_index 259 | return new.reshape(ishape) 260 | 261 | def unmap_to_all(self, inds): 262 | ishape = inds.shape 263 | assert len(ishape)>1 264 | inds = inds.reshape(ishape[0],-1) 265 | used = self.used.to(inds) 266 | if self.re_embed > self.used.shape[0]: # extra token 267 | inds[inds>=self.used.shape[0]] = 0 # simply set to zero 268 | back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) 269 | return back.reshape(ishape) 270 | 271 | def forward(self, z, temp=None, rescale_logits=False, return_logits=False): 272 | assert temp is None or temp==1.0, 
"Only for interface compatible with Gumbel" 273 | assert rescale_logits==False, "Only for interface compatible with Gumbel" 274 | assert return_logits==False, "Only for interface compatible with Gumbel" 275 | # reshape z -> (batch, height, width, channel) and flatten 276 | z = rearrange(z, 'b c h w -> b h w c').contiguous() 277 | z_flattened = z.view(-1, self.e_dim) 278 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 279 | 280 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ 281 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \ 282 | torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) 283 | 284 | min_encoding_indices = torch.argmin(d, dim=1) 285 | z_q = self.embedding(min_encoding_indices).view(z.shape) 286 | perplexity = None 287 | min_encodings = None 288 | 289 | # compute loss for embedding 290 | if not self.legacy: 291 | loss = self.beta * torch.mean((z_q.detach()-z)**2) + \ 292 | torch.mean((z_q - z.detach()) ** 2) 293 | else: 294 | loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ 295 | torch.mean((z_q - z.detach()) ** 2) 296 | 297 | # preserve gradients 298 | z_q = z + (z_q - z).detach() 299 | 300 | # reshape back to match original input shape 301 | z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() 302 | 303 | if self.remap is not None: 304 | min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis 305 | min_encoding_indices = self.remap_to_used(min_encoding_indices) 306 | min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten 307 | 308 | if self.sane_index_shape: 309 | min_encoding_indices = min_encoding_indices.reshape( 310 | z_q.shape[0], z_q.shape[2], z_q.shape[3]) 311 | 312 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices) 313 | 314 | def get_codebook_entry(self, indices, shape): 315 | # shape specifying (batch, height, width, channel) 316 | if self.remap is not None: 317 | indices = indices.reshape(shape[0],-1) # 
add batch axis 318 | indices = self.unmap_to_all(indices) 319 | indices = indices.reshape(-1) # flatten again 320 | 321 | # get quantized latent vectors 322 | z_q = self.embedding(indices) 323 | 324 | if shape is not None: 325 | z_q = z_q.view(shape) 326 | # reshape back to match original input shape 327 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 328 | 329 | return z_q 330 | 331 | class EmbeddingEMA(nn.Module): 332 | def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5): 333 | super().__init__() 334 | self.decay = decay 335 | self.eps = eps 336 | weight = torch.randn(num_tokens, codebook_dim) 337 | self.weight = nn.Parameter(weight, requires_grad = False) 338 | self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False) 339 | self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False) 340 | self.update = True 341 | 342 | def forward(self, embed_id): 343 | return F.embedding(embed_id, self.weight) 344 | 345 | def cluster_size_ema_update(self, new_cluster_size): 346 | self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay) 347 | 348 | def embed_avg_ema_update(self, new_embed_avg): 349 | self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) 350 | 351 | def weight_update(self, num_tokens): 352 | n = self.cluster_size.sum() 353 | smoothed_cluster_size = ( 354 | (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n 355 | ) 356 | #normalize embedding average with smoothed cluster size 357 | embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) 358 | self.weight.data.copy_(embed_normalized) 359 | 360 | 361 | class EMAVectorQuantizer(nn.Module): 362 | def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5, 363 | remap=None, unknown_index="random"): 364 | super().__init__() 365 | self.codebook_dim = embedding_dim 366 | self.num_tokens = self.n_embed = n_embed 367 | self.beta = beta 368 | self.embedding = EmbeddingEMA(self.num_tokens,
self.codebook_dim, decay, eps) 369 | 370 | self.remap = remap 371 | if self.remap is not None: 372 | self.register_buffer("used", torch.tensor(np.load(self.remap))) 373 | self.re_embed = self.used.shape[0] 374 | self.unknown_index = unknown_index # "random" or "extra" or integer 375 | if self.unknown_index == "extra": 376 | self.unknown_index = self.re_embed 377 | self.re_embed = self.re_embed+1 378 | print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. " 379 | f"Using {self.unknown_index} for unknown indices.") 380 | else: 381 | self.re_embed = n_embed 382 | 383 | def remap_to_used(self, inds): 384 | ishape = inds.shape 385 | assert len(ishape)>1 386 | inds = inds.reshape(ishape[0],-1) 387 | used = self.used.to(inds) 388 | match = (inds[:,:,None]==used[None,None,...]).long() 389 | new = match.argmax(-1) 390 | unknown = match.sum(2)<1 391 | if self.unknown_index == "random": 392 | new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) 393 | else: 394 | new[unknown] = self.unknown_index 395 | return new.reshape(ishape) 396 | 397 | def unmap_to_all(self, inds): 398 | ishape = inds.shape 399 | assert len(ishape)>1 400 | inds = inds.reshape(ishape[0],-1) 401 | used = self.used.to(inds) 402 | if self.re_embed > self.used.shape[0]: # extra token 403 | inds[inds>=self.used.shape[0]] = 0 # simply set to zero 404 | back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) 405 | return back.reshape(ishape) 406 | 407 | def forward(self, z): 408 | # reshape z -> (batch, height, width, channel) and flatten 409 | #z, 'b c h w -> b h w c' 410 | z = rearrange(z, 'b c h w -> b h w c') 411 | z_flattened = z.reshape(-1, self.codebook_dim) 412 | 413 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 414 | d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ 415 | self.embedding.weight.pow(2).sum(dim=1) - 2 * \ 416 | torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n' 417 | 418 | 419 | 
encoding_indices = torch.argmin(d, dim=1) 420 | 421 | z_q = self.embedding(encoding_indices).view(z.shape) 422 | encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) 423 | avg_probs = torch.mean(encodings, dim=0) 424 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) 425 | 426 | if self.training and self.embedding.update: 427 | #EMA cluster size 428 | encodings_sum = encodings.sum(0) 429 | self.embedding.cluster_size_ema_update(encodings_sum) 430 | #EMA embedding average 431 | embed_sum = encodings.transpose(0,1) @ z_flattened 432 | self.embedding.embed_avg_ema_update(embed_sum) 433 | #normalize embed_avg and update weight 434 | self.embedding.weight_update(self.num_tokens) 435 | 436 | # compute loss for embedding 437 | loss = self.beta * F.mse_loss(z_q.detach(), z) 438 | 439 | # preserve gradients 440 | z_q = z + (z_q - z).detach() 441 | 442 | # reshape back to match original input shape 443 | #z_q, 'b h w c -> b c h w' 444 | z_q = rearrange(z_q, 'b h w c -> b c h w') 445 | return z_q, loss, (perplexity, encodings, encoding_indices) 446 | -------------------------------------------------------------------------------- /taming/util.py: -------------------------------------------------------------------------------- 1 | import os, hashlib 2 | import requests 3 | from tqdm import tqdm 4 | 5 | URL_MAP = { 6 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" 7 | } 8 | 9 | CKPT_MAP = { 10 | "vgg_lpips": "vgg.pth" 11 | } 12 | 13 | MD5_MAP = { 14 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" 15 | } 16 | 17 | 18 | def download(url, local_path, chunk_size=1024): 19 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 20 | with requests.get(url, stream=True) as r: 21 | total_size = int(r.headers.get("content-length", 0)) 22 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 23 | with open(local_path, "wb") as f: 24 | for data in r.iter_content(chunk_size=chunk_size): 25 | if 
data: 26 | f.write(data) 27 | pbar.update(len(data)) 28 | 29 | 30 | def md5_hash(path): 31 | with open(path, "rb") as f: 32 | content = f.read() 33 | return hashlib.md5(content).hexdigest() 34 | 35 | 36 | def get_ckpt_path(name, root, check=False): 37 | assert name in URL_MAP 38 | path = os.path.join(root, CKPT_MAP[name]) 39 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 40 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 41 | download(URL_MAP[name], path) 42 | md5 = md5_hash(path) 43 | assert md5 == MD5_MAP[name], md5 44 | return path 45 | 46 | 47 | class KeyNotFoundError(Exception): 48 | def __init__(self, cause, keys=None, visited=None): 49 | self.cause = cause 50 | self.keys = keys 51 | self.visited = visited 52 | messages = list() 53 | if keys is not None: 54 | messages.append("Key not found: {}".format(keys)) 55 | if visited is not None: 56 | messages.append("Visited: {}".format(visited)) 57 | messages.append("Cause:\n{}".format(cause)) 58 | message = "\n".join(messages) 59 | super().__init__(message) 60 | 61 | 62 | def retrieve( 63 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False 64 | ): 65 | """Given a nested list or dict, return the desired value at key, expanding 66 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion 67 | is done in-place. 68 | 69 | Parameters 70 | ---------- 71 | list_or_dict : list or dict 72 | Possibly nested list or dictionary. 73 | key : str 74 | key/to/value, path-like string describing all keys necessary to 75 | consider to get to the desired value. List indices can also be 76 | passed here. 77 | splitval : str 78 | String that defines the delimiter between keys of the 79 | different depth levels in `key`. 80 | default : obj 81 | Value returned if :attr:`key` is not found. 82 | expand : bool 83 | Whether to expand callable nodes on the path or not.
84 | 85 | Returns 86 | ------- 87 | The desired value or if :attr:`default` is not ``None`` and the 88 | :attr:`key` is not found returns ``default``. 89 | 90 | Raises 91 | ------ 92 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is 93 | ``None``. 94 | """ 95 | 96 | keys = key.split(splitval) 97 | 98 | success = True 99 | try: 100 | visited = [] 101 | parent = None 102 | last_key = None 103 | for key in keys: 104 | if callable(list_or_dict): 105 | if not expand: 106 | raise KeyNotFoundError( 107 | ValueError( 108 | "Trying to get past callable node with expand=False." 109 | ), 110 | keys=keys, 111 | visited=visited, 112 | ) 113 | list_or_dict = list_or_dict() 114 | parent[last_key] = list_or_dict 115 | 116 | last_key = key 117 | parent = list_or_dict 118 | 119 | try: 120 | if isinstance(list_or_dict, dict): 121 | list_or_dict = list_or_dict[key] 122 | else: 123 | list_or_dict = list_or_dict[int(key)] 124 | except (KeyError, IndexError, ValueError) as e: 125 | raise KeyNotFoundError(e, keys=keys, visited=visited) 126 | 127 | visited += [key] 128 | # final expansion of retrieved value 129 | if expand and callable(list_or_dict): 130 | list_or_dict = list_or_dict() 131 | parent[last_key] = list_or_dict 132 | except KeyNotFoundError as e: 133 | if default is None: 134 | raise e 135 | else: 136 | list_or_dict = default 137 | success = False 138 | 139 | if not pass_success: 140 | return list_or_dict 141 | else: 142 | return list_or_dict, success 143 | 144 | 145 | if __name__ == "__main__": 146 | config = {"keya": "a", 147 | "keyb": "b", 148 | "keyc": 149 | {"cc1": 1, 150 | "cc2": 2, 151 | } 152 | } 153 | from omegaconf import OmegaConf 154 | config = OmegaConf.create(config) 155 | print(config) 156 | retrieve(config, "keya") 157 | 158 | --------------------------------------------------------------------------------
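For readers without torch at hand, the following is a minimal, dependency-free sketch (plain Python, not code from this repository) of what the quantizers in `taming/modules/vqvae/quantize.py` compute for individual latent vectors: the nearest-codebook lookup via the expansion ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z·e, the codebook-usage perplexity, and how `legacy=True` versus `legacy=False` in `VectorQuantizer2` changes which side of the VQ loss receives the `beta` weight. The function names `nearest_code`, `perplexity`, and `vq_loss_grads` are hypothetical. The gradients are derived by hand for the VQ term alone (no reconstruction loss, no straight-through path), and the mean is taken over one vector rather than over all B*H*W*C elements as in the torch code.

```python
import math
from typing import List, Sequence, Tuple

Vec = Sequence[float]


def nearest_code(z: Vec, codebook: Sequence[Vec]) -> int:
    """Index of the closest codebook vector, using the same expansion as the
    quantizers: ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z.e."""
    z2 = sum(v * v for v in z)
    best_idx, best_d = 0, math.inf
    for j, e in enumerate(codebook):
        d = z2 + sum(v * v for v in e) - 2.0 * sum(a * b for a, b in zip(z, e))
        if d < best_d:  # strict '<' so ties resolve to the lowest index
            best_idx, best_d = j, d
    return best_idx


def perplexity(indices: Sequence[int], n_e: int) -> float:
    """exp(entropy) of the code-usage histogram, mirroring the `perplexity`
    lines (including the 1e-10 guard inside the log)."""
    counts = [0] * n_e
    for i in indices:
        counts[i] += 1
    probs = [c / len(indices) for c in counts]
    return math.exp(-sum(p * math.log(p + 1e-10) for p in probs))


def vq_loss_grads(z: Vec, z_q: Vec, beta: float,
                  legacy: bool) -> Tuple[float, List[float], List[float]]:
    """Hypothetical helper: VQ loss value plus hand-derived gradients w.r.t.
    the encoder output z and the selected codebook entry z_q."""
    n = len(z)
    mse = sum((q - a) ** 2 for a, q in zip(z, z_q)) / n
    # legacy=True : mean((sg[z_q]-z)^2) + beta*mean((z_q-sg[z])^2)
    # legacy=False: beta*mean((sg[z_q]-z)^2) + mean((z_q-sg[z])^2)
    enc_w, code_w = (1.0, beta) if legacy else (beta, 1.0)
    grad_z = [enc_w * 2.0 * (a - q) / n for a, q in zip(z, z_q)]
    grad_code = [code_w * 2.0 * (q - a) / n for a, q in zip(z, z_q)]
    # Both terms have the same forward value, so the loss is (1 + beta) * mse.
    return (1.0 + beta) * mse, grad_z, grad_code


if __name__ == "__main__":
    codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]]
    zs = [[0.1, -0.2], [0.9, 1.2], [0.8, 0.7], [-0.9, 1.8]]
    idx = [nearest_code(z, codebook) for z in zs]
    print(idx)  # [0, 1, 1, 2]
    print(round(perplexity(idx, len(codebook)), 4))  # 2.8284
    for legacy in (True, False):
        loss, g_z, g_code = vq_loss_grads(zs[0], codebook[idx[0]], beta=0.25, legacy=legacy)
        print(legacy, loss, g_z, g_code)
```

Both settings report the same loss value because `z_q.detach() - z` and `z_q - z.detach()` are numerically identical; only the gradient split differs. With `legacy=True` the encoder output is pulled toward the code with weight 1 and the codebook entry moves with weight `beta`, which is the reverse of the commitment weighting `beta * ||z_e(x)-sg[e]||^2` described in the `VectorQuantizer` docstring.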