├── LICENSE
├── README.md
├── laion30k
│   └── datacheck.py
├── utils.py
├── paella_minimal.py
├── modules.py
├── paella.py
├── evaluation
│   ├── evaluation_generation.py
│   └── evaluator.py
└── paella_sampling.ipynb
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 delicious-tasty
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Paella (Anonymized CVPR submission)
2 | Conditional text-to-image generation has seen countless recent improvements in terms of quality, diversity and fidelity. Nevertheless, most state-of-the-art models require numerous inference steps to produce faithful generations, resulting in performance bottlenecks for end-user applications. In this paper we introduce Paella, a novel text-to-image model requiring less than 10 steps to sample high-fidelity images, using a speed-optimized architecture that allows sampling a single image in less than 500 ms, while having 573M parameters. The model operates on a compressed & quantized latent space, is conditioned on CLIP embeddings and uses an improved sampling function over previous works. Aside from text-conditional image generation, our model is able to do latent space interpolation and image manipulations such as inpainting, outpainting, and structural editing.
3 |
4 |
5 | 
6 |
7 |
8 |
9 | ## Code
10 | We especially want to highlight how little code is necessary to run & train Paella. The entire code, including training, sampling, architecture and utilities, fits in approx. 400 lines. We hope this makes the method accessible to more people.
11 |
12 | ## Sampling
13 | For sampling, just take a look at the [paella_sampling.ipynb](https://github.com/delicious-tasty/Paella/blob/main/paella_sampling.ipynb) notebook. :sunglasses:
14 |
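If you prefer a plain Python script over the notebook, a minimal sketch could look like the following (the checkpoint path ```models/Paella_f8_8192/model.pt``` is just an assumption, point it to wherever your checkpoint lives):
```
import torch
import open_clip
from open_clip import tokenizer
from rudalle import get_vae
from modules import DenoiseUNet
from utils import sample, decode

device = "cuda"

# frozen VQGAN (tokens <-> images) and frozen CLIP text encoder, as used during training
vqmodel = get_vae().to(device).eval().requires_grad_(False)
clip_model, _, _ = open_clip.create_model_and_transforms('ViT-g-14', pretrained='laion2b_s12b_b42k')
del clip_model.visual
clip_model = clip_model.to(device).eval().requires_grad_(False)

# the token denoiser itself
model = DenoiseUNet(num_labels=8192, c_clip=1024).to(device)
model.load_state_dict(torch.load("models/Paella_f8_8192/model.pt", map_location=device))
model.eval().requires_grad_(False)

captions = ["a delicious spanish paella, high quality photograph"]
with torch.inference_mode():
    text_tokens = tokenizer.tokenize(captions).to(device)
    text_embeddings = clip_model.encode_text(text_tokens).float()
    tokens = sample(model, c=text_embeddings, T=12, size=(32, 32), classifier_free_scale=5, renoise_steps=11)
    images = decode(vqmodel, tokens)  # batch_size x 3 x 256 x 256
```
Lower ```T``` (together with ```renoise_steps = T - 1```) trades quality for speed; a near-identical copy of this ```sample``` function is swept over steps and guidance scales in ```evaluation/evaluation_generation.py```.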
15 | ## Train your own Paella
16 | The main file for training is [paella.py](https://github.com/delicious-tasty/Paella/blob/main/paella.py). You can adjust all [hyperparameters](https://github.com/delicious-tasty/Paella/blob/main/paella.py#L322) to your own needs. During training we use webdataset, but you are free to replace it with your own custom dataloader: just change line 119 in [paella.py](https://github.com/delicious-tasty/Paella/blob/main/paella.py#L119) to point to your own dataloader. Make sure it returns a tuple of ```(images, captions)```, where ```images``` is a ```torch.Tensor``` of shape ```batch_size x channels x height x width``` and ```captions``` is a ```List``` of length ```batch_size``` (a minimal drop-in example is sketched at the end of this README). Now decide if you want to finetune Paella or start a new training from scratch:
17 | ### From Scratch
18 | ```
19 | python3 paella.py
20 | ```
21 | ### Finetune
22 | If you want to finetune, first download the [latest checkpoint and its optimizer state](epic_download_link.py), set the [finetune hyperparameter](https://github.com/delicious-tasty/Paella/blob/main/paella.py#L254) to ```True```, create a ```models/``` folder and move both checkpoints into it. After that, you can again just run:
23 | ```
24 | python3 paella.py
25 | ```
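Note that, as provided, [paella.py](https://github.com/delicious-tasty/Paella/blob/main/paella.py) reads the node id from the ```SLURM_PROCID``` environment variable and is configured for a multi-node setup (```args.n_nodes = 8``` with 8 GPUs per node). For a single-node run you will likely want to set ```args.n_nodes = 1```, adjust ```args.devices```, and launch with something like ```SLURM_PROCID=0 python3 paella.py```. The training loop also accesses the model through ```model.module```, which assumes the ```DistributedDataParallel``` wrapper, so a true single-GPU run may additionally require dropping those ```.module``` accesses.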
26 |
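### Custom dataloader example
As mentioned above, any dataloader works as long as each batch is a tuple of ```(images, captions)```. The following is only a sketch (not part of the original training setup), assuming a folder of ```.jpg```/```.txt``` pairs such as the one produced by ```laion30k/datacheck.py```:
```
import os
import torchvision
from PIL import Image
from torch.utils.data import Dataset, DataLoader

class FolderDataset(Dataset):
    def __init__(self, root, image_size=256):
        self.root = root
        self.names = sorted(f[:-4] for f in os.listdir(root) if f.endswith(".jpg"))
        self.transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.RandomCrop(image_size),
        ])

    def __len__(self):
        return len(self.names)

    def __getitem__(self, i):
        image = Image.open(os.path.join(self.root, self.names[i] + ".jpg")).convert("RGB")
        caption = open(os.path.join(self.root, self.names[i] + ".txt")).read().strip()
        return self.transforms(image), caption

def get_dataloader(args):
    dataset = FolderDataset(args.dataset_path, args.image_size)
    # the default collate stacks the images into one tensor and returns the captions as a list of strings
    return DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True, drop_last=True)
```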
--------------------------------------------------------------------------------
/laion30k/datacheck.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import torch
4 | import torchvision
5 | import webdataset as wds
6 | from matplotlib import pyplot as plt
7 | from torch.utils.data import DataLoader
8 | from webdataset.handlers import warn_and_continue
9 |
10 |
11 | def clean_caption(caption):
12 | caption = re.sub(" +", " ", caption)
13 | if caption[0] == "\"" or caption[0] == "'":
14 | caption = caption[1:]
15 | if caption[-1] == "\"" or caption[-1] == "'":
16 | caption = caption[:-1]
17 | return caption
18 |
19 |
20 | def filter_captions(caption):
21 | possible_url_hints = ["www.", ".com", "http"]
22 | forbidden_characters = ["-", "_", ":", ";", "(", ")", "/", "%", "|", "?"]
23 | forbidden_words = ["download", "interior", "kitchen", "chair", "getty", "how", "what", "when", "why", "laminate", "furniture", "hair", "dress", "clothing"]
24 | if len(caption.split(" ")) < 2:
25 | print(False)
26 | return False
27 | if not all([False if i in caption else True for i in forbidden_characters]):
28 | print(False)
29 | return False
30 | if len(caption) > 150:
31 | print(False)
32 | return False
33 | if not all(ord(c) < 128 for c in caption):
34 | return False
35 | if not all([False if i in caption else True for i in possible_url_hints]):
36 | return False
37 | if any(char.isdigit() for char in caption):
38 | return False
39 | if not all([False if i in caption.lower() else True for i in forbidden_words]):
40 | return False
41 | return True
42 |
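# A few illustrative examples (not from the original code):
#   filter_captions("A small wooden cabin surrounded by snowy pine trees")  -> True
#   filter_captions("kitchen interior - download at www.example.com")      -> False (forbidden character "-", URL hint, forbidden words)
#   filter_captions("photo taken in 2019")                                 -> False (contains digits)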
43 |
44 | class ProcessDataV2:
45 | def __init__(self,):
46 | self.transforms = torchvision.transforms.Compose([
47 | torchvision.transforms.ToTensor(),
48 | torchvision.transforms.Resize(256),
49 | torchvision.transforms.CenterCrop(256),
50 | ])
51 |
52 | def __call__(self, data):
53 | data["jpg"] = self.transforms(data["jpg"])
54 | return data
55 |
56 |
57 | def collate(batch):
58 | images = torch.stack([i[0] for i in batch], dim=0)
59 | json_file = [i[1] for i in batch]
60 | captions = [i[2] for i in batch]
61 | return [images, json_file, captions]
62 |
63 |
64 | def get_dataloader_new(path):
65 | dataset = wds.WebDataset(path, resampled=True, handler=warn_and_continue).decode("rgb", handler=warn_and_continue).map(
66 | ProcessDataV2(), handler=warn_and_continue).to_tuple("jpg", "json", "txt", handler=warn_and_continue).shuffle(1000, handler=warn_and_continue)
67 | dataloader = DataLoader(dataset, batch_size=10, collate_fn=collate)
68 | return dataloader
69 |
70 | dataset_length = 30_000
71 | dataset_path = "30k"
72 | print(os.getcwd())
73 | os.makedirs(dataset_path, exist_ok=True)
74 | # path = "file:000069.tar"
75 | path = "path_to_laion_aesthetic_dataset"
76 | dataloader = get_dataloader_new(path)
77 | idx = 0
78 |
79 |
80 | for _, (images, json_files, captions) in enumerate(dataloader):
81 | if idx < dataset_length:
82 | f = [i for i, json_file in enumerate(json_files) if json_file["AESTHETIC_SCORE"] > 6.0]
83 | if f:
84 | print(f)
85 | f = [i for i in f if filter_captions(captions[i])]
86 | captions = [clean_caption(captions[i]) for i in f]
87 | if f:
88 | print(captions)
89 | aesthetic_images = images[f]
90 | for image, caption in zip(aesthetic_images, captions):
91 | torchvision.utils.save_image(image, os.path.join(dataset_path, f"{idx}.jpg"))
92 | open(os.path.join(dataset_path, f"{idx}.txt"), "w").write(caption)
93 | idx += 1
94 | # plt.figure(figsize=(32, 32))
95 | # plt.imshow(torch.cat([
96 | # torch.cat([i for i in aesthetic_images.cpu()], dim=-1),
97 | # ], dim=-2).permute(1, 2, 0).cpu())
98 | # plt.show()
99 |
100 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import webdataset as wds
4 | from torch.utils.data import DataLoader
5 | from webdataset.handlers import warn_and_continue
6 |
7 |
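# encode/decode wrap the rudalle VQGAN: encode maps images in [0, 1] of shape
# (batch_size x 3 x 256 x 256) to a (batch_size x 32 x 32) grid of codebook indices,
# decode maps such an index grid back to images.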
8 | def encode(vq, x):
9 | return vq.model.encode((2 * x - 1))[-1][-1]
10 |
11 |
12 | def decode(vq, z):
13 | return vq.decode(z.view(z.shape[0], -1))
14 |
15 |
16 | def log(t, eps=1e-20):
17 | return torch.log(t + eps)
18 |
19 |
20 | def gumbel_noise(t):
21 | noise = torch.zeros_like(t).uniform_(0, 1)
22 | return -log(-log(noise))
23 |
24 |
25 | def gumbel_sample(t, temperature=1., dim=-1):
26 | return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
27 |
28 |
29 | def sample(model, c, x=None, mask=None, T=12, size=(32, 32), starting_t=0, temp_range=[1.0, 1.0], typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=-1, renoise_steps=11, renoise_mode='start'):
30 | with torch.inference_mode():
31 | r_range = torch.linspace(0, 1, T+1)[:-1][:, None].expand(-1, c.size(0)).to(c.device)
32 | temperatures = torch.linspace(temp_range[0], temp_range[1], T)
33 | if x is None:
34 | x = torch.randint(0, model.num_labels, size=(c.size(0), *size), device=c.device)
35 | elif mask is not None:
36 | noise = torch.randint(0, model.num_labels, size=(c.size(0), *size), device=c.device)
37 | x = noise * mask + (1-mask) * x
38 | init_x = x.clone()
39 | for i in range(starting_t, T):
40 | if renoise_mode == 'prev':
41 | prev_x = x.clone()
42 | r, temp = r_range[i], temperatures[i]
43 | logits = model(x, c, r)
44 | if classifier_free_scale >= 0:
45 | logits_uncond = model(x, torch.zeros_like(c), r)
46 | logits = torch.lerp(logits_uncond, logits, classifier_free_scale)
47 | x = logits
48 | x_flat = x.permute(0, 2, 3, 1).reshape(-1, x.size(1))
49 | if typical_filtering:
50 | x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1)
51 | x_flat_norm_p = torch.exp(x_flat_norm)
52 | entropy = -(x_flat_norm * x_flat_norm_p).nansum(-1, keepdim=True)
53 |
54 | c_flat_shifted = torch.abs((-x_flat_norm) - entropy)
55 | c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False)
56 | x_flat_cumsum = x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1)
57 |
58 | last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1)
59 | sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(1, last_ind.view(-1, 1))
60 | if typical_min_tokens > 1:
61 | sorted_indices_to_remove[..., :typical_min_tokens] = 0
62 | indices_to_remove = sorted_indices_to_remove.scatter(1, x_flat_indices, sorted_indices_to_remove)
63 | x_flat = x_flat.masked_fill(indices_to_remove, -float("Inf"))
64 | # x_flat = torch.multinomial(x_flat.div(temp).softmax(-1), num_samples=1)[:, 0]
65 | x_flat = gumbel_sample(x_flat, temperature=temp)
66 | x = x_flat.view(x.size(0), *x.shape[2:])
67 | if mask is not None:
68 | x = x * mask + (1-mask) * init_x
69 | if i < renoise_steps:
70 | if renoise_mode == 'start':
71 | x, _ = model.add_noise(x, r_range[i+1], random_x=init_x)
72 | elif renoise_mode == 'prev':
73 | x, _ = model.add_noise(x, r_range[i+1], random_x=prev_x)
74 | else: # 'rand'
75 | x, _ = model.add_noise(x, r_range[i+1])
76 | return x.detach()
77 |
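# Example call, with parameter choices as used in evaluation/evaluation_generation.py:
#   tokens = sample(model, c=text_embeddings, T=12, size=(32, 32),
#                   typical_filtering=True, typical_mass=0.2,
#                   classifier_free_scale=5, renoise_steps=11)
#   images = decode(vqmodel, tokens)
# classifier_free_scale < 0 disables classifier-free guidance; otherwise the conditional and
# unconditional (zero-embedding) logits are combined via torch.lerp with that scale.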
78 |
79 | class ProcessData:
80 | def __init__(self, image_size=256):
81 | self.transforms = torchvision.transforms.Compose([
82 | torchvision.transforms.ToTensor(),
83 | torchvision.transforms.Resize(image_size),
84 | torchvision.transforms.RandomCrop(image_size),
85 | ])
86 |
87 | def __call__(self, data):
88 | data["jpg"] = self.transforms(data["jpg"])
89 | return data
90 |
91 |
92 | def collate(batch):
93 | images = torch.stack([i[0] for i in batch], dim=0)
94 | captions = [i[1] for i in batch]
95 | return [images, captions]
96 |
97 |
98 | def get_dataloader(args):
99 | dataset = wds.WebDataset(args.dataset_path, resampled=True, handler=warn_and_continue).decode("rgb", handler=warn_and_continue).map(
100 | ProcessData(args.image_size), handler=warn_and_continue).to_tuple("jpg", "txt", handler=warn_and_continue).shuffle(690, handler=warn_and_continue)
101 | dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=collate)
102 | return dataloader
103 |
--------------------------------------------------------------------------------
/paella_minimal.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import torch
4 | from torch import nn, optim
5 | import torchvision
6 | from tqdm import tqdm
7 | import numpy as np
8 | import torch.multiprocessing as mp
9 | import torch.distributed as dist
10 | from torch.nn.parallel import DistributedDataParallel
11 | from modules import DenoiseUNet
12 | from utils import get_dataloader, sample, encode, decode
13 | import open_clip
14 | from open_clip import tokenizer
15 | from rudalle import get_vae
16 |
17 |
18 | def train(proc_id, args):
19 | parallel = len(args.devices) > 1
20 | device = torch.device(proc_id)
21 |
22 | vqmodel = get_vae().to(device)
23 | vqmodel.eval().requires_grad_(False)
24 |
25 | if parallel:
26 | torch.cuda.set_device(proc_id)
27 | torch.backends.cudnn.benchmark = True
28 | dist.init_process_group(backend="nccl", init_method="file://dist_file", world_size=args.n_nodes * len(args.devices), rank=proc_id + len(args.devices) * args.node_id)
29 | torch.set_num_threads(6)
30 |
31 | model = DenoiseUNet(num_labels=args.num_codebook_vectors, c_clip=1024).to(device)
32 |
33 | clip_model, _, _ = open_clip.create_model_and_transforms('ViT-g-14', pretrained='laion2b_s12b_b42k')
34 | del clip_model.visual
35 | clip_model = clip_model.to(device).eval().requires_grad_(False)
36 |
37 | lr = 3e-4
38 | dataset = get_dataloader(args)
39 | optimizer = optim.AdamW(model.parameters(), lr=lr)
40 | criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
41 |
42 | scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=math.ceil(1000 / args.accum_grad), epochs=600, pct_start=30 / 300, div_factor=25, final_div_factor=1 / 25, anneal_strategy='linear')
43 |
44 | losses, accuracies = [], []
45 | start_step, total_loss, total_acc = 0, 0, 0
46 |
47 | if parallel:
48 | model = DistributedDataParallel(model, device_ids=[device], output_device=device)
49 |
50 | pbar = tqdm(enumerate(dataset, start=start_step), total=args.total_steps, initial=start_step) if args.node_id == 0 and proc_id == 0 else enumerate(dataset, start=start_step)
51 | model.train()
52 | for step, (images, captions) in pbar:
53 | images = images.to(device)
54 | with torch.no_grad():
55 | image_indices = encode(vqmodel, images) # encode images (batch_size x 3 x 256 x 256) to tokens (batch_size x 32 x 32)
56 | r = torch.rand(images.size(0), device=device) # generate random timesteps
57 | noised_indices, mask = model.module.add_noise(image_indices, r) # noise the tokens according to the timesteps
58 |
59 | if np.random.rand() < 0.1: # 10% of the times -> unconditional training for classifier-free-guidance
60 | text_embeddings = images.new_zeros(images.size(0), 1024)
61 | else:
62 | text_tokens = tokenizer.tokenize(captions)
63 | text_tokens = text_tokens.to(device)
64 | text_embeddings = clip_model.encode_text(text_tokens).float() # text embeddings (batch_size x 1024)
65 |
66 | pred = model(noised_indices, text_embeddings, r) # predict denoised token logits (batch_size x 8192 x 32 x 32)
67 | loss = criterion(pred, image_indices) # cross entropy loss
68 | loss_adjusted = loss / args.accum_grad
69 |
70 | loss_adjusted.backward()
71 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 5).item()
72 | if (step + 1) % args.accum_grad == 0:
73 | optimizer.step()
74 | scheduler.step()
75 | optimizer.zero_grad()
76 |
77 | acc = (pred.argmax(1) == image_indices).float().mean()
78 |
79 | total_loss += loss.item()
80 | total_acc += acc.item()
81 |
82 | if not proc_id and args.node_id == 0:
83 | pbar.set_postfix({"loss": total_loss / (step + 1), "acc": total_acc / (step + 1), "curr_loss": loss.item(), "curr_acc": acc.item(), "ppx": np.exp(total_loss / (step + 1)), "lr": optimizer.param_groups[0]['lr'], "grad_norm": grad_norm})
84 |
85 | if args.node_id == 0 and proc_id == 0 and step % args.log_period == 0:
86 | print(f"Step {step} - loss {total_loss / (step + 1)} - acc {total_acc / (step + 1)} - ppx {np.exp(total_loss / (step + 1))}")
87 |
88 | losses.append(total_loss / (step + 1))
89 | accuracies.append(total_acc / (step + 1))
90 |
91 | model.eval()
92 | with torch.no_grad():
93 | sampled = sample(model.module, c=text_embeddings)
94 | sampled = decode(vqmodel, sampled)
95 |
96 | model.train()
97 | log_images = torch.cat([torch.cat([i for i in sampled.cpu()], dim=-1)], dim=-2)
98 | torchvision.utils.save_image(log_images, os.path.join(f"results/{args.run_name}", f"{step:03d}.png"))
99 |
100 | del sampled
101 |
102 | torch.save(model.module.state_dict(), f"models/{args.run_name}/model.pt")
103 | torch.save(optimizer.state_dict(), f"models/{args.run_name}/optim.pt")
104 | torch.save({'step': step, 'losses': losses, 'accuracies': accuracies}, f"results/{args.run_name}/log.pt")
105 |
106 | del images, image_indices, r, text_embeddings
107 | del noised_indices, mask, pred, loss, loss_adjusted, acc
108 |
109 |
110 | def launch(args):
111 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(d) for d in args.devices])
112 | if len(args.devices) == 1:
113 | train(0, args)
114 | else:
115 | os.environ["MASTER_ADDR"] = "localhost"
116 | os.environ["MASTER_PORT"] = "33751"
117 | p = mp.spawn(train, nprocs=len(args.devices), args=(args,))
118 |
119 |
120 | if __name__ == '__main__':
121 | import argparse
122 | parser = argparse.ArgumentParser()
123 | args = parser.parse_args()
124 | args.run_name = "Paella_f8_8192"
125 | args.dataset_type = "webdataset"
126 | args.total_steps = 501_000
127 | args.batch_size = 22
128 | args.image_size = 256
129 | args.num_workers = 10
130 | args.log_period = 5000
131 | args.accum_grad = 1
132 | args.num_codebook_vectors = 8192
133 |
134 | args.n_nodes = 8
135 | args.node_id = int(os.environ["SLURM_PROCID"])
136 | args.devices = [0, 1, 2, 3, 4, 5, 6, 7]
137 |
138 | args.dataset_path = ""
139 | print("Launching with args: ", args)
140 | launch(
141 | args
142 | )
143 |
--------------------------------------------------------------------------------
/modules.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class ModulatedLayerNorm(nn.Module):
8 | def __init__(self, num_features, eps=1e-6, channels_first=True):
9 | super().__init__()
10 | self.ln = nn.LayerNorm(num_features, eps=eps)
11 | self.gamma = nn.Parameter(torch.randn(1, 1, 1))
12 | self.beta = nn.Parameter(torch.randn(1, 1, 1))
13 | self.channels_first = channels_first
14 |
15 | def forward(self, x, w=None):
16 | x = x.permute(0, 2, 3, 1) if self.channels_first else x
17 | if w is None:
18 | x = self.ln(x)
19 | else:
20 | x = self.gamma * w * self.ln(x) + self.beta * w
21 | x = x.permute(0, 3, 1, 2) if self.channels_first else x
22 | return x
23 |
24 |
25 | class ResBlock(nn.Module):
26 | def __init__(self, c, c_hidden, c_cond=0, c_skip=0, scaler=None, layer_scale_init_value=1e-6):
27 | super().__init__()
28 | self.depthwise = nn.Sequential(
29 | nn.ReflectionPad2d(1),
30 | nn.Conv2d(c, c, kernel_size=3, groups=c)
31 | )
32 | self.ln = ModulatedLayerNorm(c, channels_first=False)
33 | self.channelwise = nn.Sequential(
34 | nn.Linear(c+c_skip, c_hidden),
35 | nn.GELU(),
36 | nn.Linear(c_hidden, c),
37 | )
38 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(c), requires_grad=True) if layer_scale_init_value > 0 else None
39 | self.scaler = scaler
40 | if c_cond > 0:
41 | self.cond_mapper = nn.Linear(c_cond, c)
42 |
43 | def forward(self, x, s=None, skip=None):
44 | res = x
45 | x = self.depthwise(x)
46 | if s is not None:
47 | s = self.cond_mapper(s.permute(0, 2, 3, 1))
48 | if s.size(1) == s.size(2) == 1:
49 | s = s.expand(-1, x.size(2), x.size(3), -1)
50 | x = self.ln(x.permute(0, 2, 3, 1), s)
51 | if skip is not None:
52 | x = torch.cat([x, skip.permute(0, 2, 3, 1)], dim=-1)
53 | x = self.channelwise(x)
54 | x = self.gamma * x if self.gamma is not None else x
55 | x = res + x.permute(0, 3, 1, 2)
56 | if self.scaler is not None:
57 | x = self.scaler(x)
58 | return x
59 |
60 |
61 | class DenoiseUNet(nn.Module):
62 | def __init__(self, num_labels, c_hidden=1280, c_clip=1024, c_r=64, down_levels=[4, 8, 16], up_levels=[16, 8, 4]):
63 | super().__init__()
64 | self.num_labels = num_labels
65 | self.c_r = c_r
66 | c_levels = [c_hidden // (2 ** i) for i in reversed(range(len(down_levels)))]
67 | self.embedding = nn.Embedding(num_labels, c_levels[0])
68 |
69 | # DOWN BLOCKS
70 | self.down_blocks = nn.ModuleList()
71 | for i, num_blocks in enumerate(down_levels):
72 | blocks = []
73 | if i > 0:
74 | blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
75 | for j in range(num_blocks):
76 | block = ResBlock(c_levels[i], c_levels[i] * 4, c_clip + c_r)
77 | block.channelwise[-1].weight.data *= np.sqrt(1 / sum(down_levels))
78 | blocks.append(block)
79 | self.down_blocks.append(nn.ModuleList(blocks))
80 |
81 | # UP BLOCKS
82 | self.up_blocks = nn.ModuleList()
83 | for i, num_blocks in enumerate(up_levels):
84 | blocks = []
85 | for j in range(num_blocks):
86 | block = ResBlock(c_levels[len(c_levels) - 1 - i], c_levels[len(c_levels) - 1 - i] * 4, c_clip + c_r,
87 | c_levels[len(c_levels) - 1 - i] if (j == 0 and i > 0) else 0)
88 | block.channelwise[-1].weight.data *= np.sqrt(1 / sum(up_levels))
89 | blocks.append(block)
90 | if i < len(up_levels) - 1:
91 | blocks.append(
92 | nn.ConvTranspose2d(c_levels[len(c_levels) - 1 - i], c_levels[len(c_levels) - 2 - i], kernel_size=4, stride=2, padding=1))
93 | self.up_blocks.append(nn.ModuleList(blocks))
94 |
95 | self.clf = nn.Conv2d(c_levels[0], num_labels, kernel_size=1)
96 |
97 | def gamma(self, r):
98 | return (r * torch.pi / 2).cos()
99 |
100 | def add_noise(self, x, r, random_x=None):
101 | r = self.gamma(r)[:, None, None]
102 | mask = torch.bernoulli(r * torch.ones_like(x), )
103 | mask = mask.round().long()
104 | if random_x is None:
105 | random_x = torch.randint_like(x, 0, self.num_labels)
106 | x = x * (1 - mask) + random_x * mask
107 | return x, mask
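# Noise schedule in words: gamma(r) = cos(r * pi / 2), so r = 0 replaces every token with a
# random one (mask probability 1) and r = 1 leaves the tokens untouched (mask probability 0).
# Sampling therefore starts at r = 0 and walks towards r = 1 (see r_range in utils.sample).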
108 |
109 | def gen_r_embedding(self, r, max_positions=10000):
110 | r = self.gamma(r) * max_positions
111 | half_dim = self.c_r // 2
112 | emb = math.log(max_positions) / (half_dim - 1)
113 | emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
114 | emb = r[:, None] * emb[None, :]
115 | emb = torch.cat([emb.sin(), emb.cos()], dim=1)
116 | if self.c_r % 2 == 1: # zero pad
117 | emb = nn.functional.pad(emb, (0, 1), mode='constant')
118 | return emb
119 |
120 | def _down_encode_(self, x, s):
121 | level_outputs = []
122 | for i, blocks in enumerate(self.down_blocks):
123 | for block in blocks:
124 | if isinstance(block, ResBlock):
125 | x = block(x, s)
126 | else:
127 | x = block(x)
128 | level_outputs.insert(0, x)
129 | return level_outputs
130 |
131 | def _up_decode(self, level_outputs, s):
132 | x = level_outputs[0]
133 | for i, blocks in enumerate(self.up_blocks):
134 | for j, block in enumerate(blocks):
135 | if isinstance(block, ResBlock):
136 | if i > 0 and j == 0:
137 | x = block(x, s, level_outputs[i])
138 | else:
139 | x = block(x, s)
140 | else:
141 | x = block(x)
142 | return x
143 |
144 | def forward(self, x, c, r): # r is a uniform value between 0 and 1
145 | r_embed = self.gen_r_embedding(r)
146 | x = self.embedding(x).permute(0, 3, 1, 2)
147 | s = torch.cat([c, r_embed], dim=-1)[:, :, None, None]
148 | level_outputs = self._down_encode_(x, s)
149 | x = self._up_decode(level_outputs, s)
150 | x = self.clf(x)
151 | return x
152 |
153 |
154 | if __name__ == '__main__':
155 | device = "cuda"
156 | model = DenoiseUNet(1024).to(device)
157 | print(sum([p.numel() for p in model.parameters()]))
158 | x = torch.randint(0, 1024, (1, 32, 32)).long().to(device)
159 | c = torch.randn((1, 1024)).to(device)
160 | r = torch.rand(1).to(device)
161 | model(x, c, r)
162 |
163 |
--------------------------------------------------------------------------------
/paella.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import torch
4 | from torch.utils.data import TensorDataset, DataLoader
5 | import wandb
6 | from torch import nn, optim
7 | import torchvision
8 | from tqdm import tqdm
9 | import time
10 | import numpy as np
11 | import torch.multiprocessing as mp
12 | import torch.distributed as dist
13 | from torch.nn.parallel import DistributedDataParallel
14 | from modules import DenoiseUNet
15 | from utils import get_dataloader, sample, encode, decode
16 | import open_clip
17 | from open_clip import tokenizer
18 | from rudalle import get_vae
19 |
20 |
21 | def train(proc_id, args):
22 | if os.path.exists(f"results/{args.run_name}/log.pt"):
23 | resume = True
24 | else:
25 | resume = False
26 | if not proc_id and args.node_id == 0:
27 | if resume:
28 | wandb.init(project="project", name=args.run_name, entity="your_entity", config=vars(args), resume=True)
29 | else:
30 | wandb.init(project="project", name=args.run_name, entity="your_entity", config=vars(args))
31 | print(f"Starting run '{args.run_name}'....")
32 | print(f"Batch Size check: {args.n_nodes * args.batch_size * args.accum_grad * len(args.devices)}")
33 | parallel = len(args.devices) > 1
34 | device = torch.device(proc_id)
35 |
36 | vqmodel = get_vae().to(device)
37 | vqmodel.eval().requires_grad_(False)
38 |
39 | if parallel:
40 | torch.cuda.set_device(proc_id)
41 | torch.backends.cudnn.benchmark = True
42 | dist.init_process_group(backend="nccl", init_method="file://dist_file",
43 | world_size=args.n_nodes * len(args.devices),
44 | rank=proc_id + len(args.devices) * args.node_id)
45 | torch.set_num_threads(6)
46 |
47 | model = DenoiseUNet(num_labels=args.num_codebook_vectors, c_clip=1024).to(device)
48 |
49 | if not proc_id and args.node_id == 0:
50 | print(f"Number of Parameters: {sum([p.numel() for p in model.parameters()])}")
51 |
52 | clip_model, _, _ = open_clip.create_model_and_transforms('ViT-g-14', pretrained='laion2b_s12b_b42k')
53 | del clip_model.visual
54 | clip_model = clip_model.to(device).eval().requires_grad_(False)
55 |
56 | lr = 3e-4
57 | dataset = get_dataloader(args)
58 | optimizer = optim.AdamW(model.parameters(), lr=lr)
59 | criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
60 |
61 | if not proc_id and args.node_id == 0:
62 | wandb.watch(model)
63 | os.makedirs(f"results/{args.run_name}", exist_ok=True)
64 | os.makedirs(f"models/{args.run_name}", exist_ok=True)
65 |
66 | scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr,
67 | steps_per_epoch=math.ceil(1000 / args.accum_grad),
68 | epochs=600, pct_start=30 / 300, div_factor=25,
69 | final_div_factor=1 / 25, anneal_strategy='linear')
70 |
71 | if resume:
72 | if not proc_id and args.node_id == 0:
73 | print("Loading last checkpoint....")
74 | logs = torch.load(f"results/{args.run_name}/log.pt")
75 | start_step = logs["step"] + 1
76 | losses = logs["losses"]
77 | accuracies = logs["accuracies"]
78 | total_loss, total_acc = losses[-1] * start_step, accuracies[-1] * start_step
79 | model.load_state_dict(torch.load(f"models/{args.run_name}/model.pt", map_location=device))
80 | if not proc_id and args.node_id == 0:
81 | print("Loaded model.")
82 | opt_state = torch.load(f"models/{args.run_name}/optim.pt", map_location=device)
83 | last_lr = opt_state["param_groups"][0]["lr"]
84 | with torch.no_grad():
85 | for _ in range(logs["step"]):
86 | scheduler.step()
87 | if not proc_id and args.node_id == 0:
88 | print(f"Initialized scheduler")
89 | print(f"Sanity check => Last-LR: {last_lr} == Current-LR: {optimizer.param_groups[0]['lr']} -> {last_lr == optimizer.param_groups[0]['lr']}")
90 | optimizer.load_state_dict(opt_state)
91 | del opt_state
92 | else:
93 | losses = []
94 | accuracies = []
95 | start_step, total_loss, total_acc = 0, 0, 0
96 |
97 | if parallel:
98 | model = DistributedDataParallel(model, device_ids=[device], output_device=device)
99 |
100 | pbar = tqdm(enumerate(dataset, start=start_step), total=args.total_steps, initial=start_step) if args.node_id == 0 and proc_id == 0 else enumerate(dataset, start=start_step)
101 | model.train()
102 | for step, (images, captions) in pbar:
103 | images = images.to(device)
104 | with torch.no_grad():
105 | image_indices = encode(vqmodel, images)
106 | r = torch.rand(images.size(0), device=device)
107 | noised_indices, mask = model.module.add_noise(image_indices, r)
108 |
109 | if np.random.rand() < 0.1: # 10% of the times -> unconditional training for classifier-free-guidance
110 | text_embeddings = images.new_zeros(images.size(0), 1024)
111 | else:
112 | text_tokens = tokenizer.tokenize(captions)
113 | text_tokens = text_tokens.to(device)
114 | text_embeddings = clip_model.encode_text(text_tokens).float()
115 |
116 | pred = model(noised_indices, text_embeddings, r)
117 | loss = criterion(pred, image_indices)
118 | loss_adjusted = loss / args.accum_grad
119 |
120 | loss_adjusted.backward()
121 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 5).item()
122 | if (step + 1) % args.accum_grad == 0:
123 | optimizer.step()
124 | scheduler.step()
125 | optimizer.zero_grad()
126 |
127 | acc = (pred.argmax(1) == image_indices).float()
128 | acc = acc.mean()
129 |
130 | total_loss += loss.item()
131 | total_acc += acc.item()
132 |
133 | if not proc_id and args.node_id == 0:
134 | log = {
135 | "loss": total_loss / (step + 1),
136 | "acc": total_acc / (step + 1),
137 | "curr_loss": loss.item(),
138 | "curr_acc": acc.item(),
139 | "ppx": np.exp(total_loss / (step + 1)),
140 | "lr": optimizer.param_groups[0]['lr'],
141 | "grad_norm": grad_norm
142 | }
143 | pbar.set_postfix(log)
144 | wandb.log(log)
145 |
146 | if args.node_id == 0 and proc_id == 0 and step % args.log_period == 0:
147 | print(f"Step {step} - loss {total_loss / (step + 1)} - acc {total_acc / (step + 1)} - ppx {np.exp(total_loss / (step + 1))}")
148 |
149 | losses.append(total_loss / (step + 1))
150 | accuracies.append(total_acc / (step + 1))
151 |
152 | model.eval()
153 | with torch.no_grad():
154 | n = 1
155 | images = images[:10]
156 | image_indices = image_indices[:10]
157 | captions = captions[:10]
158 | text_embeddings = text_embeddings[:10]
159 | sampled = sample(model.module, c=text_embeddings)
160 | sampled = decode(vqmodel, sampled)
161 | recon_images = decode(vqmodel, image_indices)
162 |
163 | if args.log_captions:
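# "cool_captions.pth" is expected to be a dict with a "captions" key holding a list of prompt strings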
164 | cool_captions_data = torch.load("cool_captions.pth")
165 | cool_captions_text = cool_captions_data["captions"]
166 |
167 | text_tokens = tokenizer.tokenize(cool_captions_text)
168 | text_tokens = text_tokens.to(device)
169 | cool_captions_embeddings = clip_model.encode_text(text_tokens).float()
170 |
171 | cool_captions = DataLoader(TensorDataset(cool_captions_embeddings.repeat_interleave(n, dim=0)), batch_size=11)
172 | cool_captions_sampled = []
173 | cool_captions_sampled_ema = []
174 | st = time.time()
175 | for caption_embedding in cool_captions:
176 | caption_embedding = caption_embedding[0].float().to(device)
177 | sampled_text = sample(model.module, c=caption_embedding)
178 | sampled_text = decode(vqmodel, sampled_text)
179 | sampled_text_ema = sampled_text  # no separate EMA model is kept in this script, so the "EMA" grid mirrors the regular samples
180 | for s, t in zip(sampled_text, sampled_text_ema):
181 | cool_captions_sampled.append(s.cpu())
182 | cool_captions_sampled_ema.append(t.cpu())
183 | print(f"Took {time.time() - st} seconds to sample {len(cool_captions_text) * 2} captions.")
184 |
185 | cool_captions_sampled = torch.stack(cool_captions_sampled)
186 | torchvision.utils.save_image(
187 | torchvision.utils.make_grid(cool_captions_sampled, nrow=11),
188 | os.path.join(f"results/{args.run_name}", f"cool_captions_{step:03d}.png")
189 | )
190 |
191 | cool_captions_sampled_ema = torch.stack(cool_captions_sampled_ema)
192 | torchvision.utils.save_image(
193 | torchvision.utils.make_grid(cool_captions_sampled_ema, nrow=11),
194 | os.path.join(f"results/{args.run_name}", f"cool_captions_{step:03d}_ema.png")
195 | )
196 |
197 | log_images = torch.cat([
198 | torch.cat([i for i in sampled.cpu()], dim=-1),
199 | ], dim=-2)
200 |
201 | model.train()
202 |
203 | torchvision.utils.save_image(log_images, os.path.join(f"results/{args.run_name}", f"{step:03d}.png"))
204 |
205 | log_data = [[captions[i]] + [wandb.Image(sampled[i])] + [wandb.Image(images[i])] + [wandb.Image(recon_images[i])] for i in range(len(captions))]
206 | log_table = wandb.Table(data=log_data, columns=["Caption", "Image", "Orig", "Recon"])
207 | wandb.log({"Log": log_table})
208 |
209 | if args.log_captions:
210 | log_data_cool = [[cool_captions_text[i]] + [wandb.Image(cool_captions_sampled[i])] + [wandb.Image(cool_captions_sampled_ema[i])] for i in range(len(cool_captions_text))]
211 | log_table_cool = wandb.Table(data=log_data_cool, columns=["Caption", "Image", "EMA Image"])
212 | wandb.log({"Log Cool": log_table_cool})
213 | del sampled_text, log_data_cool
214 |
215 | del sampled, log_data
216 |
217 | if step % args.extra_ckpt == 0:
218 | torch.save(model.module.state_dict(), f"models/{args.run_name}/model_{step}.pt")
219 | torch.save(optimizer.state_dict(), f"models/{args.run_name}/model_{step}_optim.pt")
220 | torch.save(model.module.state_dict(), f"models/{args.run_name}/model.pt")
221 | torch.save(optimizer.state_dict(), f"models/{args.run_name}/optim.pt")
222 | torch.save({'step': step, 'losses': losses, 'accuracies': accuracies}, f"results/{args.run_name}/log.pt")
223 |
224 | del images, image_indices, r, text_embeddings
225 | del noised_indices, mask, pred, loss, loss_adjusted, acc
226 |
227 |
228 | def launch(args):
229 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(d) for d in args.devices])
230 | if len(args.devices) == 1:
231 | train(0, args)
232 | else:
233 | os.environ["MASTER_ADDR"] = "localhost"
234 | os.environ["MASTER_PORT"] = "33751"
235 | p = mp.spawn(train, nprocs=len(args.devices), args=(args,))
236 |
237 |
238 | if __name__ == '__main__':
239 | import argparse
240 | parser = argparse.ArgumentParser()
241 | args = parser.parse_args()
242 | args.run_name = "run_name"
243 | args.model = "UNet"
244 | args.dataset_type = "webdataset"
245 | args.total_steps = 501_000
246 | args.batch_size = 22
247 | args.image_size = 256
248 | args.num_workers = 10
249 | args.log_period = 5000
250 | args.extra_ckpt = 50_000
251 | args.accum_grad = 1
252 | args.num_codebook_vectors = 8192
253 | args.log_captions = True
254 | args.finetune = False
255 |
256 | args.n_nodes = 8
257 | args.node_id = int(os.environ["SLURM_PROCID"])
258 | args.devices = [0, 1, 2, 3, 4, 5, 6, 7]
259 |
260 | args.dataset_path = ""
261 | print("Launching with args: ", args)
262 | launch(
263 | args
264 | )
265 |
--------------------------------------------------------------------------------
/evaluation/evaluation_generation.py:
--------------------------------------------------------------------------------
1 | import json
2 | import queue
3 | import torch.multiprocessing as mp
4 | from collections import OrderedDict
5 | import os
6 | import time
7 | import torch
8 | import pandas as pd
9 | from itertools import product
10 | import torchvision
11 | import numpy as np
12 | import open_clip
13 | from open_clip import tokenizer
14 | from rudalle import get_vae
15 | from einops import rearrange
16 | import tensorflow.compat.v1 as tf
17 | from PIL import Image
18 | from evaluator import Evaluator
19 | from modules import DenoiseUNet
20 |
21 |
22 | def chunk(lst, n):
23 | return [lst[i:i + n] for i in range(0, len(lst), n)]
24 |
25 |
26 | def save_images_npz(path, save_path):
27 | arr = []
28 | base_shape = None
29 | for item in os.listdir(path):
30 | if os.path.isfile(os.path.join(path, item)) and item.endswith(".jpg"):
31 | img = Image.open(os.path.join(path, item))
32 | img = np.array(img.resize((256, 256), Image.LANCZOS))
33 | if base_shape is None:
34 | base_shape = img.shape
35 | try:
36 | if img.shape == base_shape:
37 | arr.append(img)
38 | except Exception as e:
39 | print(e)
40 | continue
41 | arr = np.stack(arr)
42 | print(arr.shape)
43 | np.savez(save_path, arr)
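# np.savez stores the stacked array under the default key "arr_0",
# which is the key Evaluator.read_activations() reads back in.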
44 |
45 |
46 | def log(t, eps=1e-20):
47 | return torch.log(t + eps)
48 |
49 |
50 | def gumbel_noise(t):
51 | noise = torch.zeros_like(t).uniform_(0, 1)
52 | return -log(-log(noise))
53 |
54 |
55 | def gumbel_sample(t, temperature=1., dim=-1):
56 | return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
57 |
58 |
59 | def sample(model, c, x=None, mask=None, T=12, size=(32, 32), starting_t=0, temp_range=[1.0, 1.0], typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=-1, renoise_steps=11, renoise_mode='start'):
60 | with torch.inference_mode():
61 | r_range = torch.linspace(0, 1, T+1)[:-1][:, None].expand(-1, c.size(0)).to(c.device)
62 | temperatures = torch.linspace(temp_range[0], temp_range[1], T)
63 | if x is None:
64 | x = torch.randint(0, model.num_labels, size=(c.size(0), *size), device=c.device)
65 | if renoise_mode == 'start':
66 | init_x = x.clone()
67 | for i in range(starting_t, T):
68 | if renoise_mode == 'prev':
69 | prev_x = x.clone()
70 | r, temp = r_range[i], temperatures[i]
71 | logits = model(x, c, r)
72 | if classifier_free_scale >= 0:
73 | logits_uncond = model(x, torch.zeros_like(c), r)
74 | logits = torch.lerp(logits_uncond, logits, classifier_free_scale)
75 | x = logits
76 | x_flat = x.permute(0, 2, 3, 1).reshape(-1, x.size(1))
77 | if typical_filtering:
78 | x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1)
79 | x_flat_norm_p = torch.exp(x_flat_norm)
80 | entropy = -(x_flat_norm * x_flat_norm_p).nansum(-1, keepdim=True)
81 |
82 | c_flat_shifted = torch.abs((-x_flat_norm) - entropy)
83 | c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False)
84 | x_flat_cumsum = x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1)
85 |
86 | last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1)
87 | sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(1, last_ind.view(-1, 1))
88 | if typical_min_tokens > 1:
89 | sorted_indices_to_remove[..., :typical_min_tokens] = 0
90 | indices_to_remove = sorted_indices_to_remove.scatter(1, x_flat_indices, sorted_indices_to_remove)
91 | x_flat = x_flat.masked_fill(indices_to_remove, -float("Inf"))
92 | x_flat = torch.multinomial(x_flat.div(temp).softmax(-1), num_samples=1)[:, 0]
93 | # print(x_flat.shape)
94 | # x_flat = gumbel_sample(x_flat, temperature=temp)
95 | x = x_flat.view(x.size(0), *x.shape[2:])
96 | if i < renoise_steps:
97 | if renoise_mode == 'start':
98 | x, _ = model.add_noise(x, r_range[i+1], random_x=init_x)
99 | elif renoise_mode == 'prev':
100 | x, _ = model.add_noise(x, r_range[i+1], random_x=prev_x)
101 | else: # 'rand'
102 | x, _ = model.add_noise(x, r_range[i+1])
103 | return x.detach()
104 |
105 |
106 | def encode(vq, x):
107 | return vq.encode(x)[-1]
108 |
109 |
110 | def decode(vq, z):
111 | return vq.decode_indices(z)
112 |
113 |
114 | class DatasetWriter:
115 | def __init__(self, date, base_path="/home/data"):
116 | self.date = date
117 | self.base_path = base_path
118 |
119 | def saveimages(self, imgs, captions, **kwargs):
120 | try:
121 | for img, caption in zip(imgs, captions):
122 | caption = caption.replace(" ", "_").replace(".", "")
123 | path = os.path.join(self.base_path, caption + ".jpg")
124 | torchvision.utils.save_image(img, path, **kwargs)
125 | except Exception as e:
126 | print(e)
127 |
128 | def save(self, que):
129 | while True:
130 | try:
131 | data = que.get(True, 1)
132 | except (queue.Empty, FileNotFoundError):
133 | continue
134 | if data is None:
135 | print("Finished")
136 | return
137 | sampled, captions = data["payload"]
138 | self.saveimages(sampled, captions)
139 |
140 |
141 | class Sample:
142 | def __init__(self, date, device, cfg_weight=5, steps=8, typical_filtering=True, batch_size=8, base_path="/home/data", dataset="coco", captions_path="cap.parquet"):
143 | self.date = date
144 | self.cfg_weight = cfg_weight
145 | self.steps = steps
146 | self.typical_filtering = typical_filtering
147 | self.dataset = dataset
148 | self.captions_path = captions_path
149 | self.device = torch.device(device)
150 | self.batch_size = batch_size
151 | self.base_path = base_path
152 | self.path = os.path.join(base_path, f"{steps}_{cfg_weight}_{typical_filtering}")
153 | self.model, self.vqmodel, self.clip, self.clip_preprocess = self.load_models()
154 | self.setup()
155 | self.que = mp.Queue()
156 | mp.Process(target=DatasetWriter(date, base_path=self.path).save, args=(self.que,)).start()
157 |
158 | def setup(self):
159 | if self.dataset == "coco":
160 | self.captions = pd.read_parquet(self.captions_path)["caption"]
161 | elif self.dataset == "laion":
162 | self.captions = pd.read_parquet(self.captions_path)["caption"]
163 | else:
164 | raise ValueError
165 | num_sampled = len(os.listdir(self.base_path))
166 | self.captions = self.captions[num_sampled:]
167 | os.makedirs(self.path, exist_ok=True)
168 | with open(os.path.join(self.path, "log.json"), "w") as f:
169 | json.dump({
170 | "date": self.date,
171 | "cfg": self.cfg_weight,
172 | "steps": self.steps
173 | }, f)
174 |
175 | def load_models(self):
176 | # --- Paella MODEL ---
177 | model_path = f"./models/Paella_f8_8192/model_600000.pt"
178 | state_dict = torch.load(model_path, map_location=self.device)
179 | # new_state_dict = OrderedDict()
180 | # for k, v in state_dict.items():
181 | # name = k[7:] # remove `module.`
182 | # new_state_dict[name] = v
183 | model = DenoiseUNet(num_labels=8192, c_clip=1024).to(self.device)
184 | model.load_state_dict(state_dict)
185 | model.eval().requires_grad_(False)
186 | # --- VQ MODEL ---
187 | vqmodel = get_vae().to(self.device)
188 | vqmodel.eval().requires_grad_(False)
189 | # --- CLIP MODEL ---
190 | clip_model, _, _ = open_clip.create_model_and_transforms('ViT-g-14', pretrained='laion2b_s12b_b42k', cache_dir="/fsx/mas/.cache")
191 | del clip_model.visual
192 | clip_model = clip_model.to(self.device).eval().requires_grad_(False)
193 | clip_preprocess = torchvision.transforms.Compose([
194 | torchvision.transforms.Resize(224),
195 | torchvision.transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
196 | std=(0.26862954, 0.26130258, 0.27577711)),
197 | ])
198 | return model, vqmodel, clip_model, clip_preprocess
199 |
200 | @torch.no_grad()
201 | def r_decode(self, img_seq, shape=(32, 32)):
202 | img_seq = img_seq.view(img_seq.shape[0], -1)
203 | one_hot_indices = torch.nn.functional.one_hot(img_seq, num_classes=self.vqmodel.num_tokens).float()
204 | z = (one_hot_indices @ self.vqmodel.model.quantize.embed.weight)
205 | z = rearrange(z, 'b (h w) c -> b c h w', h=shape[0], w=shape[1])
206 | img = self.vqmodel.model.decode(z)
207 | img = (img.clamp(-1., 1.) + 1) * 0.5
208 | return img
209 |
210 | def convert_dataset(self):
211 | """
212 | path: base_path + folder + tar_file
213 | """
214 | num_chunks = len(self.captions) // self.batch_size  # ~batch_size captions per sampling call
215 | latent_shape = (32, 32)
216 | for cap in np.array_split(self.captions, num_chunks):
217 | cap = list(cap)
218 | s = time.time()
219 | # print(len(cap))
220 | text = tokenizer.tokenize(cap).to(self.device)
221 | with torch.inference_mode():
222 | with torch.autocast(device_type="cuda"):
223 | clip_embeddings = self.clip.encode_text(text).float()
224 |
225 | sampled = sample(self.model, clip_embeddings, T=self.steps, size=latent_shape, starting_t=0,
226 | temp_range=[1.0, 1.0],
227 | typical_filtering=self.typical_filtering, typical_mass=0.2, typical_min_tokens=1,
228 | classifier_free_scale=self.cfg_weight, renoise_steps=self.steps - 1)
229 | sampled = self.r_decode(sampled, latent_shape)
230 | data = {
231 | "payload": [sampled, cap]
232 | }
233 | self.que.put(data)
234 | # print(f"Sampled {len(cap)} in {time.time() - s} seconds.")
235 |
236 |
237 | if __name__ == '__main__':
238 | mp.set_start_method('spawn')
239 |
240 | date = "f8_600k_no_ema"
241 | dataset = "coco"
242 | base_dir = "/fsx/mas/paella_unet/evaluation"
243 | ref_images = os.path.join(base_dir, f"{dataset}_30k.npz")
244 | ref_captions = os.path.join(base_dir, f"{dataset}_30k.parquet")
245 | base_path = os.path.join(base_dir, date)
246 | os.makedirs(base_path, exist_ok=True)
247 | devices = [0, 1, 2, 3, 4, 5, 6, 7]
248 | cfgs = [3, 4, 5] # [3, 4, 5, 8]
249 | steps = [12] # [6, 8, 10, 12]
250 | typical_filtering = [True, False]
251 |
252 | combinations = iter(list(product(steps, cfgs, typical_filtering)))
253 | # chunked_combinations = chunk(combinations, n=len(devices))
254 |
255 | try:
256 | while True:
257 | processes = []
258 | for proc_id in devices:
259 | steps, cfg_weight, typical_filtering = next(combinations)
260 | while os.path.exists(os.path.join(base_path, f"{steps}_{cfg_weight}_{typical_filtering}")):
261 | print(os.path.join(base_path, f"{steps}_{cfg_weight}_{typical_filtering}") + " already done. skipping....")
262 | steps, cfg_weight, typical_filtering = next(combinations)
263 | print(f"Starting sampling with steps={steps}, cfg_weight={cfg_weight}.")
264 | conv = Sample(date, proc_id, steps=steps, cfg_weight=cfg_weight, typical_filtering=typical_filtering, batch_size=8, base_path=base_path, dataset=dataset, captions_path=ref_captions)
265 | processes.append(mp.Process(target=conv.convert_dataset))
266 | processes[proc_id].start()
267 | for p in processes:
268 | p.join()
269 | except StopIteration:
270 | if len(processes) > 0:
271 | for p in processes:
272 | p.join()
273 | print("Finished sampling....")
274 |
275 | for run in os.listdir(base_path):
276 | stat_dict = {}
277 | run_path = os.path.join(base_path, run)
278 | batch_path = os.path.join(run_path, "batch.npz")
279 | print(f"Converting {run_path} to npz....")
280 | save_images_npz(run_path, batch_path)
281 |
282 | config = tf.ConfigProto(allow_soft_placement=True)
283 | config.gpu_options.allow_growth = True
284 | evaluator = Evaluator(tf.Session(config=config))
285 |
286 | evaluator.warmup()
287 |
288 | ref_acts = evaluator.read_activations(ref_images)
289 | ref_stats, ref_stats_spatial = evaluator.read_statistics(ref_images, ref_acts)
290 |
291 | sample_acts = evaluator.read_activations(batch_path)
292 | sample_stats, sample_stats_spatial = evaluator.read_statistics(batch_path, sample_acts)
293 |
294 | prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
295 | stat_dict["inception_score"] = evaluator.compute_inception_score(sample_acts[0])
296 | stat_dict["fid"] = sample_stats.frechet_distance(ref_stats)
297 | stat_dict["sfid"] = sample_stats_spatial.frechet_distance(ref_stats_spatial)
298 | stat_dict["prec"], stat_dict["recall"] = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
299 | print("---------------------------------------------------------")
300 | print(f"Metrics for {run_path}")
301 | print(stat_dict)
302 | print("---------------------------------------------------------")
303 | json.dump(stat_dict, open(os.path.join(run_path, "stat_dict.json"), "w"))
304 | os.remove(batch_path)
305 |
306 |
307 |
--------------------------------------------------------------------------------
/evaluation/evaluator.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import io
3 | import os
4 | import random
5 | import warnings
6 | import zipfile
7 | from abc import ABC, abstractmethod
8 | from contextlib import contextmanager
9 | from functools import partial
10 | from multiprocessing import cpu_count
11 | from multiprocessing.pool import ThreadPool
12 | from typing import Iterable, Optional, Tuple
13 |
14 | import numpy as np
15 | import requests
16 | import tensorflow.compat.v1 as tf
17 | from scipy import linalg
18 | from tqdm.auto import tqdm
19 |
20 | INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
21 | INCEPTION_V3_PATH = "classify_image_graph_def.pb"
22 |
23 | FID_POOL_NAME = "pool_3:0"
24 | FID_SPATIAL_NAME = "mixed_6/conv:0"
25 |
26 |
27 | def main():
28 | parser = argparse.ArgumentParser()
29 | parser.add_argument("ref_batch", help="path to reference batch npz file")
30 | parser.add_argument("sample_batch", help="path to sample batch npz file")
31 | args = parser.parse_args()
32 |
33 | config = tf.ConfigProto(
34 | allow_soft_placement=True # allows DecodeJpeg to run on CPU in Inception graph
35 | )
36 | config.gpu_options.allow_growth = True
37 | evaluator = Evaluator(tf.Session(config=config))
38 |
39 | print("warming up TensorFlow...")
40 | # This will cause TF to print a bunch of verbose stuff now rather
41 | # than after the next print(), to help prevent confusion.
42 | evaluator.warmup()
43 |
44 | print("computing reference batch activations...")
45 | ref_acts = evaluator.read_activations(args.ref_batch)
46 | print("computing/reading reference batch statistics...")
47 | ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)
48 |
49 | print("computing sample batch activations...")
50 | sample_acts = evaluator.read_activations(args.sample_batch)
51 | print("computing/reading sample batch statistics...")
52 | sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts)
53 |
54 | print("Computing evaluations...")
55 | print("Inception Score:", evaluator.compute_inception_score(sample_acts[0]))
56 | print("FID:", sample_stats.frechet_distance(ref_stats))
57 | print("sFID:", sample_stats_spatial.frechet_distance(ref_stats_spatial))
58 | prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
59 | print("Precision:", prec)
60 | print("Recall:", recall)
61 |
62 |
63 | class InvalidFIDException(Exception):
64 | pass
65 |
66 |
67 | class FIDStatistics:
68 | def __init__(self, mu: np.ndarray, sigma: np.ndarray):
69 | self.mu = mu
70 | self.sigma = sigma
71 |
72 | def frechet_distance(self, other, eps=1e-6):
73 | """
74 | Compute the Frechet distance between two sets of statistics.
75 | """
76 | # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
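77 | # FID(1, 2) = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * (sigma1 @ sigma2)^(1/2))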
77 | mu1, sigma1 = self.mu, self.sigma
78 | mu2, sigma2 = other.mu, other.sigma
79 |
80 | mu1 = np.atleast_1d(mu1)
81 | mu2 = np.atleast_1d(mu2)
82 |
83 | sigma1 = np.atleast_2d(sigma1)
84 | sigma2 = np.atleast_2d(sigma2)
85 |
86 | assert (
87 | mu1.shape == mu2.shape
88 | ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
89 | assert (
90 | sigma1.shape == sigma2.shape
91 | ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"
92 |
93 | diff = mu1 - mu2
94 |
95 | # product might be almost singular
96 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
97 | if not np.isfinite(covmean).all():
98 | msg = (
99 | "fid calculation produces singular product; adding %s to diagonal of cov estimates"
100 | % eps
101 | )
102 | warnings.warn(msg)
103 | offset = np.eye(sigma1.shape[0]) * eps
104 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
105 |
106 | # numerical error might give slight imaginary component
107 | if np.iscomplexobj(covmean):
108 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
109 | m = np.max(np.abs(covmean.imag))
110 | raise ValueError("Imaginary component {}".format(m))
111 | covmean = covmean.real
112 |
113 | tr_covmean = np.trace(covmean)
114 |
115 | return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
116 |
117 |
118 | class Evaluator:
119 | def __init__(
120 | self,
121 | session,
122 | batch_size=64,
123 | softmax_batch_size=512,
124 | ):
125 | self.sess = session
126 | self.batch_size = batch_size
127 | self.softmax_batch_size = softmax_batch_size
128 | self.manifold_estimator = ManifoldEstimator(session)
129 | with self.sess.graph.as_default():
130 | self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
131 | self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
132 | self.pool_features, self.spatial_features = _create_feature_graph(self.image_input)
133 | self.softmax = _create_softmax_graph(self.softmax_input)
134 |
135 | def warmup(self):
136 | self.compute_activations(np.zeros([1, 8, 64, 64, 3]))
137 |
138 | def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
139 | with open_npz_array(npz_path, "arr_0") as reader:
140 | return self.compute_activations(reader.read_batches(self.batch_size))
141 |
142 | def compute_activations(self, batches: Iterable[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
143 | """
144 | Compute image features for downstream evals.
145 |
146 | :param batches: an iterator over NHWC numpy arrays in [0, 255].
147 | :return: a tuple of numpy arrays of shape [N x X], where X is a feature
148 | dimension. The tuple is (pool_3, spatial).
149 | """
150 | preds = []
151 | spatial_preds = []
152 | for batch in tqdm(batches):
153 | batch = batch.astype(np.float32)
154 | pred, spatial_pred = self.sess.run(
155 | [self.pool_features, self.spatial_features], {self.image_input: batch}
156 | )
157 | preds.append(pred.reshape([pred.shape[0], -1]))
158 | spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
159 | return (
160 | np.concatenate(preds, axis=0),
161 | np.concatenate(spatial_preds, axis=0),
162 | )
163 |
164 | def read_statistics(
165 | self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]
166 | ) -> Tuple[FIDStatistics, FIDStatistics]:
167 | obj = np.load(npz_path)
168 | if "mu" in list(obj.keys()):
169 | return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
170 | obj["mu_s"], obj["sigma_s"]
171 | )
172 | return tuple(self.compute_statistics(x) for x in activations)
173 |
174 | def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
175 | mu = np.mean(activations, axis=0)
176 | sigma = np.cov(activations, rowvar=False)
177 | return FIDStatistics(mu, sigma)
178 |
179 | def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float:
180 | softmax_out = []
181 | for i in range(0, len(activations), self.softmax_batch_size):
182 | acts = activations[i : i + self.softmax_batch_size]
183 | softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}))
184 | preds = np.concatenate(softmax_out, axis=0)
185 | # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
186 | scores = []
187 | for i in range(0, len(preds), split_size):
188 | part = preds[i : i + split_size]
189 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
190 | kl = np.mean(np.sum(kl, 1))
191 | scores.append(np.exp(kl))
192 | return float(np.mean(scores))
193 |
194 | def compute_prec_recall(
195 | self, activations_ref: np.ndarray, activations_sample: np.ndarray
196 | ) -> Tuple[float, float]:
197 | radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
198 | radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
199 | pr = self.manifold_estimator.evaluate_pr(
200 | activations_ref, radii_1, activations_sample, radii_2
201 | )
202 | return (float(pr[0][0]), float(pr[1][0]))
203 |
204 |
205 | class ManifoldEstimator:
206 | """
207 | A helper for comparing manifolds of feature vectors.
208 |
209 | Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
210 | """
211 |
212 | def __init__(
213 | self,
214 | session,
215 | row_batch_size=10000,
216 | col_batch_size=10000,
217 | nhood_sizes=(3,),
218 | clamp_to_percentile=None,
219 | eps=1e-5,
220 | ):
221 | """
222 | Estimate the manifold of given feature vectors.
223 |
224 | :param session: the TensorFlow session.
225 | :param row_batch_size: row batch size to compute pairwise distances
226 | (parameter to trade-off between memory usage and performance).
227 | :param col_batch_size: column batch size to compute pairwise distances.
228 | :param nhood_sizes: number of neighbors used to estimate the manifold.
229 | :param clamp_to_percentile: prune hyperspheres that have radius larger than
230 | the given percentile.
231 | :param eps: small number for numerical stability.
232 | """
233 | self.distance_block = DistanceBlock(session)
234 | self.row_batch_size = row_batch_size
235 | self.col_batch_size = col_batch_size
236 | self.nhood_sizes = nhood_sizes
237 | self.num_nhoods = len(nhood_sizes)
238 | self.clamp_to_percentile = clamp_to_percentile
239 | self.eps = eps
240 |
241 | def warmup(self):
242 | feats, radii = (
243 | np.zeros([1, 2048], dtype=np.float32),
244 | np.zeros([1, 1], dtype=np.float32),
245 | )
246 | self.evaluate_pr(feats, radii, feats, radii)
247 |
248 | def manifold_radii(self, features: np.ndarray) -> np.ndarray:
249 | num_images = len(features)
250 |
251 | # Estimate manifold of features by calculating distances to k-NN of each sample.
252 | radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
253 | distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
254 | seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)
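    |         # Partitioning on every index up to max(nhood_sizes) puts the smallest distances in
    |         # sorted order at the front, so column k is the distance to the k-th nearest neighbour
    |         # (index 0 is the point itself).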
255 |
256 | for begin1 in range(0, num_images, self.row_batch_size):
257 | end1 = min(begin1 + self.row_batch_size, num_images)
258 | row_batch = features[begin1:end1]
259 |
260 | for begin2 in range(0, num_images, self.col_batch_size):
261 | end2 = min(begin2 + self.col_batch_size, num_images)
262 | col_batch = features[begin2:end2]
263 |
264 | # Compute distances between batches.
265 | distance_batch[
266 | 0 : end1 - begin1, begin2:end2
267 | ] = self.distance_block.pairwise_distances(row_batch, col_batch)
268 |
269 | # Find the k-nearest neighbor from the current batch.
270 | radii[begin1:end1, :] = np.concatenate(
271 | [
272 | x[:, self.nhood_sizes]
273 | for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1)
274 | ],
275 | axis=0,
276 | )
277 |
278 | if self.clamp_to_percentile is not None:
279 | max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
280 | radii[radii > max_distances] = 0
281 | return radii
282 |
283 | def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray):
284 | """
285 | Evaluate if new feature vectors are at the manifold.
286 | """
287 | num_eval_images = eval_features.shape[0]
288 | num_ref_images = radii.shape[0]
289 | distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)
290 | batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
291 | max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
292 | nearest_indices = np.zeros([num_eval_images], dtype=np.int32)
293 |
294 | for begin1 in range(0, num_eval_images, self.row_batch_size):
295 | end1 = min(begin1 + self.row_batch_size, num_eval_images)
296 | feature_batch = eval_features[begin1:end1]
297 |
298 | for begin2 in range(0, num_ref_images, self.col_batch_size):
299 | end2 = min(begin2 + self.col_batch_size, num_ref_images)
300 | ref_batch = features[begin2:end2]
301 |
302 | distance_batch[
303 | 0 : end1 - begin1, begin2:end2
304 | ] = self.distance_block.pairwise_distances(feature_batch, ref_batch)
305 |
306 | # From the minibatch of new feature vectors, determine if they are in the estimated manifold.
307 | # If a feature vector is inside a hypersphere of some reference sample, then
308 | # the new sample lies at the estimated manifold.
309 | # The radii of the hyperspheres are determined from distances of neighborhood size k.
310 | samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
311 | batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)
312 |
313 | max_realism_score[begin1:end1] = np.max(
314 | radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
315 | )
316 | nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1)
317 |
318 | return {
319 | "fraction": float(np.mean(batch_predictions)),
320 | "batch_predictions": batch_predictions,
321 |             "max_realism_score": max_realism_score,
322 | "nearest_indices": nearest_indices,
323 | }
324 |
325 | def evaluate_pr(
326 | self,
327 | features_1: np.ndarray,
328 | radii_1: np.ndarray,
329 | features_2: np.ndarray,
330 | radii_2: np.ndarray,
331 | ) -> Tuple[np.ndarray, np.ndarray]:
332 | """
333 | Evaluate precision and recall efficiently.
334 |
335 | :param features_1: [N1 x D] feature vectors for reference batch.
336 | :param radii_1: [N1 x K1] radii for reference vectors.
337 | :param features_2: [N2 x D] feature vectors for the other batch.
338 |         :param radii_2: [N2 x K2] radii for other vectors.
339 | :return: a tuple of arrays for (precision, recall):
340 | - precision: an np.ndarray of length K1
341 | - recall: an np.ndarray of length K2
342 | """
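    |         # Convention: features_1/radii_1 describe the reference manifold and features_2/radii_2
    |         # the samples, so the first returned array is precision (samples covered by the reference
    |         # manifold) and the second is recall (references covered by the sample manifold).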
343 |         features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=bool)
344 |         features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=bool)
345 | for begin_1 in range(0, len(features_1), self.row_batch_size):
346 | end_1 = begin_1 + self.row_batch_size
347 | batch_1 = features_1[begin_1:end_1]
348 | for begin_2 in range(0, len(features_2), self.col_batch_size):
349 | end_2 = begin_2 + self.col_batch_size
350 | batch_2 = features_2[begin_2:end_2]
351 | batch_1_in, batch_2_in = self.distance_block.less_thans(
352 | batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
353 | )
354 | features_1_status[begin_1:end_1] |= batch_1_in
355 | features_2_status[begin_2:end_2] |= batch_2_in
356 | return (
357 | np.mean(features_2_status.astype(np.float64), axis=0),
358 | np.mean(features_1_status.astype(np.float64), axis=0),
359 | )
360 |
361 |
362 | class DistanceBlock:
363 | """
364 | Calculate pairwise distances between vectors.
365 |
366 | Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
367 | """
368 |
369 | def __init__(self, session):
370 | self.session = session
371 |
372 | # Initialize TF graph to calculate pairwise distances.
373 | with session.graph.as_default():
374 | self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
375 | self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
376 | distance_block_16 = _batch_pairwise_distances(
377 | tf.cast(self._features_batch1, tf.float16),
378 | tf.cast(self._features_batch2, tf.float16),
379 | )
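    |             # Distances are first computed in float16 for speed and memory; if that produces
    |             # non-finite values, tf.cond falls back to the float32 computation.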
380 | self.distance_block = tf.cond(
381 | tf.reduce_all(tf.math.is_finite(distance_block_16)),
382 | lambda: tf.cast(distance_block_16, tf.float32),
383 | lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2),
384 | )
385 |
386 | # Extra logic for less thans.
387 | self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
388 | self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
389 | dist32 = tf.cast(self.distance_block, tf.float32)[..., None]
390 | self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)
391 | self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0)
392 |
393 | def pairwise_distances(self, U, V):
394 | """
395 | Evaluate pairwise distances between two batches of feature vectors.
396 | """
397 | return self.session.run(
398 | self.distance_block,
399 | feed_dict={self._features_batch1: U, self._features_batch2: V},
400 | )
401 |
402 | def less_thans(self, batch_1, radii_1, batch_2, radii_2):
403 | return self.session.run(
404 | [self._batch_1_in, self._batch_2_in],
405 | feed_dict={
406 | self._features_batch1: batch_1,
407 | self._features_batch2: batch_2,
408 | self._radii1: radii_1,
409 | self._radii2: radii_2,
410 | },
411 | )
412 |
413 |
414 | def _batch_pairwise_distances(U, V):
415 | """
416 | Compute pairwise distances between two batches of feature vectors.
417 | """
418 | with tf.variable_scope("pairwise_dist_block"):
419 | # Squared norms of each row in U and V.
420 | norm_u = tf.reduce_sum(tf.square(U), 1)
421 | norm_v = tf.reduce_sum(tf.square(V), 1)
422 |
423 | # norm_u as a column and norm_v as a row vectors.
424 | norm_u = tf.reshape(norm_u, [-1, 1])
425 | norm_v = tf.reshape(norm_v, [1, -1])
426 |
427 | # Pairwise squared Euclidean distances.
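    |         # ||u - v||^2 = ||u||^2 - 2 u.v + ||v||^2; no square root is taken, and the squared
    |         # distances and radii are compared consistently throughout.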
428 | D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0)
429 |
430 | return D
431 |
432 |
433 | class NpzArrayReader(ABC):
434 | @abstractmethod
435 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
436 | pass
437 |
438 | @abstractmethod
439 | def remaining(self) -> int:
440 | pass
441 |
442 | def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
443 | def gen_fn():
444 | while True:
445 | batch = self.read_batch(batch_size)
446 | if batch is None:
447 | break
448 | yield batch
449 |
450 | rem = self.remaining()
451 | num_batches = rem // batch_size + int(rem % batch_size != 0)
452 | return BatchIterator(gen_fn, num_batches)
453 |
454 |
455 | class BatchIterator:
456 | def __init__(self, gen_fn, length):
457 | self.gen_fn = gen_fn
458 | self.length = length
459 |
460 | def __len__(self):
461 | return self.length
462 |
463 | def __iter__(self):
464 | return self.gen_fn()
465 |
466 |
467 | class StreamingNpzArrayReader(NpzArrayReader):
468 | def __init__(self, arr_f, shape, dtype):
469 | self.arr_f = arr_f
470 | self.shape = shape
471 | self.dtype = dtype
472 | self.idx = 0
473 |
474 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
475 | if self.idx >= self.shape[0]:
476 | return None
477 |
478 | bs = min(batch_size, self.shape[0] - self.idx)
479 | self.idx += bs
480 |
481 | if self.dtype.itemsize == 0:
482 | return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)
483 |
484 | read_count = bs * np.prod(self.shape[1:])
485 | read_size = int(read_count * self.dtype.itemsize)
486 | data = _read_bytes(self.arr_f, read_size, "array data")
487 | return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])
488 |
489 | def remaining(self) -> int:
490 | return max(0, self.shape[0] - self.idx)
491 |
492 |
493 | class MemoryNpzArrayReader(NpzArrayReader):
494 | def __init__(self, arr):
495 | self.arr = arr
496 | self.idx = 0
497 |
498 | @classmethod
499 | def load(cls, path: str, arr_name: str):
500 | with open(path, "rb") as f:
501 | arr = np.load(f)[arr_name]
502 | return cls(arr)
503 |
504 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
505 | if self.idx >= self.arr.shape[0]:
506 | return None
507 |
508 | res = self.arr[self.idx : self.idx + batch_size]
509 | self.idx += batch_size
510 | return res
511 |
512 | def remaining(self) -> int:
513 | return max(0, self.arr.shape[0] - self.idx)
514 |
515 |
516 | @contextmanager
517 | def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
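    |     # Parse the .npy header inside the .npz archive: plain C-ordered arrays are streamed from
    |     # disk batch by batch, while Fortran-ordered/object arrays (or unknown format versions)
    |     # fall back to loading the whole array into memory.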
518 | with _open_npy_file(path, arr_name) as arr_f:
519 | version = np.lib.format.read_magic(arr_f)
520 | if version == (1, 0):
521 | header = np.lib.format.read_array_header_1_0(arr_f)
522 | elif version == (2, 0):
523 | header = np.lib.format.read_array_header_2_0(arr_f)
524 | else:
525 | yield MemoryNpzArrayReader.load(path, arr_name)
526 | return
527 | shape, fortran, dtype = header
528 | if fortran or dtype.hasobject:
529 | yield MemoryNpzArrayReader.load(path, arr_name)
530 | else:
531 | yield StreamingNpzArrayReader(arr_f, shape, dtype)
532 |
533 |
534 | def _read_bytes(fp, size, error_template="ran out of data"):
535 | """
536 | Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
537 |
538 | Read from file-like object until size bytes are read.
539 |     Raises ValueError if EOF is encountered before `size` bytes are read.
540 |     Non-blocking objects are only supported if they derive from io objects.
541 | Required as e.g. ZipExtFile in python 2.6 can return less data than
542 | requested.
543 | """
544 | data = bytes()
545 | while True:
546 | # io files (default in python3) return None or raise on
547 | # would-block, python2 file will truncate, probably nothing can be
548 | # done about that. note that regular files can't be non-blocking
549 | try:
550 | r = fp.read(size - len(data))
551 | data += r
552 | if len(r) == 0 or len(data) == size:
553 | break
554 | except io.BlockingIOError:
555 | pass
556 | if len(data) != size:
557 | msg = "EOF: reading %s, expected %d bytes got %d"
558 | raise ValueError(msg % (error_template, size, len(data)))
559 | else:
560 | return data
561 |
562 |
563 | @contextmanager
564 | def _open_npy_file(path: str, arr_name: str):
565 | with open(path, "rb") as f:
566 | with zipfile.ZipFile(f, "r") as zip_f:
567 | if f"{arr_name}.npy" not in zip_f.namelist():
568 | raise ValueError(f"missing {arr_name} in npz file")
569 | with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
570 | yield arr_f
571 |
572 |
573 | def _download_inception_model():
574 | if os.path.exists(INCEPTION_V3_PATH):
575 | return
576 | print("downloading InceptionV3 model...")
577 | with requests.get(INCEPTION_V3_URL, stream=True) as r:
578 | r.raise_for_status()
579 | tmp_path = INCEPTION_V3_PATH + ".tmp"
580 | with open(tmp_path, "wb") as f:
581 | for chunk in tqdm(r.iter_content(chunk_size=8192)):
582 | f.write(chunk)
583 | os.rename(tmp_path, INCEPTION_V3_PATH)
584 |
585 |
586 | def _create_feature_graph(input_batch):
587 | _download_inception_model()
588 | prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
589 | with open(INCEPTION_V3_PATH, "rb") as f:
590 | graph_def = tf.GraphDef()
591 | graph_def.ParseFromString(f.read())
592 | pool3, spatial = tf.import_graph_def(
593 | graph_def,
594 |         input_map={"ExpandDims:0": input_batch},
595 | return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],
596 | name=prefix,
597 | )
598 | _update_shapes(pool3)
599 | spatial = spatial[..., :7]
600 | return pool3, spatial
601 |
602 |
603 | def _create_softmax_graph(input_batch):
604 | _download_inception_model()
605 | prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
606 | with open(INCEPTION_V3_PATH, "rb") as f:
607 | graph_def = tf.GraphDef()
608 | graph_def.ParseFromString(f.read())
609 | (matmul,) = tf.import_graph_def(
610 |         graph_def, return_elements=["softmax/logits/MatMul"], name=prefix
611 | )
612 | w = matmul.inputs[1]
613 | logits = tf.matmul(input_batch, w)
614 | return tf.nn.softmax(logits)
615 |
616 |
617 | def _update_shapes(pool3):
618 | # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
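    |     # The pretrained Inception graph hard-codes a batch dimension of 1 on some ops; this rewrites
    |     # the recorded static shapes so the features can be evaluated with arbitrary batch sizes.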
619 | ops = pool3.graph.get_operations()
620 | for op in ops:
621 | for o in op.outputs:
622 | shape = o.get_shape()
623 | if shape._dims is not None: # pylint: disable=protected-access
624 | # shape = [s.value for s in shape] TF 1.x
625 | shape = [s for s in shape] # TF 2.x
626 | new_shape = []
627 | for j, s in enumerate(shape):
628 | if s == 1 and j == 0:
629 | new_shape.append(None)
630 | else:
631 | new_shape.append(s)
632 | o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
633 | return pool3
634 |
635 |
636 | def _numpy_partition(arr, kth, **kwargs):
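    |     # Split the rows across a thread pool and run np.partition on each chunk in parallel; this is
    |     # safe because the caller partitions along axis=1, i.e. each row independently.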
637 | num_workers = min(cpu_count(), len(arr))
638 | chunk_size = len(arr) // num_workers
639 | extra = len(arr) % num_workers
640 |
641 | start_idx = 0
642 | batches = []
643 | for i in range(num_workers):
644 | size = chunk_size + (1 if i < extra else 0)
645 | batches.append(arr[start_idx : start_idx + size])
646 | start_idx += size
647 |
648 | with ThreadPool(num_workers) as pool:
649 | return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))
650 |
651 |
652 | if __name__ == "__main__":
653 | main()
654 |
--------------------------------------------------------------------------------
/paella_sampling.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "b3ea0ec1-436f-4fca-99fa-5122a73b52d7",
7 | "metadata": {
8 | "id": "b3ea0ec1-436f-4fca-99fa-5122a73b52d7",
9 | "outputId": "b90ef2e0-156a-458d-f495-8a0fda0253e5",
10 | "scrolled": true,
11 | "tags": []
12 | },
13 | "outputs": [],
14 | "source": [
15 | "!pip install kornia lpips einops rudalle open_clip_torch pytorch_lightning webdataset timm git+https://github.com/pabloppp/pytorch-tools git+https://github.com/openai/CLIP.git -U\n",
16 | "!pip uninstall torch torchvision torchaudio -y\n",
17 | "!pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116\n",
18 | "!pip install --upgrade Pillow"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "id": "e05c89a6-4648-46e2-90c7-cf7c1b0dd861",
25 | "metadata": {
26 | "id": "788a2a72",
27 | "outputId": "702cbb5b-8c71-4b33-e86e-f84b718b7a1b"
28 | },
29 | "outputs": [],
30 | "source": [
31 | "import os\n",
32 | "import time\n",
33 | "import torch\n",
34 | "from torch import nn\n",
35 | "import torchvision\n",
36 | "import matplotlib.pyplot as plt\n",
37 | "from tqdm import tqdm\n",
38 | "from PIL import Image\n",
39 | "import requests\n",
40 | "from io import BytesIO\n",
41 | "from modules import DenoiseUNet\n",
42 | "import open_clip\n",
43 | "from open_clip import tokenizer\n",
44 | "from rudalle import get_vae\n",
45 | "from einops import rearrange\n",
46 | "\n",
47 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
48 | "print(\"Using device:\", device)"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": null,
54 | "id": "d6640155",
55 | "metadata": {
56 | "id": "d6640155",
57 | "outputId": "99d4752b-3a85-49da-cbf3-72637160e2b9",
58 | "scrolled": true
59 | },
60 | "outputs": [],
61 | "source": [
62 | "def showmask(mask):\n",
63 | " plt.axis(\"off\")\n",
64 | " plt.imshow(torch.cat([\n",
65 | " torch.cat([i for i in mask[0:1].cpu()], dim=-1),\n",
66 | " ], dim=-2).cpu())\n",
67 | " plt.show()\n",
68 | "\n",
69 | "def showimages(imgs, **kwargs):\n",
70 | " plt.figure(figsize=(kwargs.get(\"width\", 32), kwargs.get(\"height\", 32)))\n",
71 | " plt.axis(\"off\")\n",
72 | " plt.imshow(torch.cat([\n",
73 | " torch.cat([i for i in imgs], dim=-1),\n",
74 | " ], dim=-2).permute(1, 2, 0).cpu())\n",
75 | " plt.show()\n",
76 | " \n",
77 | "def saveimages(imgs, name, **kwargs):\n",
78 | " name = name.replace(\" \", \"_\").replace(\".\", \"\")\n",
79 | " path = os.path.join(\"outputs\", name + \".jpg\")\n",
80 | " while os.path.exists(path):\n",
81 | " base, ext = path.split(\".\")\n",
82 | " num = base.split(\"_\")[-1]\n",
83 | " if num.isdigit():\n",
84 | " num = int(num) + 1\n",
85 | " base = \"_\".join(base.split(\"_\")[:-1])\n",
86 | " else:\n",
87 | " num = 0\n",
88 | " path = base + \"_\" + str(num) + \".\" + ext\n",
89 | " torchvision.utils.save_image(imgs, path, **kwargs)"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": null,
95 | "id": "7e2e1f26-b4ca-4e11-947d-be80196d440f",
96 | "metadata": {
97 | "tags": []
98 | },
99 | "outputs": [],
100 | "source": [
101 | "def log(t, eps=1e-20):\n",
102 | " return torch.log(t + eps)\n",
103 | "\n",
104 | "def gumbel_noise(t):\n",
105 | " noise = torch.zeros_like(t).uniform_(0, 1)\n",
106 | " return -log(-log(noise))\n",
107 | "\n",
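    | "# Gumbel-max trick: adding Gumbel noise to the temperature-scaled logits and taking the argmax\n",
    | "# draws a sample from softmax(logits / temperature) without an explicit multinomial call.\n",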
108 | "def gumbel_sample(t, temperature=1., dim=-1):\n",
109 | " return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)\n",
110 | "\n",
111 | "def sample(model, c, x=None, mask=None, T=12, size=(32, 32), starting_t=0, temp_range=[1.0, 1.0], typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=-1, renoise_steps=11, renoise_mode='start'):\n",
112 | " with torch.inference_mode():\n",
113 | " r_range = torch.linspace(0, 1, T+1)[:-1][:, None].expand(-1, c.size(0)).to(c.device)\n",
114 | " temperatures = torch.linspace(temp_range[0], temp_range[1], T)\n",
115 | " preds = []\n",
116 | " if x is None:\n",
117 | " x = torch.randint(0, model.num_labels, size=(c.size(0), *size), device=c.device)\n",
118 | " elif mask is not None:\n",
119 | " noise = torch.randint(0, model.num_labels, size=(c.size(0), *size), device=c.device)\n",
120 | " x = noise * mask + (1-mask) * x\n",
121 | " init_x = x.clone()\n",
122 | " for i in range(starting_t, T):\n",
123 | " if renoise_mode == 'prev':\n",
124 | " prev_x = x.clone()\n",
125 | " r, temp = r_range[i], temperatures[i]\n",
126 | " logits = model(x, c, r)\n",
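    | "            # Classifier-free guidance: blend unconditional (zero conditioning) and conditional\n",
    | "            # logits; classifier_free_scale > 1 extrapolates towards the conditional prediction.\n",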
127 | " if classifier_free_scale >= 0:\n",
128 | " logits_uncond = model(x, torch.zeros_like(c), r)\n",
129 | " logits = torch.lerp(logits_uncond, logits, classifier_free_scale)\n",
130 | " x = logits\n",
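    | "            # Typical filtering keeps only tokens whose information content is close to the\n",
    | "            # distribution's entropy (up to cumulative mass typical_mass) before Gumbel sampling;\n",
    | "            # afterwards the intermediate result is partially re-noised for the next step.\n",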
131 | "            x_flat = x.permute(0, 2, 3, 1).reshape(-1, x.size(1))\n",
132 | "            if typical_filtering:\n",
133 | "                x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1)\n",
134 | "                x_flat_norm_p = torch.exp(x_flat_norm)\n",
135 | "                entropy = -(x_flat_norm * x_flat_norm_p).nansum(-1, keepdim=True)\n",
136 | "\n",
137 | "                c_flat_shifted = torch.abs((-x_flat_norm) - entropy)\n",
138 | "                c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False)\n",
139 | "                x_flat_cumsum = x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1)\n",
140 | "\n",
141 | "                last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1)\n",
142 | "                sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(1, last_ind.view(-1, 1))\n",
143 | "                if typical_min_tokens > 1:\n",
144 | "                    sorted_indices_to_remove[..., :typical_min_tokens] = 0\n",
145 | "                indices_to_remove = sorted_indices_to_remove.scatter(1, x_flat_indices, sorted_indices_to_remove)\n",
146 | "                x_flat = x_flat.masked_fill(indices_to_remove, -float(\"Inf\"))\n",
147 | "            # x_flat = torch.multinomial(x_flat.div(temp).softmax(-1), num_samples=1)[:, 0]\n",
148 | "            x_flat = gumbel_sample(x_flat, temperature=temp)\n",
149 | "            x = x_flat.view(x.size(0), *x.shape[2:])\n",
150 | "            if mask is not None:\n",
151 | "                x = x * mask + (1-mask) * init_x\n",
152 | "            if i < renoise_steps:\n",
153 | "                if renoise_mode == 'start':\n",
154 | "                    x, _ = model.add_noise(x, r_range[i+1], random_x=init_x)\n",
155 | "                elif renoise_mode == 'prev':\n",
156 | "                    x, _ = model.add_noise(x, r_range[i+1], random_x=prev_x)\n",
157 | "                else: # 'rand'\n",
158 | "                    x, _ = model.add_noise(x, r_range[i+1])\n",
159 | "            preds.append(x.detach())\n",
160 | " return preds"
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": null,
166 | "id": "34c7c626",
167 | "metadata": {
168 | "id": "34c7c626",
169 | "outputId": "16108313-6b3d-4751-d84b-a92d7f8ebf68",
170 | "tags": []
171 | },
172 | "outputs": [],
173 | "source": [
174 | "vqmodel = get_vae().to(device)\n",
175 | "vqmodel.eval().requires_grad_(False)\n",
176 | "\n",
177 | "clip_model, _, _ = open_clip.create_model_and_transforms('ViT-g-14', pretrained='laion2b_s12b_b42k')\n",
178 | "clip_model = clip_model.to(device).eval().requires_grad_(False)\n",
179 | "\n",
180 | "clip_preprocess = torchvision.transforms.Compose([\n",
181 | " torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),\n",
182 | " torchvision.transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),\n",
183 | "])\n",
184 | "\n",
185 | "preprocess = torchvision.transforms.Compose([\n",
186 | " torchvision.transforms.Resize(256),\n",
187 | " # torchvision.transforms.CenterCrop(256),\n",
188 | " torchvision.transforms.ToTensor(),\n",
189 | "])\n",
190 | "\n",
191 | "def encode(x):\n",
192 | " return vqmodel.model.encode((2 * x - 1))[-1][-1]\n",
193 | " \n",
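    | "# decode: map a grid of VQ token indices back to pixels by looking up the codebook embeddings\n",
    | "# (one-hot @ embedding matrix) and running the VQGAN decoder, then rescale from [-1, 1] to [0, 1].\n",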
194 | "def decode(img_seq, shape=(32,32)):\n",
195 | " img_seq = img_seq.view(img_seq.shape[0], -1)\n",
196 | " b, n = img_seq.shape\n",
197 | " one_hot_indices = torch.nn.functional.one_hot(img_seq, num_classes=vqmodel.num_tokens).float()\n",
198 | " z = (one_hot_indices @ vqmodel.model.quantize.embed.weight)\n",
199 | " z = rearrange(z, 'b (h w) c -> b c h w', h=shape[0], w=shape[1])\n",
200 | " img = vqmodel.model.decode(z)\n",
201 | " img = (img.clamp(-1., 1.) + 1) * 0.5\n",
202 | " return img\n",
203 | " \n",
204 | "state_dict = torch.load(\"./models/f8_600000.pt\", map_location=device)\n",
205 | "# state_dict = torch.load(\"./models/f8_img_40000.pt\", map_location=device)\n",
206 | "model = DenoiseUNet(num_labels=8192).to(device)\n",
207 | "model.load_state_dict(state_dict)\n",
208 | "model.eval().requires_grad_()\n",
209 | "print()"
210 | ]
211 | },
212 | {
213 | "cell_type": "markdown",
214 | "id": "753a98f2-bf40-4059-86f6-c83dac8c15eb",
215 | "metadata": {},
216 | "source": [
217 | "# Text-Conditional"
218 | ]
219 | },
220 | {
221 | "cell_type": "code",
222 | "execution_count": null,
223 | "id": "d39cd5d7-220a-438b-8229-823fa9e3fff9",
224 | "metadata": {
225 | "tags": []
226 | },
227 | "outputs": [],
228 | "source": [
229 | "mode = \"text\"\n",
230 | "batch_size = 6\n",
231 | "text = \"highly detailed photograph of darth vader. artstation\"\n",
232 | "latent_shape = (32, 32)\n",
233 | "tokenized_text = tokenizer.tokenize([text] * batch_size).to(device)\n",
234 | "with torch.inference_mode():\n",
235 | " with torch.autocast(device_type=\"cuda\"):\n",
236 | " clip_embeddings = clip_model.encode_text(tokenized_text)\n",
237 | " s = time.time()\n",
238 | " sampled = sample(model, clip_embeddings, T=12, size=latent_shape, starting_t=0, temp_range=[1.0, 1.0],\n",
239 | " typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=5, renoise_steps=11,\n",
240 | " renoise_mode=\"start\")\n",
241 | " print(time.time() - s)\n",
242 | " sampled = decode(sampled[-1], latent_shape)\n",
243 | "\n",
244 | "showimages(sampled)\n",
245 | "saveimages(sampled, mode + \"_\" + text, nrow=len(sampled))"
246 | ]
247 | },
248 | {
249 | "cell_type": "markdown",
250 | "id": "c3f5cecc-4e30-465e-bc51-425ec76a7d06",
251 | "metadata": {},
252 | "source": [
253 | "# Interpolation"
254 | ]
255 | },
256 | {
257 | "cell_type": "code",
258 | "execution_count": null,
259 | "id": "35ae32e2",
260 | "metadata": {
261 | "id": "35ae32e2",
262 | "outputId": "28283340-fe61-44c6-9d92-720d04cf56bf"
263 | },
264 | "outputs": [],
265 | "source": [
266 | "mode = \"interpolation\"\n",
267 | "text = \"surreal painting of a yellow tulip. artstation\"\n",
268 | "text2 = \"surreal painting of a red tulip. artstation\"\n",
269 | "text_encoded = tokenizer.tokenize([text]).to(device)\n",
270 | "text2_encoded = tokenizer.tokenize([text2]).to(device)\n",
271 | "with torch.inference_mode():\n",
272 | " with torch.autocast(device_type=\"cuda\"):\n",
273 | " clip_embeddings = clip_model.encode_text(text_encoded).float()\n",
274 | " clip_embeddings2 = clip_model.encode_text(text2_encoded).float()\n",
275 | "\n",
276 | " l = torch.linspace(0, 1, 10).to(device)\n",
277 | " embeddings = []\n",
278 | " for i in l:\n",
279 | " lerp = torch.lerp(clip_embeddings, clip_embeddings2, i)\n",
280 | " embeddings.append(lerp)\n",
281 | " embeddings = torch.cat(embeddings)\n",
282 | " \n",
283 | " s = time.time()\n",
284 | " sampled = sample(model, embeddings, T=12, size=(32, 32), starting_t=0, temp_range=[1.0, 1.0],\n",
285 | " typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=4, renoise_steps=11)\n",
286 | " print(time.time() - s)\n",
287 | " sampled = decode(sampled[-1])\n",
288 | "showimages(sampled)\n",
289 | "saveimages(sampled, mode + \"_\" + text + \"_\" + text2, nrow=len(sampled))"
290 | ]
291 | },
292 | {
293 | "cell_type": "code",
294 | "execution_count": null,
295 | "id": "026fade5-0384-4521-992e-8cd66929d26d",
296 | "metadata": {},
297 | "outputs": [],
298 | "source": [
299 | "mode = \"interpolation\"\n",
300 | "text = \"High quality front portrait photo of a tiger.\"\n",
301 | "text2 = \"High quality front portrait photo of a dog.\"\n",
302 | "text_encoded = tokenizer.tokenize([text]).to(device)\n",
303 | "text2_encoded = tokenizer.tokenize([text2]).to(device)\n",
304 | "with torch.inference_mode():\n",
305 | " with torch.autocast(device_type=\"cuda\"):\n",
306 | " clip_embeddings = clip_model.encode_text(text_encoded).float()\n",
307 | " clip_embeddings2 = clip_model.encode_text(text2_encoded).float()\n",
308 | "\n",
309 | " l = torch.linspace(0, 1, 10).to(device)\n",
310 | " s = time.time()\n",
311 | " outputs = []\n",
312 | " for i in l:\n",
313 | " # lerp = torch.lerp(clip_embeddings, clip_embeddings2, i)\n",
314 | " low, high = clip_embeddings, clip_embeddings2\n",
315 | " low_norm = low/torch.norm(low, dim=1, keepdim=True)\n",
316 | " high_norm = high/torch.norm(high, dim=1, keepdim=True)\n",
317 | " omega = torch.acos((low_norm*high_norm).sum(1)).unsqueeze(1)\n",
318 | " so = torch.sin(omega)\n",
319 | " lerp = (torch.sin((1.0-i)*omega)/so)*low + (torch.sin(i*omega)/so) * high\n",
320 | " with torch.random.fork_rng():\n",
321 | " torch.random.manual_seed(32)\n",
322 | " sampled = sample(model, lerp, T=20, size=(32, 32), starting_t=0, temp_range=[1.0, 1.0],\n",
323 | " typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=5, renoise_steps=11)\n",
324 | " outputs.append(sampled[-1])\n",
325 | " print(time.time() - s)\n",
326 | " sampled = torch.cat(outputs)\n",
327 | " sampled = decode(sampled)\n",
328 | "showimages(sampled)\n",
329 | "saveimages(sampled, mode + \"_\" + text + \"_\" + text2, nrow=len(sampled))"
330 | ]
331 | },
332 | {
333 | "cell_type": "markdown",
334 | "id": "561133a9",
335 | "metadata": {
336 | "id": "0bd51975"
337 | },
338 | "source": [
339 | "# Multi-Conditioning"
340 | ]
341 | },
342 | {
343 | "cell_type": "code",
344 | "execution_count": null,
345 | "id": "83e9ac9f",
346 | "metadata": {
347 | "id": "83e9ac9f"
348 | },
349 | "outputs": [],
350 | "source": [
351 | "batch_size = 4\n",
352 | "latent_shape = (32, 32)\n",
353 | "text_a = \"a cute portrait of a dog\"\n",
354 | "text_b = \"a cute portrait of a cat\"\n",
355 | "mode = \"vertical\"\n",
356 | "# mode = \"horizontal\"\n",
357 | "text = tokenizer.tokenize([text_a, text_b] * batch_size).to(device)\n",
358 | "\n",
359 | "with torch.inference_mode():\n",
360 | " with torch.autocast(device_type=\"cuda\"):\n",
361 | " clip_embeddings = clip_model.encode_text(text).float()[:, :, None, None].expand(-1, -1, latent_shape[0], latent_shape[1])\n",
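    | "        # Spatial multi-conditioning: broadcast each prompt embedding over the latent grid and\n",
    | "        # blend the two prompts with a linear ramp along one axis, so the conditioning varies\n",
    | "        # smoothly across the image.\n",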
362 | " if mode == 'vertical':\n",
363 | " interp_mask = torch.linspace(0, 1, latent_shape[0], device=device)[None, None, :, None].expand(batch_size, 1, -1, latent_shape[1])\n",
364 | " else: \n",
365 | " interp_mask = torch.linspace(0, 1, latent_shape[1], device=device)[None, None, None, :].expand(batch_size, 1, latent_shape[0], -1)\n",
366 | " # LERP\n",
367 | " clip_embeddings = clip_embeddings[0::2] * (1-interp_mask) + clip_embeddings[1::2] * interp_mask\n",
368 | " # # SLERP\n",
369 | " # low, high = clip_embeddings[0::2], clip_embeddings[1::2]\n",
370 | " # low_norm = low/torch.norm(low, dim=1, keepdim=True)\n",
371 | " # high_norm = high/torch.norm(high, dim=1, keepdim=True)\n",
372 | " # omega = torch.acos((low_norm*high_norm).sum(1)).unsqueeze(1)\n",
373 | " # so = torch.sin(omega)\n",
374 | " # clip_embeddings = (torch.sin((1.0-interp_mask)*omega)/so)*low + (torch.sin(interp_mask*omega)/so) * high\n",
375 | " \n",
376 | " sampled = sample(model, clip_embeddings, T=12, size=latent_shape, starting_t=0, temp_range=[1.0, 1.0],\n",
377 | " typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=5, renoise_steps=11,\n",
378 | " renoise_mode=\"start\")\n",
379 | " sampled = decode(sampled[-1], latent_shape)\n",
380 | "\n",
381 | "showimages(sampled)"
382 | ]
383 | },
384 | {
385 | "cell_type": "code",
386 | "execution_count": null,
387 | "id": "3838e233-34fd-4637-b0df-476f4e66cd83",
388 | "metadata": {},
389 | "outputs": [],
390 | "source": [
391 | "mode = \"multiconditioning\"\n",
392 | "batch_size = 4\n",
393 | "latent_shape = (32, 32)\n",
394 | "conditions = [\n",
395 | " [\"High quality portrait of a dog.\", 16],\n",
396 | " [\"High quality portrait of a wolf.\", 32],\n",
397 | "]\n",
398 | "clip_embedding = torch.zeros(batch_size, 1024, *latent_shape).to(device)\n",
399 | "last_pos = 0\n",
400 | "for text, pos in conditions:\n",
401 | " tokenized_text = tokenizer.tokenize([text] * batch_size).to(device)\n",
402 | " part_clip_embedding = clip_model.encode_text(tokenized_text).float()[:, :, None, None]\n",
403 | " print(f\"{last_pos}:{pos}={text}\")\n",
404 | " clip_embedding[:, :, :, last_pos:pos] = part_clip_embedding\n",
405 | " last_pos = pos\n",
406 | "with torch.inference_mode():\n",
407 | " with torch.autocast(device_type=\"cuda\"):\n",
408 | " sampled = sample(model, clip_embedding, T=12, size=latent_shape, starting_t=0, temp_range=[1.0, 1.0],\n",
409 | " typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=5, renoise_steps=11,\n",
410 | " renoise_mode=\"start\")\n",
411 | " sampled = decode(sampled[-1], latent_shape)\n",
412 | " \n",
413 | "showimages(sampled)\n",
414 | "saveimages(sampled, mode + \"_\" + \":\".join(list(map(lambda x: x[0], conditions))), nrow=batch_size)"
415 | ]
416 | },
417 | {
418 | "cell_type": "markdown",
419 | "id": "0a65e885-e166-4997-bf81-94e344e46a78",
420 | "metadata": {},
421 | "source": [
422 | "#### Load Image: Disk or Web"
423 | ]
424 | },
425 | {
426 | "cell_type": "code",
427 | "execution_count": null,
428 | "id": "9fc57b1a-de96-4ef4-b699-0417078b9da3",
429 | "metadata": {},
430 | "outputs": [],
431 | "source": [
432 | "images = preprocess(Image.open(\"path_to_image\")).unsqueeze(0).expand(4, -1, -1, -1).to(device)[:, :3]\n",
433 | "showimages(images)"
434 | ]
435 | },
436 | {
437 | "cell_type": "code",
438 | "execution_count": null,
439 | "id": "7fcf26db-a58a-4f53-9c0b-d83707e7713d",
440 | "metadata": {},
441 | "outputs": [],
442 | "source": [
443 | "url = \"https://media.istockphoto.com/id/1193591781/photo/obedient-dog-breed-welsh-corgi-pembroke-sitting-and-smiles-on-a-white-background-not-isolate.jpg?s=612x612&w=0&k=20&c=ZDKTgSFQFG9QvuDziGsnt55kvQoqJtIhrmVRkpYqxtQ=\"\n",
444 | "# url = \"https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1200px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg\"\n",
445 | "response = requests.get(url)\n",
446 | "img = Image.open(BytesIO(response.content)).convert(\"RGB\")\n",
447 | "images = preprocess(img).unsqueeze(0).expand(4, -1, -1, -1).to(device)[:, :3]\n",
448 | "showimages(images)"
449 | ]
450 | },
451 | {
452 | "cell_type": "markdown",
453 | "id": "11eb27d7-5d3b-4a2d-9b2a-86cb81850dad",
454 | "metadata": {},
455 | "source": [
456 | "# Inpainting"
457 | ]
458 | },
459 | {
460 | "cell_type": "code",
461 | "execution_count": null,
462 | "id": "7d59c35c-b7a5-4468-99b8-4e4377df63c2",
463 | "metadata": {},
464 | "outputs": [],
465 | "source": [
466 | "mode = \"inpainting\"\n",
467 | "text = \"a delicious spanish paella\"\n",
468 | "tokenized_text = tokenizer.tokenize([text] * images.shape[0]).to(device)\n",
469 | "with torch.inference_mode():\n",
470 | " with torch.autocast(device_type=\"cuda\"):\n",
471 | " # clip_embeddings = clip_model.encode_image(clip_preprocess(images)).float() # clip_embeddings = clip_model.encode_text(text).float()\n",
472 | " clip_embeddings = clip_model.encode_text(tokenized_text).float()\n",
473 | " encoded_tokens = encode(images)\n",
474 | " latent_shape = encoded_tokens.shape[1:]\n",
475 | " mask = torch.zeros_like(encoded_tokens)\n",
476 | " mask[:,5:28,5:28] = 1\n",
477 | " sampled = sample(model, clip_embeddings, x=encoded_tokens, mask=mask, T=12, size=latent_shape, starting_t=0, temp_range=[1.0, 1.0],\n",
478 | " typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=6, renoise_steps=11)\n",
479 | " sampled = decode(sampled[-1], latent_shape)\n",
480 | "\n",
481 | "showimages(images[0:1], height=10, width=10)\n",
482 | "showmask(mask[0:1])\n",
483 | "showimages(sampled, height=16, width=16)\n",
484 | "saveimages(torch.cat([images[0:1], sampled]), mode + \"_\" + text, nrow=images.shape[0]+1)"
485 | ]
486 | },
487 | {
488 | "cell_type": "markdown",
489 | "id": "86f2bd06-aef7-4887-a5e8-dcc4af883dab",
490 | "metadata": {
491 | "tags": []
492 | },
493 | "source": [
494 | "# Outpainting"
495 | ]
496 | },
497 | {
498 | "cell_type": "code",
499 | "execution_count": null,
500 | "id": "97e4243a-186e-40e1-9b71-c11df0c60963",
501 | "metadata": {},
502 | "outputs": [],
503 | "source": [
504 | "mode = \"outpainting\"\n",
505 | "size = (40, 64)\n",
506 | "top_left = (0, 16)\n",
507 | "text = \"black & white photograph of a rocket from the bottom.\"\n",
508 | "tokenized_text = tokenizer.tokenize([text] * images.shape[0]).to(device)\n",
509 | "with torch.inference_mode():\n",
510 | " with torch.autocast(device_type=\"cuda\"):\n",
511 | " # clip_embeddings = clip_model.encode_image(clip_preprocess(images)).float()\n",
512 | " clip_embeddings = clip_model.encode_text(tokenized_text).float()\n",
513 | " encoded_tokens = encode(images)\n",
514 | " canvas = torch.zeros((images.shape[0], *size), dtype=torch.long).to(device)\n",
515 | " canvas[:, top_left[0]:top_left[0]+encoded_tokens.shape[1], top_left[1]:top_left[1]+encoded_tokens.shape[2]] = encoded_tokens\n",
516 | " mask = torch.ones_like(canvas)\n",
517 | " mask[:, top_left[0]:top_left[0]+encoded_tokens.shape[1], top_left[1]:top_left[1]+encoded_tokens.shape[2]] = 0\n",
518 | " sampled = sample(model, clip_embeddings, x=canvas, mask=mask, T=12, size=size, starting_t=0, temp_range=[1.0, 1.0],\n",
519 | " typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=4, renoise_steps=11)\n",
520 | " sampled = decode(sampled[-1], size)\n",
521 | "\n",
522 | "showimages(images[0:1], height=10, width=10)\n",
523 | "showmask(mask[0:1])\n",
524 | "showimages(sampled, height=16, width=16)\n",
525 | "saveimages(sampled, mode + \"_\" + text, nrow=images.shape[0])"
526 | ]
527 | },
528 | {
529 | "cell_type": "markdown",
530 | "id": "8ad2d6da-ae1c-4c76-a494-ceb5469a649b",
531 | "metadata": {},
532 | "source": [
533 | "# Structural Morphing"
534 | ]
535 | },
536 | {
537 | "cell_type": "code",
538 | "execution_count": null,
539 | "id": "cd46fd02-6770-4824-ad1d-356c6f197eaf",
540 | "metadata": {},
541 | "outputs": [],
542 | "source": [
543 | "mode = \"morphing\"\n",
544 | "max_steps = 24\n",
545 | "init_step = 8\n",
546 | "\n",
547 | "text = \"A fox posing for a photo. stock photo. highly detailed. 4k\"\n",
548 | "\n",
549 | "with torch.inference_mode():\n",
550 | " with torch.autocast(device_type=\"cuda\"):\n",
551 | " # images = preprocess(Image.open(\"data/city sketch.png\")).unsqueeze(0).expand(4, -1, -1, -1).to(device)[:, :3]\n",
552 | " latent_image = encode(images)\n",
553 | " latent_shape = latent_image.shape[-2:]\n",
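    | "        # Structural morphing: partially noise the encoded image up to init_step/max_steps and\n",
    | "        # let the sampler resume from starting_t=init_step, so the layout of the input is kept\n",
    | "        # while texture and content follow the text prompt.\n",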
554 | " r = torch.ones(latent_image.size(0), device=device) * (init_step/max_steps)\n",
555 | " noised_latent_image, _ = model.add_noise(latent_image, r)\n",
556 | " \n",
557 | " tokenized_text = tokenizer.tokenize([text] * images.size(0)).to(device)\n",
558 | " clip_embeddings = clip_model.encode_text(tokenized_text).float()\n",
559 | " \n",
560 | " sampled = sample(model, clip_embeddings, x=noised_latent_image, T=max_steps, size=latent_shape, starting_t=init_step, temp_range=[1.0, 1.0],\n",
561 | " typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=6, renoise_steps=max_steps-1,\n",
562 | " renoise_mode=\"prev\")\n",
563 | " sampled = decode(sampled[-1], latent_shape)\n",
564 | "showimages(sampled)\n",
565 | "showimages(images)\n",
566 | "saveimages(torch.cat([images[0:1], sampled]), mode + \"_\" + text, nrow=images.shape[0]+1)"
567 | ]
568 | },
569 | {
570 | "cell_type": "markdown",
571 | "id": "109dbf24-8b68-4af1-9fa2-492bbf5c6a26",
572 | "metadata": {},
573 | "source": [
574 | "# Image Variations"
575 | ]
576 | },
577 | {
578 | "cell_type": "code",
579 | "execution_count": null,
580 | "id": "cf81732b-9adc-4cda-9961-12480d932ff4",
581 | "metadata": {},
582 | "outputs": [],
583 | "source": [
584 | "clip_preprocess = torchvision.transforms.Compose([\n",
585 | " torchvision.transforms.Resize((224, 224), interpolation=torchvision.transforms.InterpolationMode.BICUBIC),\n",
586 | " torchvision.transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),\n",
587 | "])"
588 | ]
589 | },
590 | {
591 | "cell_type": "code",
592 | "execution_count": null,
593 | "id": "ac72cd16-0f3e-407e-ba5e-ad7664254c7f",
594 | "metadata": {},
595 | "outputs": [],
596 | "source": [
597 | "latent_shape = (32, 32)\n",
598 | "with torch.inference_mode():\n",
599 | " with torch.autocast(device_type=\"cuda\"):\n",
600 | " clip_embeddings = clip_model.encode_image(clip_preprocess(images)).float() # clip_embeddings = clip_model.encode_text(text).float() \n",
601 | " sampled = sample(model, clip_embeddings, T=12, size=latent_shape, starting_t=0, temp_range=[1.0, 1.0],\n",
602 | " typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=5, renoise_steps=11)\n",
603 | " sampled = decode(sampled[-1], latent_shape)\n",
604 | "\n",
605 | "showimages(images)\n",
606 | "showimages(sampled)"
607 | ]
608 | },
609 | {
610 | "cell_type": "markdown",
611 | "id": "4446993a-e5b4-40b9-a898-f2b8e9e03576",
612 | "metadata": {
613 | "jp-MarkdownHeadingCollapsed": true,
614 | "tags": []
615 | },
616 | "source": [
617 | "# Experimental: Concept Learning"
618 | ]
619 | },
620 | {
621 | "cell_type": "code",
622 | "execution_count": null,
623 | "id": "4e32d746-3cb1-4a27-a537-0e0baad20830",
624 | "metadata": {},
625 | "outputs": [],
626 | "source": [
627 | "def text_encode(x, clip_model, insertion_index):\n",
628 | " # x = x.type(clip_model.dtype)\n",
629 | " x = x + clip_model.positional_embedding\n",
630 | " x = x.permute(1, 0, 2) # NLD -> LND\n",
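    | "    # Note: unlike clip_model.encode_text, no causal attention mask is passed here; depending on\n",
    | "    # the open_clip version you may want to supply clip_model.attn_mask to the transformer call.\n",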
631 | " x = clip_model.transformer(x)\n",
632 | " x = x.permute(1, 0, 2) # LND -> NLD\n",
633 | " x = clip_model.ln_final(x)\n",
634 | "\n",
635 | " # x.shape = [batch_size, n_ctx, transformer.width]\n",
636 | " # take features from the eot embedding (eot_token is the highest number in each sequence)\n",
637 | " x = x[torch.arange(x.shape[0]), insertion_index] @ clip_model.text_projection\n",
638 | "\n",
639 | " return x"
640 | ]
641 | },
642 | {
643 | "cell_type": "code",
644 | "execution_count": null,
645 | "id": "e391e34d-ddab-4718-83d3-3cbf56c19363",
646 | "metadata": {},
647 | "outputs": [],
648 | "source": [
649 | "from torch.optim import AdamW\n",
650 | "batch_size = 1\n",
651 | "asteriks_emb = clip_model.token_embedding(tokenizer.tokenize([\"*\"]).to(device))[0][1]\n",
652 | "context_word = torch.randn(batch_size, 1, asteriks_emb.shape[-1]).to(device)\n",
653 | "context_word.requires_grad_(True)\n",
654 | "optim = AdamW(params=[context_word], lr=0.1)\n",
655 | "criterion = nn.CrossEntropyLoss(label_smoothing=0.1)"
656 | ]
657 | },
658 | {
659 | "cell_type": "code",
660 | "execution_count": null,
661 | "id": "ff0c9101-c62e-465c-a3aa-2d2afd5de075",
662 | "metadata": {},
663 | "outputs": [],
664 | "source": [
665 | "import requests\n",
666 | "from torch.utils.data import TensorDataset\n",
667 | "\n",
668 | "_preprocess = torchvision.transforms.Compose([\n",
669 | " torchvision.transforms.Resize(256),\n",
670 | " torchvision.transforms.CenterCrop(256),\n",
671 | " torchvision.transforms.ToTensor(),\n",
672 | "])\n",
673 | "\n",
674 | "urls = [\n",
675 | " \"https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcStVHtFcMqIP4xuDYn8n_FzPDKjPtP_iTSbOQ&usqp=CAU\",\n",
676 | " \"https://i.insider.com/58d919eaf2d0331b008b4bbd?width=700\",\n",
677 | " \"https://media.cntraveler.com/photos/5539216cab60aad20f3f3aaa/16:9/w_2560%2Cc_limit/eiffel-tower-paris-secret-apartment.jpg\",\n",
678 | " \"https://static.independent.co.uk/s3fs-public/thumbnails/image/2014/03/25/12/eiffel.jpg?width=1200\"\n",
679 | "]\n",
680 | "images = []\n",
681 | "for url in urls:\n",
682 | " response = requests.get(url)\n",
683 | " img = Image.open(BytesIO(response.content))\n",
684 | " images.append(_preprocess(img))\n",
685 | "\n",
686 | "data = torch.stack(images)\n",
687 | "dataset = DataLoader(TensorDataset(data), batch_size=1, shuffle=True)\n",
688 | "loader = iter(dataset)"
689 | ]
690 | },
691 | {
692 | "cell_type": "code",
693 | "execution_count": null,
694 | "id": "d3db5f60-41d5-47e5-9c4b-b9b2911467e5",
695 | "metadata": {
696 | "scrolled": true,
697 | "tags": []
698 | },
699 | "outputs": [],
700 | "source": [
701 | "steps = 100\n",
702 | "total_loss = 0\n",
703 | "total_acc = 0\n",
704 | "pbar = tqdm(range(steps))\n",
705 | "for i in pbar:\n",
706 | " try:\n",
707 | " images = next(loader)[0]\n",
708 | " except StopIteration:\n",
709 | " loader = iter(dataset)\n",
710 | " images = next(loader)[0]\n",
711 | " images = images.to(device)\n",
712 | " text = \"a photo of *\"\n",
713 | " tokenized_text = tokenizer.tokenize([text]).to(device)\n",
714 | " insertion_index = tokenized_text.argmax(dim=-1)\n",
715 | " neutral_text_encoded = clip_model.token_embedding(tokenized_text)\n",
716 | " insertion_idx = torch.where(neutral_text_encoded == asteriks_emb)[1].unique()\n",
717 | " neutral_text_encoded[:, insertion_idx, :] = context_word\n",
718 | " clip_embeddings = text_encode(neutral_text_encoded, clip_model, insertion_index)\n",
719 | " with torch.no_grad():\n",
720 | " image_indices = encode(images)\n",
721 | " r = torch.rand(images.size(0), device=device)\n",
722 | " noised_indices, mask = model.add_noise(image_indices, r)\n",
723 | "\n",
724 | " # with torch.autocast(device_type=\"cuda\"):\n",
725 | " pred = model(noised_indices, clip_embeddings, r)\n",
726 | " loss = criterion(pred, image_indices)\n",
727 | " \n",
728 | " loss.backward()\n",
729 | " optim.step()\n",
730 | " optim.zero_grad()\n",
731 | " \n",
732 | " acc = (pred.argmax(1) == image_indices).float() # .mean()\n",
733 | " acc = acc.mean()\n",
734 | "\n",
735 | " total_loss += loss.item()\n",
736 | " total_acc += acc.item()\n",
737 | " pbar.set_postfix({\"total_loss\": total_loss / (i+1), \"total_acc\": total_acc / (i+1)})\n"
738 | ]
739 | },
740 | {
741 | "cell_type": "code",
742 | "execution_count": null,
743 | "id": "ce95a42f-e121-43da-b432-b5f81c408ba8",
744 | "metadata": {},
745 | "outputs": [],
746 | "source": [
747 | "with torch.inference_mode():\n",
748 | " with torch.autocast(device_type=\"cuda\"):\n",
749 | " sampled = sample(model, clip_embeddings.expand(4, -1), T=12, size=(32, 32), starting_t=0, temp_range=[1., 1.],\n",
750 | " typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=4, renoise_steps=11)\n",
751 | " sampled = decode(sampled[-1])\n",
752 | "\n",
753 | "plt.figure(figsize=(32, 32))\n",
754 | "plt.axis(\"off\")\n",
755 | "plt.imshow(torch.cat([\n",
756 | " torch.cat([i for i in images.expand(4, -1, -1, -1).cpu()], dim=-1),\n",
757 | " torch.cat([i for i in sampled.cpu()], dim=-1),\n",
758 | "], dim=-2).permute(1, 2, 0).cpu())\n",
759 | "plt.show()"
760 | ]
761 | },
762 | {
763 | "cell_type": "code",
764 | "execution_count": null,
765 | "id": "31eac484-b3bf-428a-8e59-1742dbca6cef",
766 | "metadata": {},
767 | "outputs": [],
768 | "source": [
769 | "text = \"* at night\"\n",
770 | "tokenized_text = tokenizer.tokenize([text]).to(device)\n",
771 | "insertion_index = tokenized_text.argmax(dim=-1)\n",
772 | "neutral_text_encoded = clip_model.token_embedding(tokenized_text)\n",
773 | "insertion_idx = torch.where(neutral_text_encoded == asteriks_emb)[1].unique()\n",
774 | "neutral_text_encoded[:, insertion_idx, :] = context_word\n",
775 | "clip_embeddings = text_encode(neutral_text_encoded, clip_model, insertion_index)\n",
776 | "with torch.inference_mode():\n",
777 | " with torch.autocast(device_type=\"cuda\"):\n",
778 | " sampled = sample(model, clip_embeddings.expand(4, -1), T=12, size=(32, 32), starting_t=0, temp_range=[1., 1.],\n",
779 | " typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=4, renoise_steps=11)\n",
780 | " sampled = decode(sampled[-1])\n",
781 | "\n",
782 | "plt.figure(figsize=(32, 32))\n",
783 | "plt.axis(\"off\")\n",
784 | "plt.imshow(torch.cat([\n",
785 | " torch.cat([i for i in images.expand(4, -1, -1, -1).cpu()], dim=-1),\n",
786 | " torch.cat([i for i in sampled.cpu()], dim=-1),\n",
787 | "], dim=-2).permute(1, 2, 0).cpu())\n",
788 | "plt.show()"
789 | ]
790 | }
791 | ],
792 | "metadata": {
793 | "colab": {
794 | "collapsed_sections": [],
795 | "name": "DenoiseGIT_sampling.ipynb",
796 | "provenance": []
797 | },
798 | "kernelspec": {
799 | "display_name": "Python 3 (ipykernel)",
800 | "language": "python",
801 | "name": "python3"
802 | },
803 | "language_info": {
804 | "codemirror_mode": {
805 | "name": "ipython",
806 | "version": 3
807 | },
808 | "file_extension": ".py",
809 | "mimetype": "text/x-python",
810 | "name": "python",
811 | "nbconvert_exporter": "python",
812 | "pygments_lexer": "ipython3",
813 | "version": "3.8.12"
814 | }
815 | },
816 | "nbformat": 4,
817 | "nbformat_minor": 5
818 | }
819 |
--------------------------------------------------------------------------------