├── LICENSE
├── README.md
├── celle
│   ├── __init__.py
│   ├── attention.py
│   ├── celle.py
│   ├── reversible.py
│   ├── transformer.py
│   └── vae.py
├── celle_main.py
├── celle_taming_main.py
├── configs
│   ├── celle.yaml
│   ├── nucleus_vqgan.yaml
│   └── threshold_vqgan.yaml
├── data
│   └── aaDescriptors.csv
├── dataloader.py
├── images
│   ├── generate.gif
│   ├── huanglogo.jpeg
│   ├── nucleus.jpg
│   └── preview.png
├── notebooks
│   ├── Demo.ipynb
│   ├── grad_map.py
│   └── visualize_attention.ipynb
├── requirements.txt
└── taming
    ├── lr_scheduler.py
    ├── models
    │   ├── cond_transformer.py
    │   ├── dummy_cond_stage.py
    │   └── vqgan.py
    ├── modules
    │   ├── autoencoder
    │   │   └── lpips
    │   │       └── vgg.pth
    │   ├── diffusionmodules
    │   │   └── model.py
    │   ├── discriminator
    │   │   └── model.py
    │   ├── losses
    │   │   ├── __init__.py
    │   │   ├── lpips.py
    │   │   ├── segmentation.py
    │   │   └── vqperceptual.py
    │   ├── misc
    │   │   └── coord.py
    │   ├── transformer
    │   │   ├── mingpt.py
    │   │   └── permuter.py
    │   ├── util.py
    │   └── vqvae
    │       └── quantize.py
    └── util.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Emaad Khwaja, Yun Song, & Bo Huang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | # CELL-E: Biological Zero-Shot Text-to-Image Synthesis for Protein Localization Prediction
8 |
9 | This repository is the official implementation of [CELL-E: Biological Zero-Shot Text-to-Image Synthesis for Protein Localization Prediction](https://www.biorxiv.org/content/10.1101/2022.05.27.493774v1).
10 |
11 | 
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 | ## Requirements
21 |
22 | Create a virtual environment and install the required packages via:
23 |
24 | ```setup
25 | pip install -r requirements.txt
26 | ```
27 |
28 | Next, install ```torch==1.7.1``` and ```torchvision==0.8.2``` with the [appropriate CUDA version](https://pytorch.org/get-started/previous-versions/#v171).
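
For example, for CUDA 11.0 the wheels listed on that page can be installed with the command below; check the linked page for the build that matches your CUDA version.

```setup
pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 -f https://download.pytorch.org/whl/torch_stable.html
```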
29 |
30 | ## Preparing Dataset
31 |
32 | ### Images
33 |
34 | We used OpenCell for CELL-E, which has [information on downloading the entire dataset](https://opencell.czbiohub.org/download). A ```data_csv``` is needed for the dataloader. You must generate a csv file which contains the columns ```nucleus_image_path```, ```protein_image_path```, ```metadata_path```, and ```split``` (train or val). It is assumed that this file exists within the same general ```data``` folder as the images and metadata files.
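
As an illustration only (the file names and paths below are hypothetical placeholders), a minimal ```data_csv``` could look like:

```
nucleus_image_path,protein_image_path,metadata_path,split
images/GENE1/nucleus_crop_00.png,images/GENE1/protein_crop_00.png,metadata/GENE1.json,train
images/GENE2/nucleus_crop_03.png,images/GENE2/protein_crop_03.png,metadata/GENE2.json,val
```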
35 |
36 | ### Metadata
37 |
38 | Metadata is a JSON file which should accompany every protein sequence. If a sequence does not appear in the ```data_csv```, it must appear in ```metadata.json``` with a key named ```protein_sequence```.
39 |
40 | Adding more information here can be useful for querying individual proteins. This information can be retrieved via ```retrieve_metadata```, which creates a ```self.metadata``` variable within the dataset object.
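
As a sketch, a ```metadata.json``` entry might look like the following; only ```protein_sequence``` is required, and the sequence and extra keys shown here are placeholders.

```json
{
    "protein_sequence": "MSKGEELFTGVVPILVELDGD",
    "gene_name": "EXAMPLE1",
    "notes": "additional fields can be added for querying"
}
```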
41 |
42 | ## Training
43 |
44 | Training for CELL-E occurs in 2 (or 3) stages:
45 |
46 | - Training a Protein Threshold Image Encoder
47 | - (Optional, but recommended) Training a Nucleus Image (Conditional Image) Encoder
48 | - Training the CELL-E Transformer
49 |
50 | ### Image Encoders
51 |
52 | There are two image encoders available in this repository: a discrete VAE (similar to the original OpenAI implementation) and a VQGAN (recommended). If using the protein threshold image, set ```threshold: True``` for the dataset.
53 |
54 | #### Discrete VAE
55 |
56 | The discrete VAE can be trained using the following code:
57 |
58 | ```python
59 | import torch
60 | from celle import DiscreteVAE
61 | vae = DiscreteVAE(
62 | image_size = 256,
63 | num_layers = 3, # number of downsamples - ex. 256 / (2 ** 3) = (32 x 32 feature map)
64 | channels = 1,
65 | num_tokens = 512, # number of visual tokens. in the paper, they used 8192, but could be smaller for downsized projects
66 | codebook_dim = 512, # codebook dimension
67 | hidden_dim = 64, # hidden dimension
68 | num_resnet_blocks = 1, # number of resnet blocks
69 | temperature = 0.9, # gumbel softmax temperature, the lower this is, the harder the discretization
70 | straight_through = False) # straight-through for gumbel softmax. unclear if it is better one way or the other
71 | images = torch.rand(4, 1, 256, 256) # example batch of single-channel images; replace with real data
72 | loss = vae(images, return_loss = True)
73 | loss.backward()
74 | ```
75 |
76 | #### VQGAN (Recommended)
77 |
78 | We use a slightly modified version of the [taming-transformers](https://github.com/CompVis/taming-transformers) code.
79 |
80 | To train, run the following script:
81 |
82 | ```python celle_taming_main.py --base configs/threshold_vqgan.yaml -t True```
83 |
84 | Please refer to the original repo for additional flags, such as ```--gpus```.
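
For example, to train on the first GPU (assuming the same comma-separated ```--gpus``` format as taming-transformers / PyTorch Lightning):

```python celle_taming_main.py --base configs/threshold_vqgan.yaml -t True --gpus 0,```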
85 |
86 |
87 | ### CELL-E
88 |
89 | To train, run the following script:
90 |
91 | ```python celle_main.py --base configs/celle.yaml -t True```
92 |
93 | Specify ```--gpus``` in the same format as VQGAN.
94 |
95 | CELL-E contains the following options from [dalle-pytorch](https://github.com/lucidrains/DALLE-pytorch); a sketch of how they might fit into ```configs/celle.yaml``` follows the list.
96 |
97 | - ```ckpt_path```: Resume a previous CELL-E training run from a saved model (with ```state_dict```)
98 | - ```vqgan_model_path``` : Saved protein image model (with state_dict) for protein image encoder
99 | - ```vqgan_config_path```: Saved protein image model yaml
100 | - ```condition_model_path``` : (Optional) Saved condition (nucleus) model (with state_dict) for protein image encoder
101 | - ```condition_config_path```: (Optional) Saved condition (nucleus) model yaml
102 | - ```num_images```: 1 if only using protein image encoder, 2 if including condition image encoder
103 | - ```image_key```: ```nucleus```, ```target```, or ```threshold```
104 | - ```dim```: Dimension of language model embedding (768 for BERT)
105 | - ```num_text_tokens```: total number of tokens in language model (30 for BERT)
106 | - ```text_seq_len```: Total number of amino acids considered
107 |
108 |
109 | - ```depth```: Transformer model depth, deeper is usually better at the cost of VRAM
110 | - ```heads```: number of heads used in multi-headed attention
111 | - ```dim_head```: size of attention heads
112 | - ```reversible```: See [https://github.com/lucidrains/DALLE-pytorch#scaling-depth](https://github.com/lucidrains/DALLE-pytorch#scaling-depth)
113 | - ```attn_dropout```: Attention Dropout rate in training
114 | - ```ff_dropout```: Feed-Forward Dropout rate in training
115 | - ```attn_types```: See [https://github.com/lucidrains/DALLE-pytorch#sparse-attention](https://github.com/lucidrains/DALLE-pytorch#sparse-attention). Sparse attention is not supported
116 | - ```loss_img_weight```: Weighting applied to the image reconstruction loss (text loss weight = 1)
117 | - ```loss_cond_weight```: Weighting applied to the condition image reconstruction loss
118 | - ```stable```: Uses numerically stable softmax and output normalization (for when exploding gradients occur)
119 | - ```sandwich_norm```: See [https://github.com/lucidrains/x-transformers#sandwich-norm](https://github.com/lucidrains/x-transformers#sandwich-norm)
120 | - ```shift_tokens```: Applies shift in feature dimension. Only applied to images.
121 | - ```rotary_emb```: [Rotary embedding](https://github.com/lucidrains/x-transformers#rotary-positional-embeddings) scheme for positional encoding
122 | - ```text_embedding```: Language model used for protein sequence embeddings. ```no_text```, ```unirep```, ```bert```, ```esm1b```, ```onehot```, and ```aadescriptors``` are available
123 | - ```fixed_embedding```: Setting to ```True``` keeps the protein sequence embedding fixed (not updated) during training; set to ```False``` to allow it to be updated
124 | - ```learning_rate```: Learning rate for Adam optimizer
125 | - ```monitor```: Metric monitored when saving model checkpoints (e.g. ```val_loss```)
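
As a rough sketch (not the actual file shipped in ```configs```, and all paths are placeholders), these options might be laid out in ```configs/celle.yaml``` roughly as follows, mirroring the ```model```/```data``` layout described in ```celle_main.py```:

```yaml
model:
  target: celle_main.CELLE_trainer
  params:
    vqgan_model_path: models/threshold_vqgan.ckpt      # placeholder path
    vqgan_config_path: configs/threshold_vqgan.yaml
    condition_model_path: models/nucleus_vqgan.ckpt    # placeholder path
    condition_config_path: configs/nucleus_vqgan.yaml
    num_images: 2
    image_key: threshold
    dim: 768
    num_text_tokens: 30
    text_seq_len: 1000
    depth: 16
    heads: 16
    dim_head: 64
    text_embedding: bert
    fixed_embedding: True
    learning_rate: 3.0e-4
    monitor: val_loss
data:
  target: celle_main.CellDataModule
  params:
    data_csv: data/data.csv                            # placeholder path
    crop_size: 256
    batch_size: 1
    threshold: True
```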
126 |
127 | ## Generating Images
128 |
129 | To generate images, set the saved model as ```ckpt_path```. This loading method can be unstable, so refer to ```Demo.ipynb``` to see another way of loading.
130 |
131 | ```python
132 | from omegaconf import OmegaConf
133 | from celle_main import instantiate_from_config
134 |
135 | configs = OmegaConf.load("configs/celle.yaml")
136 |
137 | model = instantiate_from_config(configs.model).to(device)  # device = "cuda" or "cpu"
138 |
139 | model.generate_images(text=sequence,  # tokenized protein sequence
140 | condition=nucleus,  # nucleus image tensor
141 | return_logits=True,
142 | progress=True,
143 | use_cache=True)
144 | ```
145 |
146 | ## Citation
147 |
148 | Please cite us if you decide to use our code for any part of your research.
149 | ```
150 | CELL-E: Biological Zero-Shot Text-to-Image Synthesis for Protein Localization Prediction
151 | Emaad Khwaja, Yun S. Song, Bo Huang
152 | bioRxiv 2022.05.27.493774; doi: https://doi.org/10.1101/2022.05.27.493774
153 | ```
154 |
155 | ## Acknowledgements
156 |
157 | Huge shoutout to [@lucidrains](https://github.com/lucidrains) for putting out [dalle-pytorch](https://github.com/lucidrains/DALLE-pytorch), which this code is based on. This work would not have been possible without their invaluable contribution.
158 |
--------------------------------------------------------------------------------
/celle/__init__.py:
--------------------------------------------------------------------------------
1 | from celle.celle import CELLE, DiscreteVAE
2 | from celle.vae import VQGanVAE
3 |
4 | from pkg_resources import get_distribution
5 |
6 | # __version__ = get_distribution('dalle_pytorch').version
7 | __version__ = "1.4.1"
8 |
--------------------------------------------------------------------------------
/celle/attention.py:
--------------------------------------------------------------------------------
1 | from inspect import isfunction
2 |
3 | import torch
4 | from torch import nn, einsum
5 | import torch.nn.functional as F
6 | from einops import rearrange, repeat
7 |
8 | from rotary_embedding_torch import apply_rotary_emb
9 |
10 | # helpers
11 |
12 |
13 | def exists(val):
14 | return val is not None
15 |
16 |
17 | def uniq(arr):
18 | return {el: True for el in arr}.keys()
19 |
20 |
21 | def default(val, d):
22 | if exists(val):
23 | return val
24 | return d() if isfunction(d) else d
25 |
26 |
27 | def max_neg_value(t):
28 | return -torch.finfo(t.dtype).max
29 |
30 |
31 | def stable_softmax(t, dim=-1, alpha=32 ** 2):
32 | t = t / alpha
33 | t = t - torch.amax(t, dim=dim, keepdim=True).detach()
34 | return (t * alpha).softmax(dim=dim)
35 |
36 |
37 | def apply_pos_emb(pos_emb, qkv):
38 | n = qkv[0].shape[-2]
39 | pos_emb = pos_emb[..., :n, :]
40 | return tuple(map(lambda t: apply_rotary_emb(pos_emb, t), qkv))
41 |
42 |
43 | # classes
44 |
45 |
46 | class Attention(nn.Module):
47 | def __init__(
48 | self,
49 | dim,
50 | seq_len,
51 | causal=True,
52 | heads=8,
53 | dim_head=64,
54 | dropout=0.0,
55 | stable=False,
56 | static_mask=None,
57 | ):
58 | super().__init__()
59 | inner_dim = dim_head * heads
60 | self.heads = heads
61 | self.seq_len = seq_len
62 | self.scale = dim_head ** -0.5
63 |
64 | self.stable = stable
65 | self.causal = causal
66 | self.register_buffer("static_mask", static_mask, persistent=False)
67 |
68 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
69 | self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
70 |
71 | self.save_attn = nn.Identity()
72 |
73 | def forward(self, x, mask=None, rotary_pos_emb=None, cache=None, cache_key=None):
74 | b, n, _, h, device = *x.shape, self.heads, x.device
75 | softmax = torch.softmax if not self.stable else stable_softmax
76 |
77 | offset = cache.get("offset", 0) if exists(cache) else 0
78 |
79 | qkv = self.to_qkv(x).chunk(3, dim=-1)
80 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)
81 |
82 | if exists(rotary_pos_emb):
83 | q, k, v = apply_pos_emb(rotary_pos_emb[..., offset:, :], (q, k, v))
84 |
85 | q = q * self.scale
86 |
87 | if offset > 0:
88 | k_top, v_top = cache[cache_key]
89 | k = torch.cat([k_top, k], dim=-2)
90 | v = torch.cat([v_top, v], dim=-2)
91 | if exists(cache):
92 | cache[cache_key] = k, v
93 |
94 | dots = torch.einsum("b h i d, b h j d -> b h i j", q, k)
95 | mask_value = max_neg_value(dots)
96 |
97 | if exists(mask):
98 | mask = rearrange(mask, "b j -> b () () j")
99 | dots.masked_fill_(~mask, mask_value)
100 | del mask
101 |
102 | if (
103 | self.causal and offset == 0
104 | ): # causality is naturally enforced for the cached inference
105 | i, j = dots.shape[-2:]
106 | mask = torch.ones(i, j, device=device).triu_(j - i + 1).bool()
107 | dots.masked_fill_(mask, mask_value)
108 |
109 | if exists(self.static_mask):
110 | dots.masked_fill_(
111 | ~self.static_mask[offset : offset + n, : offset + n], mask_value
112 | )
113 |
114 | attn = softmax(dots, dim=-1)
115 |
116 | self.save_attn(attn)
117 |
118 | out = torch.einsum("b h i j, b h j d -> b h i d", attn, v)
119 | out = rearrange(out, "b h n d -> b n (h d)")
120 | out = self.to_out(out)
121 | return out
122 |
123 |
124 | # sparse attention with convolutional pattern, as mentioned in the blog post. customizable kernel size and dilation
125 |
126 |
127 | class SparseConvCausalAttention(nn.Module):
128 | def __init__(
129 | self,
130 | dim,
131 | seq_len,
132 | image_size=32,
133 | kernel_size=5,
134 | dilation=1,
135 | heads=8,
136 | dim_head=64,
137 | dropout=0.0,
138 | stable=False,
139 | **kwargs,
140 | ):
141 | super().__init__()
142 | assert kernel_size % 2 == 1, "kernel size must be odd"
143 |
144 | inner_dim = dim_head * heads
145 | self.seq_len = seq_len
146 | self.heads = heads
147 | self.scale = dim_head ** -0.5
148 | self.image_size = image_size
149 | self.kernel_size = kernel_size
150 | self.dilation = dilation
151 |
152 | self.stable = stable
153 |
154 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
155 |
156 | self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
157 |
158 | def forward(self, x, mask=None, rotary_pos_emb=None):
159 | b, n, _, h, img_size, kernel_size, dilation, seq_len, device = (
160 | *x.shape,
161 | self.heads,
162 | self.image_size,
163 | self.kernel_size,
164 | self.dilation,
165 | self.seq_len,
166 | x.device,
167 | )
168 | softmax = torch.softmax if not self.stable else stable_softmax
169 |
170 | img_seq_len = img_size ** 2
171 | text_len = seq_len + 1 - img_seq_len
172 |
173 | # padding
174 |
175 | padding = seq_len - n + 1
176 | mask = default(mask, lambda: torch.ones(b, text_len, device=device).bool())
177 |
178 | x = F.pad(x, (0, 0, 0, padding), value=0)
179 | mask = mask[:, :text_len]
180 |
181 | # derive query / keys / values
182 |
183 | qkv = self.to_qkv(x).chunk(3, dim=-1)
184 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), qkv)
185 |
186 | if exists(rotary_pos_emb):
187 | q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))
188 |
189 | q *= self.scale
190 |
191 | ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(
192 | lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v)
193 | )
194 |
195 | # text attention
196 |
197 | dots_text = einsum("b i d, b j d -> b i j", q_text, k_text)
198 | mask_value = max_neg_value(dots_text)
199 |
200 | i, j = dots_text.shape[-2:]
201 | text_causal_mask = torch.ones(i, j, device=device).triu_(j - i + 1).bool()
202 | dots_text.masked_fill_(text_causal_mask, mask_value)
203 |
204 | attn_text = softmax(dots_text, dim=-1)
205 | out_text = einsum("b i j, b j d -> b i d", attn_text, v_text)
206 |
207 | # image attention
208 |
209 | effective_kernel_size = (kernel_size - 1) * dilation + 1
210 | padding = effective_kernel_size // 2
211 |
212 | k_img, v_img = map(
213 | lambda t: rearrange(t, "b (h w) c -> b c h w", h=img_size), (k_img, v_img)
214 | )
215 | k_img, v_img = map(
216 | lambda t: F.unfold(t, kernel_size, padding=padding, dilation=dilation),
217 | (k_img, v_img),
218 | )
219 | k_img, v_img = map(
220 | lambda t: rearrange(t, "b (d j) i -> b i j d", j=kernel_size ** 2),
221 | (k_img, v_img),
222 | )
223 |
224 | # let image attend to all of text
225 |
226 | dots_image = einsum("b i d, b i j d -> b i j", q_img, k_img)
227 | dots_image_to_text = einsum("b i d, b j d -> b i j", q_img, k_text)
228 |
229 | # calculate causal attention for local convolution
230 |
231 | i, j = dots_image.shape[-2:]
232 | img_seq = torch.arange(img_seq_len, device=device)
233 | k_img_indices = rearrange(img_seq.float(), "(h w) -> () () h w", h=img_size)
234 | k_img_indices = F.pad(
235 | k_img_indices, (padding,) * 4, value=img_seq_len
236 | ) # padding set to be max, so it is never attended to
237 | k_img_indices = F.unfold(k_img_indices, kernel_size, dilation=dilation)
238 | k_img_indices = rearrange(k_img_indices, "b j i -> b i j")
239 |
240 | # mask image attention
241 |
242 | q_img_indices = rearrange(img_seq, "i -> () i ()")
243 | causal_mask = q_img_indices < k_img_indices
244 |
245 | # concat text mask with image causal mask
246 |
247 | causal_mask = repeat(causal_mask, "() i j -> b i j", b=b * h)
248 | mask = repeat(mask, "b j -> (b h) i j", i=i, h=h)
249 | mask = torch.cat((~mask, causal_mask), dim=-1)
250 |
251 | # image can attend to all of text
252 |
253 | dots = torch.cat((dots_image_to_text, dots_image), dim=-1)
254 | dots.masked_fill_(mask, mask_value)
255 |
256 | attn = softmax(dots, dim=-1)
257 |
258 | # aggregate
259 |
260 | attn_image_to_text, attn_image = attn[..., :text_len], attn[..., text_len:]
261 |
262 | out_image_to_image = einsum("b i j, b i j d -> b i d", attn_image, v_img)
263 | out_image_to_text = einsum("b i j, b j d -> b i d", attn_image_to_text, v_text)
264 |
265 | out_image = out_image_to_image + out_image_to_text
266 |
267 | # combine attended values for both text and image
268 |
269 | out = torch.cat((out_text, out_image), dim=1)
270 |
271 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
272 | out = self.to_out(out)
273 | return out[:, :n]
274 |
275 |
276 | # sparse axial causal attention
277 |
278 |
279 | class SparseAxialCausalAttention(nn.Module):
280 | def __init__(
281 | self,
282 | dim,
283 | seq_len,
284 | image_size=32,
285 | axis=0,
286 | heads=8,
287 | dim_head=64,
288 | dropout=0.0,
289 | stable=False,
290 | **kwargs,
291 | ):
292 | super().__init__()
293 | assert axis in {0, 1}, "axis must be either 0 (along height) or 1 (along width)"
294 | self.axis = axis
295 |
296 | inner_dim = dim_head * heads
297 | self.seq_len = seq_len
298 | self.heads = heads
299 | self.scale = dim_head ** -0.5
300 | self.image_size = image_size
301 |
302 | self.stable = stable
303 |
304 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
305 |
306 | self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
307 |
308 | def forward(self, x, mask=None, rotary_pos_emb=None):
309 | b, n, _, h, img_size, axis, seq_len, device = (
310 | *x.shape,
311 | self.heads,
312 | self.image_size,
313 | self.axis,
314 | self.seq_len,
315 | x.device,
316 | )
317 | softmax = torch.softmax if not self.stable else stable_softmax
318 |
319 | img_seq_len = img_size ** 2
320 | text_len = seq_len + 1 - img_seq_len
321 |
322 | # padding
323 |
324 | padding = seq_len - n + 1
325 | mask = default(mask, lambda: torch.ones(b, text_len, device=device).bool())
326 |
327 | x = F.pad(x, (0, 0, 0, padding), value=0)
328 | mask = mask[:, :text_len]
329 |
330 | # derive queries / keys / values
331 |
332 | qkv = self.to_qkv(x).chunk(3, dim=-1)
333 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), qkv)
334 |
335 | if exists(rotary_pos_emb):
336 | q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))
337 |
338 | q *= self.scale
339 |
340 | ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(
341 | lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v)
342 | )
343 |
344 | # text attention
345 |
346 | dots_text = einsum("b i d, b j d -> b i j", q_text, k_text)
347 | mask_value = max_neg_value(dots_text)
348 |
349 | i, j = dots_text.shape[-2:]
350 | text_causal_mask = torch.ones(i, j, device=device).triu_(j - i + 1).bool()
351 | dots_text.masked_fill_(text_causal_mask, mask_value)
352 |
353 | attn_text = softmax(dots_text, dim=-1)
354 | out_text = einsum("b i j, b j d -> b i d", attn_text, v_text)
355 |
356 | # image attention
357 |
358 | split_axis_einops = (
359 | "b (h w) c -> b h w c" if axis == 0 else "b (h w) c -> b w h c"
360 | )
361 | merge_axis_einops = (
362 | "b x n d -> b (x n) d" if axis == 0 else "b x n d -> b (n x) d"
363 | )
364 |
365 | # split out axis
366 |
367 | q_img, k_img, v_img = map(
368 | lambda t: rearrange(t, split_axis_einops, h=img_size), (q_img, k_img, v_img)
369 | )
370 |
371 | # similarity
372 |
373 | dots_image_to_image = einsum("b x i d, b x j d -> b x i j", q_img, k_img)
374 | dots_image_to_text = einsum("b x i d, b j d -> b x i j", q_img, k_text)
375 |
376 | dots = torch.cat((dots_image_to_text, dots_image_to_image), dim=-1)
377 |
378 | # mask so image has full attention to text, but causal along axis
379 |
380 | bh, x, i, j = dots.shape
381 | causal_mask = (
382 | torch.ones(i, img_size, device=device).triu_(img_size - i + 1).bool()
383 | )
384 | causal_mask = repeat(causal_mask, "i j -> b x i j", b=bh, x=x)
385 |
386 | mask = repeat(mask, "b j -> (b h) x i j", h=h, x=x, i=i)
387 | mask = torch.cat((~mask, causal_mask), dim=-1)
388 |
389 | dots.masked_fill_(mask, mask_value)
390 |
391 | # attention.
392 |
393 | attn = softmax(dots, dim=-1)
394 |
395 | # aggregate
396 |
397 | attn_image_to_text, attn_image_to_image = (
398 | attn[..., :text_len],
399 | attn[..., text_len:],
400 | )
401 |
402 | out_image_to_image = einsum(
403 | "b x i j, b x j d -> b x i d", attn_image_to_image, v_img
404 | )
405 | out_image_to_text = einsum(
406 | "b x i j, b j d -> b x i d", attn_image_to_text, v_text
407 | )
408 |
409 | out_image = out_image_to_image + out_image_to_text
410 |
411 | # merge back axis
412 |
413 | out_image = rearrange(out_image, merge_axis_einops, x=img_size)
414 |
415 | # combine attended values for both text and image
416 |
417 | out = torch.cat((out_text, out_image), dim=1)
418 |
419 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
420 | out = self.to_out(out)
421 | return out[:, :n]
--------------------------------------------------------------------------------
/celle/reversible.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from operator import itemgetter
4 | from torch.autograd.function import Function
5 | from torch.utils.checkpoint import get_device_states, set_device_states
6 |
7 | # for routing arguments into the functions of the reversible layer
8 | def route_args(router, args, depth):
9 | routed_args = [(dict(), dict()) for _ in range(depth)]
10 | matched_keys = [key for key in args.keys() if key in router]
11 |
12 | for key in matched_keys:
13 | val = args[key]
14 | for depth, ((f_args, g_args), routes) in enumerate(
15 | zip(routed_args, router[key])
16 | ):
17 | new_f_args, new_g_args = map(
18 | lambda route: ({key: val} if route else {}), routes
19 | )
20 | routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
21 | return routed_args
22 |
23 |
24 | # following example for saving and setting rng here https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html
25 | class Deterministic(nn.Module):
26 | def __init__(self, net):
27 | super().__init__()
28 | self.net = net
29 | self.cpu_state = None
30 | self.cuda_in_fwd = None
31 | self.gpu_devices = None
32 | self.gpu_states = None
33 |
34 | def record_rng(self, *args):
35 | self.cpu_state = torch.get_rng_state()
36 | if torch.cuda._initialized:
37 | self.cuda_in_fwd = True
38 | self.gpu_devices, self.gpu_states = get_device_states(*args)
39 |
40 | def forward(self, *args, record_rng=False, set_rng=False, **kwargs):
41 | if record_rng:
42 | self.record_rng(*args)
43 |
44 | if not set_rng:
45 | return self.net(*args, **kwargs)
46 |
47 | rng_devices = []
48 | if self.cuda_in_fwd:
49 | rng_devices = self.gpu_devices
50 |
51 | with torch.random.fork_rng(devices=rng_devices, enabled=True):
52 | torch.set_rng_state(self.cpu_state)
53 | if self.cuda_in_fwd:
54 | set_device_states(self.gpu_devices, self.gpu_states)
55 | return self.net(*args, **kwargs)
56 |
57 |
58 | # heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
59 | # once multi-GPU is confirmed working, refactor and send PR back to source
60 | class ReversibleBlock(nn.Module):
61 | def __init__(self, f, g):
62 | super().__init__()
63 | self.f = Deterministic(f)
64 | self.g = Deterministic(g)
65 |
66 | def forward(self, x, f_args={}, g_args={}):
67 | x1, x2 = torch.chunk(x, 2, dim=2)
68 | y1, y2 = None, None
69 |
70 | with torch.no_grad():
71 | y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
72 | y2 = x2 + self.g(y1, record_rng=self.training, **g_args)
73 |
74 | return torch.cat([y1, y2], dim=2)
75 |
76 | def backward_pass(self, y, dy, f_args={}, g_args={}):
77 | y1, y2 = torch.chunk(y, 2, dim=2)
78 | del y
79 |
80 | dy1, dy2 = torch.chunk(dy, 2, dim=2)
81 | del dy
82 |
83 | with torch.enable_grad():
84 | y1.requires_grad = True
85 | gy1 = self.g(y1, set_rng=True, **g_args)
86 | torch.autograd.backward(gy1, dy2)
87 |
88 | with torch.no_grad():
89 | x2 = y2 - gy1
90 | del y2, gy1
91 |
92 | dx1 = dy1 + y1.grad
93 | del dy1
94 | y1.grad = None
95 |
96 | with torch.enable_grad():
97 | x2.requires_grad = True
98 | fx2 = self.f(x2, set_rng=True, **f_args)
99 | torch.autograd.backward(fx2, dx1, retain_graph=True)
100 |
101 | with torch.no_grad():
102 | x1 = y1 - fx2
103 | del y1, fx2
104 |
105 | dx2 = dy2 + x2.grad
106 | del dy2
107 | x2.grad = None
108 |
109 | x = torch.cat([x1, x2.detach()], dim=2)
110 | dx = torch.cat([dx1, dx2], dim=2)
111 |
112 | return x, dx
113 |
114 |
115 | class _ReversibleFunction(Function):
116 | @staticmethod
117 | def forward(ctx, x, blocks, args):
118 | ctx.args = args
119 | for block, kwarg in zip(blocks, args):
120 | x = block(x, **kwarg)
121 | ctx.y = x.detach()
122 | ctx.blocks = blocks
123 | return x
124 |
125 | @staticmethod
126 | def backward(ctx, dy):
127 | y = ctx.y
128 | args = ctx.args
129 | for block, kwargs in zip(ctx.blocks[::-1], args[::-1]):
130 | y, dy = block.backward_pass(y, dy, **kwargs)
131 | return dy, None, None
132 |
133 |
134 | class SequentialSequence(nn.Module):
135 | def __init__(self, layers, args_route={}, layer_dropout=0.0):
136 | super().__init__()
137 | assert all(
138 | len(route) == len(layers) for route in args_route.values()
139 | ), "each argument route map must have the same depth as the number of sequential layers"
140 | self.layers = layers
141 | self.args_route = args_route
142 | self.layer_dropout = layer_dropout
143 |
144 | def forward(self, x, **kwargs):
145 | args = route_args(self.args_route, kwargs, len(self.layers))
146 | layers_and_args = list(zip(self.layers, args))
147 |
148 | for (f, g), (f_args, g_args) in layers_and_args:
149 | x = x + f(x, **f_args)
150 | x = x + g(x, **g_args)
151 | return x
152 |
153 |
154 | class ReversibleSequence(nn.Module):
155 | def __init__(self, blocks, args_route={}):
156 | super().__init__()
157 | self.args_route = args_route
158 | self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks])
159 |
160 | def forward(self, x, **kwargs):
161 | x = torch.cat([x, x], dim=-1)
162 |
163 | blocks = self.blocks
164 | args = route_args(self.args_route, kwargs, len(blocks))
165 | args = list(map(lambda x: {"f_args": x[0], "g_args": x[1]}, args))
166 |
167 | out = _ReversibleFunction.apply(x, blocks, args)
168 | return torch.stack(out.chunk(2, dim=-1)).mean(dim=0)
169 |
--------------------------------------------------------------------------------
/celle/transformer.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 | from functools import partial
3 | from itertools import islice, cycle
4 |
5 | import torch
6 | from torch import nn
7 | import torch.nn.functional as F
8 | from einops import rearrange
9 |
10 | from celle.reversible import ReversibleSequence, SequentialSequence
11 | from celle.attention import (
12 | Attention,
13 | SparseConvCausalAttention,
14 | SparseAxialCausalAttention,
15 | )
16 |
17 | from rotary_embedding_torch import RotaryEmbedding, broadcat
18 |
19 | # helpers
20 |
21 |
22 | def exists(val):
23 | return val is not None
24 |
25 |
26 | def default(val, d):
27 | return val if exists(val) else d
28 |
29 |
30 | def cast_tuple(val, depth=1):
31 | if isinstance(val, list):
32 | val = tuple(val)
33 | return val if isinstance(val, tuple) else (val,) * depth
34 |
35 |
36 | # classes
37 |
38 |
39 | class DivideMax(nn.Module):
40 | def __init__(self, dim):
41 | super().__init__()
42 | self.dim = dim
43 |
44 | def forward(self, x):
45 | maxes = x.amax(dim=self.dim, keepdim=True).detach()
46 | return x / maxes
47 |
48 |
49 | class NonCached(nn.Module):
50 | """
51 | A wrapper for layers that don't support the inference cache themselves.
52 | Reconstructs the full sequence before the layer and
53 | cuts the suffix of the outputs after the layer.
54 | """
55 |
56 | def __init__(self, fn):
57 | super().__init__()
58 | self.fn = fn
59 |
60 | def forward(self, x, *, cache=None, cache_key=None, **kwargs):
61 | n = x.shape[-2]
62 | if exists(cache):
63 | if cache_key in cache:
64 | x = torch.cat([cache[cache_key], x], dim=-2)
65 | cache[cache_key] = x
66 |
67 | out = self.fn(x, **kwargs)
68 |
69 | return out[:, -n:]
70 |
71 |
72 | class CachedAs(nn.Module):
73 | """
74 | A wrapper that defines a key for the inference cache.
75 | """
76 |
77 | def __init__(self, cache_key, fn):
78 | super().__init__()
79 | self.cache_key = cache_key
80 | self.fn = fn
81 |
82 | def forward(self, x, *, cache=None, **kwargs):
83 | return self.fn(x, cache=cache, cache_key=self.cache_key, **kwargs)
84 |
85 |
86 | # https://arxiv.org/abs/2103.17239
87 | class LayerScale(nn.Module):
88 | def __init__(self, dim, depth, fn):
89 | super().__init__()
90 | if depth <= 18:
91 | init_eps = 0.1
92 | elif depth > 18 and depth <= 24:
93 | init_eps = 1e-5
94 | else:
95 | init_eps = 1e-6
96 |
97 | scale = torch.zeros(1, 1, dim).fill_(init_eps)
98 | self.scale = nn.Parameter(scale)
99 | self.fn = fn
100 |
101 | def forward(self, x, **kwargs):
102 | return self.fn(x, **kwargs) * self.scale
103 |
104 |
105 | # layer norm
106 |
107 |
108 | class PreNorm(nn.Module):
109 | def __init__(self, dim, fn, sandwich=False):
110 | super().__init__()
111 | self.norm = nn.LayerNorm(dim)
112 | self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity()
113 | self.fn = fn
114 |
115 | def forward(self, x, **kwargs):
116 | x = self.norm(x)
117 | x = self.fn(x, **kwargs)
118 | return self.norm_out(x)
119 |
120 |
121 | # feed forward
122 |
123 |
124 | class GEGLU(nn.Module):
125 | def forward(self, x):
126 | x, gates = x.chunk(2, dim=-1)
127 | return x * F.gelu(gates)
128 |
129 |
130 | class FeedForward(nn.Module):
131 | def __init__(self, dim, dropout=0.0, mult=4.0):
132 | super().__init__()
133 | self.net = nn.Sequential(
134 | nn.Linear(dim, dim * mult * 2),
135 | GEGLU(),
136 | nn.Dropout(dropout),
137 | nn.Linear(dim * mult, dim),
138 | )
139 |
140 | def forward(self, x, cache=None, cache_key=None):
141 | return self.net(x)
142 |
143 |
144 | # token shift classes
145 |
146 |
147 | class PreShiftToken(nn.Module):
148 | def __init__(self, fn, image_size, num_images, seq_len):
149 | super().__init__()
150 | self.fn = fn
151 | self.image_size = image_size
152 | self.seq_len = seq_len
153 | img_seq_len = ((image_size // num_images) ** 2) * num_images
154 | self.text_len = seq_len - img_seq_len + 1
155 |
156 | def forward(self, x, cache=None, cache_key=None, **kwargs):
157 |
158 | seq_len, image_size, text_len = self.seq_len, self.image_size, self.text_len
159 |
160 | if exists(cache) and cache_key in cache:
161 | offset = cache["offset"]
162 | assert offset >= text_len, "cached inference for text is not supported"
163 | q = cache[cache_key]
164 | assert isinstance(q, deque) and len(q) == image_size
165 |
166 | x_top, x_left, *x_pass = x[:, -1].chunk(4, dim=-1)
167 |
168 | q.append((x_top, x_left))
169 | x_top = q.popleft()[0]
170 | x_left = q[-2][1]
171 | if (offset - text_len) % image_size == 0:
172 | x_left = torch.zeros_like(x_left)
173 |
174 | x = torch.cat((x_top, x_left, *x_pass), dim=-1)
175 | return self.fn(x[:, None], cache=cache, **kwargs)
176 |
177 | n = x.shape[1]
178 |
179 | padding = seq_len - n + 1
180 |
181 | if n < text_len:
182 | return self.fn(x, **kwargs)
183 |
184 | # get text and image tokens
185 |
186 | x_text, x_img = x[:, :text_len], x[:, text_len:]
187 | x_img = F.pad(x_img, (0, 0, 0, padding))
188 | x_img = rearrange(x_img, "b (h w) d -> b h w d", h=image_size)
189 |
190 | # shift 1 from the left for text tokens
191 |
192 | # x_text_shift, x_text_pass = x_text.chunk(2, dim=-1)
193 | # x_text_shift = F.pad(x_text_shift, (0, 0, 1, -1))
194 | # x_text = torch.cat((x_text_shift, x_text_pass), dim=-1)
195 |
196 | # shift from top, left for image tokens
197 |
198 | x_img_shift_top, x_img_shift_left, *x_img_pass = x_img.chunk(4, dim=-1)
199 | x_img_shift_left = F.pad(x_img_shift_left, (0, 0, 1, -1))
200 | x_img_shift_top = F.pad(x_img_shift_top, (0, 0, 0, 0, 1, -1))
201 | x_img = torch.cat((x_img_shift_top, x_img_shift_left, *x_img_pass), dim=-1)
202 |
203 | # merge text and image sequence back together
204 |
205 | x_img = rearrange(x_img, "b h w d -> b (h w) d")
206 | x_img = x_img[:, :-padding]
207 | x = torch.cat((x_text, x_img), dim=1)
208 |
209 | if exists(cache):
210 | dummy_top, dummy_left, *_ = x[:, -1].chunk(4, dim=-1)
211 | dummy_top, dummy_left = torch.zeros_like(dummy_top), torch.zeros_like(
212 | dummy_left
213 | )
214 |
215 | q = deque()
216 | x_img = x_img[:, -image_size:]
217 | for _ in range(image_size - x_img.shape[1]):
218 | q.append((dummy_top, dummy_left))
219 | for i in range(x_img.shape[1]):
220 | q.append(x_img[:, i].chunk(4, dim=-1)[:2])
221 | cache[cache_key] = q
222 |
223 | return self.fn(x, cache=cache, **kwargs)
224 |
225 |
226 | # main transformer class
227 |
228 |
229 | class Transformer(nn.Module):
230 | def __init__(
231 | self,
232 | *,
233 | dim,
234 | depth,
235 | seq_len,
236 | reversible=False,
237 | causal=True,
238 | heads=8,
239 | dim_head=64,
240 | ff_mult=4,
241 | attn_dropout=0.0,
242 | ff_dropout=0.0,
243 | attn_types=None,
244 | image_fmap_size=None,
245 | num_images=None,
246 | stable=False,
247 | sandwich_norm=False,
248 | shift_tokens=False,
249 | rotary_emb=True,
250 | shared_attn_ids=None,
251 | shared_ff_ids=None,
252 | optimize_for_inference=False, # use cache-friendly masked attention instead of sparse one
253 | ):
254 | super().__init__()
255 | layers = nn.ModuleList([])
256 |
257 | self.seq_len = seq_len
258 | self.image_fmap_size = image_fmap_size
259 |
260 | attn_types = default(attn_types, ("full",))
261 | attn_types = cast_tuple(attn_types)
262 | attn_type_layer = islice(cycle(attn_types), depth)
263 |
264 | shared_attn_ids = cycle(default(shared_attn_ids, range(depth)))
265 | shared_ff_ids = cycle(default(shared_ff_ids, range(depth)))
266 | shared_attn_layers = {}
267 | shared_ff_layers = {}
268 |
269 | for (ind, attn_type, attn_id, ff_id) in zip(
270 | range(depth), attn_type_layer, shared_attn_ids, shared_ff_ids
271 | ):
272 |
273 | if attn_type == "full":
274 | attn_class = partial(Attention, stable=stable)
275 |
276 | elif attn_type == "axial_row":
277 | if optimize_for_inference:
278 | attn_class = partial(
279 | Attention,
280 | stable=stable,
281 | static_mask=self._get_attention_mask(attn_type),
282 | )
283 | else:
284 | attn_class = partial(
285 | SparseAxialCausalAttention,
286 | seq_len=seq_len,
287 | axis=0,
288 | image_size=image_fmap_size,
289 | stable=stable,
290 | )
291 | elif attn_type == "axial_col":
292 | if optimize_for_inference:
293 | attn_class = partial(
294 | Attention,
295 | stable=stable,
296 | static_mask=self._get_attention_mask(attn_type),
297 | )
298 | else:
299 | attn_class = partial(
300 | SparseAxialCausalAttention,
301 | seq_len=seq_len,
302 | axis=1,
303 | image_size=image_fmap_size,
304 | stable=stable,
305 | )
306 | elif attn_type == "conv_like":
307 | attn_class = partial(
308 | SparseConvCausalAttention,
309 | seq_len=seq_len,
310 | image_size=image_fmap_size,
311 | stable=stable,
312 | )
313 |
314 | else:
315 | raise ValueError(f'attention type "{attn_type}" is not valid')
316 |
317 | attn, reused_attn_type = shared_attn_layers.get(attn_id, (None, None))
318 |
319 | if not exists(attn):
320 | attn = attn_class(
321 | dim,
322 | causal=causal,
323 | seq_len=seq_len,
324 | heads=heads,
325 | dim_head=dim_head,
326 | dropout=attn_dropout,
327 | )
328 |
329 | shared_attn_layers[attn_id] = (attn, attn_type)
330 | elif attn_type != reused_attn_type:
331 | raise ValueError(
332 | "attn_types do not match shared_attn_ids "
333 | f'(ind = {ind}, attn_type = "{attn_type}", reused_attn_type = "{reused_attn_type}")'
334 | )
335 |
336 | ff = shared_ff_layers.get(ff_id)
337 | if not exists(ff):
338 | ff = FeedForward(dim, mult=ff_mult, dropout=ff_dropout)
339 | shared_ff_layers[ff_id] = ff
340 |
341 | if isinstance(attn, Attention):
342 | attn = CachedAs(f"attn_{ind}", attn)
343 | else:
344 | # at the moment, other attention classes don't support cache
345 | attn = NonCached(attn)
346 |
347 | ff = FeedForward(dim, mult=ff_mult, dropout=ff_dropout)
348 |
349 | if shift_tokens:
350 | attn = CachedAs(
351 | f"preshift_attn_{ind}",
352 | PreShiftToken(
353 | attn,
354 | image_size=image_fmap_size,
355 | num_images=num_images,
356 | seq_len=seq_len,
357 | ),
358 | )
359 | ff = CachedAs(
360 | f"preshift_ff_{ind}",
361 | PreShiftToken(
362 | ff,
363 | image_size=image_fmap_size,
364 | num_images=num_images,
365 | seq_len=seq_len,
366 | ),
367 | )
368 |
369 | layers.append(
370 | nn.ModuleList(
371 | [
372 | LayerScale(
373 | dim, ind + 1, PreNorm(dim, attn, sandwich=sandwich_norm)
374 | ),
375 | LayerScale(
376 | dim, ind + 1, PreNorm(dim, ff, sandwich=sandwich_norm)
377 | ),
378 | ]
379 | )
380 | )
381 |
382 | execute_type = ReversibleSequence if reversible else SequentialSequence
383 | route_attn = ((True, False),) * depth
384 | route_all = ((True, True),) * depth
385 | attn_route_map = {
386 | "mask": route_attn,
387 | "rotary_pos_emb": route_attn,
388 | "cache": route_all,
389 | }
390 |
391 | self.layers = execute_type(layers, args_route=attn_route_map)
392 |
393 | # generate positional embeddings for rotary
394 |
395 | pos_emb = None
396 | if rotary_emb:
397 |
398 | rot_dim = dim_head // 3
399 | img_seq_len = ((image_fmap_size // num_images) ** 2) * num_images
400 |
401 | text_len = seq_len - img_seq_len + 1
402 |
403 | text_pos_emb = RotaryEmbedding(dim=rot_dim)
404 |
405 | img_axial_pos_emb = RotaryEmbedding(dim=rot_dim, freqs_for="pixel")
406 |
407 | text_freqs = text_pos_emb(torch.arange(text_len))
408 |
409 | img_to_text_freqs = text_pos_emb(
410 | torch.full((img_seq_len,), 8192)
411 | ) # image is given a position far away from text
412 |
413 | text_freqs = torch.cat((text_freqs, img_to_text_freqs), dim=0)
414 |
415 | img_freqs_axial = img_axial_pos_emb(
416 | torch.linspace(-1, 1, steps=image_fmap_size)
417 | )
418 |
419 | if num_images > 1:
420 |
421 | split_img_freqs_axial = torch.split(
422 | img_freqs_axial, image_fmap_size // num_images, dim=0
423 | )
424 |
425 | split_img_freqs = [
426 | broadcat(
427 | (
428 | rearrange(img_freqs_axial_per_image, "i d -> i () d"),
429 | rearrange(img_freqs_axial_per_image, "j d -> () j d"),
430 | ),
431 | dim=-1,
432 | )
433 | for img_freqs_axial_per_image in split_img_freqs_axial
434 | ]
435 |
436 | split_img_freqs = [
437 | rearrange(img_freqs_per_image, "h w d -> (h w) d")
438 | for img_freqs_per_image in split_img_freqs
439 | ]
440 |
441 | # concat per image-image_freqs
442 |
443 | img_freqs = torch.cat(split_img_freqs, dim=0)
444 |
445 | elif num_images == 1:
446 | img_freqs = broadcat(
447 | (
448 | rearrange(img_freqs_axial, "i d -> i () d"),
449 | rearrange(img_freqs_axial, "j d -> () j d"),
450 | ),
451 | dim=-1,
452 | )
453 |
454 | img_freqs = rearrange(img_freqs, "h w d -> (h w) d")
455 |
456 | else:
457 | assert False, "num_images must be int greater than 0"
458 | self.img_axial_pos_emb = img_axial_pos_emb
459 | self.text_pos_emb = text_pos_emb
460 |
461 | text_axial_freqs = img_axial_pos_emb(
462 | torch.full((text_len,), -10.0)
463 | ) # text is given a position of -10 apart from the image axial positions, which is from range [-1, 1]
464 |
465 | text_axial_freqs = torch.cat((text_axial_freqs, text_axial_freqs), dim=-1)
466 |
467 | img_freqs = torch.cat((text_axial_freqs, img_freqs), dim=0)
468 |
469 | pos_emb = torch.cat((text_freqs, img_freqs), dim=-1)
470 |
471 | pos_emb = rearrange(pos_emb, "n d -> () n d")
472 |
473 | self.register_buffer("pos_emb", pos_emb)
474 |
475 | def forward(self, x, **kwargs):
476 | return self.layers(x, rotary_pos_emb=self.pos_emb, **kwargs)
477 |
478 | def _get_attention_mask(self, attn_type):
479 | img_seq_len = self.image_fmap_size ** 2
480 | text_len = self.seq_len + 1 - img_seq_len
481 |
482 | static_mask = torch.zeros(self.seq_len, self.seq_len, dtype=torch.bool)
483 | static_mask[:, :text_len] = True
484 | if attn_type == "axial_row":
485 | for row in range(self.image_fmap_size):
486 | begin = text_len + row * self.image_fmap_size
487 | end = text_len + (row + 1) * self.image_fmap_size
488 | static_mask[begin:end, begin:end] = True
489 | elif attn_type == "axial_col":
490 | for col in range(self.image_fmap_size):
491 | begin = text_len + col
492 | static_mask[
493 | begin :: self.image_fmap_size, begin :: self.image_fmap_size
494 | ] = True
495 | else:
496 | raise ValueError(
497 | f'attention type "{attn_type}" can\'t be simulated with a static mask'
498 | )
499 | return static_mask
--------------------------------------------------------------------------------
/celle/vae.py:
--------------------------------------------------------------------------------
1 | import os
2 | import urllib
3 | from tqdm import tqdm
4 | from math import sqrt, log
5 | from omegaconf import OmegaConf
6 | from taming.models.vqgan import VQModel, GumbelVQ
7 | import importlib
8 |
9 | import torch
10 | from torch import nn
11 | import torch.nn.functional as F
12 |
13 | from einops import rearrange
14 |
15 | # constants
16 |
17 | CACHE_PATH = os.path.expanduser("~/.cache/dalle")
18 |
19 | # helpers methods
20 |
21 |
22 | def exists(val):
23 | return val is not None
24 |
25 |
26 | def default(val, d):
27 | return val if exists(val) else d
28 |
29 |
30 | def load_model(path):
31 | with open(path, "rb") as f:
32 | return torch.load(f, map_location=torch.device("cpu"))
33 |
34 |
35 | def map_pixels(x, eps=0.1):
36 | return (1 - 2 * eps) * x + eps
37 |
38 |
39 | def unmap_pixels(x, eps=0.1):
40 | return torch.clamp((x - eps) / (1 - 2 * eps), 0, 1)
41 |
42 | def make_contiguous(module):
43 | with torch.no_grad():
44 | for param in module.parameters():
45 | param.set_(param.contiguous())
46 |
47 |
48 | # VQGAN from Taming Transformers paper
49 | # https://arxiv.org/abs/2012.09841
50 |
51 |
52 | def get_obj_from_str(string, reload=False):
53 | module, cls = string.rsplit(".", 1)
54 | if reload:
55 | module_imp = importlib.import_module(module)
56 | importlib.reload(module_imp)
57 | return getattr(importlib.import_module(module, package=None), cls)
58 |
59 |
60 | def instantiate_from_config(config):
61 | if not "target" in config:
62 | raise KeyError("Expected key `target` to instantiate.")
63 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
64 |
65 |
66 | class VQGanVAE(nn.Module):
67 | def __init__(self, vqgan_model_path=None, vqgan_config_path=None, channels = 1):
68 | super().__init__()
69 |
70 | assert(vqgan_model_path is not None)
71 |
72 | model_path = vqgan_model_path
73 | config_path = vqgan_config_path
74 |
75 | config = OmegaConf.load(config_path)
76 |
77 | model = instantiate_from_config(config["model"])
78 |
79 | state = torch.load(model_path, map_location="cpu")["state_dict"]
80 | model.load_state_dict(state, strict=False)
81 |
82 | print(f"Loaded VQGAN from {model_path} and {config_path}")
83 |
84 | self.model = model
85 |
86 | # f as used in https://github.com/CompVis/taming-transformers#overview-of-pretrained-models
87 | f = (
88 | config.model.params.ddconfig.resolution
89 | / config.model.params.ddconfig.attn_resolutions[0]
90 | )
91 | self.num_layers = int(log(f) / log(2))
92 | self.image_size = config.model.params.ddconfig.resolution
93 | self.num_tokens = config.model.params.n_embed
94 | self.is_gumbel = isinstance(self.model, GumbelVQ)
95 | self.channels = config.model.params.ddconfig.in_channels
96 |
97 | def get_codebook_indices(self, img):
98 | b = img.shape[0]
99 | # img = (2 * img) - 1
100 | _, _, [_, _, indices] = self.model.encode(img)
101 | if self.is_gumbel:
102 | return rearrange(indices, "b h w -> b (h w)", b=b)
103 | return rearrange(indices, "(b n) -> b n", b=b)
104 |
105 | def decode(self, img_seq):
106 | b, n = img_seq.shape
107 | one_hot_indices = F.one_hot(img_seq, num_classes=self.num_tokens).float()
108 | z = (
109 | one_hot_indices @ self.model.quantize.embed.weight
110 | if self.is_gumbel
111 | else (one_hot_indices @ self.model.quantize.embedding.weight)
112 | )
113 |
114 | z = rearrange(z, "b (h w) c -> b c h w", h=int(sqrt(n)))
115 | img = self.model.decode(z)
116 |
117 | # img = (img.clamp(-1.0, 1.0) + 1) * 0.5
118 | return img
119 |
120 | def forward(self, img, optimizer_idx=1):
121 | return self.model.training_step(img, optimizer_idx=optimizer_idx)
122 |
--------------------------------------------------------------------------------
/celle_main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 | import torch
5 | import torch.random
6 | from torch.optim import Adam
7 | from torch.utils.data import DataLoader
8 | import pytorch_lightning as pl
9 | from pytorch_lightning import seed_everything
10 | from pytorch_lightning.trainer import Trainer
11 |
12 | from dataloader import OpenCellLoader
13 | from celle import VQGanVAE, CELLE
14 | from omegaconf import OmegaConf
15 | import argparse, os, sys, datetime, glob
16 |
17 |
18 | from celle.celle import gumbel_sample, top_k
19 |
20 | torch.random.manual_seed(42)
21 | np.random.seed(42)
22 |
23 | from celle_taming_main import (
24 | instantiate_from_config,
25 | nondefault_trainer_args,
26 | get_parser,
27 | )
28 |
29 |
30 | class CellDataModule(pl.LightningDataModule):
31 | def __init__(
32 | self,
33 | data_csv,
34 | sequence_mode="simple",
35 | vocab="bert",
36 | crop_size=256,
37 | batch_size=1,
38 | threshold=False,
39 | text_seq_len=1000,
40 | num_workers=1,
41 | **kwargs,
42 | ):
43 | super().__init__()
44 |
45 | self.data_csv = data_csv
46 | self.protein_sequence_length = 0
47 | self.image_folders = []
48 | self.crop_size = crop_size
49 | self.batch_size = batch_size
50 | self.sequence_mode = sequence_mode
51 | self.threshold = threshold
52 | self.text_seq_len = int(text_seq_len)
53 | self.vocab = vocab
54 | self.num_workers = num_workers if num_workers is not None else batch_size * 2
55 |
56 | def setup(self):
57 | # called on every GPU
58 | self.cell_dataset_train = OpenCellLoader(
59 | data_csv=self.data_csv,
60 | crop_size=self.crop_size,
61 | split_key="train",
62 | crop_method="random",
63 | sequence_mode=self.sequence_mode,
64 | vocab=self.vocab,
65 | text_seq_len=self.text_seq_len,
66 | threshold=self.threshold,
67 | )
68 |
69 | self.cell_dataset_val = OpenCellLoader(
70 | data_csv=self.data_csv,
71 | crop_size=self.crop_size,
72 | crop_method="center",
73 | split_key="val",
74 | sequence_mode=self.sequence_mode,
75 | vocab=self.vocab,
76 | text_seq_len=self.text_seq_len,
77 | threshold=self.threshold,
78 | )
79 |
80 | def prepare_data(self):
81 |
82 | pass
83 |
84 | def train_dataloader(self):
85 | return DataLoader(
86 | self.cell_dataset_train,
87 | num_workers=self.num_workers,
88 | shuffle=True,
89 | batch_size=self.batch_size,
90 | )
91 |
92 | def val_dataloader(self):
93 | return DataLoader(
94 | self.cell_dataset_val,
95 | num_workers=self.num_workers,
96 | batch_size=self.batch_size,
97 | )
98 |
99 | # def test_dataloader(self):
100 | # transforms = ...
101 | # return DataLoader(self.test, batch_size=64)
102 |
103 |
104 | class CELLE_trainer(pl.LightningModule):
105 | def __init__(
106 | self,
107 | vqgan_model_path,
108 | vqgan_config_path,
109 | ckpt_path=None,
110 | image_key="threshold",
111 | condition_model_path=None,
112 | condition_config_path=None,
113 | num_images=2,
114 | dim=2,
115 | num_text_tokens=30,
116 | text_seq_len=1000,
117 | depth=16,
118 | heads=16,
119 | dim_head=64,
120 | reversible=False,
121 | attn_dropout=0.1,
122 | ff_dropout=0.1,
123 | attn_types="full",
124 | loss_img_weight=7,
125 | stable=False,
126 | sandwich_norm=False,
127 | shift_tokens=True,
128 | rotary_emb=True,
129 | text_embedding="bert",
130 | fixed_embedding=True,
131 | loss_cond_weight=1,
132 | learning_rate=3e-4,
133 | monitor="val_loss",
134 | ):
135 | super().__init__()
136 |
137 | vae = VQGanVAE(
138 | vqgan_model_path=vqgan_model_path, vqgan_config_path=vqgan_config_path
139 | )
140 |
141 | self.image_key = image_key
142 |
143 | if condition_config_path:
144 | condition_vae = VQGanVAE(
145 | vqgan_model_path=condition_model_path,
146 | vqgan_config_path=condition_config_path,
147 | )
148 | else:
149 | condition_vae = None
150 |
151 | self.celle = CELLE(
152 | dim=dim,
153 | vae=vae, # automatically infer (1) image sequence length and (2) number of image tokens
154 | condition_vae=condition_vae,
155 | num_images=num_images,
156 | num_text_tokens=num_text_tokens, # vocab size for text
157 | text_seq_len=text_seq_len, # text sequence length
158 | depth=depth, # should aim to be 64
159 | heads=heads, # attention heads
160 | reversible=reversible, # should aim to be True
161 | dim_head=dim_head, # attention head dimension
162 | attn_dropout=attn_dropout, # attention dropout
163 | ff_dropout=ff_dropout,
164 | attn_types=attn_types,
165 | loss_img_weight=loss_img_weight,
166 | stable=stable,
167 | sandwich_norm=sandwich_norm,
168 | shift_tokens=shift_tokens,
169 | rotary_emb=rotary_emb,
170 | text_embedding=text_embedding,
171 | fixed_embedding=fixed_embedding,
172 | loss_cond_weight=loss_cond_weight
173 | # feedforward dropout
174 | )
175 |
176 | self.learning_rate = learning_rate
177 | self.num_text_tokens = num_text_tokens
178 | self.num_images = num_images
179 |
180 | if monitor is not None:
181 | self.monitor = monitor
182 |
183 | if ckpt_path is not None:
184 | self.init_from_ckpt(ckpt_path, ignore_keys=[])
185 |
186 | def init_from_ckpt(self, path, ignore_keys=list()):
187 | sd = torch.load(path, map_location="cpu")["state_dict"]
188 | for k in sd.keys():
189 | for ik in ignore_keys:
190 | if k.startswith(ik):
191 | self.print("Deleting key {} from state_dict.".format(k))
192 | del sd[k]
193 | self.celle.load_state_dict(sd, strict=False)
194 | print(f"Restored from {path}")
195 |
196 | def forward(self, text, condition, target, return_loss=True):
197 |
198 | return self.celle(
199 | text=text, condition=condition, image=target, return_loss=return_loss
200 | )
201 |
202 | def get_input(self, batch):
203 | text = batch["sequence"].squeeze(1)
204 | condition = batch["nucleus"]
205 | target = batch[self.image_key]
206 |
207 | return text, condition, target
208 |
209 | def get_image_from_logits(self, logits, temperature=0.9):
210 |
211 | filtered_logits = top_k(logits, thres=0.5)
212 | sample = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)
213 |
214 | self.celle.vae.eval()
215 | out = self.celle.vae.decode(
216 | sample[:, self.celle.text_seq_len + self.celle.condition_seq_len :]
217 | - (self.celle.num_text_tokens + self.celle.num_condition_tokens)
218 | )
219 |
220 | return out
221 |
222 | def get_loss(self, text, condition, target):
223 |
224 |
225 | loss_dict = {}
226 |
227 | loss, loss_dict, logits = self(
228 | text, condition, target, return_loss=True
229 | )
230 |
231 | return loss, loss_dict
232 |
233 | def total_loss(
234 | self,
235 | loss,
236 | loss_dict={"loss_text": 0, "loss_cond": 0, "loss_img": 0},
237 | mode="train",
238 | ):
239 |
240 | loss_dict = {f"{mode}/{key}": value for key, value in loss_dict.items()}
241 |
242 | self.log(
243 | f"{mode}/loss_text",
244 | loss_dict[f"{mode}/loss_text"],
245 | prog_bar=True,
246 | logger=True,
247 | on_step=True,
248 | on_epoch=True,
249 | )
250 |
251 | self.log(
252 | f"{mode}/loss_cond",
253 | loss_dict[f"{mode}/loss_cond"],
254 | prog_bar=True,
255 | logger=True,
256 | on_step=True,
257 | on_epoch=True,
258 | )
259 |
260 | self.log(
261 | f"{mode}/loss_img",
262 | loss_dict[f"{mode}/loss_img"],
263 | prog_bar=True,
264 | logger=True,
265 | on_step=True,
266 | on_epoch=True,
267 | )
268 |
269 | return loss
270 |
271 | def training_step(self, batch, batch_idx):
272 |
273 | text, condition, target = self.get_input(batch)
274 | loss, log_dict = self.get_loss(text, condition, target)
275 |
276 | loss = self.total_loss(loss, log_dict, mode="train")
277 |
278 | return loss
279 |
280 | def validation_step(self, batch, batch_idx):
281 |
282 | with torch.no_grad():
283 |
284 | text, condition, target = self.get_input(batch)
285 | loss, log_dict = self.get_loss(text, condition, target)
286 |
287 | loss = self.total_loss(loss, log_dict, mode="val")
288 |
289 | return loss
290 |
291 | def configure_optimizers(self):
292 |
293 |
294 | optimizer = Adam(self.parameters(), lr=self.learning_rate)
295 |
296 | return optimizer
297 |
298 | def scale_image(self, image):
299 |
300 | for tensor in image:
301 | if torch.min(tensor) < 0:
302 | tensor += -torch.min(tensor)
303 | else:
304 | tensor -= torch.min(tensor)
305 |
306 | tensor /= torch.max(tensor)
307 |
308 | return image
309 |
310 | @torch.no_grad()
311 | def log_images(self, batch, **kwargs):
312 |
313 | log = dict()
314 | text, condition, target = self.get_input(batch)
315 | text = text.squeeze(1).to(self.device)
316 | condition = condition.to(self.device)
317 |
318 | out = self.celle.generate_images(text=text, condition=condition, use_cache=True)
319 |
320 | log["condition"] = self.scale_image(condition)
321 | log["output"] = self.scale_image(out)
322 | if self.image_key == "threshold":
323 | log["threshold"] = self.scale_image(target)
324 | log["target"] = self.scale_image(batch["target"])
325 | else:
326 | log["target"] = self.scale_image(target)
327 |
328 | return log
329 |
330 |
331 | # from https://github.com/CompVis/taming-transformers/blob/master/celle_main.py
332 |
333 | if __name__ == "__main__":
334 | # custom parser to specify config files, train, test and debug mode,
335 | # postfix, resume.
336 | # `--key value` arguments are interpreted as arguments to the trainer.
337 | # `nested.key=value` arguments are interpreted as config parameters.
338 | # configs are merged from left-to-right followed by command line parameters.
339 |
340 | # model:
341 | # learning_rate: float
342 | # target: path to lightning module
343 | # params:
344 | # key: value
345 | # data:
346 | # target: celle_main.DataModuleFromConfig
347 | # params:
348 | # batch_size: int
349 | # wrap: bool
350 | # train:
351 | # target: path to train dataset
352 | # params:
353 | # key: value
354 | # validation:
355 | # target: path to validation dataset
356 | # params:
357 | # key: value
358 | # test:
359 | # target: path to test dataset
360 | # params:
361 | # key: value
362 | # lightning: (optional, has sane defaults and can be specified on cmdline)
363 | # trainer:
364 | # additional arguments to trainer
365 | # logger:
366 | # logger to instantiate
367 | # modelcheckpoint:
368 | # modelcheckpoint to instantiate
369 | # callbacks:
370 | # callback1:
371 | # target: importpath
372 | # params:
373 | # key: value
374 |
375 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
376 |
377 | # add cwd for convenience and to make classes in this file available when
378 | # running as `python celle_main.py`
379 | # (in particular `celle_main.DataModuleFromConfig`)
380 | sys.path.append(os.getcwd())
381 |
382 | parser = get_parser()
383 | parser = Trainer.add_argparse_args(parser)
384 |
385 | opt, unknown = parser.parse_known_args()
386 | if opt.name and opt.resume:
387 | raise ValueError(
388 | "-n/--name and -r/--resume cannot be specified both."
389 | "If you want to resume training in a new log folder, "
390 | "use -n/--name in combination with --resume_from_checkpoint"
391 | )
392 | if opt.resume:
393 | if not os.path.exists(opt.resume):
394 | raise ValueError("Cannot find {}".format(opt.resume))
395 | if os.path.isfile(opt.resume):
396 | paths = opt.resume.split("/")
397 | idx = len(paths) - paths[::-1].index("logs") + 1
398 | logdir = "/".join(paths[:idx])
399 | ckpt = opt.resume
400 | else:
401 | assert os.path.isdir(opt.resume), opt.resume
402 | logdir = opt.resume.rstrip("/")
403 | ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
404 |
405 | opt.resume_from_checkpoint = ckpt
406 | base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
407 | opt.base = base_configs + opt.base
408 | _tmp = logdir.split("/")
409 | nowname = _tmp[_tmp.index("logs") + 1]
410 | else:
411 | if opt.name:
412 | name = "_" + opt.name
413 | elif opt.base:
414 | cfg_fname = os.path.split(opt.base[0])[-1]
415 | cfg_name = os.path.splitext(cfg_fname)[0]
416 | name = "_" + cfg_name
417 | else:
418 | name = ""
419 | nowname = now + name + opt.postfix
420 | logdir = os.path.join("logs", nowname)
421 |
422 | ckptdir = os.path.join(logdir, "checkpoints")
423 | cfgdir = os.path.join(logdir, "configs")
424 | seed_everything(opt.seed)
425 |
426 | try:
427 | # init and save configs
428 | configs = [OmegaConf.load(cfg) for cfg in opt.base]
429 | cli = OmegaConf.from_dotlist(unknown)
430 | config = OmegaConf.merge(*configs, cli)
431 | lightning_config = config.pop("lightning", OmegaConf.create())
432 | # merge trainer cli with config
433 | trainer_config = lightning_config.get("trainer", OmegaConf.create())
434 | # default to ddp
435 | # trainer_config["distributed_backend"] = "ddp"
436 | for k in nondefault_trainer_args(opt):
437 | trainer_config[k] = getattr(opt, k)
438 |             if "gpus" not in trainer_config:
439 |                 trainer_config.pop("distributed_backend", None)
440 | cpu = True
441 | else:
442 | gpuinfo = trainer_config["gpus"]
443 | print(f"Running on GPUs {gpuinfo}")
444 | cpu = False
445 | trainer_opt = argparse.Namespace(**trainer_config)
446 | lightning_config.trainer = trainer_config
447 |
448 | # model
449 | model = instantiate_from_config(config.model)
450 |
451 | # trainer and callbacks
452 | trainer_kwargs = dict()
453 |
454 | # default logger configs
455 | # NOTE wandb < 0.10.0 interferes with shutdown
456 | # wandb >= 0.10.0 seems to fix it but still interferes with pudb
457 | # debugging (wrongly sized pudb ui)
458 | # thus prefer testtube for now
459 | default_logger_cfgs = {
460 | "wandb": {
461 | "target": "pytorch_lightning.loggers.WandbLogger",
462 | "params": {
463 | "name": nowname,
464 | "save_dir": logdir,
465 | "offline": opt.debug,
466 | "id": nowname,
467 | },
468 | },
469 | "testtube": {
470 | "target": "pytorch_lightning.loggers.TestTubeLogger",
471 | "params": {
472 | "name": "testtube",
473 | "save_dir": logdir,
474 | },
475 | },
476 | }
477 | default_logger_cfg = default_logger_cfgs["testtube"]
478 | logger_cfg = lightning_config.logger or OmegaConf.create()
479 | logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
480 | trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
481 |
482 | # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
483 | # specify which metric is used to determine best models
484 | default_modelckpt_cfg = {
485 | "target": "pytorch_lightning.callbacks.ModelCheckpoint",
486 | "params": {
487 | "dirpath": ckptdir,
488 | "filename": "{epoch:06}",
489 | "verbose": True,
490 | "save_last": True,
491 | },
492 | }
493 | if hasattr(model, "monitor"):
494 | print(f"Monitoring {model.monitor} as checkpoint metric.")
495 | default_modelckpt_cfg["params"]["monitor"] = model.monitor
496 | default_modelckpt_cfg["params"]["save_top_k"] = 3
497 |
498 | modelckpt_cfg = lightning_config.modelcheckpoint or OmegaConf.create()
499 | modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
500 | trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
501 |
502 | # add callback which sets up log directory
503 | default_callbacks_cfg = {
504 | "setup_callback": {
505 | "target": "celle_taming_main.SetupCallback",
506 | "params": {
507 | "resume": opt.resume,
508 | "now": now,
509 | "logdir": logdir,
510 | "ckptdir": ckptdir,
511 | "cfgdir": cfgdir,
512 | "config": config,
513 | "lightning_config": lightning_config,
514 | },
515 | },
516 | "image_logger": {
517 | "target": "celle_taming_main.ImageLogger",
518 | "params": {
519 | "batch_frequency": 1500,
520 | "max_images": 5,
521 | "clamp": False,
522 | "increase_log_steps": False,
523 | },
524 | },
525 | # "learning_rate_logger": {
526 | # "target": "celle_taming_main.LearningRateMonitor",
527 | # "params": {
528 | # "logging_interval": "step",
529 | # # "log_momentum": True
530 | # },
531 | # },
532 | }
533 | callbacks_cfg = lightning_config.callbacks or OmegaConf.create()
534 | callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
535 | trainer_kwargs["callbacks"] = [
536 | instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
537 | ]
538 |
539 | trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
540 |
541 | # data
542 | data = instantiate_from_config(config.data)
543 | # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
544 | # calling these ourselves should not be necessary but it is.
545 | # lightning still takes care of proper multiprocessing though
546 | data.setup()
547 | data.prepare_data()
548 |
549 | # configure learning rate
550 | bs, lr = config.data.params.batch_size, config.model.learning_rate
551 |
552 | if not cpu:
553 | ngpu = len(lightning_config.trainer.gpus.strip(",").split(","))
554 | else:
555 | ngpu = 1
556 | accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches or 1
557 | print(f"accumulate_grad_batches = {accumulate_grad_batches}")
558 | lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
559 | model.learning_rate = accumulate_grad_batches * ngpu * bs * lr
560 |
561 | print(
562 | "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (lr)".format(
563 | model.learning_rate, accumulate_grad_batches, ngpu, bs, lr
564 | )
565 | )
566 |
567 | # allow checkpointing via USR1
568 | def melk(*args, **kwargs):
569 | # run all checkpoint hooks
570 | if trainer.global_rank == 0:
571 | print("Summoning checkpoint.")
572 | ckpt_path = os.path.join(ckptdir, "last.ckpt")
573 | trainer.save_checkpoint(ckpt_path)
574 |
575 | def divein(*args, **kwargs):
576 | if trainer.global_rank == 0:
577 | import pudb
578 |
579 | pudb.set_trace()
580 |
581 | import signal
582 |
583 | signal.signal(signal.SIGUSR1, melk)
584 | signal.signal(signal.SIGUSR2, divein)
585 |
586 | # run
587 | if opt.train:
588 | try:
589 | trainer.fit(model, data)
590 | except Exception:
591 | melk()
592 | raise
593 | if not opt.no_test and not trainer.interrupted:
594 | trainer.test(model, data)
595 | except Exception:
596 | if opt.debug and trainer.global_rank == 0:
597 | try:
598 | import pudb as debugger
599 | except ImportError:
600 | import pdb as debugger
601 | debugger.post_mortem()
602 | raise
603 | finally:
604 | # move newly created debug project to debug_runs
605 | if opt.debug and not opt.resume and trainer.global_rank == 0:
606 | dst, name = os.path.split(logdir)
607 | dst = os.path.join(dst, "debug_runs", name)
608 | os.makedirs(os.path.split(dst)[0], exist_ok=True)
609 | os.rename(logdir, dst)
610 |
--------------------------------------------------------------------------------
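The `__main__` block in `celle_main.py` above merges every `--base` YAML left-to-right and then applies `nested.key=value` command-line overrides via OmegaConf. Below is a minimal sketch of that merge behavior, assuming OmegaConf (as pinned in `requirements.txt`) and the shipped `configs/celle.yaml`; the override keys are illustrative only.

```python
# Sketch of the left-to-right config merge performed in celle_main.py's __main__.
# The dotlist entries below are illustrative overrides, not required settings.
from omegaconf import OmegaConf

base = OmegaConf.load("configs/celle.yaml")          # first --base file
cli = OmegaConf.from_dotlist(
    ["model.learning_rate=0.0001", "data.params.batch_size=2"]
)
config = OmegaConf.merge(base, cli)                  # later values win

print(config.model.learning_rate)     # 0.0001
print(config.data.params.batch_size)  # 2
```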
/configs/celle.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | learning_rate: 0.0003
3 | target: celle_main.CELLE_trainer
4 | params:
5 | ckpt_path:
6 | condition_model_path:
7 | condition_config_path:
8 | vqgan_model_path:
9 | vqgan_config_path:
10 | image_key: threshold
11 | num_images: 2
12 | dim: 768
13 | num_text_tokens: 30
14 | text_seq_len: 1000
15 | depth: 32
16 | heads: 16
17 | dim_head: 64
18 | attn_dropout: 0.1
19 | ff_dropout: 0.1
20 | attn_types: full
21 | rotary_emb: true
22 | fixed_embedding: True
23 | monitor: val/loss_img_epoch
24 | text_embedding: bert
25 |
26 | data:
27 | target: celle_main.CellDataModule
28 | params:
29 | data_csv:
30 | crop_size: 256
31 | batch_size: 1
32 | sequence_mode: embedding
33 | vocab: bert
34 | threshold: true
35 | text_seq_len: 1000
36 | num_workers: 4
37 |
--------------------------------------------------------------------------------
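Each `target`/`params` pair in the config above is resolved at runtime by `instantiate_from_config` (imported from `celle_main` in the demo notebook). The helper below is a simplified stand-in written for illustration, not the repository's exact implementation, but it shows the mechanism.

```python
import importlib

def instantiate_from_config_sketch(config):
    """Simplified stand-in: import `target` as a dotted path, call it with `params`."""
    module_name, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**config.get("params", dict()))

# e.g. {"target": "torch.optim.Adam", "params": {"params": [], "lr": 1e-3}}
# imports torch.optim and constructs Adam(**params); the repo applies the same
# idea to the model and data classes named in configs/celle.yaml.
```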
/configs/nucleus_vqgan.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: taming.models.vqgan.VQModel
4 | params:
5 | image_key: nucleus
6 | monitor:
7 | ckpt_path:
8 | embed_dim: 256
9 | n_embed: 512
10 | ddconfig:
11 | double_z: false
12 | z_channels: 256
13 | resolution: 256
14 | in_channels: 1
15 | out_ch: 1
16 | ch: 128
17 | ch_mult:
18 | - 1
19 | - 1
20 | - 2
21 | - 2
22 | - 4
23 | num_res_blocks: 2
24 | attn_resolutions:
25 | - 16
26 | dropout: 0.0
27 | lossconfig:
28 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
29 | params:
30 | disc_conditional: false
31 | disc_in_channels: 1
32 | disc_start: 50000
33 | disc_weight: 0.2
34 | codebook_weight: 1.0
35 |
36 | data:
37 | target: celle_main.CellDataModule
38 | params:
39 | data_csv:
40 | crop_size: 256
41 | batch_size: 1
42 | num_workers: 8
43 |
--------------------------------------------------------------------------------
/configs/threshold_vqgan.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: taming.models.vqgan.VQModel
4 | params:
5 | image_key: threshold
6 | monitor:
7 | embed_dim: 256
8 | n_embed: 512
9 | ddconfig:
10 | double_z: false
11 | z_channels: 256
12 | resolution: 256
13 | in_channels: 1
14 | out_ch: 1
15 | ch: 128
16 | ch_mult:
17 | - 1
18 | - 1
19 | - 2
20 | - 2
21 | - 4
22 | num_res_blocks: 2
23 | attn_resolutions:
24 | - 16
25 | dropout: 0.0
26 | lossconfig:
27 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
28 | params:
29 | disc_conditional: false
30 | disc_in_channels: 1
31 | disc_start: 50000
32 | disc_weight: 0.2
33 | codebook_weight: 1.0
34 |
35 | data:
36 | target: celle_main.CellDataModule
37 | params:
38 | data_csv:
39 | crop_size: 256
40 | batch_size: 1
41 | num_workers: 8
42 | threshold: True
--------------------------------------------------------------------------------
/data/aaDescriptors.csv:
--------------------------------------------------------------------------------
1 | "AA","PP1","PP2","PP3","KF1","KF2","KF3","KF4","KF5","KF6","KF7","KF8","KF9","KF10","Z1","Z2","Z3","Z4","Z5","F1","F2","F3","F4","F5","F6","T1","T2","T3","T4","T5","VHSE1","VHSE2","VHSE3","VHSE4","VHSE5","VHSE6","VHSE7","VHSE8","ProtFP1","ProtFP2","ProtFP3","ProtFP4","ProtFP5","ProtFP6","ProtFP7","ProtFP8","ST1","ST2","ST3","ST4","ST5","ST6","ST7","ST8","BLOSUM1","BLOSUM2","BLOSUM3","BLOSUM4","BLOSUM5","BLOSUM6","BLOSUM7","BLOSUM8","BLOSUM9","BLOSUM10","MSWHIM1","MSWHIM2","MSWHIM3"
2 | "A",-0.96,-0.76,0.31,-1.56,-1.67,-0.97,-0.27,-0.93,-0.78,-0.2,-0.08,0.21,-0.48,0.24,-2.32,0.6,-0.14,1.3,0.207,0.821,-1.009,1.387,0.063,-0.6,-9.11,-1.63,0.63,1.04,2.26,0.15,-1.11,-1.35,-0.92,0.02,-0.91,0.36,-0.48,-0.1,-4.94,-2.13,1.7,-0.39,1.06,-1.39,0.97,-1.552,-0.791,-0.627,0.237,-0.461,-2.229,0.283,1.221,0.08,-0.92,0.53,0,0.24,0.19,0.66,-0.05,1.36,0.33,-0.73,0.2,-0.62
3 | "R",0.8,0.63,0.99,0.22,1.27,1.37,1.87,-1.7,0.46,0.92,-0.39,0.23,0.93,3.52,2.5,-3.5,1.99,-0.17,-1.229,0.378,0.516,-0.328,-0.052,2.728,0.23,3.89,-1.16,-0.39,-0.06,-1.47,1.45,1.24,1.27,1.55,1.47,1.3,0.83,-2.79,6.6,1.21,2.07,1.67,0.76,0,0.32,-0.059,0.731,-0.013,-0.096,-0.253,0.3,1.256,0.854,1.01,0.19,-0.86,-0.61,1.28,0.2,0.66,0.18,-0.22,-0.52,-0.22,0.27,1
4 | "N",0.82,-0.57,0.02,1.14,-0.07,-0.12,0.81,0.18,0.37,-0.09,1.23,1.1,-1.73,3.05,1.62,1.04,-1.15,1.61,-1.009,-0.939,-0.428,-0.397,-0.539,-0.605,-4.62,0.66,1.16,-0.22,0.93,-0.99,0,-0.37,0.69,-0.55,0.85,0.73,-0.8,-4.88,0.81,0.14,-0.14,1.23,-0.65,1.02,-1.94,-0.888,-0.057,-0.651,-0.214,0.917,0.164,-0.14,-0.166,1.51,0.22,-0.05,1.01,0.12,0.83,-0.03,-0.57,-1.2,-0.14,0.14,0.2,-0.66
5 | "D",1,-0.89,-1,0.58,-0.22,-1.58,0.81,-0.92,0.15,-1.52,0.47,0.76,0.7,3.98,0.93,1.93,-2.46,0.75,-1.298,-0.444,-0.584,-0.175,-0.259,-1.762,-4.65,0.75,1.39,-0.4,1.05,-1.15,0.67,-0.41,-0.01,-2.68,1.31,0.03,0.56,-6.61,0.94,-3.04,-4.58,0.48,-1.31,0.1,0.94,-0.907,-0.054,-0.781,-0.248,1.12,0.101,-0.245,-0.075,1.55,0.01,0.32,0.49,-0.99,0.01,-1.62,0.53,-0.15,-0.28,0.11,-1,-0.96
6 | "C",-0.55,-0.47,0.19,0.12,-0.89,0.45,-1.05,-0.71,2.41,1.52,-0.69,1.13,1.1,0.84,-1.67,3.75,0.18,-2.65,0.997,0.021,-1.419,-2.08,-0.799,0.502,-7.35,-0.86,-0.33,0.8,0.98,0.18,-1.67,-0.46,-0.21,0,1.2,-1.61,-0.19,4.62,-3.54,1.5,-1.26,3.27,-0.34,-0.47,-0.23,-1.276,-0.401,0.134,0.859,-0.196,-0.72,0.639,-0.857,-1.08,-1.11,1.56,0.81,1.83,-1.05,-0.74,0.38,-0.12,-0.1,-0.66,0.26,-0.27
7 | "Q",0.78,-0.3,-0.38,-0.47,0.24,0.07,1.1,1.1,0.59,0.84,-0.71,-0.03,-2.33,1.75,0.5,-1.44,-1.34,0.66,-0.88,0.381,-0.044,-0.455,-0.04,0.405,-3,1.72,0.28,-0.39,0.33,-0.96,0.12,0.18,0.16,0.09,0.42,-0.2,-0.41,-3.95,2.88,-0.83,0.52,0.9,0.55,-0.08,0.64,-0.662,0.228,-0.193,-0.105,0.418,0.474,0.172,0.408,1.09,0.3,-0.87,-0.72,0.5,-0.08,-0.44,0.2,0.38,0.67,0.3,1,-0.3
8 | "E",0.94,-0.54,-0.99,-1.45,0.19,-1.61,1.17,-1.31,0.4,0.04,0.38,-0.35,-0.12,3.11,0.26,-0.11,-3.04,-0.25,-1.349,1.388,-0.361,0.213,0.424,-1.303,-3.03,1.82,0.51,-0.58,0.43,-1.18,0.4,0.1,0.36,-2.16,-0.17,0.91,0.02,-5.1,2.2,-3.59,-2.26,-2.14,1.35,-0.45,-1.31,-0.629,-0.39,-0.38,-0.366,0.635,0.514,0.175,0.367,1.48,0.23,-0.67,-0.36,-0.28,-0.08,-1.01,0.36,0.77,0.3,0.24,-0.39,-0.04
9 | "G",-0.88,-1,0.49,1.46,-1.96,-0.23,-0.16,0.1,-0.11,1.32,2.36,-1.66,0.46,2.05,-4.06,0.36,-0.82,-0.38,-0.205,-2.219,-1.656,1.229,-1.115,-1.146,-10.61,-1.21,-0.12,0.75,3.25,-0.2,-1.53,-2.63,2.28,-0.53,-1.18,2.01,-1.34,-5.7,-8.72,4.18,-1.35,-0.31,2.91,0.32,-0.11,-1.844,-0.018,-0.184,0.573,-0.728,-3.317,0.166,2.522,0.85,0.17,1.73,0.09,-0.55,1.19,1.21,0.87,0.01,0.24,-0.31,-0.28,-0.75
10 | "H",0.67,-0.11,0.37,-0.41,0.52,-0.28,0.28,1.61,1.01,-1.85,0.47,1.13,1.63,2.47,1.95,0.26,3.9,0.09,-0.27,0.461,-0.024,-1.407,0.001,0.169,-1.01,-1.31,0.01,-1.81,-0.21,-0.43,-0.25,0.37,0.19,0.51,1.28,0.93,0.65,0.17,2.14,1.2,0.71,1.16,-0.38,-1.85,-2.79,-0.225,0.361,0.079,-1.037,0.568,0.273,1.208,-0.001,0.72,1.55,-0.8,1.55,0.35,-0.79,0.66,-0.08,-0.19,0.99,0.84,0.67,-0.78
11 | "I",-0.94,-0.05,-0.18,-0.73,-0.16,1.79,-0.77,-0.54,0.03,-0.83,0.51,0.66,-1.78,-3.89,-1.73,-1.71,-0.84,0.26,1.524,0.536,0.809,0.734,-0.196,0.427,-4.25,-0.28,-0.15,1.4,-0.21,1.27,-0.14,0.3,-1.8,0.3,-1.61,-0.16,-0.13,6.58,-1.73,-2.49,1.09,-0.34,-0.28,1.97,-0.92,-0.785,-1.01,-0.349,-0.097,-0.402,1.091,-0.139,-0.764,-1.46,-1.13,-0.76,0.38,-0.6,0.28,-0.13,0.2,-0.22,0.21,-0.91,0.83,-0.25
12 | "L",-0.9,0.03,-0.24,-1.04,0,-0.24,-1.1,-0.55,-2.05,0.96,-0.76,0.45,0.93,-4.28,-1.3,-1.49,-0.72,0.84,1.2,1.128,0.703,1.904,0.536,-0.141,-4.38,0.28,-0.49,1.45,0.02,1.36,0.07,0.36,-0.8,0.22,-1.37,0.08,-0.62,5.76,-1.33,-1.71,0.63,-1.7,0.71,-0.05,-0.51,-0.826,-0.379,0.038,-0.059,-0.625,1.025,-0.229,-0.129,-1.41,-0.86,-0.88,-0.17,0.03,0.34,0.11,0.15,-0.44,-0.02,-0.74,0.72,-0.16
13 | "K",0.6,0.1,1,-0.34,0.82,-0.23,1.7,1.54,-1.62,1.15,-0.08,-0.48,0.6,2.29,0.89,-2.49,1.49,0.31,-1.387,0.572,0.285,0.333,-0.169,1.157,-2.59,2.34,-1.69,0.41,-0.21,-1.17,0.7,0.7,0.8,1.64,0.67,1.63,0.13,-4.99,5,0.7,3,-1.23,1.41,0.19,0.87,-0.504,0.245,0.297,-0.065,-0.387,1.011,0.525,0.553,1.14,-0.04,-0.8,-0.85,0.82,0.1,0.21,0.13,0.18,-0.85,-0.51,0.08,0.6
14 | "M",-0.82,0.03,-0.08,-1.4,0.18,-0.42,-0.73,2,1.52,0.26,0.11,-1.27,0.27,-2.85,-0.22,0.47,1.94,-0.98,0.886,1.346,0.277,-0.913,0.007,-0.265,-4.08,0.98,-2.34,1.64,-0.79,1.01,-0.53,0.43,0,0.23,0.1,-0.86,-0.68,5.11,0.19,-1.02,0.15,0.13,-0.3,-2.95,0.5,-0.693,0.498,0.658,0.457,-0.231,1.064,0.248,-0.778,-0.96,-0.59,-0.97,-0.53,0.24,0.37,0.06,0.21,-0.56,0.36,-0.7,1,-0.32
15 | "F",-0.85,0.48,-0.58,-0.21,0.98,-0.36,-1.43,0.22,-0.81,0.67,1.1,1.71,-0.44,-4.22,1.94,1.06,0.54,-0.62,1.247,0.293,1.336,-0.026,0.012,-0.015,0.49,-0.94,-0.63,-1.27,-0.44,1.52,0.61,0.96,-0.16,0.25,0.28,-1.33,-0.2,6.76,0.88,0.89,-1.12,-0.49,-0.55,-0.87,1.05,-0.019,0.024,1.08,-0.22,-0.937,0.57,-0.357,0.278,-1.62,1.01,-0.31,0.62,-0.55,0.29,-0.02,0.1,0.43,-1.29,0.76,0.85,-0.34
16 | "P",-0.81,-0.4,-0.07,2.06,-0.33,-1.15,-0.75,0.88,-0.45,0.3,-2.3,0.74,-0.28,-1.66,0.27,1.84,0.7,2,-0.407,-2.038,-0.564,-0.128,3.847,-1.108,-5.11,-3.54,-0.53,-0.36,-0.29,0.22,-0.17,-0.5,0.05,-0.01,-1.34,-0.19,3.56,-3.82,-2.31,3.45,1,-3.22,-3.54,-0.36,-0.3,-1.049,-0.407,-0.067,-0.066,-0.813,-0.89,0.021,-0.894,0.88,-0.68,0.38,-0.87,-1.24,-2.02,0.85,-0.35,-0.42,-0.3,-0.43,0.73,-0.6
17 | "S",0.41,-0.82,0.57,0.81,-1.08,0.16,0.42,-0.21,-0.43,-1.89,-1.15,-0.97,-0.23,2.39,-1.07,1.15,-1.39,0.67,-0.495,-0.847,-1.079,0.582,0.035,-0.068,-7.44,-0.65,0.68,-0.17,1.58,-0.67,-0.86,-1.07,-0.41,-0.32,0.27,-0.64,0.11,-4.57,-2.55,-0.67,1.11,0.99,-1.02,0.11,0.65,-1.343,-0.311,-0.917,-0.049,0.549,-1.533,0.166,0.28,0.84,-0.45,0.42,0.32,0.2,0.54,0.01,-0.8,0.62,-0.13,-0.8,0.61,-1
18 | "T",0.4,-0.64,0.37,0.26,-0.7,1.21,0.63,-0.1,0.21,0.24,-1.15,-0.56,0.19,0.75,-2.18,-1.12,-1.46,-0.4,-0.032,-0.45,-0.61,0.341,0.117,0.577,-5.97,-0.62,1.11,0.31,0.95,-0.34,-0.51,-0.55,-1.06,-0.06,-0.01,-0.79,0.39,-2,-1.77,-0.7,1.02,1.06,-1.2,0.74,1.65,-1.061,-0.928,-0.911,-0.063,0.538,-0.775,-0.147,-0.717,0.19,-0.73,0.18,-0.01,0.02,0.38,-0.3,-1.96,0.15,0.06,-0.58,0.85,-0.89
19 | "W",0.06,1,-0.47,0.3,2.1,-0.72,-1.57,-1.16,0.57,-0.48,-0.4,-2.3,-0.6,-4.36,3.94,0.59,3.44,-1.59,0.844,-0.075,2.069,-1.36,-0.81,-0.38,5.73,-2.67,-0.07,-1.96,-0.54,1.5,2.06,1.79,0.75,0.75,-0.13,-1.01,-0.85,7.33,4.55,2.77,-2.41,-1.08,1.04,0.23,0.59,0.853,0.039,0.26,-1.163,0.16,-0.202,1.01,0.195,-1.58,2.28,1.17,-1.61,0.12,0.24,-0.54,-0.4,-0.35,0.5,1,0.98,-0.47
20 | "Y",0.31,0.42,-0.2,1.38,1.48,0.8,-0.56,0,-0.68,-0.31,1.03,-0.05,0.53,-2.54,2.44,0.43,0.04,-1.47,0.329,-0.858,1.753,-0.479,-0.835,0.289,2.08,-0.47,0.07,-1.67,-0.35,0.61,1.6,1.17,0.73,0.53,0.25,-0.96,-0.52,3.14,3.59,2.45,-1.27,-0.06,-0.29,1.99,0.3,0.308,0.569,1.1,-0.464,-0.144,-0.354,-1.099,0.162,-1.14,1.74,-0.58,0.75,-0.12,-0.48,0.24,-0.25,0.71,-0.25,0.97,0.66,-0.16
21 | "V",-1,-0.43,-0.14,-0.74,-0.71,2.04,-0.4,0.5,-0.81,-1.07,0.06,-0.46,0.65,-2.59,-2.64,-1.54,-0.85,-0.02,-1.332,0.545,0.029,1.026,-0.229,1.038,-5.87,-0.94,0.28,1.1,0.48,0.76,-0.92,-0.17,-1.91,0.22,-1.4,-0.24,-0.03,5.04,-2.9,-2.29,1.38,0.06,0.08,1.79,-0.38,-1.133,-0.893,-0.325,0.303,-0.561,-0.175,-0.02,-0.311,-1.13,-1.23,-0.63,0.06,-0.6,0.16,0.01,0.02,0.25,0.61,-1,0.79,-0.58
22 |
--------------------------------------------------------------------------------
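The table above lists 66 physicochemical descriptor values per amino acid (Z-scales, Kidera factors, VHSE, ProtFP, ST-scales, BLOSUM indices, MSWHIM, and related sets). A hedged sketch of loading it with pandas and mapping a short sequence to a descriptor matrix follows; this only illustrates the file's layout and is not how `dataloader.py` consumes it (its `aadescriptors` mode tokenizes integer indices only).

```python
import pandas as pd
import torch

# Illustration only: build a (sequence length x 66) descriptor matrix.
descriptors = pd.read_csv("data/aaDescriptors.csv", index_col="AA")
sequence = "MRHK"  # toy sequence
features = torch.tensor(descriptors.loc[list(sequence)].values, dtype=torch.float32)
print(features.shape)  # torch.Size([4, 66])
```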
/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from PIL import Image
4 | import random
5 | import json
6 | import pandas as pd
7 |
8 | import torch
9 | from torch.utils.data import Dataset
10 | from torchvision import transforms
11 | import torchvision.transforms.functional as TF
12 |
13 |
14 | def simple_conversion(seq):
15 |     """Map a sequence to integer indices over a 27-character amino-acid vocabulary"""
16 | chars = [
17 | "-",
18 | "M",
19 | "R",
20 | "H",
21 | "K",
22 | "D",
23 | "E",
24 | "S",
25 | "T",
26 | "N",
27 | "Q",
28 | "C",
29 | "U",
30 | "G",
31 | "P",
32 | "A",
33 | "V",
34 | "I",
35 | "F",
36 | "Y",
37 | "W",
38 | "L",
39 | "O",
40 | "X",
41 | "Z",
42 | "B",
43 | "J",
44 | ]
45 |
46 | nums = range(len(chars))
47 |
48 | seqs_x = np.zeros(len(seq))
49 |
50 | for idx, char in enumerate(seq):
51 |
52 | lui = chars.index(char)
53 |
54 | seqs_x[idx] = nums[lui]
55 |
56 | return torch.tensor([seqs_x]).long()
57 |
58 |
59 | def convert_descriptor(seq):
60 | seq_dict = {
61 | "": 0,
62 | "M": 1,
63 | "R": 2,
64 | "H": 3,
65 | "K": 4,
66 | "D": 5,
67 | "E": 6,
68 | "S": 7,
69 | "T": 8,
70 | "N": 9,
71 | "Q": 10,
72 | "C": 11,
73 | "G": 12,
74 | "P": 13,
75 | "A": 14,
76 | "V": 15,
77 | "I": 16,
78 | "F": 17,
79 | "Y": 18,
80 | "W": 19,
81 | "L": 20,
82 | "": 21,
83 | }
84 | seq = seq.upper()
85 | return torch.tensor([seq_dict[char] for char in seq]).long()
86 |
87 |
88 | class OpenCellLoader(Dataset):
89 |     """Loads mined OpenCell images together with the corresponding protein sequence"""
90 |
91 | def __init__(
92 | self,
93 | data_csv,
94 | split_key=None,
95 | crop_size=600,
96 | crop_method="random",
97 | sequence_mode="simple",
98 | vocab="bert",
99 | threshold=False,
100 | text_seq_len=0,
101 | ):
102 | self.data_csv = data_csv
103 | self.image_folders = []
104 | self.crop_method = crop_method
105 | self.crop_size = crop_size
106 | self.sequence_mode = sequence_mode
107 | self.threshold = threshold
108 | self.text_seq_len = int(text_seq_len)
109 | self.vocab = vocab
110 |
111 | if self.sequence_mode == "embedding" or self.sequence_mode == "onehot":
112 |
113 | from tape import TAPETokenizer
114 |
115 | if self.vocab == "unirep" or self.sequence_mode == "onehot":
116 | self.tokenizer = TAPETokenizer(vocab="unirep")
117 |
118 | elif self.vocab == "bert":
119 | self.tokenizer = TAPETokenizer(vocab="iupac")
120 |
121 | elif self.vocab == "esm1b":
122 | from esm import Alphabet
123 |
124 | self.tokenizer = Alphabet.from_architecture(
125 | "ESM-1b"
126 | ).get_batch_converter()
127 |
128 | data = pd.read_csv(data_csv)
129 |
130 |         self.parent_path = os.path.dirname(data_csv)
131 |
132 | if split_key == "train":
133 | self.data = data[data["split"] == "train"]
134 | elif split_key == "val":
135 | self.data = data[data["split"] == "val"]
136 | else:
137 | self.data = data
138 |
139 | self.data = self.data.reset_index(drop=True)
140 |
141 | def __len__(self):
142 | return len(self.data)
143 |
144 | def __getitem__(self, idx):
145 |
146 | protein_vector = self.get_protein_vector(idx)
147 |
148 | nucleus, target, threshold = self.get_images(idx)
149 |
150 | data_dict = {
151 | "nucleus": nucleus.float(),
152 | "target": target.float(),
153 | "threshold": threshold.float(),
154 | "sequence": protein_vector.long(),
155 | }
156 |
157 | return data_dict
158 |
159 | def get_protein_vector(self, idx):
160 |
161 | if "protein_sequence" not in self.data.columns:
162 |
163 | metadata = self.retrieve_metadata(idx)
164 | protein_sequence = metadata["sequence"]
165 | else:
166 | protein_sequence = self.data.iloc[idx]["protein_sequence"]
167 |
168 | protein_vector = self.tokenize_seqeuence(protein_sequence)
169 |
170 | return protein_vector
171 |
172 | def get_images(self, idx):
173 |
174 | nucleus = Image.open(
175 | os.path.join(self.parent_path, self.data.iloc[idx]["nucleus_image_path"])
176 | )
177 | target = Image.open(
178 | os.path.join(self.parent_path, self.data.iloc[idx]["protein_image_path"])
179 | )
180 |
181 | # from https://discuss.pytorch.org/t/how-to-apply-same-transform-on-a-pair-of-picture/14914
182 |
183 | if self.crop_method == "random":
184 |
185 | # Random crop
186 | i, j, h, w = transforms.RandomCrop.get_params(
187 | nucleus, output_size=(self.crop_size, self.crop_size)
188 | )
189 |
190 | nucleus = TF.crop(nucleus, i, j, h, w)
191 | target = TF.crop(target, i, j, h, w)
192 |
193 | # Random horizontal flipping
194 | if random.random() > 0.5:
195 | nucleus = TF.hflip(nucleus)
196 | target = TF.hflip(target)
197 |
198 | # Random vertical flipping
199 | if random.random() > 0.5:
200 | nucleus = TF.vflip(nucleus)
201 | target = TF.vflip(target)
202 |
203 | elif self.crop_method == "center":
204 | nucleus = TF.center_crop(nucleus, self.crop_size)
205 | target = TF.center_crop(target, self.crop_size)
206 |
207 | nucleus = TF.to_tensor(nucleus)
208 | target = TF.to_tensor(target)
209 |
210 | threshold = target
211 |
212 | if self.threshold:
213 | threshold = 1.0 * (threshold > (torch.mean(threshold)))
214 |
215 | return nucleus, target, threshold
216 |
217 | def retrieve_metadata(self, idx):
218 |
219 | with open(
220 | os.path.join(self.parent_path, self.data.iloc[idx]["metadata_path"])
221 | ) as f:
222 | metadata = json.load(f)
223 |
224 | return metadata
225 |
226 | def tokenize_seqeuence(self, protein_sequence):
227 |
228 | prot_len = len(protein_sequence)
229 |
230 | if prot_len > self.text_seq_len:
231 | start_int = np.random.randint(0, len(protein_sequence) - self.text_seq_len)
232 | protein_sequence = protein_sequence[
233 | start_int : start_int + self.text_seq_len
234 | ]
235 |
236 | if self.sequence_mode == "simple":
237 | protein_vector = simple_conversion(protein_sequence)
238 |
239 | elif self.sequence_mode == "center":
240 |             protein_sequence = protein_sequence.center(self.text_seq_len, "-")
241 | protein_vector = simple_conversion(protein_sequence)
242 |
243 | elif self.sequence_mode == "alternating":
244 |             protein_sequence = protein_sequence.center(self.text_seq_len, "-")
245 | protein_sequence = protein_sequence[::18]
246 | protein_sequence = protein_sequence.center(
247 |                 int(self.text_seq_len / 18) + 1, "-"
248 | )
249 | protein_vector = simple_conversion(protein_sequence)
250 |
251 | elif self.sequence_mode == "onehot":
252 |
253 | protein_vector = torch.tensor([self.tokenizer.encode(protein_sequence)])[
254 | :, 1:-1
255 | ]
256 |
257 | elif self.sequence_mode == "aadescriptors":
258 |
259 | protein_vector = convert_descriptor(protein_sequence).long().unsqueeze(0)
260 |
261 | elif self.sequence_mode == "embedding":
262 |
263 | if self.vocab == "esm1b":
264 | pad_token = 1
265 |
266 | protein_vector = self.tokenizer([("", protein_sequence)])[-1][:, 1:]
267 |
268 | elif self.vocab == "unirep" or self.vocab == "bert":
269 | pad_token = 0
270 | protein_vector = torch.tensor(
271 | [self.tokenizer.encode(protein_sequence)]
272 | )[:, 1:]
273 |
274 |
275 | if prot_len > self.text_seq_len:
276 | protein_vector = protein_vector[:, :-1]
277 | elif prot_len == self.text_seq_len:
278 | protein_vector = protein_vector[:, :-2]
279 |
280 | if protein_vector.shape[-1] < self.text_seq_len:
281 | diff = self.text_seq_len - protein_vector.shape[-1]
282 | protein_vector = torch.nn.functional.pad(
283 | protein_vector, (0, diff), "constant", pad_token
284 | )
285 |
286 | return protein_vector.long()
287 |
288 | else:
289 |
290 |             raise ValueError("No valid sequence mode selected")
291 |
292 | if protein_vector.shape[-1] + 1 < self.text_seq_len:
293 | diff = self.text_seq_len - protein_vector.shape[-1]
294 | protein_vector = torch.nn.functional.pad(
295 | protein_vector, (0, diff), "constant", 0
296 | )
297 |
298 | return protein_vector.long()
299 |
--------------------------------------------------------------------------------
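A hedged usage sketch of the dataset defined above: `data.csv` is a placeholder for a CSV containing the `nucleus_image_path`, `protein_image_path`, `metadata_path`, and `split` columns that `OpenCellLoader` reads, and the remaining arguments mirror `configs/celle.yaml`.

```python
from torch.utils.data import DataLoader
from dataloader import OpenCellLoader

# "data.csv" is a placeholder path; it must provide the columns
# nucleus_image_path, protein_image_path, metadata_path, and split.
dataset = OpenCellLoader(
    data_csv="data.csv",
    split_key="train",
    crop_size=256,
    sequence_mode="embedding",
    vocab="bert",
    threshold=True,
    text_seq_len=1000,
)
loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)

batch = next(iter(loader))
print(batch["nucleus"].shape, batch["target"].shape, batch["sequence"].shape)
```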
/images/generate.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BoHuangLab/Protein-Localization-Transformer/9d7d0bb4296a0363d1af21a7e50ddc8672e0eca6/images/generate.gif
--------------------------------------------------------------------------------
/images/huanglogo.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BoHuangLab/Protein-Localization-Transformer/9d7d0bb4296a0363d1af21a7e50ddc8672e0eca6/images/huanglogo.jpeg
--------------------------------------------------------------------------------
/images/nucleus.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BoHuangLab/Protein-Localization-Transformer/9d7d0bb4296a0363d1af21a7e50ddc8672e0eca6/images/nucleus.jpg
--------------------------------------------------------------------------------
/images/preview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BoHuangLab/Protein-Localization-Transformer/9d7d0bb4296a0363d1af21a7e50ddc8672e0eca6/images/preview.png
--------------------------------------------------------------------------------
/notebooks/Demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "protein_sequence = ''\n",
10 | "nucleus_image = 'images/nucleus.jpg'\n",
11 | "protein_name = None\n",
12 | "device = \"cuda:0\"\n",
13 | "config_file = 'configs/celle.yaml'\n",
14 | "ckpt_path = None"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "#run once\n",
24 | "import os\n",
25 | "\n",
26 | "if 'notebooks' in os.getcwd():\n",
27 | " os.chdir('..')\n",
28 | "\n",
29 | "import torch\n",
30 | "import numpy as np\n",
31 | "from matplotlib import pyplot as plt\n",
32 | "from matplotlib import cm\n",
33 | "from matplotlib.colors import LinearSegmentedColormap\n",
34 | "import torchvision\n",
35 | "\n",
36 | "from einops import rearrange\n",
37 | "from omegaconf import OmegaConf\n",
38 | "\n",
39 | "from celle_main import instantiate_from_config\n",
40 | "from dataloader import OpenCellLoader\n",
41 | "\n",
42 | "# color map for plot\n",
43 | "color_array = plt.get_cmap('gist_rainbow')(range(256))\n",
44 | "color_array[:,-1] = np.linspace(1.0,0.0,256)\n",
45 | "map_object = LinearSegmentedColormap.from_list(name='rainbow_alpha',colors=color_array[::-1])\n",
46 | "plt.register_cmap(cmap=map_object)\n",
47 | "\n",
48 | "device = torch.device(device)\n",
49 | "\n",
50 | "#load model\n",
51 | "configs = OmegaConf.load(config_file);\n",
52 | "model = instantiate_from_config(configs.model).to(device);\n",
53 | "if ckpt_path:\n",
54 | " t = torch.load(ckpt_path,map_location = 'cpu')['state_dict'];\n",
55 | " for key in list(t.keys()):\n",
56 | " t[key.replace('celle.','')] = t.pop(key);\n",
57 | "    model.celle.load_state_dict(t,strict=False);\n",
58 | "model = model.celle\n",
59 | "model = model.to(device)\n",
60 | "model = model.eval()\n",
61 | "\n",
62 | "# get some params\n",
63 | "crop_size = configs.data.params.crop_size\n",
64 | "sequence_mode = configs.data.params.sequence_mode\n",
65 | "vocab = configs.data.params.vocab\n",
66 | "threshold = configs.data.params.threshold\n",
67 | "text_seq_len = configs.data.params.text_seq_len\n",
68 | "\n",
69 | "# convert string to numbered index\n",
70 | "dataset = OpenCellLoader(crop_size=crop_size, sequence_mode=sequence_mode, vocab=vocab, threshold=threshold, text_seq_len=text_seq_len)"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": null,
76 | "metadata": {},
77 | "outputs": [],
78 | "source": [
79 | "protein_sequence = ''.join(filter(str.isalpha, protein_sequence)) \n",
80 | "protein_sequence = dataset.tokenize_seqeuence(protein_sequence)\n",
81 | "\n",
82 | "# import nucleus, scale and crop\n",
83 | "nucleus = torch.tensor(plt.imread(nucleus_image)).float()\n",
84 | "nucleus /= 255\n",
85 | "nucleus = torchvision.transforms.RandomCrop(256)(nucleus).unsqueeze(0).unsqueeze(0)\n",
86 | "\n",
87 | "# generate image\n",
88 | "with torch.no_grad():\n",
89 | " output = model.generate_images(text=protein_sequence.to(device), condition = nucleus.to(device), return_logits=True, use_cache=True, progress=True)\n",
90 | " \n",
91 | " logits = output[-1][:,-256:,-512:]\n",
92 | " image_tokens = logits @ model.vae.model.quantize.embedding.weight\n",
93 | "            image_tokens = rearrange(image_tokens, \"b (h w) c -> b c h w\", h=int(np.sqrt(256)))\n",
94 | " pdf = model.vae.model.decode(image_tokens)\n",
95 | " pdf = torch.clip(pdf,0,1)\n",
96 | " \n",
97 | " plt.figure(dpi=300, clear=True) \n",
98 | " plt.axis('off')\n",
99 | " plt.imshow(nucleus[0,0],cmap='gray',interpolation='bicubic')\n",
100 | " plt.imshow(pdf.cpu()[0,0],cmap='rainbow_alpha',alpha = .75,interpolation='bicubic')\n",
101 | " plt.colorbar(mappable=cm.ScalarMappable(cmap='rainbow_alpha'))\n",
102 | " \n",
103 | " if protein_name:\n",
104 | " plt.title(protein_name)"
105 | ]
106 | }
107 | ],
108 | "metadata": {
109 | "interpreter": {
110 | "hash": "f703332b1593e5986aec844d60dd2796d9b0ddf157e3991cb22534f7b76c19d6"
111 | },
112 | "kernelspec": {
113 | "display_name": "Python 3.8.5 ('env': venv)",
114 | "language": "python",
115 | "name": "python3"
116 | },
117 | "language_info": {
118 | "codemirror_mode": {
119 | "name": "ipython",
120 | "version": 3
121 | },
122 | "file_extension": ".py",
123 | "mimetype": "text/x-python",
124 | "name": "python",
125 | "nbconvert_exporter": "python",
126 | "pygments_lexer": "ipython3",
127 | "version": "3.8.5"
128 | },
129 | "orig_nbformat": 4
130 | },
131 | "nbformat": 4,
132 | "nbformat_minor": 2
133 | }
134 |
--------------------------------------------------------------------------------
/notebooks/grad_map.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class ActivationsAndGradients:
5 | """Class for extracting activations and
6 |     registering gradients from targeted intermediate layers"""
7 |
8 | def __init__(self, model, target_layers, reshape_transform):
9 | self.model = model
10 | self.gradients = []
11 | self.activations = []
12 | self.reshape_transform = reshape_transform
13 | self.handles = []
14 | for target_layer in target_layers:
15 | self.handles.append(
16 | target_layer.register_forward_hook(self.save_activation)
17 | )
18 | # Because of https://github.com/pytorch/pytorch/issues/61519,
19 | # we don't use backward hook to record gradients.
20 | self.handles.append(target_layer.register_forward_hook(self.save_gradient))
21 |
22 | def save_activation(self, module, input, output):
23 |
24 | if len(output) == 2:
25 | output = output["attn"]
26 |
27 | if len(self.activations) == 0:
28 | self.min_val = torch.min(output)
29 |
30 | attn = torch.nn.functional.pad(
31 | output,
32 | (0, 1512 - (output.shape[-1]), 0, 0, 0, 0, 0, 0),
33 | "constant",
34 | self.min_val,
35 | )
36 |
37 | self.activations.append(attn.cpu().detach())
38 |
39 | def save_gradient(self, module, input, output):
40 | # if not hasattr(output, "requires_grad") or not output.requires_grad:
41 | # # You can only register hooks on tensor requires grad.
42 | # return
43 |
44 | # Gradients are computed in reverse order
45 | def _store_grad(grad):
46 | if self.reshape_transform is not None:
47 | grad = self.reshape_transform(grad)
48 | self.gradients = [grad.cpu().detach()] + self.gradients
49 |
50 | # output.register_hook(_store_grad(output))
51 |
52 | def __call__(self, x, filter_thres=0.5):
53 | self.gradients = []
54 | self.activations = []
55 | sequence, nucleus, target = x
56 | return self.model.generate_images(
57 | text=sequence,
58 | condition=nucleus,
59 | return_logits=True,
60 | progress=True,
61 | use_cache=True,
62 | filter_thres=filter_thres,
63 | )
64 |
65 | def release(self):
66 | for handle in self.handles:
67 | handle.remove()
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.0.0
2 | antlr4-python3-runtime==4.8
3 | argon2-cffi==21.1.0
4 | asttokens==2.0.5
5 | attrs==21.2.0
6 | axial-positional-embedding==0.2.1
7 | backcall==0.2.0
8 | biopython==1.79
9 | bleach==4.1.0
10 | boto3==1.20.5
11 | botocore==1.23.5
12 | cachetools==4.2.4
13 | calmsize==0.1.3
14 | certifi==2021.10.8
15 | cffi==1.15.0
16 | charset-normalizer==2.0.7
17 | click==7.1.2
18 | cycler==0.11.0
19 | dataclasses==0.6
20 | debugpy==1.5.1
21 | decorator==5.1.0
22 | defusedxml==0.7.1
23 | einops==0.3.2
24 | entrypoints==0.3
25 | executing==0.8.3
26 | filelock==3.3.2
27 | fsspec==2021.11.0
28 | future==0.18.2
29 | g-mlp-pytorch==0.1.5
30 | google-auth==2.3.3
31 | google-auth-oauthlib==0.4.6
32 | grpcio==1.41.1
33 | idna==3.3
34 | imageio==2.10.3
35 | importlib-metadata==4.11.3
36 | importlib-resources==5.4.0
37 | ipykernel==6.6.0
38 | ipython==7.30.1
39 | ipython-genutils==0.2.0
40 | ipywidgets==7.6.5
41 | jedi==0.18.0
42 | Jinja2==3.0.3
43 | jmespath==0.10.0
44 | joblib==1.1.0
45 | jsonschema==4.2.1
46 | jupyter==1.0.0
47 | jupyter-client==7.0.6
48 | jupyter-console==6.4.0
49 | jupyter-core==4.9.1
50 | jupyterlab-pygments==0.1.2
51 | jupyterlab-widgets==1.0.2
52 | kiwisolver==1.3.2
53 | llvmlite==0.38.0
54 | lmdb==1.2.1
55 | Markdown==3.3.6
56 | MarkupSafe==2.0.1
57 | matplotlib==3.4.3
58 | matplotlib-inline==0.1.3
59 | mistune==0.8.4
60 | nbclient==0.5.8
61 | nbconvert==6.4.3
62 | nbformat==5.1.3
63 | nest-asyncio==1.5.1
64 | nopdb==0.2.0
65 | notebook==6.4.5
66 | numba==0.55.1
67 | numpy==1.21.4
68 | oauthlib==3.1.1
69 | omegaconf==2.0.0
70 | opencv-python==4.5.5.64
71 | packaging==21.2
72 | pandas==1.3.4
73 | pandocfilters==1.5.0
74 | parso==0.8.2
75 | pexpect==4.8.0
76 | pickleshare==0.7.5
77 | Pillow==8.4.0
78 | prometheus-client==0.12.0
79 | prompt-toolkit==3.0.22
80 | protobuf==3.19.1
81 | psutil==5.9.1
82 | ptyprocess==0.7.0
83 | pure-eval==0.2.2
84 | pyasn1==0.4.8
85 | pyasn1-modules==0.2.8
86 | pycparser==2.21
87 | pyDeprecate==0.3.1
88 | Pygments==2.10.0
89 | pynndescent==0.5.6
90 | pyparsing==2.4.7
91 | pyrsistent==0.18.0
92 | python-dateutil==2.8.2
93 | pytorch-lightning==1.0.8
94 | pytorch-memlab==0.2.4
95 | pytz==2021.3
96 | PyYAML==6.0
97 | pyzmq==22.3.0
98 | qtconsole==5.2.0
99 | QtPy==1.11.2
100 | regex==2021.11.10
101 | requests==2.26.0
102 | requests-oauthlib==1.3.0
103 | rotary-embedding-torch==0.1.2
104 | rsa==4.7.2
105 | s3transfer==0.5.0
106 | sacremoses==0.0.46
107 | scikit-learn==1.0.2
108 | scipy==1.7.2
109 | Send2Trash==1.8.0
110 | six==1.16.0
111 | stack-data==0.2.0
112 | tape-proteins==0.5
113 | tensorboard==2.7.0
114 | tensorboard-data-server==0.6.1
115 | tensorboard-plugin-wit==1.8.0
116 | tensorboardX==2.4
117 | terminado==0.12.1
118 | test-tube==0.7.5
119 | testpath==0.5.0
120 | threadpoolctl==3.1.0
121 | tokenizers==0.10.3
122 | torchmetrics==0.6.0
123 | tornado==6.1
124 | tqdm==4.62.3
125 | traitlets==5.1.1
126 | transformers==4.3.1
127 | typing-extensions==4.1.1
128 | umap-learn==0.5.2
129 | urllib3==1.26.7
130 | wcwidth==0.2.5
131 | webencodings==0.5.1
132 | Werkzeug==2.0.2
133 | widgetsnbextension==3.5.2
134 | zipp==3.6.0
135 |
--------------------------------------------------------------------------------
/taming/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9 | self.lr_warm_up_steps = warm_up_steps
10 | self.lr_start = lr_start
11 | self.lr_min = lr_min
12 | self.lr_max = lr_max
13 | self.lr_max_decay_steps = max_decay_steps
14 | self.last_lr = 0.
15 | self.verbosity_interval = verbosity_interval
16 |
17 | def schedule(self, n):
18 | if self.verbosity_interval > 0:
19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20 | if n < self.lr_warm_up_steps:
21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22 | self.last_lr = lr
23 | return lr
24 | else:
25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26 | t = min(t, 1.0)
27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28 | 1 + np.cos(t * np.pi))
29 | self.last_lr = lr
30 | return lr
31 |
32 | def __call__(self, n):
33 | return self.schedule(n)
34 |
35 |
--------------------------------------------------------------------------------
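As the docstring notes, `LambdaWarmUpCosineScheduler` returns an LR multiplier meant to be applied on top of a base learning rate. A hedged sketch of pairing it with `torch.optim.lr_scheduler.LambdaLR` follows; the model, base learning rate, and warm-up/decay step counts are illustrative only.

```python
import torch
from taming.lr_scheduler import LambdaWarmUpCosineScheduler

# The schedule multiplies the optimizer's base lr, so it warms up towards
# lr_max * base_lr and then decays cosine-style towards lr_min * base_lr.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=4.5e-6)

schedule = LambdaWarmUpCosineScheduler(
    warm_up_steps=1_000,
    lr_min=0.01,
    lr_max=1.0,
    lr_start=0.0,
    max_decay_steps=100_000,
)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule)

for _ in range(5):           # one optimizer/scheduler step per training step
    optimizer.step()
    scheduler.step()
```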
/taming/models/cond_transformer.py:
--------------------------------------------------------------------------------
1 | import os, math
2 | import torch
3 | import torch.nn.functional as F
4 | import pytorch_lightning as pl
5 |
6 | from main import instantiate_from_config
7 | from taming.modules.util import SOSProvider
8 |
9 |
10 | def disabled_train(self, mode=True):
11 | """Overwrite model.train with this function to make sure train/eval mode
12 | does not change anymore."""
13 | return self
14 |
15 |
16 | class Net2NetTransformer(pl.LightningModule):
17 | def __init__(self,
18 | transformer_config,
19 | first_stage_config,
20 | cond_stage_config,
21 | permuter_config=None,
22 | ckpt_path=None,
23 | ignore_keys=[],
24 | first_stage_key="image",
25 | cond_stage_key="depth",
26 | downsample_cond_size=-1,
27 | pkeep=1.0,
28 | sos_token=0,
29 | unconditional=False,
30 | ):
31 | super().__init__()
32 | self.be_unconditional = unconditional
33 | self.sos_token = sos_token
34 | self.first_stage_key = first_stage_key
35 | self.cond_stage_key = cond_stage_key
36 | self.init_first_stage_from_ckpt(first_stage_config)
37 | self.init_cond_stage_from_ckpt(cond_stage_config)
38 | if permuter_config is None:
39 | permuter_config = {"target": "taming.modules.transformer.permuter.Identity"}
40 | self.permuter = instantiate_from_config(config=permuter_config)
41 | self.transformer = instantiate_from_config(config=transformer_config)
42 |
43 | if ckpt_path is not None:
44 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
45 | self.downsample_cond_size = downsample_cond_size
46 | self.pkeep = pkeep
47 |
48 | def init_from_ckpt(self, path, ignore_keys=list()):
49 | sd = torch.load(path, map_location="cpu")["state_dict"]
50 |         for k in list(sd.keys()):
51 | for ik in ignore_keys:
52 | if k.startswith(ik):
53 | self.print("Deleting key {} from state_dict.".format(k))
54 | del sd[k]
55 | self.load_state_dict(sd, strict=False)
56 | print(f"Restored from {path}")
57 |
58 | def init_first_stage_from_ckpt(self, config):
59 | model = instantiate_from_config(config)
60 | model = model.eval()
61 | model.train = disabled_train
62 | self.first_stage_model = model
63 |
64 | def init_cond_stage_from_ckpt(self, config):
65 | if config == "__is_first_stage__":
66 | print("Using first stage also as cond stage.")
67 | self.cond_stage_model = self.first_stage_model
68 | elif config == "__is_unconditional__" or self.be_unconditional:
69 | print(f"Using no cond stage. Assuming the training is intended to be unconditional. "
70 | f"Prepending {self.sos_token} as a sos token.")
71 | self.be_unconditional = True
72 | self.cond_stage_key = self.first_stage_key
73 | self.cond_stage_model = SOSProvider(self.sos_token)
74 | else:
75 | model = instantiate_from_config(config)
76 | model = model.eval()
77 | model.train = disabled_train
78 | self.cond_stage_model = model
79 |
80 | def forward(self, x, c):
81 | # one step to produce the logits
82 | # x = target
83 | # c = nucleus
84 | _, z_indices = self.encode_to_z(x)
85 | _, c_indices = self.encode_to_c(c)
86 |
87 | if self.training and self.pkeep < 1.0:
88 | mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape,
89 | device=z_indices.device))
90 | mask = mask.round().to(dtype=torch.int64)
91 | r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
92 | a_indices = mask*z_indices+(1-mask)*r_indices
93 | else:
94 | a_indices = z_indices
95 |
96 | cz_indices = torch.cat((c_indices, a_indices), dim=1)
97 |
98 | # target includes all sequence elements (no need to handle first one
99 | # differently because we are conditioning)
100 | target = z_indices
101 | # make the prediction
102 | logits, _ = self.transformer(cz_indices[:, :-1])
103 |         # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
179 |         if self.downsample_cond_size > -1:
180 | c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size))
181 |
182 | #quant_c, _, info = self.cond_stage_model.encode(x)
183 | #indices = info[2].view(quant_c.shape[0], -1)
184 | #indices = self.permuter(indices)
185 | quant_c, _, [_,_,indices] = self.cond_stage_model.encode(c)
186 | if len(indices.shape) != 2:
187 | indices = indices.view(c.shape[0], -1)
188 | return quant_c, indices
189 |
190 | @torch.no_grad()
191 | def decode_to_img(self, index, zshape):
192 | index = self.permuter(index, reverse=True)
193 | bhwc = (zshape[0],zshape[2],zshape[3],zshape[1])
194 | quant_z = self.first_stage_model.quantize.get_codebook_entry(
195 | index.reshape(-1), shape=bhwc)
196 | x = self.first_stage_model.decode(quant_z)
197 | return x
198 |
199 | @torch.no_grad()
200 | def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs):
201 | log = dict()
202 |
203 | N = 4
204 | if lr_interface:
205 | x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8)
206 | else:
207 | x, c = self.get_xc(batch, N)
208 | x = x.to(device=self.device)
209 | c = c.to(device=self.device)
210 |
211 | quant_z, z_indices = self.encode_to_z(x)
212 | quant_c, c_indices = self.encode_to_c(c)
213 |
214 |         # create a "half" sample
215 | z_start_indices = z_indices[:,:z_indices.shape[1]//2]
216 | index_sample = self.sample(z_start_indices, c_indices,
217 | steps=z_indices.shape[1]-z_start_indices.shape[1],
218 | temperature=temperature if temperature is not None else 1.0,
219 | sample=True,
220 | top_k=top_k if top_k is not None else 100,
221 | callback=callback if callback is not None else lambda k: None)
222 | x_sample = self.decode_to_img(index_sample, quant_z.shape)
223 |
224 | # sample
225 | z_start_indices = z_indices[:, :0]
226 | index_sample = self.sample(z_start_indices, c_indices,
227 | steps=z_indices.shape[1],
228 | temperature=temperature if temperature is not None else 1.0,
229 | sample=True,
230 | top_k=top_k if top_k is not None else 100,
231 | callback=callback if callback is not None else lambda k: None)
232 | x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape)
233 |
234 | # det sample
235 | z_start_indices = z_indices[:, :0]
236 | index_sample = self.sample(z_start_indices, c_indices,
237 | steps=z_indices.shape[1],
238 | sample=False,
239 | callback=callback if callback is not None else lambda k: None)
240 | x_sample_det = self.decode_to_img(index_sample, quant_z.shape)
241 |
242 | # reconstruction
243 | x_rec = self.decode_to_img(z_indices, quant_z.shape)
244 |
245 | log["inputs"] = x
246 | log["reconstructions"] = x_rec
247 |
248 |         if self.cond_stage_key not in ("image", "nucleus", "target"):
249 | cond_rec = self.cond_stage_model.decode(quant_c)
250 | if self.cond_stage_key == "segmentation":
251 | # get image from segmentation mask
252 | num_classes = cond_rec.shape[1]
253 |
254 | c = torch.argmax(c, dim=1, keepdim=True)
255 | c = F.one_hot(c, num_classes=num_classes)
256 | c = c.squeeze(1).permute(0, 3, 1, 2).float()
257 | c = self.cond_stage_model.to_rgb(c)
258 |
259 | cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True)
260 | cond_rec = F.one_hot(cond_rec, num_classes=num_classes)
261 | cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float()
262 | cond_rec = self.cond_stage_model.to_rgb(cond_rec)
263 | log["conditioning_rec"] = cond_rec
264 | log["conditioning"] = c
265 |
266 | log["samples_half"] = x_sample
267 | log["samples_nopix"] = x_sample_nopix
268 | log["samples_det"] = x_sample_det
269 | return log
270 |
271 | def get_input(self, key, batch):
272 | x = batch[key]
273 | if len(x.shape) == 3:
274 | x = x[..., None]
275 | #if len(x.shape) == 4:
276 | # x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
277 | if x.dtype == torch.double:
278 | x = x.float()
279 | return x
280 |
281 | def get_xc(self, batch, N=None):
282 | x = self.get_input(self.first_stage_key, batch)
283 | c = self.get_input(self.cond_stage_key, batch)
284 | if N is not None:
285 | x = x[:N]
286 | c = c[:N]
287 | return x, c
288 |
289 | def shared_step(self, batch):
290 | x, c = self.get_xc(batch)
291 | logits, target = self(x, c)
292 | loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
293 | return loss
294 |
295 | def training_step(self, batch, batch_idx):
296 | loss = self.shared_step(batch)
297 | self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
298 | return loss
299 |
300 | def validation_step(self, batch, batch_idx):
301 | loss = self.shared_step(batch)
302 | self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
303 | return loss
304 |
305 | def configure_optimizers(self):
306 | """
307 | Following minGPT:
308 | This long function is unfortunately doing something very simple and is being very defensive:
309 | We are separating out all parameters of the model into two buckets: those that will experience
310 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
311 | We are then returning the PyTorch optimizer object.
312 | """
313 | # separate out all parameters to those that will and won't experience regularizing weight decay
314 | decay = set()
315 | no_decay = set()
316 | whitelist_weight_modules = (torch.nn.Linear, )
317 | blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
318 | for mn, m in self.transformer.named_modules():
319 | for pn, p in m.named_parameters():
320 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
321 |
322 | if pn.endswith('bias'):
323 | # all biases will not be decayed
324 | no_decay.add(fpn)
325 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
326 | # weights of whitelist modules will be weight decayed
327 | decay.add(fpn)
328 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
329 | # weights of blacklist modules will NOT be weight decayed
330 | no_decay.add(fpn)
331 |
332 | # special case the position embedding parameter in the root GPT module as not decayed
333 | no_decay.add('pos_emb')
334 |
335 | # validate that we considered every parameter
336 | param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
337 | inter_params = decay & no_decay
338 | union_params = decay | no_decay
339 | assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
340 | assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
341 | % (str(param_dict.keys() - union_params), )
342 |
343 | # create the pytorch optimizer object
344 | optim_groups = [
345 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
346 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
347 | ]
348 | optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
349 | return optimizer
350 |
--------------------------------------------------------------------------------
/taming/models/dummy_cond_stage.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor
2 |
3 |
4 | class DummyCondStage:
5 | def __init__(self, conditional_key):
6 | self.conditional_key = conditional_key
7 | self.train = None
8 |
9 | def eval(self):
10 | return self
11 |
12 | @staticmethod
13 | def encode(c: Tensor):
14 | return c, None, (None, None, c)
15 |
16 | @staticmethod
17 | def decode(c: Tensor):
18 | return c
19 |
20 | @staticmethod
21 | def to_rgb(c: Tensor):
22 | return c
23 |
--------------------------------------------------------------------------------
/taming/models/vqgan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import pytorch_lightning as pl
4 |
5 | from celle_taming_main import instantiate_from_config
6 |
7 | from taming.modules.diffusionmodules.model import Encoder, Decoder
8 | from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
9 | from taming.modules.vqvae.quantize import GumbelQuantize
10 | from taming.modules.vqvae.quantize import EMAVectorQuantizer
11 |
12 |
13 | class VQModel(pl.LightningModule):
14 | def __init__(
15 | self,
16 | ddconfig,
17 | lossconfig,
18 | n_embed,
19 | embed_dim,
20 | ckpt_path=None,
21 | ignore_keys=[],
22 | image_key="image",
23 | colorize_nlabels=None,
24 | monitor=None,
25 | remap=None,
26 | sane_index_shape=False, # tell vector quantizer to return indices as bhw
27 | ):
28 | super().__init__()
29 | self.image_key = image_key
30 | self.encoder = Encoder(**ddconfig)
31 | self.decoder = Decoder(**ddconfig)
32 | self.loss = instantiate_from_config(lossconfig)
33 | self.quantize = VectorQuantizer(
34 | n_embed,
35 | embed_dim,
36 | beta=0.25,
37 | remap=remap,
38 | sane_index_shape=sane_index_shape,
39 | )
40 | self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
41 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
42 | if ckpt_path is not None:
43 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
44 | self.image_key = image_key
45 | if colorize_nlabels is not None:
46 | assert type(colorize_nlabels) == int
47 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
48 | if monitor is not None:
49 | self.monitor = monitor
50 |
51 | def init_from_ckpt(self, path, ignore_keys=list()):
52 | sd = torch.load(path, map_location="cpu")["state_dict"]
53 | keys = list(sd.keys())
54 | for k in keys:
55 | for ik in ignore_keys:
56 | if k.startswith(ik):
57 | print("Deleting key {} from state_dict.".format(k))
58 | del sd[k]
59 | self.load_state_dict(sd, strict=False)
60 | print(f"Restored from {path}")
61 |
62 | def encode(self, x):
63 | h = self.encoder(x)
64 | h = self.quant_conv(h)
65 | quant, emb_loss, info = self.quantize(h)
66 | return quant, emb_loss, info
67 |
68 | def decode(self, quant):
69 | quant = self.post_quant_conv(quant)
70 | dec = self.decoder(quant)
71 | return dec
72 |
73 | def decode_code(self, code_b):
74 | quant_b = self.quantize.embed_code(code_b)
75 | dec = self.decode(quant_b)
76 | return dec
77 |
78 | def forward(self, input):
79 | quant, diff, _ = self.encode(input)
80 | dec = self.decode(quant)
81 | return dec, diff
82 |
83 | def get_input(self, batch, k):
84 |
85 | if k == "mixed":
86 | keys = ["nucleus", "target"]
87 | index = torch.randint(low=0, high=2, size=(1,), dtype=int).item()
88 | k = keys[index]
89 |
90 | x = batch[k]
91 | if len(x.shape) == 3:
92 | x = x[..., None]
93 |
94 | # x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
95 | return x
96 |
97 | def training_step(self, batch, batch_idx=None, optimizer_idx=0):
98 |
99 | if type(batch) == dict:
100 |
101 | x = self.get_input(batch, self.image_key)
102 |
103 | else:
104 | x = batch
105 |
106 | xrec, qloss = self(
107 | x,
108 | )
109 |
110 | if optimizer_idx == 0:
111 | # autoencode
112 | aeloss, log_dict_ae = self.loss(
113 | qloss,
114 | x,
115 | xrec,
116 | optimizer_idx,
117 | self.global_step,
118 | last_layer=self.get_last_layer(),
119 | split="train",
120 | )
121 |
122 | self.log(
123 | "train/aeloss",
124 | aeloss,
125 | prog_bar=True,
126 | logger=True,
127 | on_step=True,
128 | on_epoch=True,
129 | )
130 | self.log_dict(
131 | log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
132 | )
133 | return aeloss
134 |
135 | if optimizer_idx == 1:
136 | # discriminator
137 | discloss, log_dict_disc = self.loss(
138 | qloss,
139 | x,
140 | xrec,
141 | optimizer_idx,
142 | self.global_step,
143 | last_layer=self.get_last_layer(),
144 | split="train",
145 | )
146 | self.log(
147 | "train/discloss",
148 | discloss,
149 | prog_bar=True,
150 | logger=True,
151 | on_step=True,
152 | on_epoch=True,
153 | )
154 | self.log_dict(
155 | log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
156 | )
157 | return discloss
158 |
159 | def validation_step(self, batch, batch_idx):
160 |
161 | if type(batch) == dict:
162 |
163 | x = self.get_input(batch, self.image_key)
164 |
165 | else:
166 | x = batch
167 |
168 | xrec, qloss = self(x)
169 | aeloss, log_dict_ae = self.loss(
170 | qloss,
171 | x,
172 | xrec,
173 | 0,
174 | self.global_step,
175 | last_layer=self.get_last_layer(),
176 | split="val",
177 | )
178 |
179 | discloss, log_dict_disc = self.loss(
180 | qloss,
181 | x,
182 | xrec,
183 | 1,
184 | self.global_step,
185 | last_layer=self.get_last_layer(),
186 | split="val",
187 | )
188 | rec_loss = log_dict_ae["val/rec_loss"]
189 | self.log(
190 | "val/rec_loss",
191 | rec_loss,
192 | prog_bar=True,
193 | logger=True,
194 | on_step=True,
195 | on_epoch=True,
196 | sync_dist=True,
197 | )
198 | self.log(
199 | "val/aeloss",
200 | aeloss,
201 | prog_bar=True,
202 | logger=True,
203 | on_step=True,
204 | on_epoch=True,
205 | sync_dist=True,
206 | )
207 | self.log_dict(log_dict_ae)
208 | self.log_dict(log_dict_disc)
209 | return self.log_dict
210 |
211 | def configure_optimizers(self):
212 | lr = self.learning_rate
213 | opt_ae = torch.optim.Adam(
214 | list(self.encoder.parameters())
215 | + list(self.decoder.parameters())
216 | + list(self.quantize.parameters())
217 | + list(self.quant_conv.parameters())
218 | + list(self.post_quant_conv.parameters()),
219 | lr=lr,
220 | betas=(0.5, 0.9),
221 | )
222 | opt_disc = torch.optim.Adam(
223 | self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
224 | )
225 | return [opt_ae, opt_disc], []
226 |
227 | def get_last_layer(self):
228 | return self.decoder.conv_out.weight
229 |
230 | def log_images(self, batch, **kwargs):
231 | log = dict()
232 | x = self.get_input(batch, self.image_key)
233 | x = x.to(self.device)
234 | xrec, _ = self(x)
235 | if x.shape[1] > 3:
236 | # colorize with random projection
237 | assert xrec.shape[1] > 3
238 | x = self.to_rgb(x)
239 | xrec = self.to_rgb(xrec)
240 | log["inputs"] = x
241 | log["reconstructions"] = xrec
242 | return log
243 |
244 | def to_rgb(self, x):
245 | assert self.image_key == "segmentation"
246 | if not hasattr(self, "colorize"):
247 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
248 | x = F.conv2d(x, weight=self.colorize)
249 | x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
250 | return x
251 |
252 |
253 | class VQSegmentationModel(VQModel):
254 | def __init__(self, n_labels, *args, **kwargs):
255 | super().__init__(*args, **kwargs)
256 | self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1))
257 |
258 | def configure_optimizers(self):
259 | lr = self.learning_rate
260 | opt_ae = torch.optim.Adam(
261 | list(self.encoder.parameters())
262 | + list(self.decoder.parameters())
263 | + list(self.quantize.parameters())
264 | + list(self.quant_conv.parameters())
265 | + list(self.post_quant_conv.parameters()),
266 | lr=lr,
267 | betas=(0.5, 0.9),
268 | )
269 | return opt_ae
270 |
271 | def training_step(self, batch, batch_idx):
272 | x = self.get_input(batch, self.image_key)
273 | xrec, qloss = self(x)
274 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
275 | self.log_dict(
276 | log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
277 | )
278 | return aeloss
279 |
280 | def validation_step(self, batch, batch_idx):
281 | x = self.get_input(batch, self.image_key)
282 | xrec, qloss = self(x)
283 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val")
284 | self.log_dict(
285 | log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
286 | )
287 | total_loss = log_dict_ae["val/total_loss"]
288 | self.log(
289 | "val/total_loss",
290 | total_loss,
291 | prog_bar=True,
292 | logger=True,
293 | on_step=True,
294 | on_epoch=True,
295 | sync_dist=True,
296 | )
297 | return aeloss
298 |
299 | @torch.no_grad()
300 | def log_images(self, batch, **kwargs):
301 | log = dict()
302 | x = self.get_input(batch, self.image_key)
303 | x = x.to(self.device)
304 | xrec, _ = self(x)
305 | if x.shape[1] > 3:
306 | # colorize with random projection
307 | assert xrec.shape[1] > 3
308 | # convert logits to indices
309 | xrec = torch.argmax(xrec, dim=1, keepdim=True)
310 | xrec = F.one_hot(xrec, num_classes=x.shape[1])
311 | xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
312 | x = self.to_rgb(x)
313 | xrec = self.to_rgb(xrec)
314 | log["inputs"] = x
315 | log["reconstructions"] = xrec
316 | return log
317 |
318 |
319 | class VQNoDiscModel(VQModel):
320 | def __init__(
321 | self,
322 | ddconfig,
323 | lossconfig,
324 | n_embed,
325 | embed_dim,
326 | ckpt_path=None,
327 | ignore_keys=[],
328 | image_key="image",
329 | colorize_nlabels=None,
330 | ):
331 | super().__init__(
332 | ddconfig=ddconfig,
333 | lossconfig=lossconfig,
334 | n_embed=n_embed,
335 | embed_dim=embed_dim,
336 | ckpt_path=ckpt_path,
337 | ignore_keys=ignore_keys,
338 | image_key=image_key,
339 | colorize_nlabels=colorize_nlabels,
340 | )
341 |
342 | def training_step(self, batch, batch_idx):
343 | x = self.get_input(batch, self.image_key)
344 | xrec, qloss = self(x)
345 | # autoencode
346 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train")
347 |         # pl.TrainResult was removed from PyTorch Lightning; log and return the loss directly
348 |         self.log(
349 |             "train/aeloss",
350 |             aeloss,
351 |             prog_bar=True,
352 |             logger=True,
353 |             on_step=True,
354 |             on_epoch=True,
355 |         )
356 |         self.log_dict(
357 |             log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
358 |         )
359 |         return aeloss
360 |
361 | def validation_step(self, batch, batch_idx):
362 | x = self.get_input(batch, self.image_key)
363 | xrec, qloss = self(x)
364 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val")
365 | rec_loss = log_dict_ae["val/rec_loss"]
366 |         # pl.EvalResult was removed from PyTorch Lightning; log directly, as in VQModel above
367 |         self.log(
368 |             "val/rec_loss",
369 |             rec_loss,
370 |             prog_bar=True,
371 |             logger=True,
372 |             on_step=True,
373 |             on_epoch=True,
374 |         )
375 |         self.log(
376 |             "val/aeloss",
377 |             aeloss,
378 |             prog_bar=True,
379 |             logger=True,
380 |             on_step=True,
381 |             on_epoch=True,
382 |         )
383 |         self.log_dict(log_dict_ae)
384 |
385 |         return aeloss
386 |
387 | def configure_optimizers(self):
388 | optimizer = torch.optim.Adam(
389 | list(self.encoder.parameters())
390 | + list(self.decoder.parameters())
391 | + list(self.quantize.parameters())
392 | + list(self.quant_conv.parameters())
393 | + list(self.post_quant_conv.parameters()),
394 | lr=self.learning_rate,
395 | betas=(0.5, 0.9),
396 | )
397 | return optimizer
398 |
399 |
400 | class GumbelVQ(VQModel):
401 | def __init__(
402 | self,
403 | ddconfig,
404 | lossconfig,
405 | n_embed,
406 | embed_dim,
407 | temperature_scheduler_config,
408 | ckpt_path=None,
409 | ignore_keys=[],
410 | image_key="image",
411 | colorize_nlabels=None,
412 | monitor=None,
413 | kl_weight=1e-8,
414 | remap=None,
415 | ):
416 |
417 | z_channels = ddconfig["z_channels"]
418 | super().__init__(
419 | ddconfig,
420 | lossconfig,
421 | n_embed,
422 | embed_dim,
423 | ckpt_path=None,
424 | ignore_keys=ignore_keys,
425 | image_key=image_key,
426 | colorize_nlabels=colorize_nlabels,
427 | monitor=monitor,
428 | )
429 |
430 | self.loss.n_classes = n_embed
431 | self.vocab_size = n_embed
432 |
433 | self.quantize = GumbelQuantize(
434 | z_channels,
435 | embed_dim,
436 | n_embed=n_embed,
437 | kl_weight=kl_weight,
438 | temp_init=1.0,
439 | remap=remap,
440 | )
441 |
442 | self.temperature_scheduler = instantiate_from_config(
443 | temperature_scheduler_config
444 | ) # annealing of temp
445 |
446 | if ckpt_path is not None:
447 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
448 |
449 | def temperature_scheduling(self):
450 | self.quantize.temperature = self.temperature_scheduler(self.global_step)
451 |
452 | def encode_to_prequant(self, x):
453 | h = self.encoder(x)
454 | h = self.quant_conv(h)
455 | return h
456 |
457 | def decode_code(self, code_b):
458 | raise NotImplementedError
459 |
460 | def training_step(self, batch, batch_idx, optimizer_idx):
461 | self.temperature_scheduling()
462 | x = self.get_input(batch, self.image_key)
463 | xrec, qloss = self(x)
464 |
465 | if optimizer_idx == 0:
466 | # autoencode
467 | aeloss, log_dict_ae = self.loss(
468 | qloss,
469 | x,
470 | xrec,
471 | optimizer_idx,
472 | self.global_step,
473 | last_layer=self.get_last_layer(),
474 | split="train",
475 | )
476 |
477 | self.log_dict(
478 | log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
479 | )
480 | self.log(
481 | "temperature",
482 | self.quantize.temperature,
483 | prog_bar=False,
484 | logger=True,
485 | on_step=True,
486 | on_epoch=True,
487 | )
488 | return aeloss
489 |
490 | if optimizer_idx == 1:
491 | # discriminator
492 | discloss, log_dict_disc = self.loss(
493 | qloss,
494 | x,
495 | xrec,
496 | optimizer_idx,
497 | self.global_step,
498 | last_layer=self.get_last_layer(),
499 | split="train",
500 | )
501 | self.log_dict(
502 | log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
503 | )
504 | return discloss
505 |
506 | def validation_step(self, batch, batch_idx):
507 | x = self.get_input(batch, self.image_key)
508 | xrec, qloss = self(x)
509 | aeloss, log_dict_ae = self.loss(
510 | qloss,
511 | x,
512 | xrec,
513 | 0,
514 | self.global_step,
515 | last_layer=self.get_last_layer(),
516 | split="val",
517 | )
518 |
519 | discloss, log_dict_disc = self.loss(
520 | qloss,
521 | x,
522 | xrec,
523 | 1,
524 | self.global_step,
525 | last_layer=self.get_last_layer(),
526 | split="val",
527 | )
528 | rec_loss = log_dict_ae["val/rec_loss"]
529 | self.log(
530 | "val/rec_loss",
531 | rec_loss,
532 | prog_bar=True,
533 | logger=True,
534 | on_step=False,
535 | on_epoch=True,
536 | sync_dist=True,
537 | )
538 | self.log(
539 | "val/aeloss",
540 | aeloss,
541 | prog_bar=True,
542 | logger=True,
543 | on_step=False,
544 | on_epoch=True,
545 | sync_dist=True,
546 | )
547 | self.log_dict(log_dict_ae)
548 | self.log_dict(log_dict_disc)
549 | return self.log_dict
550 |
551 | def log_images(self, batch, **kwargs):
552 | log = dict()
553 | x = self.get_input(batch, self.image_key)
554 | x = x.to(self.device)
555 | # encode
556 | h = self.encoder(x)
557 | h = self.quant_conv(h)
558 | quant, _, _ = self.quantize(h)
559 | # decode
560 | x_rec = self.decode(quant)
561 | log["inputs"] = x
562 | log["reconstructions"] = x_rec
563 | return log
564 |
565 |
566 | class EMAVQ(VQModel):
567 | def __init__(
568 | self,
569 | ddconfig,
570 | lossconfig,
571 | n_embed,
572 | embed_dim,
573 | ckpt_path=None,
574 | ignore_keys=[],
575 | image_key="image",
576 | colorize_nlabels=None,
577 | monitor=None,
578 | remap=None,
579 | sane_index_shape=False, # tell vector quantizer to return indices as bhw
580 | ):
581 | super().__init__(
582 | ddconfig,
583 | lossconfig,
584 | n_embed,
585 | embed_dim,
586 | ckpt_path=None,
587 | ignore_keys=ignore_keys,
588 | image_key=image_key,
589 | colorize_nlabels=colorize_nlabels,
590 | monitor=monitor,
591 | )
592 | self.quantize = EMAVectorQuantizer(
593 | n_embed=n_embed, embedding_dim=embed_dim, beta=0.25, remap=remap
594 | )
595 |
596 | def configure_optimizers(self):
597 | lr = self.learning_rate
598 | # Remove self.quantize from parameter list since it is updated via EMA
599 | opt_ae = torch.optim.Adam(
600 | list(self.encoder.parameters())
601 | + list(self.decoder.parameters())
602 | + list(self.quant_conv.parameters())
603 | + list(self.post_quant_conv.parameters()),
604 | lr=lr,
605 | betas=(0.5, 0.9),
606 | )
607 | opt_disc = torch.optim.Adam(
608 | self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
609 | )
610 | return [opt_ae, opt_disc], []
611 |
--------------------------------------------------------------------------------
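
`VQModel.encode` returns the quantized latents together with the codebook loss and the selected codebook indices, while `VQModel.decode` maps quantized latents back to pixel space; the discrete indices are what the downstream transformer consumes. A minimal usage sketch, assuming `model` is an already-instantiated `VQModel` with weights restored via `init_from_ckpt` (the helper name `vq_roundtrip` is illustrative, not part of the repository):

```python
import torch

# Sketch only: `model` is assumed to be a configured VQModel with a loaded checkpoint.
@torch.no_grad()
def vq_roundtrip(model, x):
    """Encode an image batch to quantized latents and reconstruct it."""
    quant, emb_loss, info = model.encode(x)  # quantized latents, codebook loss, indices
    indices = info[2]                        # discrete codes picked by the quantizer
    x_rec = model.decode(quant)              # reconstruction from the quantized latents
    return x_rec, indices
```
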
/taming/modules/autoencoder/lpips/vgg.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BoHuangLab/Protein-Localization-Transformer/9d7d0bb4296a0363d1af21a7e50ddc8672e0eca6/taming/modules/autoencoder/lpips/vgg.pth
--------------------------------------------------------------------------------
/taming/modules/discriminator/model.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import torch.nn as nn
3 |
4 |
5 | from taming.modules.util import ActNorm
6 |
7 |
8 | def weights_init(m):
9 | classname = m.__class__.__name__
10 | if classname.find('Conv') != -1:
11 | nn.init.normal_(m.weight.data, 0.0, 0.02)
12 | elif classname.find('BatchNorm') != -1:
13 | nn.init.normal_(m.weight.data, 1.0, 0.02)
14 | nn.init.constant_(m.bias.data, 0)
15 |
16 |
17 | class NLayerDiscriminator(nn.Module):
18 | """Defines a PatchGAN discriminator as in Pix2Pix
19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
20 | """
21 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
22 | """Construct a PatchGAN discriminator
23 | Parameters:
24 | input_nc (int) -- the number of channels in input images
25 | ndf (int) -- the number of filters in the last conv layer
26 | n_layers (int) -- the number of conv layers in the discriminator
27 | norm_layer -- normalization layer
28 | """
29 | super(NLayerDiscriminator, self).__init__()
30 | if not use_actnorm:
31 | norm_layer = nn.BatchNorm2d
32 | else:
33 | norm_layer = ActNorm
34 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
35 | use_bias = norm_layer.func != nn.BatchNorm2d
36 | else:
37 | use_bias = norm_layer != nn.BatchNorm2d
38 |
39 | kw = 4
40 | padw = 1
41 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
42 | nf_mult = 1
43 | nf_mult_prev = 1
44 | for n in range(1, n_layers): # gradually increase the number of filters
45 | nf_mult_prev = nf_mult
46 | nf_mult = min(2 ** n, 8)
47 | sequence += [
48 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
49 | norm_layer(ndf * nf_mult),
50 | nn.LeakyReLU(0.2, True)
51 | ]
52 |
53 | nf_mult_prev = nf_mult
54 | nf_mult = min(2 ** n_layers, 8)
55 | sequence += [
56 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
57 | norm_layer(ndf * nf_mult),
58 | nn.LeakyReLU(0.2, True)
59 | ]
60 |
61 | sequence += [
62 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
63 | self.main = nn.Sequential(*sequence)
64 |
65 | def forward(self, input):
66 | """Standard forward."""
67 | return self.main(input)
68 |
--------------------------------------------------------------------------------
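
`NLayerDiscriminator` is a PatchGAN: it scores overlapping patches rather than emitting a single real/fake score, so its output is a spatial map of logits. A small, self-contained shape check with random weights:

```python
import torch
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init

# With the default kernel/stride settings, a 256x256 single-channel input
# produces a (1, 1, 30, 30) map of per-patch logits.
disc = NLayerDiscriminator(input_nc=1, ndf=64, n_layers=3).apply(weights_init)
logits = disc(torch.randn(1, 1, 256, 256))
print(logits.shape)  # torch.Size([1, 1, 30, 30])
```
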
/taming/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from taming.modules.losses.vqperceptual import DummyLoss
2 |
3 |
--------------------------------------------------------------------------------
/taming/modules/losses/lpips.py:
--------------------------------------------------------------------------------
1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torchvision import models
6 | from collections import namedtuple
7 |
8 | from taming.util import get_ckpt_path
9 |
10 |
11 | class LPIPS(nn.Module):
12 | # Learned perceptual metric
13 | def __init__(self, use_dropout=True):
14 | super().__init__()
15 | self.scaling_layer = ScalingLayer()
16 |         self.chns = [64, 128, 256, 512, 512]  # vgg16 features
17 | self.net = vgg16(pretrained=True, requires_grad=False)
18 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
19 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
20 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
21 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
22 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
23 | self.load_from_pretrained()
24 | for param in self.parameters():
25 | param.requires_grad = False
26 |
27 | def load_from_pretrained(self, name="vgg_lpips"):
28 | ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
29 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
30 | print("loaded pretrained LPIPS loss from {}".format(ckpt))
31 |
32 | @classmethod
33 | def from_pretrained(cls, name="vgg_lpips"):
34 | if name != "vgg_lpips":
35 | raise NotImplementedError
36 | model = cls()
37 | ckpt = get_ckpt_path(name)
38 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
39 | return model
40 |
41 | def forward(self, input, target):
42 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
43 | outs0, outs1 = self.net(in0_input), self.net(in1_input)
44 | feats0, feats1, diffs = {}, {}, {}
45 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
46 | for kk in range(len(self.chns)):
47 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
48 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
49 |
50 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
51 | val = res[0]
52 | for l in range(1, len(self.chns)):
53 | val += res[l]
54 | return val
55 |
56 |
57 | class ScalingLayer(nn.Module):
58 | def __init__(self):
59 | super(ScalingLayer, self).__init__()
60 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
61 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
62 |
63 | def forward(self, inp):
64 | return (inp - self.shift) / self.scale
65 |
66 |
67 | class NetLinLayer(nn.Module):
68 | """ A single linear layer which does a 1x1 conv """
69 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
70 | super(NetLinLayer, self).__init__()
71 | layers = [nn.Dropout(), ] if (use_dropout) else []
72 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
73 | self.model = nn.Sequential(*layers)
74 |
75 |
76 | class vgg16(torch.nn.Module):
77 | def __init__(self, requires_grad=False, pretrained=True):
78 | super(vgg16, self).__init__()
79 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
80 | self.slice1 = torch.nn.Sequential()
81 | self.slice2 = torch.nn.Sequential()
82 | self.slice3 = torch.nn.Sequential()
83 | self.slice4 = torch.nn.Sequential()
84 | self.slice5 = torch.nn.Sequential()
85 | self.N_slices = 5
86 | for x in range(4):
87 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
88 | for x in range(4, 9):
89 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
90 | for x in range(9, 16):
91 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
92 | for x in range(16, 23):
93 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
94 | for x in range(23, 30):
95 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
96 | if not requires_grad:
97 | for param in self.parameters():
98 | param.requires_grad = False
99 |
100 | def forward(self, X):
101 | h = self.slice1(X)
102 | h_relu1_2 = h
103 | h = self.slice2(h)
104 | h_relu2_2 = h
105 | h = self.slice3(h)
106 | h_relu3_3 = h
107 | h = self.slice4(h)
108 | h_relu4_3 = h
109 | h = self.slice5(h)
110 | h_relu5_3 = h
111 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
112 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
113 | return out
114 |
115 |
116 | def normalize_tensor(x,eps=1e-10):
117 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
118 | return x/(norm_factor+eps)
119 |
120 |
121 | def spatial_average(x, keepdim=True):
122 | return x.mean([2,3],keepdim=keepdim)
123 |
124 |
--------------------------------------------------------------------------------
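
The LPIPS module is the perceptual term that `VQLPIPSWithDiscriminator` adds to the pixel-wise reconstruction loss. A short usage sketch; note that constructing `LPIPS()` loads (or downloads) the pretrained VGG backbone plus the LPIPS linear weights, and the inputs are expected to be 3-channel images roughly in [-1, 1]:

```python
import torch
from taming.modules.losses.lpips import LPIPS

perceptual = LPIPS().eval()                 # loads VGG features + LPIPS linear layers
a = torch.rand(2, 3, 256, 256) * 2 - 1      # two random "images" in [-1, 1]
b = torch.rand(2, 3, 256, 256) * 2 - 1
with torch.no_grad():
    d = perceptual(a, b)                    # one distance per image, shape (2, 1, 1, 1)
print(d.squeeze())
```
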
/taming/modules/losses/segmentation.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class BCELoss(nn.Module):
6 | def forward(self, prediction, target):
7 | loss = F.binary_cross_entropy_with_logits(prediction,target)
8 | return loss, {}
9 |
10 |
11 | class BCELossWithQuant(nn.Module):
12 | def __init__(self, codebook_weight=1.):
13 | super().__init__()
14 | self.codebook_weight = codebook_weight
15 |
16 | def forward(self, qloss, target, prediction, split):
17 | bce_loss = F.binary_cross_entropy_with_logits(prediction,target)
18 | loss = bce_loss + self.codebook_weight*qloss
19 | return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(),
20 | "{}/bce_loss".format(split): bce_loss.detach().mean(),
21 | "{}/quant_loss".format(split): qloss.detach().mean()
22 | }
23 |
--------------------------------------------------------------------------------
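
The VQ losses share the `(loss, log_dict)` return convention used by the Lightning modules above; `BCELossWithQuant` simply adds the weighted codebook loss to a pixel-wise BCE term. A toy call with random tensors and an illustrative codebook-loss value:

```python
import torch
from taming.modules.losses.segmentation import BCELossWithQuant

criterion = BCELossWithQuant(codebook_weight=1.0)
qloss = torch.tensor(0.05)                              # stand-in codebook/commitment loss
target = torch.randint(0, 2, (2, 4, 16, 16)).float()    # binary segmentation target
logits = torch.randn(2, 4, 16, 16)                      # raw, unnormalised predictions
loss, log = criterion(qloss, target, logits, split="train")
print(loss.item(), sorted(log))                         # scalar loss + train/* log keys
```
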
/taming/modules/losses/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from taming.modules.losses.lpips import LPIPS
6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7 |
8 |
9 | class DummyLoss(nn.Module):
10 | def __init__(self):
11 | super().__init__()
12 |
13 |
14 | def adopt_weight(weight, global_step, threshold=0, value=0.):
15 | if global_step < threshold:
16 | weight = value
17 | return weight
18 |
19 |
20 | def hinge_d_loss(logits_real, logits_fake):
21 | loss_real = torch.mean(F.relu(1. - logits_real))
22 | loss_fake = torch.mean(F.relu(1. + logits_fake))
23 | d_loss = 0.5 * (loss_real + loss_fake)
24 | return d_loss
25 |
26 |
27 | def vanilla_d_loss(logits_real, logits_fake):
28 | d_loss = 0.5 * (
29 | torch.mean(torch.nn.functional.softplus(-logits_real)) +
30 | torch.mean(torch.nn.functional.softplus(logits_fake)))
31 | return d_loss
32 |
33 |
34 | class VQLPIPSWithDiscriminator(nn.Module):
35 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
36 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
37 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
38 | disc_ndf=64, disc_loss="hinge"):
39 | super().__init__()
40 | assert disc_loss in ["hinge", "vanilla"]
41 | self.codebook_weight = codebook_weight
42 | self.pixel_weight = pixelloss_weight
43 | self.perceptual_loss = LPIPS().eval()
44 | self.perceptual_weight = perceptual_weight
45 |
46 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
47 | n_layers=disc_num_layers,
48 | use_actnorm=use_actnorm,
49 | ndf=disc_ndf
50 | ).apply(weights_init)
51 | self.discriminator_iter_start = disc_start
52 | if disc_loss == "hinge":
53 | self.disc_loss = hinge_d_loss
54 | elif disc_loss == "vanilla":
55 | self.disc_loss = vanilla_d_loss
56 | else:
57 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
58 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
59 | self.disc_factor = disc_factor
60 | self.discriminator_weight = disc_weight
61 | self.disc_conditional = disc_conditional
62 |
63 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
64 | if last_layer is not None:
65 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
66 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
67 | else:
68 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
69 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
70 |
71 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
72 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
73 | d_weight = d_weight * self.discriminator_weight
74 | return d_weight
75 |
76 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
77 | global_step, last_layer=None, cond=None, split="train"):
78 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
79 | if self.perceptual_weight > 0:
80 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
81 | rec_loss = rec_loss + self.perceptual_weight * p_loss
82 | else:
83 | p_loss = torch.tensor([0.0])
84 |
85 | nll_loss = rec_loss
86 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
87 | nll_loss = torch.mean(nll_loss)
88 |
89 | # now the GAN part
90 | if optimizer_idx == 0:
91 | # generator update
92 | if cond is None:
93 | assert not self.disc_conditional
94 | logits_fake = self.discriminator(reconstructions.contiguous())
95 | else:
96 | assert self.disc_conditional
97 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
98 | g_loss = -torch.mean(logits_fake)
99 |
100 | try:
101 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
102 | except RuntimeError:
103 | assert not self.training
104 | d_weight = torch.tensor(0.0)
105 |
106 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
107 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
108 |
109 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
110 | "{}/quant_loss".format(split): codebook_loss.detach().mean(),
111 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
112 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
113 | "{}/p_loss".format(split): p_loss.detach().mean(),
114 | "{}/d_weight".format(split): d_weight.detach(),
115 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
116 | "{}/g_loss".format(split): g_loss.detach().mean(),
117 | }
118 | return loss, log
119 |
120 | if optimizer_idx == 1:
121 | # second pass for discriminator update
122 | if cond is None:
123 | logits_real = self.discriminator(inputs.contiguous().detach())
124 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
125 | else:
126 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
127 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
128 |
129 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
130 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
131 |
132 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
133 | "{}/logits_real".format(split): logits_real.detach().mean(),
134 | "{}/logits_fake".format(split): logits_fake.detach().mean()
135 | }
136 | return d_loss, log
137 |
--------------------------------------------------------------------------------
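
Two details of `VQLPIPSWithDiscriminator` are worth spelling out: the adversarial term is switched on only after `disc_start` steps via `adopt_weight`, and its weight is set adaptively as `d_weight = ||grad(nll_loss)|| / (||grad(g_loss)|| + 1e-4)` (clamped to [0, 1e4]) with respect to the decoder's last layer. The gating behaviour is easy to verify directly:

```python
from taming.modules.losses.vqperceptual import adopt_weight

# Before `threshold` steps the GAN factor is zeroed out; afterwards the
# configured disc_factor is passed through unchanged.
print(adopt_weight(1.0, global_step=500, threshold=1000))   # 0.0
print(adopt_weight(1.0, global_step=2000, threshold=1000))  # 1.0
```
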
/taming/modules/misc/coord.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class CoordStage(object):
4 | def __init__(self, n_embed, down_factor):
5 | self.n_embed = n_embed
6 | self.down_factor = down_factor
7 |
8 | def eval(self):
9 | return self
10 |
11 | def encode(self, c):
12 | """fake vqmodel interface"""
13 | assert 0.0 <= c.min() and c.max() <= 1.0
14 | b,ch,h,w = c.shape
15 | assert ch == 1
16 |
17 | c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor,
18 | mode="area")
19 | c = c.clamp(0.0, 1.0)
20 | c = self.n_embed*c
21 | c_quant = c.round()
22 | c_ind = c_quant.to(dtype=torch.long)
23 |
24 | info = None, None, c_ind
25 | return c_quant, None, info
26 |
27 | def decode(self, c):
28 | c = c/self.n_embed
29 | c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor,
30 | mode="nearest")
31 | return c
32 |
--------------------------------------------------------------------------------
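
`CoordStage` mimics the VQ-model interface for a coordinate map in [0, 1]: `encode` downsamples by `down_factor` and rounds the values into `n_embed` integer bins, and `decode` maps the bins back to [0, 1] and upsamples with nearest-neighbour interpolation. A quick shape check:

```python
import torch
from taming.modules.misc.coord import CoordStage

stage = CoordStage(n_embed=1024, down_factor=16)
c = torch.rand(1, 1, 256, 256)                            # coordinate map in [0, 1]
c_quant, _, (_, _, c_ind) = stage.encode(c)
print(c_quant.shape, int(c_ind.min()), int(c_ind.max()))  # (1, 1, 16, 16), bins in [0, 1024]
print(stage.decode(c_quant).shape)                        # back to (1, 1, 256, 256)
```
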
/taming/modules/transformer/mingpt.py:
--------------------------------------------------------------------------------
1 | """
2 | taken from: https://github.com/karpathy/minGPT/
3 | GPT model:
4 | - the initial stem consists of a combination of token encoding and a positional encoding
5 | - the meat of it is a uniform sequence of Transformer blocks
6 | - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
7 | - all blocks feed into a central residual pathway similar to resnets
8 | - the final decoder is a linear projection into a vanilla Softmax classifier
9 | """
10 |
11 | import math
12 | import logging
13 |
14 | import torch
15 | import torch.nn as nn
16 | from torch.nn import functional as F
17 | from transformers import top_k_top_p_filtering
18 |
19 | logger = logging.getLogger(__name__)
20 |
21 |
22 | class GPTConfig:
23 | """ base GPT config, params common to all GPT versions """
24 | embd_pdrop = 0.1
25 | resid_pdrop = 0.1
26 | attn_pdrop = 0.1
27 |
28 | def __init__(self, vocab_size, block_size, **kwargs):
29 | self.vocab_size = vocab_size
30 | self.block_size = block_size
31 | for k,v in kwargs.items():
32 | setattr(self, k, v)
33 |
34 |
35 | class GPT1Config(GPTConfig):
36 | """ GPT-1 like network roughly 125M params """
37 | n_layer = 12
38 | n_head = 12
39 | n_embd = 768
40 |
41 |
42 | class CausalSelfAttention(nn.Module):
43 | """
44 | A vanilla multi-head masked self-attention layer with a projection at the end.
45 | It is possible to use torch.nn.MultiheadAttention here but I am including an
46 |     explicit implementation to show that there is nothing too scary here.
47 | """
48 |
49 | def __init__(self, config):
50 | super().__init__()
51 | assert config.n_embd % config.n_head == 0
52 | # key, query, value projections for all heads
53 | self.key = nn.Linear(config.n_embd, config.n_embd)
54 | self.query = nn.Linear(config.n_embd, config.n_embd)
55 | self.value = nn.Linear(config.n_embd, config.n_embd)
56 | # regularization
57 | self.attn_drop = nn.Dropout(config.attn_pdrop)
58 | self.resid_drop = nn.Dropout(config.resid_pdrop)
59 | # output projection
60 | self.proj = nn.Linear(config.n_embd, config.n_embd)
61 | # causal mask to ensure that attention is only applied to the left in the input sequence
62 | mask = torch.tril(torch.ones(config.block_size,
63 | config.block_size))
64 | if hasattr(config, "n_unmasked"):
65 | mask[:config.n_unmasked, :config.n_unmasked] = 1
66 | self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
67 | self.n_head = config.n_head
68 |
69 | def forward(self, x, layer_past=None):
70 | B, T, C = x.size()
71 |
72 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
73 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
74 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
75 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
76 |
77 | present = torch.stack((k, v))
78 | if layer_past is not None:
79 | past_key, past_value = layer_past
80 | k = torch.cat((past_key, k), dim=-2)
81 | v = torch.cat((past_value, v), dim=-2)
82 |
83 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
84 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
85 | if layer_past is None:
86 | att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
87 |
88 | att = F.softmax(att, dim=-1)
89 | att = self.attn_drop(att)
90 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
91 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
92 |
93 | # output projection
94 | y = self.resid_drop(self.proj(y))
95 | return y, present # TODO: check that this does not break anything
96 |
97 |
98 | class Block(nn.Module):
99 | """ an unassuming Transformer block """
100 | def __init__(self, config):
101 | super().__init__()
102 | self.ln1 = nn.LayerNorm(config.n_embd)
103 | self.ln2 = nn.LayerNorm(config.n_embd)
104 | self.attn = CausalSelfAttention(config)
105 | self.mlp = nn.Sequential(
106 | nn.Linear(config.n_embd, 4 * config.n_embd),
107 | nn.GELU(), # nice
108 | nn.Linear(4 * config.n_embd, config.n_embd),
109 | nn.Dropout(config.resid_pdrop),
110 | )
111 |
112 | def forward(self, x, layer_past=None, return_present=False):
113 | # TODO: check that training still works
114 | if return_present: assert not self.training
115 | # layer past: tuple of length two with B, nh, T, hs
116 | attn, present = self.attn(self.ln1(x), layer_past=layer_past)
117 |
118 | x = x + attn
119 | x = x + self.mlp(self.ln2(x))
120 | if layer_past is not None or return_present:
121 | return x, present
122 | return x
123 |
124 |
125 | class GPT(nn.Module):
126 | """ the full GPT language model, with a context size of block_size """
127 | def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256,
128 | embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
129 | super().__init__()
130 | config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
131 | embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
132 | n_layer=n_layer, n_head=n_head, n_embd=n_embd,
133 | n_unmasked=n_unmasked)
134 | # input embedding stem
135 | self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
136 | self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
137 | self.drop = nn.Dropout(config.embd_pdrop)
138 | # transformer
139 | self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
140 | # decoder head
141 | self.ln_f = nn.LayerNorm(config.n_embd)
142 | self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
143 | self.block_size = config.block_size
144 | self.apply(self._init_weights)
145 | self.config = config
146 | logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
147 |
148 | def get_block_size(self):
149 | return self.block_size
150 |
151 | def _init_weights(self, module):
152 | if isinstance(module, (nn.Linear, nn.Embedding)):
153 | module.weight.data.normal_(mean=0.0, std=0.02)
154 | if isinstance(module, nn.Linear) and module.bias is not None:
155 | module.bias.data.zero_()
156 | elif isinstance(module, nn.LayerNorm):
157 | module.bias.data.zero_()
158 | module.weight.data.fill_(1.0)
159 |
160 | def forward(self, idx, embeddings=None, targets=None):
161 | # forward the GPT model
162 | token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
163 |
164 | if embeddings is not None: # prepend explicit embeddings
165 | token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
166 |
167 | t = token_embeddings.shape[1]
168 | assert t <= self.block_size, "Cannot forward, model block size is exhausted."
169 | position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
170 | x = self.drop(token_embeddings + position_embeddings)
171 | x = self.blocks(x)
172 | x = self.ln_f(x)
173 | logits = self.head(x)
174 |
175 | # if we are given some desired targets also calculate the loss
176 | loss = None
177 | if targets is not None:
178 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
179 |
180 | return logits, loss
181 |
182 | def forward_with_past(self, idx, embeddings=None, targets=None, past=None, past_length=None):
183 | # inference only
184 | assert not self.training
185 | token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
186 | if embeddings is not None: # prepend explicit embeddings
187 | token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
188 |
189 | if past is not None:
190 | assert past_length is not None
191 | past = torch.cat(past, dim=-2) # n_layer, 2, b, nh, len_past, dim_head
192 | past_shape = list(past.shape)
193 | expected_shape = [self.config.n_layer, 2, idx.shape[0], self.config.n_head, past_length, self.config.n_embd//self.config.n_head]
194 | assert past_shape == expected_shape, f"{past_shape} =/= {expected_shape}"
195 | position_embeddings = self.pos_emb[:, past_length, :] # each position maps to a (learnable) vector
196 | else:
197 | position_embeddings = self.pos_emb[:, :token_embeddings.shape[1], :]
198 |
199 | x = self.drop(token_embeddings + position_embeddings)
200 | presents = [] # accumulate over layers
201 | for i, block in enumerate(self.blocks):
202 | x, present = block(x, layer_past=past[i, ...] if past is not None else None, return_present=True)
203 | presents.append(present)
204 |
205 | x = self.ln_f(x)
206 | logits = self.head(x)
207 | # if we are given some desired targets also calculate the loss
208 | loss = None
209 | if targets is not None:
210 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
211 |
212 | return logits, loss, torch.stack(presents) # _, _, n_layer, 2, b, nh, 1, dim_head
213 |
214 |
215 | class DummyGPT(nn.Module):
216 | # for debugging
217 | def __init__(self, add_value=1):
218 | super().__init__()
219 | self.add_value = add_value
220 |
221 | def forward(self, idx):
222 | return idx + self.add_value, None
223 |
224 |
225 | class CodeGPT(nn.Module):
226 | """Takes in semi-embeddings"""
227 | def __init__(self, vocab_size, block_size, in_channels, n_layer=12, n_head=8, n_embd=256,
228 | embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
229 | super().__init__()
230 | config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
231 | embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
232 | n_layer=n_layer, n_head=n_head, n_embd=n_embd,
233 | n_unmasked=n_unmasked)
234 | # input embedding stem
235 | self.tok_emb = nn.Linear(in_channels, config.n_embd)
236 | self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
237 | self.drop = nn.Dropout(config.embd_pdrop)
238 | # transformer
239 | self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
240 | # decoder head
241 | self.ln_f = nn.LayerNorm(config.n_embd)
242 | self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
243 | self.block_size = config.block_size
244 | self.apply(self._init_weights)
245 | self.config = config
246 | logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
247 |
248 | def get_block_size(self):
249 | return self.block_size
250 |
251 | def _init_weights(self, module):
252 | if isinstance(module, (nn.Linear, nn.Embedding)):
253 | module.weight.data.normal_(mean=0.0, std=0.02)
254 | if isinstance(module, nn.Linear) and module.bias is not None:
255 | module.bias.data.zero_()
256 | elif isinstance(module, nn.LayerNorm):
257 | module.bias.data.zero_()
258 | module.weight.data.fill_(1.0)
259 |
260 | def forward(self, idx, embeddings=None, targets=None):
261 | # forward the GPT model
262 | token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
263 |
264 | if embeddings is not None: # prepend explicit embeddings
265 | token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
266 |
267 | t = token_embeddings.shape[1]
268 | assert t <= self.block_size, "Cannot forward, model block size is exhausted."
269 | position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
270 | x = self.drop(token_embeddings + position_embeddings)
271 | x = self.blocks(x)
272 |         x = self.ln_f(x)
273 | logits = self.head(x)
274 |
275 | # if we are given some desired targets also calculate the loss
276 | loss = None
277 | if targets is not None:
278 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
279 |
280 | return logits, loss
281 |
282 |
283 |
284 | #### sampling utils
285 |
286 | def top_k_logits(logits, k):
287 | v, ix = torch.topk(logits, k)
288 | out = logits.clone()
289 | out[out < v[:, [-1]]] = -float('Inf')
290 | return out
291 |
292 | @torch.no_grad()
293 | def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
294 | """
295 | take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
296 | the sequence, feeding the predictions back into the model each time. Clearly the sampling
297 | has quadratic complexity unlike an RNN that is only linear, and has a finite context window
298 | of block_size, unlike an RNN that has an infinite context window.
299 | """
300 | block_size = model.get_block_size()
301 | model.eval()
302 | for k in range(steps):
303 | x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
304 | logits, _ = model(x_cond)
305 | # pluck the logits at the final step and scale by temperature
306 | logits = logits[:, -1, :] / temperature
307 | # optionally crop probabilities to only the top k options
308 | if top_k is not None:
309 | logits = top_k_logits(logits, top_k)
310 | # apply softmax to convert to probabilities
311 | probs = F.softmax(logits, dim=-1)
312 | # sample from the distribution or take the most likely
313 | if sample:
314 | ix = torch.multinomial(probs, num_samples=1)
315 | else:
316 | _, ix = torch.topk(probs, k=1, dim=-1)
317 | # append to the sequence and continue
318 | x = torch.cat((x, ix), dim=1)
319 |
320 | return x
321 |
322 |
323 | @torch.no_grad()
324 | def sample_with_past(x, model, steps, temperature=1., sample_logits=True,
325 | top_k=None, top_p=None, callback=None):
326 | # x is conditioning
327 | sample = x
328 | cond_len = x.shape[1]
329 | past = None
330 | for n in range(steps):
331 | if callback is not None:
332 | callback(n)
333 | logits, _, present = model.forward_with_past(x, past=past, past_length=(n+cond_len-1))
334 | if past is None:
335 | past = [present]
336 | else:
337 | past.append(present)
338 | logits = logits[:, -1, :] / temperature
339 | if top_k is not None:
340 | logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
341 |
342 | probs = F.softmax(logits, dim=-1)
343 | if not sample_logits:
344 | _, x = torch.topk(probs, k=1, dim=-1)
345 | else:
346 | x = torch.multinomial(probs, num_samples=1)
347 | # append to the sequence and continue
348 | sample = torch.cat((sample, x), dim=1)
349 | del past
350 | sample = sample[:, cond_len:] # cut conditioning off
351 | return sample
352 |
353 |
354 | #### clustering utils
355 |
356 | class KMeans(nn.Module):
357 | def __init__(self, ncluster=512, nc=3, niter=10):
358 | super().__init__()
359 | self.ncluster = ncluster
360 | self.nc = nc
361 | self.niter = niter
362 | self.shape = (3,32,32)
363 | self.register_buffer("C", torch.zeros(self.ncluster,nc))
364 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
365 |
366 | def is_initialized(self):
367 | return self.initialized.item() == 1
368 |
369 | @torch.no_grad()
370 | def initialize(self, x):
371 | N, D = x.shape
372 | assert D == self.nc, D
373 | c = x[torch.randperm(N)[:self.ncluster]] # init clusters at random
374 | for i in range(self.niter):
375 | # assign all pixels to the closest codebook element
376 | a = ((x[:, None, :] - c[None, :, :])**2).sum(-1).argmin(1)
377 | # move each codebook element to be the mean of the pixels that assigned to it
378 | c = torch.stack([x[a==k].mean(0) for k in range(self.ncluster)])
379 | # re-assign any poorly positioned codebook elements
380 | nanix = torch.any(torch.isnan(c), dim=1)
381 | ndead = nanix.sum().item()
382 | print('done step %d/%d, re-initialized %d dead clusters' % (i+1, self.niter, ndead))
383 | c[nanix] = x[torch.randperm(N)[:ndead]] # re-init dead clusters
384 |
385 | self.C.copy_(c)
386 | self.initialized.fill_(1)
387 |
388 |
389 | def forward(self, x, reverse=False, shape=None):
390 | if not reverse:
391 | # flatten
392 | bs,c,h,w = x.shape
393 | assert c == self.nc
394 | x = x.reshape(bs,c,h*w,1)
395 | C = self.C.permute(1,0)
396 | C = C.reshape(1,c,1,self.ncluster)
397 | a = ((x-C)**2).sum(1).argmin(-1) # bs, h*w indices
398 | return a
399 | else:
400 | # flatten
401 | bs, HW = x.shape
402 | """
403 | c = self.C.reshape( 1, self.nc, 1, self.ncluster)
404 | c = c[bs*[0],:,:,:]
405 | c = c[:,:,HW*[0],:]
406 | x = x.reshape(bs, 1, HW, 1)
407 | x = x[:,3*[0],:,:]
408 | x = torch.gather(c, dim=3, index=x)
409 | """
410 | x = self.C[x]
411 | x = x.permute(0,2,1)
412 | shape = shape if shape is not None else self.shape
413 | x = x.reshape(bs, *shape)
414 |
415 | return x
416 |
--------------------------------------------------------------------------------
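
The `sample` helper above drives any of the GPT variants autoregressively: at each step it crops the context to `block_size`, takes the logits at the last position, optionally restricts them to the top-k options, and either samples or takes the argmax. A tiny, untrained model is enough to exercise the loop (assuming the pinned requirements, including `transformers`, are installed, since `mingpt.py` imports `top_k_top_p_filtering` at module level):

```python
import torch
from taming.modules.transformer.mingpt import GPT, sample

# Random weights, so the generated tokens are meaningless; the point is the API and shapes.
model = GPT(vocab_size=32, block_size=16, n_layer=2, n_head=2, n_embd=32)
start = torch.randint(0, 32, (1, 4))                    # (batch, conditioning length)
out = sample(model, start, steps=8, temperature=1.0, sample=True, top_k=5)
print(out.shape)                                        # torch.Size([1, 12]): 4 + 8 tokens
```
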
/taming/modules/transformer/permuter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 |
6 | class AbstractPermuter(nn.Module):
7 | def __init__(self, *args, **kwargs):
8 | super().__init__()
9 | def forward(self, x, reverse=False):
10 | raise NotImplementedError
11 |
12 |
13 | class Identity(AbstractPermuter):
14 | def __init__(self):
15 | super().__init__()
16 |
17 | def forward(self, x, reverse=False):
18 | return x
19 |
20 |
21 | class Subsample(AbstractPermuter):
22 | def __init__(self, H, W):
23 | super().__init__()
24 | C = 1
25 | indices = np.arange(H*W).reshape(C,H,W)
26 | while min(H, W) > 1:
27 | indices = indices.reshape(C,H//2,2,W//2,2)
28 | indices = indices.transpose(0,2,4,1,3)
29 | indices = indices.reshape(C*4,H//2, W//2)
30 | H = H//2
31 | W = W//2
32 | C = C*4
33 | assert H == W == 1
34 | idx = torch.tensor(indices.ravel())
35 | self.register_buffer('forward_shuffle_idx',
36 | nn.Parameter(idx, requires_grad=False))
37 | self.register_buffer('backward_shuffle_idx',
38 | nn.Parameter(torch.argsort(idx), requires_grad=False))
39 |
40 | def forward(self, x, reverse=False):
41 | if not reverse:
42 | return x[:, self.forward_shuffle_idx]
43 | else:
44 | return x[:, self.backward_shuffle_idx]
45 |
46 |
47 | def mortonify(i, j):
48 | """(i,j) index to linear morton code"""
49 | i = np.uint64(i)
50 | j = np.uint64(j)
51 |
52 | z = np.uint(0)
53 |
54 | for pos in range(32):
55 | z = (z |
56 | ((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) |
57 | ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1))
58 | )
59 | return z
60 |
61 |
62 | class ZCurve(AbstractPermuter):
63 | def __init__(self, H, W):
64 | super().__init__()
65 | reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)]
66 | idx = np.argsort(reverseidx)
67 | idx = torch.tensor(idx)
68 | reverseidx = torch.tensor(reverseidx)
69 | self.register_buffer('forward_shuffle_idx',
70 | idx)
71 | self.register_buffer('backward_shuffle_idx',
72 | reverseidx)
73 |
74 | def forward(self, x, reverse=False):
75 | if not reverse:
76 | return x[:, self.forward_shuffle_idx]
77 | else:
78 | return x[:, self.backward_shuffle_idx]
79 |
80 |
81 | class SpiralOut(AbstractPermuter):
82 | def __init__(self, H, W):
83 | super().__init__()
84 | assert H == W
85 | size = W
86 | indices = np.arange(size*size).reshape(size,size)
87 |
88 | i0 = size//2
89 | j0 = size//2-1
90 |
91 | i = i0
92 | j = j0
93 |
94 | idx = [indices[i0, j0]]
95 | step_mult = 0
96 | for c in range(1, size//2+1):
97 | step_mult += 1
98 | # steps left
99 | for k in range(step_mult):
100 | i = i - 1
101 | j = j
102 | idx.append(indices[i, j])
103 |
104 | # step down
105 | for k in range(step_mult):
106 | i = i
107 | j = j + 1
108 | idx.append(indices[i, j])
109 |
110 | step_mult += 1
111 | if c < size//2:
112 | # step right
113 | for k in range(step_mult):
114 | i = i + 1
115 | j = j
116 | idx.append(indices[i, j])
117 |
118 | # step up
119 | for k in range(step_mult):
120 | i = i
121 | j = j - 1
122 | idx.append(indices[i, j])
123 | else:
124 | # end reached
125 | for k in range(step_mult-1):
126 | i = i + 1
127 | idx.append(indices[i, j])
128 |
129 | assert len(idx) == size*size
130 | idx = torch.tensor(idx)
131 | self.register_buffer('forward_shuffle_idx', idx)
132 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
133 |
134 | def forward(self, x, reverse=False):
135 | if not reverse:
136 | return x[:, self.forward_shuffle_idx]
137 | else:
138 | return x[:, self.backward_shuffle_idx]
139 |
140 |
141 | class SpiralIn(AbstractPermuter):
142 | def __init__(self, H, W):
143 | super().__init__()
144 | assert H == W
145 | size = W
146 | indices = np.arange(size*size).reshape(size,size)
147 |
148 | i0 = size//2
149 | j0 = size//2-1
150 |
151 | i = i0
152 | j = j0
153 |
154 | idx = [indices[i0, j0]]
155 | step_mult = 0
156 | for c in range(1, size//2+1):
157 | step_mult += 1
158 | # steps left
159 | for k in range(step_mult):
160 | i = i - 1
161 | j = j
162 | idx.append(indices[i, j])
163 |
164 | # step down
165 | for k in range(step_mult):
166 | i = i
167 | j = j + 1
168 | idx.append(indices[i, j])
169 |
170 | step_mult += 1
171 | if c < size//2:
172 | # step right
173 | for k in range(step_mult):
174 | i = i + 1
175 | j = j
176 | idx.append(indices[i, j])
177 |
178 | # step up
179 | for k in range(step_mult):
180 | i = i
181 | j = j - 1
182 | idx.append(indices[i, j])
183 | else:
184 | # end reached
185 | for k in range(step_mult-1):
186 | i = i + 1
187 | idx.append(indices[i, j])
188 |
189 | assert len(idx) == size*size
190 | idx = idx[::-1]
191 | idx = torch.tensor(idx)
192 | self.register_buffer('forward_shuffle_idx', idx)
193 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
194 |
195 | def forward(self, x, reverse=False):
196 | if not reverse:
197 | return x[:, self.forward_shuffle_idx]
198 | else:
199 | return x[:, self.backward_shuffle_idx]
200 |
201 |
202 | class Random(nn.Module):
203 | def __init__(self, H, W):
204 | super().__init__()
205 | indices = np.random.RandomState(1).permutation(H*W)
206 | idx = torch.tensor(indices.ravel())
207 | self.register_buffer('forward_shuffle_idx', idx)
208 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
209 |
210 | def forward(self, x, reverse=False):
211 | if not reverse:
212 | return x[:, self.forward_shuffle_idx]
213 | else:
214 | return x[:, self.backward_shuffle_idx]
215 |
216 |
217 | class AlternateParsing(AbstractPermuter):
218 | def __init__(self, H, W):
219 | super().__init__()
220 | indices = np.arange(W*H).reshape(H,W)
221 | for i in range(1, H, 2):
222 | indices[i, :] = indices[i, ::-1]
223 | idx = indices.flatten()
224 | assert len(idx) == H*W
225 | idx = torch.tensor(idx)
226 | self.register_buffer('forward_shuffle_idx', idx)
227 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
228 |
229 | def forward(self, x, reverse=False):
230 | if not reverse:
231 | return x[:, self.forward_shuffle_idx]
232 | else:
233 | return x[:, self.backward_shuffle_idx]
234 |
235 |
236 | if __name__ == "__main__":
237 | p0 = AlternateParsing(16, 16)
238 | print(p0.forward_shuffle_idx)
239 | print(p0.backward_shuffle_idx)
240 |
241 | x = torch.randint(0, 768, size=(11, 256))
242 | y = p0(x)
243 | xre = p0(y, reverse=True)
244 | assert torch.equal(x, xre)
245 |
246 | p1 = SpiralOut(2, 2)
247 | print(p1.forward_shuffle_idx)
248 | print(p1.backward_shuffle_idx)
249 |
--------------------------------------------------------------------------------
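
The permuters all act on flattened `(B, H*W)` sequences of code indices: `forward` reorders a raster-scan sequence into the chosen scan order, and `reverse=True` undoes it, which is what the `__main__` block above checks for `AlternateParsing`. The same round trip holds for the Z-curve ordering on a power-of-two grid:

```python
import torch
from taming.modules.transformer.permuter import ZCurve

perm = ZCurve(H=4, W=4)
codes = torch.arange(16).unsqueeze(0)     # raster-ordered token indices, shape (1, 16)
z_order = perm(codes)                     # same tokens in Morton / Z-curve order
assert torch.equal(perm(z_order, reverse=True), codes)
print(z_order)
```
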
/taming/modules/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def count_params(model):
6 | total_params = sum(p.numel() for p in model.parameters())
7 | return total_params
8 |
9 |
10 | class ActNorm(nn.Module):
11 | def __init__(self, num_features, logdet=False, affine=True,
12 | allow_reverse_init=False):
13 | assert affine
14 | super().__init__()
15 | self.logdet = logdet
16 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
17 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
18 | self.allow_reverse_init = allow_reverse_init
19 |
20 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
21 |
22 | def initialize(self, input):
23 | with torch.no_grad():
24 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
25 | mean = (
26 | flatten.mean(1)
27 | .unsqueeze(1)
28 | .unsqueeze(2)
29 | .unsqueeze(3)
30 | .permute(1, 0, 2, 3)
31 | )
32 | std = (
33 | flatten.std(1)
34 | .unsqueeze(1)
35 | .unsqueeze(2)
36 | .unsqueeze(3)
37 | .permute(1, 0, 2, 3)
38 | )
39 |
40 | self.loc.data.copy_(-mean)
41 | self.scale.data.copy_(1 / (std + 1e-6))
42 |
43 | def forward(self, input, reverse=False):
44 | if reverse:
45 | return self.reverse(input)
46 | if len(input.shape) == 2:
47 | input = input[:,:,None,None]
48 | squeeze = True
49 | else:
50 | squeeze = False
51 |
52 | _, _, height, width = input.shape
53 |
54 | if self.training and self.initialized.item() == 0:
55 | self.initialize(input)
56 | self.initialized.fill_(1)
57 |
58 | h = self.scale * (input + self.loc)
59 |
60 | if squeeze:
61 | h = h.squeeze(-1).squeeze(-1)
62 |
63 | if self.logdet:
64 | log_abs = torch.log(torch.abs(self.scale))
65 | logdet = height*width*torch.sum(log_abs)
66 | logdet = logdet * torch.ones(input.shape[0]).to(input)
67 | return h, logdet
68 |
69 | return h
70 |
71 | def reverse(self, output):
72 | if self.training and self.initialized.item() == 0:
73 | if not self.allow_reverse_init:
74 | raise RuntimeError(
75 | "Initializing ActNorm in reverse direction is "
76 | "disabled by default. Use allow_reverse_init=True to enable."
77 | )
78 | else:
79 | self.initialize(output)
80 | self.initialized.fill_(1)
81 |
82 | if len(output.shape) == 2:
83 | output = output[:,:,None,None]
84 | squeeze = True
85 | else:
86 | squeeze = False
87 |
88 | h = output / self.scale - self.loc
89 |
90 | if squeeze:
91 | h = h.squeeze(-1).squeeze(-1)
92 | return h
93 |
94 |
95 | class AbstractEncoder(nn.Module):
96 | def __init__(self):
97 | super().__init__()
98 |
99 | def encode(self, *args, **kwargs):
100 | raise NotImplementedError
101 |
102 |
103 | class Labelator(AbstractEncoder):
104 | """Net2Net Interface for Class-Conditional Model"""
105 | def __init__(self, n_classes, quantize_interface=True):
106 | super().__init__()
107 | self.n_classes = n_classes
108 | self.quantize_interface = quantize_interface
109 |
110 | def encode(self, c):
111 | c = c[:,None]
112 | if self.quantize_interface:
113 | return c, None, [None, None, c.long()]
114 | return c
115 |
116 |
117 | class SOSProvider(AbstractEncoder):
118 | # for unconditional training
119 | def __init__(self, sos_token, quantize_interface=True):
120 | super().__init__()
121 | self.sos_token = sos_token
122 | self.quantize_interface = quantize_interface
123 |
124 | def encode(self, x):
125 | # get batch size from data and replicate sos_token
126 | c = torch.ones(x.shape[0], 1)*self.sos_token
127 | c = c.long().to(x.device)
128 | if self.quantize_interface:
129 | return c, None, [None, None, c]
130 | return c
131 |
--------------------------------------------------------------------------------
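
`ActNorm` (used by the discriminator when `use_actnorm=True`) performs a data-dependent initialisation: on the first training-mode forward pass it sets its shift and scale from the per-channel statistics of that batch, so the normalised activations of that batch come out roughly zero-mean and unit-variance. A quick sanity check:

```python
import torch
from taming.modules.util import ActNorm

norm = ActNorm(num_features=8)
x = torch.randn(4, 8, 16, 16) * 3.0 + 1.0   # deliberately shifted and scaled activations
norm.train()
y = norm(x)
print(round(y.mean().item(), 3), round(y.std().item(), 3))  # approximately 0.0 and 1.0
```
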
/taming/modules/vqvae/quantize.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from torch import einsum
6 | from einops import rearrange
7 |
8 |
9 | class VectorQuantizer(nn.Module):
10 | """
11 | see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
12 | ____________________________________________
13 | Discretization bottleneck part of the VQ-VAE.
14 | Inputs:
15 | - n_e : number of embeddings
16 | - e_dim : dimension of embedding
17 | - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
18 | _____________________________________________
19 | """
20 |
21 | # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for
22 | # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be
23 | # used wherever VectorQuantizer has been used before and is additionally
24 | # more efficient.
25 | def __init__(self, n_e, e_dim, beta):
26 | super(VectorQuantizer, self).__init__()
27 | self.n_e = n_e
28 | self.e_dim = e_dim
29 | self.beta = beta
30 |
31 | self.embedding = nn.Embedding(self.n_e, self.e_dim)
32 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
33 |
34 | def forward(self, z):
35 | """
36 | Inputs the output of the encoder network z and maps it to a discrete
37 | one-hot vector that is the index of the closest embedding vector e_j
38 | z (continuous) -> z_q (discrete)
39 | z.shape = (batch, channel, height, width)
40 | quantization pipeline:
41 | 1. get encoder input (B,C,H,W)
42 | 2. flatten input to (B*H*W,C)
43 | """
44 | # reshape z -> (batch, height, width, channel) and flatten
45 | z = z.permute(0, 2, 3, 1).contiguous()
46 | z_flattened = z.view(-1, self.e_dim)
47 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
48 |
49 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
50 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \
51 | torch.matmul(z_flattened, self.embedding.weight.t())
52 |
53 |         ## could possibly replace this here
54 | # #\start...
55 | # find closest encodings
56 | min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
57 |
58 | min_encodings = torch.zeros(
59 | min_encoding_indices.shape[0], self.n_e).to(z)
60 | min_encodings.scatter_(1, min_encoding_indices, 1)
61 |
62 | # dtype min encodings: torch.float32
63 | # min_encodings shape: torch.Size([2048, 512])
64 | # min_encoding_indices.shape: torch.Size([2048, 1])
65 |
66 | # get quantized latent vectors
67 | z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
68 | #.........\end
69 |
70 | # with:
71 | # .........\start
72 | #min_encoding_indices = torch.argmin(d, dim=1)
73 | #z_q = self.embedding(min_encoding_indices)
74 | # ......\end......... (TODO)
75 |
76 | # compute loss for embedding
77 | loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
78 | torch.mean((z_q - z.detach()) ** 2)
79 |
80 | # preserve gradients
81 | z_q = z + (z_q - z).detach()
82 |
83 | # perplexity
84 | e_mean = torch.mean(min_encodings, dim=0)
85 | perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
86 |
87 | # reshape back to match original input shape
88 | z_q = z_q.permute(0, 3, 1, 2).contiguous()
89 |
90 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
91 |
92 | def get_codebook_entry(self, indices, shape):
93 | # shape specifying (batch, height, width, channel)
94 | # TODO: check for more easy handling with nn.Embedding
95 | min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
96 | min_encodings.scatter_(1, indices[:,None], 1)
97 |
98 | # get quantized latent vectors
99 | z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
100 |
101 | if shape is not None:
102 | z_q = z_q.view(shape)
103 |
104 | # reshape back to match original input shape
105 | z_q = z_q.permute(0, 3, 1, 2).contiguous()
106 |
107 | return z_q
108 |
109 |
110 | class GumbelQuantize(nn.Module):
111 | """
112 | credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
113 | Gumbel Softmax trick quantizer
114 | Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
115 | https://arxiv.org/abs/1611.01144
116 | """
117 | def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
118 | kl_weight=5e-4, temp_init=1.0, use_vqinterface=True,
119 | remap=None, unknown_index="random"):
120 | super().__init__()
121 |
122 | self.embedding_dim = embedding_dim
123 | self.n_embed = n_embed
124 |
125 | self.straight_through = straight_through
126 | self.temperature = temp_init
127 | self.kl_weight = kl_weight
128 |
129 | self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
130 | self.embed = nn.Embedding(n_embed, embedding_dim)
131 |
132 | self.use_vqinterface = use_vqinterface
133 |
134 | self.remap = remap
135 | if self.remap is not None:
136 | self.register_buffer("used", torch.tensor(np.load(self.remap)))
137 | self.re_embed = self.used.shape[0]
138 | self.unknown_index = unknown_index # "random" or "extra" or integer
139 | if self.unknown_index == "extra":
140 | self.unknown_index = self.re_embed
141 | self.re_embed = self.re_embed+1
142 | print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
143 | f"Using {self.unknown_index} for unknown indices.")
144 | else:
145 | self.re_embed = n_embed
146 |
147 | def remap_to_used(self, inds):
148 | ishape = inds.shape
149 | assert len(ishape)>1
150 | inds = inds.reshape(ishape[0],-1)
151 | used = self.used.to(inds)
152 | match = (inds[:,:,None]==used[None,None,...]).long()
153 | new = match.argmax(-1)
154 | unknown = match.sum(2)<1
155 | if self.unknown_index == "random":
156 | new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
157 | else:
158 | new[unknown] = self.unknown_index
159 | return new.reshape(ishape)
160 |
161 | def unmap_to_all(self, inds):
162 | ishape = inds.shape
163 | assert len(ishape)>1
164 | inds = inds.reshape(ishape[0],-1)
165 | used = self.used.to(inds)
166 | if self.re_embed > self.used.shape[0]: # extra token
167 | inds[inds>=self.used.shape[0]] = 0 # simply set to zero
168 | back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
169 | return back.reshape(ishape)
170 |
171 | def forward(self, z, temp=None, return_logits=False):
172 | # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
173 | hard = self.straight_through if self.training else True
174 | temp = self.temperature if temp is None else temp
175 |
176 | logits = self.proj(z)
177 | if self.remap is not None:
178 | # continue only with used logits
179 | full_zeros = torch.zeros_like(logits)
180 | logits = logits[:,self.used,...]
181 |
182 | soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
183 | if self.remap is not None:
184 | # go back to all entries but unused set to zero
185 | full_zeros[:,self.used,...] = soft_one_hot
186 | soft_one_hot = full_zeros
187 | z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)
188 |
189 | # + kl divergence to the prior loss
190 | qy = F.softmax(logits, dim=1)
191 | diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
192 |
193 | ind = soft_one_hot.argmax(dim=1)
194 | if self.remap is not None:
195 | ind = self.remap_to_used(ind)
196 | if self.use_vqinterface:
197 | if return_logits:
198 | return z_q, diff, (None, None, ind), logits
199 | return z_q, diff, (None, None, ind)
200 | return z_q, diff, ind
201 |
202 | def get_codebook_entry(self, indices, shape):
203 | b, h, w, c = shape
204 | assert b*h*w == indices.shape[0]
205 | indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
206 | if self.remap is not None:
207 | indices = self.unmap_to_all(indices)
208 | one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
209 | z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)
210 | return z_q
211 |
212 |
213 | class VectorQuantizer2(nn.Module):
214 | """
215 | Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
216 | avoids costly matrix multiplications and allows for post-hoc remapping of indices.
217 | """
218 | # NOTE: due to a bug the beta term was applied to the wrong term. for
219 | # backwards compatibility we use the buggy version by default, but you can
220 | # specify legacy=False to fix it.
221 | def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
222 | sane_index_shape=False, legacy=True):
223 | super().__init__()
224 | self.n_e = n_e
225 | self.e_dim = e_dim
226 | self.beta = beta
227 | self.legacy = legacy
228 |
229 | self.embedding = nn.Embedding(self.n_e, self.e_dim)
230 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
231 |
232 | self.remap = remap
233 | if self.remap is not None:
234 | self.register_buffer("used", torch.tensor(np.load(self.remap)))
235 | self.re_embed = self.used.shape[0]
236 | self.unknown_index = unknown_index # "random" or "extra" or integer
237 | if self.unknown_index == "extra":
238 | self.unknown_index = self.re_embed
239 | self.re_embed = self.re_embed+1
240 | print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
241 | f"Using {self.unknown_index} for unknown indices.")
242 | else:
243 | self.re_embed = n_e
244 |
245 | self.sane_index_shape = sane_index_shape
246 |
247 | def remap_to_used(self, inds):
248 | ishape = inds.shape
249 | assert len(ishape)>1
250 | inds = inds.reshape(ishape[0],-1)
251 | used = self.used.to(inds)
252 | match = (inds[:,:,None]==used[None,None,...]).long()
253 | new = match.argmax(-1)
254 | unknown = match.sum(2)<1
255 | if self.unknown_index == "random":
256 | new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
257 | else:
258 | new[unknown] = self.unknown_index
259 | return new.reshape(ishape)
260 |
261 | def unmap_to_all(self, inds):
262 | ishape = inds.shape
263 | assert len(ishape)>1
264 | inds = inds.reshape(ishape[0],-1)
265 | used = self.used.to(inds)
266 | if self.re_embed > self.used.shape[0]: # extra token
267 | inds[inds>=self.used.shape[0]] = 0 # simply set to zero
268 | back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
269 | return back.reshape(ishape)
270 |
271 | def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
272 | assert temp is None or temp==1.0, "Only for interface compatible with Gumbel"
273 | assert rescale_logits==False, "Only for interface compatible with Gumbel"
274 | assert return_logits==False, "Only for interface compatible with Gumbel"
275 | # reshape z -> (batch, height, width, channel) and flatten
276 | z = rearrange(z, 'b c h w -> b h w c').contiguous()
277 | z_flattened = z.view(-1, self.e_dim)
278 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
279 |
280 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
281 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \
282 | torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
283 |
284 | min_encoding_indices = torch.argmin(d, dim=1)
285 | z_q = self.embedding(min_encoding_indices).view(z.shape)
286 | perplexity = None
287 | min_encodings = None
288 |
289 | # compute loss for embedding
290 | if not self.legacy:
291 | loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
292 | torch.mean((z_q - z.detach()) ** 2)
293 | else:
294 | loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
295 | torch.mean((z_q - z.detach()) ** 2)
296 |
297 | # preserve gradients
298 | z_q = z + (z_q - z).detach()
299 |
300 | # reshape back to match original input shape
301 | z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
302 |
303 | if self.remap is not None:
304 | min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis
305 | min_encoding_indices = self.remap_to_used(min_encoding_indices)
306 | min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten
307 |
308 | if self.sane_index_shape:
309 | min_encoding_indices = min_encoding_indices.reshape(
310 | z_q.shape[0], z_q.shape[2], z_q.shape[3])
311 |
312 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
313 |
314 | def get_codebook_entry(self, indices, shape):
315 | # shape specifying (batch, height, width, channel)
316 | if self.remap is not None:
317 | indices = indices.reshape(shape[0],-1) # add batch axis
318 | indices = self.unmap_to_all(indices)
319 | indices = indices.reshape(-1) # flatten again
320 |
321 | # get quantized latent vectors
322 | z_q = self.embedding(indices)
323 |
324 | if shape is not None:
325 | z_q = z_q.view(shape)
326 | # reshape back to match original input shape
327 | z_q = z_q.permute(0, 3, 1, 2).contiguous()
328 |
329 | return z_q
330 |
331 | class EmbeddingEMA(nn.Module):
332 | def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
333 | super().__init__()
334 | self.decay = decay
335 | self.eps = eps
336 | weight = torch.randn(num_tokens, codebook_dim)
337 | self.weight = nn.Parameter(weight, requires_grad = False)
338 | self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False)
339 | self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False)
340 | self.update = True
341 |
342 | def forward(self, embed_id):
343 | return F.embedding(embed_id, self.weight)
344 |
345 | def cluster_size_ema_update(self, new_cluster_size):
346 | self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
347 |
348 | def embed_avg_ema_update(self, new_embed_avg):
349 | self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
350 |
351 | def weight_update(self, num_tokens):
352 | n = self.cluster_size.sum()
353 | smoothed_cluster_size = (
354 | (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
355 | )
356 | #normalize embedding average with smoothed cluster size
357 | embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
358 | self.weight.data.copy_(embed_normalized)
359 |
360 |
361 | class EMAVectorQuantizer(nn.Module):
362 | def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
363 | remap=None, unknown_index="random"):
364 | super().__init__()
365 |         self.codebook_dim = embedding_dim  # codebook vector dimensionality
366 |         self.num_tokens = n_embed          # number of codebook entries
367 | self.beta = beta
368 | self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
369 |
370 | self.remap = remap
371 | if self.remap is not None:
372 | self.register_buffer("used", torch.tensor(np.load(self.remap)))
373 | self.re_embed = self.used.shape[0]
374 | self.unknown_index = unknown_index # "random" or "extra" or integer
375 | if self.unknown_index == "extra":
376 | self.unknown_index = self.re_embed
377 | self.re_embed = self.re_embed+1
378 |                 print(f"Remapping {self.num_tokens} indices to {self.re_embed} indices. "
379 | f"Using {self.unknown_index} for unknown indices.")
380 | else:
381 | self.re_embed = n_embed
382 |
383 | def remap_to_used(self, inds):
384 | ishape = inds.shape
385 | assert len(ishape)>1
386 | inds = inds.reshape(ishape[0],-1)
387 | used = self.used.to(inds)
388 | match = (inds[:,:,None]==used[None,None,...]).long()
389 | new = match.argmax(-1)
390 | unknown = match.sum(2)<1
391 | if self.unknown_index == "random":
392 | new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
393 | else:
394 | new[unknown] = self.unknown_index
395 | return new.reshape(ishape)
396 |
397 | def unmap_to_all(self, inds):
398 | ishape = inds.shape
399 | assert len(ishape)>1
400 | inds = inds.reshape(ishape[0],-1)
401 | used = self.used.to(inds)
402 | if self.re_embed > self.used.shape[0]: # extra token
403 | inds[inds>=self.used.shape[0]] = 0 # simply set to zero
404 | back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
405 | return back.reshape(ishape)
406 |
407 | def forward(self, z):
408 | # reshape z -> (batch, height, width, channel) and flatten
409 | #z, 'b c h w -> b h w c'
410 | z = rearrange(z, 'b c h w -> b h w c')
411 | z_flattened = z.reshape(-1, self.codebook_dim)
412 |
413 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
414 | d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
415 | self.embedding.weight.pow(2).sum(dim=1) - 2 * \
416 | torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
417 |
418 |
419 | encoding_indices = torch.argmin(d, dim=1)
420 |
421 | z_q = self.embedding(encoding_indices).view(z.shape)
422 | encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
423 | avg_probs = torch.mean(encodings, dim=0)
424 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
425 |
426 | if self.training and self.embedding.update:
427 | #EMA cluster size
428 | encodings_sum = encodings.sum(0)
429 | self.embedding.cluster_size_ema_update(encodings_sum)
430 | #EMA embedding average
431 | embed_sum = encodings.transpose(0,1) @ z_flattened
432 | self.embedding.embed_avg_ema_update(embed_sum)
433 | #normalize embed_avg and update weight
434 | self.embedding.weight_update(self.num_tokens)
435 |
436 | # compute loss for embedding
437 | loss = self.beta * F.mse_loss(z_q.detach(), z)
438 |
439 | # preserve gradients
440 | z_q = z + (z_q - z).detach()
441 |
442 | # reshape back to match original input shape
443 | #z_q, 'b h w c -> b c h w'
444 | z_q = rearrange(z_q, 'b h w c -> b c h w')
445 | return z_q, loss, (perplexity, encodings, encoding_indices)
446 |
--------------------------------------------------------------------------------
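
A minimal usage sketch for the quantizers above (the shapes and hyperparameters are illustrative, not values taken from the project configs): `VectorQuantizer2` maps an encoder feature map to its nearest codebook entries and returns the straight-through quantized tensor, the codebook/commitment loss, and the selected indices; note that it returns `None` for the perplexity and one-hot encodings.

```python
# Hedged sketch: drive VectorQuantizer2 with a dummy encoder output.
# n_e, e_dim, and the input shape below are illustrative.
import torch
from taming.modules.vqvae.quantize import VectorQuantizer2

quantizer = VectorQuantizer2(n_e=1024, e_dim=256, beta=0.25, legacy=False)

z = torch.randn(2, 256, 16, 16)  # (B, C=e_dim, H, W) encoder features
z_q, loss, (perplexity, min_encodings, indices) = quantizer(z)

print(z_q.shape)      # torch.Size([2, 256, 16, 16]); same shape, but codebook entries
print(indices.shape)  # torch.Size([512]); one flattened index per spatial position
print(loss.item())    # scalar codebook + commitment loss
```

Because `z_q = z + (z_q - z).detach()`, gradients flow back to the encoder as if quantization were the identity (the straight-through estimator), while the loss term pulls the encoder outputs and the codebook toward each other.
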
/taming/util.py:
--------------------------------------------------------------------------------
1 | import os, hashlib
2 | import requests
3 | from tqdm import tqdm
4 |
5 | URL_MAP = {
6 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
7 | }
8 |
9 | CKPT_MAP = {
10 | "vgg_lpips": "vgg.pth"
11 | }
12 |
13 | MD5_MAP = {
14 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
15 | }
16 |
17 |
18 | def download(url, local_path, chunk_size=1024):
19 | os.makedirs(os.path.split(local_path)[0], exist_ok=True)
20 | with requests.get(url, stream=True) as r:
21 | total_size = int(r.headers.get("content-length", 0))
22 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
23 | with open(local_path, "wb") as f:
24 | for data in r.iter_content(chunk_size=chunk_size):
25 | if data:
26 | f.write(data)
27 |                     pbar.update(len(data))
28 |
29 |
30 | def md5_hash(path):
31 | with open(path, "rb") as f:
32 | content = f.read()
33 | return hashlib.md5(content).hexdigest()
34 |
35 |
36 | def get_ckpt_path(name, root, check=False):
37 | assert name in URL_MAP
38 | path = os.path.join(root, CKPT_MAP[name])
39 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
40 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
41 | download(URL_MAP[name], path)
42 | md5 = md5_hash(path)
43 | assert md5 == MD5_MAP[name], md5
44 | return path
45 |
46 |
47 | class KeyNotFoundError(Exception):
48 | def __init__(self, cause, keys=None, visited=None):
49 | self.cause = cause
50 | self.keys = keys
51 | self.visited = visited
52 | messages = list()
53 | if keys is not None:
54 | messages.append("Key not found: {}".format(keys))
55 | if visited is not None:
56 | messages.append("Visited: {}".format(visited))
57 | messages.append("Cause:\n{}".format(cause))
58 | message = "\n".join(messages)
59 | super().__init__(message)
60 |
61 |
62 | def retrieve(
63 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
64 | ):
65 | """Given a nested list or dict return the desired value at key expanding
66 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion
67 | is done in-place.
68 |
69 | Parameters
70 | ----------
71 | list_or_dict : list or dict
72 | Possibly nested list or dictionary.
73 | key : str
74 | key/to/value, path like string describing all keys necessary to
75 | consider to get to the desired value. List indices can also be
76 | passed here.
77 | splitval : str
78 | String that defines the delimiter between keys of the
79 | different depth levels in `key`.
80 | default : obj
81 | Value returned if :attr:`key` is not found.
82 | expand : bool
83 | Whether to expand callable nodes on the path or not.
84 |
85 | Returns
86 | -------
87 | The desired value or if :attr:`default` is not ``None`` and the
88 | :attr:`key` is not found returns ``default``.
89 |
90 | Raises
91 | ------
92 |         KeyNotFoundError if ``key`` not in ``list_or_dict`` and :attr:`default` is
93 | ``None``.
94 | """
95 |
96 | keys = key.split(splitval)
97 |
98 | success = True
99 | try:
100 | visited = []
101 | parent = None
102 | last_key = None
103 | for key in keys:
104 | if callable(list_or_dict):
105 | if not expand:
106 | raise KeyNotFoundError(
107 | ValueError(
108 | "Trying to get past callable node with expand=False."
109 | ),
110 | keys=keys,
111 | visited=visited,
112 | )
113 | list_or_dict = list_or_dict()
114 | parent[last_key] = list_or_dict
115 |
116 | last_key = key
117 | parent = list_or_dict
118 |
119 | try:
120 | if isinstance(list_or_dict, dict):
121 | list_or_dict = list_or_dict[key]
122 | else:
123 | list_or_dict = list_or_dict[int(key)]
124 | except (KeyError, IndexError, ValueError) as e:
125 | raise KeyNotFoundError(e, keys=keys, visited=visited)
126 |
127 | visited += [key]
128 | # final expansion of retrieved value
129 | if expand and callable(list_or_dict):
130 | list_or_dict = list_or_dict()
131 | parent[last_key] = list_or_dict
132 | except KeyNotFoundError as e:
133 | if default is None:
134 | raise e
135 | else:
136 | list_or_dict = default
137 | success = False
138 |
139 | if not pass_success:
140 | return list_or_dict
141 | else:
142 | return list_or_dict, success
143 |
144 |
145 | if __name__ == "__main__":
146 | config = {"keya": "a",
147 | "keyb": "b",
148 | "keyc":
149 | {"cc1": 1,
150 | "cc2": 2,
151 | }
152 | }
153 | from omegaconf import OmegaConf
154 | config = OmegaConf.create(config)
155 | print(config)
156 | retrieve(config, "keya")
157 |
158 |
--------------------------------------------------------------------------------
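
A brief, hedged example of the helpers above (the config dict is made up for illustration): `get_ckpt_path` downloads the LPIPS VGG weights into the given root on first use and md5-verifies them when `check=True`, and `retrieve` walks a nested dict along a `/`-separated key path, falling back to `default` when the key is absent and a default is provided.

```python
# Hedged sketch of the taming/util.py helpers; the checkpoint root mirrors where
# vgg.pth sits in this repository, and the config dict is illustrative.
from taming.util import get_ckpt_path, retrieve

# Downloads vgg.pth on the first call and verifies its md5 when check=True.
ckpt = get_ckpt_path("vgg_lpips", "taming/modules/autoencoder/lpips", check=True)
print(ckpt)  # taming/modules/autoencoder/lpips/vgg.pth

config = {"model": {"params": {"embed_dim": 256}}}
print(retrieve(config, "model/params/embed_dim"))           # 256
print(retrieve(config, "model/params/missing", default=0))  # 0 (key absent, default used)
```
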