├── LICENSE
├── README.md
├── laion30k
│   └── datacheck.py
├── utils.py
├── paella_minimal.py
├── modules.py
├── paella.py
├── evaluation
│   ├── evaluation_generation.py
│   └── evaluator.py
└── paella_sampling.ipynb

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 delicious-tasty

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Paella (Anonymized CVPR submission)
Conditional text-to-image generation has seen countless recent improvements in terms of quality, diversity and fidelity. Nevertheless, most state-of-the-art models require numerous inference steps to produce faithful generations, resulting in performance bottlenecks for end-user applications.
In this paper we introduce Paella, a novel text-to-image model that requires fewer than 10 steps to sample high-fidelity images. Its speed-optimized architecture has 573M parameters and can sample a single image in under 500 ms. The model operates on a compressed and quantized latent space, is conditioned on CLIP embeddings, and uses a sampling function that improves on previous works. Aside from text-conditional image generation, our model can perform latent space interpolation and image manipulations such as inpainting, outpainting, and structural editing.
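
In the code, the compressed, quantized latent space is a 32 x 32 grid of VQGAN codebook indices per 256 x 256 image. The default configuration uses 8192 codebook entries. Training corrupts this grid with a cosine schedule: `DenoiseUNet.gamma` in modules.py returns `cos(r * pi / 2)`. That value is the probability with which `add_noise` replaces each token by a random codebook index at timestep `r`. So every token is replaced at `r = 0`, and almost none are replaced as `r` approaches 1.

Below is a minimal sketch of that corruption step. It is written in plain Python (standard library only) rather than with the repository's PyTorch tensors. `mask_ratio` and `add_token_noise` are hypothetical names used for illustration.

```python
import math
import random


def mask_ratio(r: float) -> float:
    """Cosine schedule gamma(r) = cos(r * pi / 2), clamped to [0, 1]."""
    return min(1.0, max(0.0, math.cos(r * math.pi / 2)))


def add_token_noise(tokens, r, num_labels, rng=random):
    """Replace each token by a random codebook index with probability gamma(r).

    Returns the noised tokens and the 0/1 mask of replaced positions,
    analogous to the (x, mask) pair returned by add_noise in modules.py.
    A replaced position may, by chance, receive its original index again.
    """
    p = mask_ratio(r)
    mask = [1 if rng.random() < p else 0 for _ in tokens]
    noised = [rng.randrange(num_labels) if m else t for t, m in zip(tokens, mask)]
    return noised, mask


if __name__ == "__main__":
    rng = random.Random(0)
    tokens = [rng.randrange(8192) for _ in range(32 * 32)]  # one flattened 32x32 token grid
    for r in (0.0, 0.5, 0.9):
        _, mask = add_token_noise(tokens, r, num_labels=8192, rng=rng)
        print(f"r={r:.1f}  gamma={mask_ratio(r):.3f}  replaced={sum(mask) / len(mask):.3f}")
```

`sample()` in utils.py starts from fully random tokens at `r = 0`. After each step it re-noises its prediction with `add_noise` at the next, larger `r`. As a result, the same schedule decides how many tokens are resampled at each of the few inference steps.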

![cover-figure](https://user-images.githubusercontent.com/117442814/201474789-a192f6ab-9626-4402-a3ec-81b8f3fd436c.png)

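
The sampling function mentioned in the abstract is `sample()` in utils.py. At every step it turns the model's logits for each token position into a new token, in three stages:

- It optionally mixes conditional and unconditional logits with `torch.lerp`. This is classifier-free guidance, enabled when `classifier_free_scale >= 0`.
- It keeps only the "typical" candidates, whose surprisal is closest to the entropy, until `typical_mass` of the probability is covered.
- It draws a token with the Gumbel-max trick.

The sketch below reproduces these per-token steps for a single logit vector in plain Python. It is an illustration under that simplification, not the repository's batched PyTorch code, and the helper names are hypothetical.

```python
import math
import random


def log_softmax(logits):
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - lse for v in logits]


def cfg_mix(uncond, cond, scale):
    """Classifier-free guidance as written in sample(): lerp(uncond, cond, scale)."""
    return [u + scale * (c - u) for u, c in zip(uncond, cond)]


def typical_filter(logits, mass=0.2, min_tokens=1):
    """Keep the most typical candidates covering `mass` probability; mask the rest with -inf."""
    logp = log_softmax(logits)
    entropy = -sum(lp * math.exp(lp) for lp in logp)
    # Sort candidates by |surprisal - entropy|, most typical first.
    order = sorted(range(len(logits)), key=lambda i: abs(-logp[i] - entropy))
    keep, cum = set(), 0.0
    for rank, i in enumerate(order):
        keep.add(i)
        cum += math.exp(logp[i])
        if cum >= mass and rank + 1 >= min_tokens:
            break
    return [v if i in keep else -math.inf for i, v in enumerate(logits)]


def gumbel_sample(logits, temperature=1.0, rng=random):
    """Gumbel-max trick: argmax(logits / T + Gumbel noise) draws from softmax(logits / T)."""
    t = max(temperature, 1e-10)
    scores = [v / t - math.log(-math.log(rng.random() + 1e-20) + 1e-20) for v in logits]
    return max(range(len(scores)), key=scores.__getitem__)


if __name__ == "__main__":
    rng = random.Random(0)
    cond = [2.0, 1.5, 0.2, -1.0, -3.0]  # made-up logits for one token position
    uncond = [0.5, 0.5, 0.5, 0.5, 0.5]
    mixed = cfg_mix(uncond, cond, scale=3.0)
    filtered = typical_filter(mixed, mass=0.2)
    print("kept candidates:", [i for i, v in enumerate(filtered) if v != -math.inf])
    print("sampled token:", gumbel_sample(filtered, temperature=1.0, rng=rng))
```

The repository applies the same operations to all `batch_size x 32 x 32` positions at once, sorting and masking with `torch.sort`, `scatter` and `masked_fill`.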
## Code
We especially want to highlight how little code is needed to run and train Paella: the entire code, including training, sampling, architecture and utilities, fits in approximately 400 lines. We hope this makes the method accessible to more people.

## Sampling
For sampling, just take a look at the [paella_sampling.ipynb](https://github.com/delicious-tasty/Paella/blob/main/paella_sampling.ipynb) notebook. :sunglasses:

## Train your own Paella
The main file for training is [paella.py](https://github.com/delicious-tasty/Paella/blob/main/paella.py). You can adjust all [hyperparameters](https://github.com/delicious-tasty/Paella/blob/main/paella.py#L322) to your own needs. During training we use webdataset, but you are free to replace it with your own custom dataloader: just change the ```dataset = get_dataloader(args)``` line in [paella.py](https://github.com/delicious-tasty/Paella/blob/main/paella.py#L119) to point to your own dataloader. Make sure it yields tuples of ```(images, captions)```. Here ```images``` is a ```torch.Tensor``` of shape ```batch_size x channels x height x width``` with values in ```[0, 1]```, and ```captions``` is a ```List``` of length ```batch_size```. Now decide whether you want to finetune Paella or start a new training from scratch:
### From Scratch
```
python3 paella.py
```
### Finetune
If you want to finetune, first download the [latest checkpoint and its optimizer state](epic_download_link.py) and set the [finetune hyperparameter](https://github.com/delicious-tasty/Paella/blob/main/paella.py#L254) to ```True```. Then create a folder ```models/``` and move both checkpoints into this folder.
After that you can also just run:
```
python3 paella.py
```

--------------------------------------------------------------------------------
/laion30k/datacheck.py:
--------------------------------------------------------------------------------
import os
import re
import torch
import torchvision
import webdataset as wds
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from webdataset.handlers import warn_and_continue


def clean_caption(caption):
    # Collapse repeated spaces and strip one leading/trailing quote character
    # (startswith/endswith also handle empty captions safely).
    caption = re.sub(" +", " ", caption)
    if caption.startswith(("\"", "'")):
        caption = caption[1:]
    if caption.endswith(("\"", "'")):
        caption = caption[:-1]
    return caption


def filter_captions(caption):
    # Keep only short, plain-ASCII, multi-word captions without URLs, digits,
    # forbidden characters or any of the blocklisted words below.
    possible_url_hints = ["www.", ".com", "http"]
    forbidden_characters = ["-", "_", ":", ";", "(", ")", "/", "%", "|", "?"]
    forbidden_words = ["download", "interior", "kitchen", "chair", "getty", "how", "what", "when", "why", "laminate", "furniture", "hair", "dress", "clothing"]
    if len(caption.split(" ")) < 2:
        return False
    if any(ch in caption for ch in forbidden_characters):
        return False
    if len(caption) > 150:
        return False
    if not all(ord(c) < 128 for c in caption):
        return False
    if any(hint in caption for hint in possible_url_hints):
        return False
    if any(char.isdigit() for char in caption):
        return False
    if any(word in caption.lower() for word in forbidden_words):
        return False
    return True


class ProcessDataV2:
    def __init__(self):
        self.transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Resize(256),
            torchvision.transforms.CenterCrop(256),
        ])

    def __call__(self, data):
        data["jpg"] = self.transforms(data["jpg"])
        return data


def collate(batch):
    images = torch.stack([i[0] for i in batch], dim=0)
    json_file = [i[1] for i in batch]
    captions = [i[2] for i in batch]
    return [images, json_file, captions]


def get_dataloader_new(path):
    dataset = wds.WebDataset(path, resampled=True, handler=warn_and_continue).decode("rgb", handler=warn_and_continue).map(
        ProcessDataV2(), handler=warn_and_continue).to_tuple("jpg", "json", "txt", handler=warn_and_continue).shuffle(1000, handler=warn_and_continue)
    dataloader = DataLoader(dataset, batch_size=10, collate_fn=collate)
    return dataloader


dataset_length = 30_000
dataset_path = "30k"
print(os.getcwd())
os.makedirs(dataset_path, exist_ok=True)
# path = "file:000069.tar"
path = "path_to_laion_aesthetic_dataset"
dataloader = get_dataloader_new(path)
idx = 0


for images, json_files, captions in dataloader:
    # The dataset is resampled (infinite), so stop explicitly once enough pairs are saved.
    if idx >= dataset_length:
        break
    # Keep samples with a high aesthetic score and a caption that passes the filters.
    f = [i for i, json_file in enumerate(json_files) if json_file["AESTHETIC_SCORE"] > 6.0]
    f = [i for i in f if filter_captions(captions[i])]
    if f:
        captions = [clean_caption(captions[i]) for i in f]
        print(captions)
        aesthetic_images = images[f]
        for image, caption in zip(aesthetic_images, captions):
            if idx >= dataset_length:
                break
            torchvision.utils.save_image(image, os.path.join(dataset_path, f"{idx}.jpg"))
            with open(os.path.join(dataset_path, f"{idx}.txt"), "w") as caption_file:
                caption_file.write(caption)
            idx += 1
        # plt.figure(figsize=(32, 32))
        # plt.imshow(torch.cat([
        #     torch.cat([i for i in aesthetic_images.cpu()], dim=-1),
        # ], dim=-2).permute(1, 2, 0).cpu())
        # plt.show()

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch
import torchvision
import webdataset as wds
from torch.utils.data import DataLoader
from webdataset.handlers import warn_and_continue


def encode(vq, x):
    # Map images from [0, 1] to [-1, 1] and return the VQGAN token indices.
    return vq.model.encode((2 * x - 1))[-1][-1]


def decode(vq, z):
    # Flatten the (batch, h, w) token grid and decode it back to images.
    return vq.decode(z.view(z.shape[0], -1))


def log(t, eps=1e-20):
    return torch.log(t + eps)


def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))


def gumbel_sample(t, temperature=1., dim=-1):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)


def sample(model, c, x=None, mask=None, T=12, size=(32, 32), starting_t=0, temp_range=(1.0, 1.0), typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=-1, renoise_steps=11, renoise_mode='start'):
    # Returns the final token grid of shape (batch_size, *size).
    with torch.inference_mode():
        r_range = torch.linspace(0, 1, T+1)[:-1][:, None].expand(-1, c.size(0)).to(c.device)
        temperatures = torch.linspace(temp_range[0], temp_range[1], T)
        if x is None:
            x = torch.randint(0, model.num_labels, size=(c.size(0), *size), device=c.device)
        elif mask is not None:
            noise = torch.randint(0, model.num_labels, size=(c.size(0), *size), device=c.device)
            x = noise * mask + (1-mask) * x
        init_x = x.clone()
        for i in range(starting_t, T):
            if renoise_mode == 'prev':
                prev_x = x.clone()
            r, temp = r_range[i], temperatures[i]
            logits = model(x, c, r)
            if classifier_free_scale >= 0:
                logits_uncond = model(x, torch.zeros_like(c), r)
                logits = torch.lerp(logits_uncond, logits, classifier_free_scale)
            x = logits
            x_flat = x.permute(0, 2, 3, 1).reshape(-1, x.size(1))
            if typical_filtering:
                x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1)
                x_flat_norm_p = torch.exp(x_flat_norm)
                entropy = -(x_flat_norm * x_flat_norm_p).nansum(-1, keepdim=True)

                c_flat_shifted = torch.abs((-x_flat_norm) - entropy)
                c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False)
                x_flat_cumsum = x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1)

                # Clamp so that typical_mass >= 1 cannot index past the last candidate.
                last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1).clamp(max=x_flat.size(-1) - 1)
                sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(1, last_ind.view(-1, 1))
                if typical_min_tokens > 1:
                    sorted_indices_to_remove[..., :typical_min_tokens] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(1, x_flat_indices, sorted_indices_to_remove)
                x_flat = x_flat.masked_fill(indices_to_remove, -float("Inf"))
            # x_flat = torch.multinomial(x_flat.div(temp).softmax(-1), num_samples=1)[:, 0]
            x_flat = gumbel_sample(x_flat, temperature=temp)
            x = x_flat.view(x.size(0), *x.shape[2:])
            if mask is not None:
                x = x * mask + (1-mask) * init_x
            # r_range has only T entries, so never renoise after the final step.
            if i < min(renoise_steps, T - 1):
                if renoise_mode == 'start':
                    x, _ = model.add_noise(x, r_range[i+1], random_x=init_x)
                elif renoise_mode == 'prev':
                    x, _ = model.add_noise(x, r_range[i+1], random_x=prev_x)
                else:  # 'rand'
                    x, _ = model.add_noise(x, r_range[i+1])
        return x.detach()


class ProcessData:
    def __init__(self, image_size=256):
        self.transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.RandomCrop(image_size),
        ])

    def __call__(self, data):
        data["jpg"] = self.transforms(data["jpg"])
        return data


def collate(batch):
    images = torch.stack([i[0] for i in batch], dim=0)
    captions = [i[1] for i in batch]
    return [images, captions]


def get_dataloader(args):
    dataset = wds.WebDataset(args.dataset_path, resampled=True, handler=warn_and_continue).decode("rgb", handler=warn_and_continue).map(
        ProcessData(args.image_size), handler=warn_and_continue).to_tuple("jpg", "txt", handler=warn_and_continue).shuffle(690, handler=warn_and_continue)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=collate)
    return dataloader

--------------------------------------------------------------------------------
/paella_minimal.py:
--------------------------------------------------------------------------------
import math
import os
import torch
from torch import nn, optim
import torchvision
from tqdm import tqdm
import numpy as np
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from modules import DenoiseUNet
from utils import get_dataloader, sample, encode, decode
import open_clip
from open_clip import tokenizer
from rudalle import get_vae


def train(proc_id, args):
    parallel = len(args.devices) > 1
    device = torch.device(proc_id)

    vqmodel = get_vae().to(device)
    vqmodel.eval().requires_grad_(False)

    if parallel:
        torch.cuda.set_device(proc_id)
        torch.backends.cudnn.benchmark = True
        dist.init_process_group(backend="nccl", init_method="file://dist_file", world_size=args.n_nodes * len(args.devices), rank=proc_id + len(args.devices) * args.node_id)
        torch.set_num_threads(6)

    model = DenoiseUNet(num_labels=args.num_codebook_vectors, c_clip=1024).to(device)

    clip_model, _, _ = open_clip.create_model_and_transforms('ViT-g-14', pretrained='laion2b_s12b_b42k')
    del clip_model.visual
    clip_model = clip_model.to(device).eval().requires_grad_(False)

    lr = 3e-4
    dataset = get_dataloader(args)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr,
                                              steps_per_epoch=math.ceil(1000 / args.accum_grad),
                                              epochs=600, pct_start=30 / 300,
                                              div_factor=25, final_div_factor=1 / 25,
                                              anneal_strategy='linear')

    losses, accuracies = [], []
    start_step, total_loss, total_acc = 0, 0, 0

    if parallel:
        model = DistributedDataParallel(model, device_ids=[device], output_device=device)
    # Unwrapped model: DDP exposes it as model.module, a single-GPU run uses the model directly.
    base_model = model.module if parallel else model

    # Create the output folders before the first sample image / checkpoint is written.
    os.makedirs(f"results/{args.run_name}", exist_ok=True)
    os.makedirs(f"models/{args.run_name}", exist_ok=True)

    pbar = tqdm(enumerate(dataset, start=start_step), total=args.total_steps, initial=start_step) if args.node_id == 0 and proc_id == 0 else enumerate(dataset, start=start_step)
    model.train()
    for step, (images, captions) in pbar:
        images = images.to(device)
        with torch.no_grad():
            image_indices = encode(vqmodel, images)  # encode images (batch_size x 3 x 256 x 256) to tokens (batch_size x 32 x 32)
            r = torch.rand(images.size(0), device=device)  # generate random timesteps
            noised_indices, mask = base_model.add_noise(image_indices, r)  # noise the tokens according to the timesteps

            if np.random.rand() < 0.1:  # 10% of the time -> unconditional training for classifier-free guidance
                text_embeddings = images.new_zeros(images.size(0), 1024)
            else:
                text_tokens = tokenizer.tokenize(captions)
                text_tokens = text_tokens.to(device)
                text_embeddings = clip_model.encode_text(text_tokens).float()  # text embeddings (batch_size x 1024)

        pred = model(noised_indices, text_embeddings, r)  # predict token logits (batch_size x 8192 x 32 x 32)
        loss = criterion(pred, image_indices)  # cross entropy loss
        loss_adjusted = loss / args.accum_grad

        loss_adjusted.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 5).item()
        if (step + 1) % args.accum_grad == 0:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        acc = (pred.argmax(1) == image_indices).float().mean()

        total_loss += loss.item()
        total_acc += acc.item()

        if not proc_id and args.node_id == 0:
            pbar.set_postfix({"loss": total_loss / (step + 1), "acc": total_acc / (step + 1), "curr_loss": loss.item(), "curr_acc": acc.item(), "ppx": np.exp(total_loss / (step + 1)), "lr": optimizer.param_groups[0]['lr'], "grad_norm": grad_norm})

        if args.node_id == 0 and proc_id == 0 and step % args.log_period == 0:
            print(f"Step {step} - loss {total_loss / (step + 1)} - acc {total_acc / (step + 1)} - ppx {np.exp(total_loss / (step + 1))}")

            losses.append(total_loss / (step + 1))
            accuracies.append(total_acc / (step + 1))

            model.eval()
            with torch.no_grad():
                sampled = sample(base_model, c=text_embeddings)  # final token grid (batch_size x 32 x 32)
                sampled = decode(vqmodel, sampled)

            model.train()
            log_images = torch.cat([torch.cat([i for i in sampled.cpu()], dim=-1)], dim=-2)
            torchvision.utils.save_image(log_images, os.path.join(f"results/{args.run_name}", f"{step:03d}.png"))

            del sampled

            torch.save(base_model.state_dict(), f"models/{args.run_name}/model.pt")
            torch.save(optimizer.state_dict(), f"models/{args.run_name}/optim.pt")
            torch.save({'step': step, 'losses': losses, 'accuracies': accuracies}, f"results/{args.run_name}/log.pt")

        del images, image_indices, r, text_embeddings
        del noised_indices, mask, pred, loss, loss_adjusted, acc


def launch(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(d) for d in args.devices])
    if len(args.devices) == 1:
        train(0, args)
    else:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "33751"
        mp.spawn(train, nprocs=len(args.devices), args=(args,))


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    args.run_name = "Paella_f8_8192"
    args.dataset_type = "webdataset"
    args.total_steps = 501_000
    args.batch_size = 22
    args.image_size = 256
    args.num_workers = 10
    args.log_period = 5000
    args.accum_grad = 1
    args.num_codebook_vectors = 8192

    args.n_nodes = 8
    args.node_id = int(os.environ["SLURM_PROCID"])
    args.devices = [0, 1, 2, 3, 4, 5, 6, 7]

    args.dataset_path = ""
    print("Launching with args: ", args)
    launch(
        args
    )

--------------------------------------------------------------------------------
/modules.py:
--------------------------------------------------------------------------------
import math
import numpy as np
import torch
import torch.nn as nn


class ModulatedLayerNorm(nn.Module):
    def __init__(self, num_features, eps=1e-6, channels_first=True):
        super().__init__()
        self.ln = nn.LayerNorm(num_features, eps=eps)
        self.gamma = nn.Parameter(torch.randn(1, 1, 1))
        self.beta = nn.Parameter(torch.randn(1, 1, 1))
        self.channels_first = channels_first

    def forward(self, x, w=None):
        x = x.permute(0, 2, 3, 1) if self.channels_first else x
        if w is None:
            x = self.ln(x)
        else:
            x = self.gamma * w * self.ln(x) + self.beta * w
        x = x.permute(0, 3, 1, 2) if self.channels_first else x
        return x


class ResBlock(nn.Module):
    def __init__(self, c, c_hidden, c_cond=0, c_skip=0, scaler=None, layer_scale_init_value=1e-6):
        super().__init__()
        self.depthwise = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(c, c, kernel_size=3, groups=c)
        )
        self.ln = ModulatedLayerNorm(c, channels_first=False)
        self.channelwise = nn.Sequential(
            nn.Linear(c+c_skip, c_hidden),
            nn.GELU(),
            nn.Linear(c_hidden, c),
        )
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(c), requires_grad=True) if layer_scale_init_value > 0 else None
        self.scaler = scaler
        if c_cond > 0:
            self.cond_mapper = nn.Linear(c_cond, c)

    def forward(self, x, s=None, skip=None):
        res = x
        x = self.depthwise(x)
        if s is not None:
            s = self.cond_mapper(s.permute(0, 2, 3, 1))
            if s.size(1) == s.size(2) == 1:
                s = s.expand(-1, x.size(2), x.size(3), -1)
        x = self.ln(x.permute(0, 2, 3, 1), s)
        if skip is not None:
            x = torch.cat([x, skip.permute(0, 2, 3, 1)], dim=-1)
        x = self.channelwise(x)
        x = self.gamma * x if self.gamma is not None else x
        x = res + x.permute(0, 3, 1, 2)
        if self.scaler is not None:
            x = self.scaler(x)
        return x


class DenoiseUNet(nn.Module):
    def __init__(self, num_labels, c_hidden=1280, c_clip=1024, c_r=64, down_levels=[4, 8, 16], up_levels=[16, 8, 4]):
        super().__init__()
        self.num_labels = num_labels
        self.c_r = c_r
        c_levels = [c_hidden // (2 ** i) for i in reversed(range(len(down_levels)))]
        self.embedding = nn.Embedding(num_labels, c_levels[0])

        # DOWN BLOCKS
        self.down_blocks = nn.ModuleList()
        for i, num_blocks in enumerate(down_levels):
            blocks = []
            if i > 0:
                blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
            for j in range(num_blocks):
                block = ResBlock(c_levels[i], c_levels[i] * 4, c_clip + c_r)
                block.channelwise[-1].weight.data *= np.sqrt(1 / sum(down_levels))
                blocks.append(block)
            self.down_blocks.append(nn.ModuleList(blocks))

        # UP BLOCKS
        self.up_blocks = nn.ModuleList()
        for i, num_blocks in enumerate(up_levels):
            blocks = []
            for j in range(num_blocks):
                block = ResBlock(c_levels[len(c_levels) - 1 - i], c_levels[len(c_levels) - 1 - i] * 4, c_clip + c_r,
                                 c_levels[len(c_levels) - 1 - i] if (j == 0 and i > 0) else 0)
                block.channelwise[-1].weight.data *= np.sqrt(1 / sum(up_levels))
                blocks.append(block)
            if i < len(up_levels) - 1:
                blocks.append(
                    nn.ConvTranspose2d(c_levels[len(c_levels) - 1 - i], c_levels[len(c_levels) - 2 - i], kernel_size=4, stride=2, padding=1))
            self.up_blocks.append(nn.ModuleList(blocks))

        self.clf = nn.Conv2d(c_levels[0], num_labels, kernel_size=1)

    def gamma(self, r):
        # Cosine schedule: fraction of tokens to replace at timestep r (1 at r=0, 0 at r=1).
        return (r * torch.pi / 2).cos()

    def add_noise(self, x, r, random_x=None):
        r = self.gamma(r)[:, None, None]
        mask = torch.bernoulli(r * torch.ones_like(x))
        mask = mask.round().long()
        if random_x is None:
            random_x = torch.randint_like(x, 0, self.num_labels)
        x = x * (1 - mask) + random_x * mask
        return x, mask

    def gen_r_embedding(self, r, max_positions=10000):
        r = self.gamma(r) * max_positions
        half_dim = self.c_r // 2
        emb = math.log(max_positions) / (half_dim - 1)
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
        emb = r[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.c_r % 2 == 1:  # zero pad
            emb = nn.functional.pad(emb, (0, 1), mode='constant')
        return emb

    def _down_encode_(self, x, s):
        level_outputs = []
        for i, blocks in enumerate(self.down_blocks):
            for block in blocks:
                if isinstance(block, ResBlock):
                    x = block(x, s)
                else:
                    x = block(x)
            level_outputs.insert(0, x)
        return level_outputs

    def _up_decode(self, level_outputs, s):
        x = level_outputs[0]
        for i, blocks in enumerate(self.up_blocks):
            for j, block in enumerate(blocks):
                if isinstance(block, ResBlock):
                    if i > 0 and j == 0:
                        x = block(x, s, level_outputs[i])
                    else:
                        x = block(x, s)
                else:
                    x = block(x)
        return x

    def forward(self, x, c, r):  # r is a uniform value between 0 and 1
        r_embed = self.gen_r_embedding(r)
        x = self.embedding(x).permute(0, 3, 1, 2)
        s = torch.cat([c, r_embed], dim=-1)[:, :, None, None]
        level_outputs = self._down_encode_(x, s)
        x = self._up_decode(level_outputs, s)
        x = self.clf(x)
        return x


if __name__ == '__main__':
    device = "cuda"
    model = DenoiseUNet(1024).to(device)
    print(sum([p.numel() for p in model.parameters()]))
    x = torch.randint(0, 1024, (1, 32, 32)).long().to(device)
    c = torch.randn((1, 1024)).to(device)
    r = torch.rand(1).to(device)
    model(x, c, r)

--------------------------------------------------------------------------------
/paella.py:
--------------------------------------------------------------------------------
import math
import os
import torch
from torch.utils.data import TensorDataset, DataLoader
import wandb
from torch import nn, optim
import torchvision
from tqdm import tqdm
import time
import numpy as np
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from modules import DenoiseUNet
from utils import get_dataloader, sample, encode, decode
import open_clip
from open_clip import tokenizer
from rudalle import get_vae


def train(proc_id, args):
    # Resume automatically if a log from a previous run with the same name exists.
    resume = os.path.exists(f"results/{args.run_name}/log.pt")
    if not proc_id and args.node_id == 0:
        wandb.init(project="project", name=args.run_name, entity="your_entity", config=vars(args))
        print(f"Starting run '{args.run_name}'....")
        print(f"Batch Size check: {args.n_nodes * args.batch_size * args.accum_grad * len(args.devices)}")
    parallel = len(args.devices) > 1
    device = torch.device(proc_id)

    vqmodel = get_vae().to(device)
    vqmodel.eval().requires_grad_(False)

    if parallel:
        torch.cuda.set_device(proc_id)
        torch.backends.cudnn.benchmark = True
        dist.init_process_group(backend="nccl", init_method="file://dist_file",
                                world_size=args.n_nodes * len(args.devices),
                                rank=proc_id + len(args.devices) * args.node_id)
torch.set_num_threads(6) 46 | 47 | model = DenoiseUNet(num_labels=args.num_codebook_vectors, c_clip=1024).to(device) 48 | 49 | if not proc_id and args.node_id == 0: 50 | print(f"Number of Parameters: {sum([p.numel() for p in model.parameters()])}") 51 | 52 | clip_model, _, _ = open_clip.create_model_and_transforms('ViT-g-14', pretrained='laion2b_s12b_b42k') 53 | del clip_model.visual 54 | clip_model = clip_model.to(device).eval().requires_grad_(False) 55 | 56 | lr = 3e-4 57 | dataset = get_dataloader(args) 58 | optimizer = optim.AdamW(model.parameters(), lr=lr) 59 | criterion = nn.CrossEntropyLoss(label_smoothing=0.1) 60 | 61 | if not proc_id and args.node_id == 0: 62 | wandb.watch(model) 63 | os.makedirs(f"results/{args.run_name}", exist_ok=True) 64 | os.makedirs(f"models/{args.run_name}", exist_ok=True) 65 | 66 | scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, 67 | steps_per_epoch=math.ceil(1000 / args.accum_grad), 68 | epochs=600, pct_start=30 / 300, div_factor=25, 69 | final_div_factor=1 / 25, anneal_strategy='linear') 70 | 71 | if resume: 72 | if not proc_id and args.node_id == 0: 73 | print("Loading last checkpoint....") 74 | logs = torch.load(f"results/{args.run_name}/log.pt") 75 | start_step = logs["step"] + 1 76 | losses = logs["losses"] 77 | accuracies = logs["accuracies"] 78 | total_loss, total_acc = losses[-1] * start_step, accuracies[-1] * start_step 79 | model.load_state_dict(torch.load(f"models/{args.run_name}/model.pt", map_location=device)) 80 | if not proc_id and args.node_id == 0: 81 | print("Loaded model.") 82 | opt_state = torch.load(f"models/{args.run_name}/optim.pt", map_location=device) 83 | last_lr = opt_state["param_groups"][0]["lr"] 84 | with torch.no_grad(): 85 | for _ in range(logs["step"]): 86 | scheduler.step() 87 | if not proc_id and args.node_id == 0: 88 | print(f"Initialized scheduler") 89 | print(f"Sanity check => Last-LR: {last_lr} == Current-LR: {optimizer.param_groups[0]['lr']} -> {last_lr == 
optimizer.param_groups[0]['lr']}") 90 | optimizer.load_state_dict(opt_state) 91 | del opt_state 92 | else: 93 | losses = [] 94 | accuracies = [] 95 | start_step, total_loss, total_acc = 0, 0, 0 96 | 97 | if parallel: 98 | model = DistributedDataParallel(model, device_ids=[device], output_device=device) 99 | 100 | pbar = tqdm(enumerate(dataset, start=start_step), total=args.total_steps, initial=start_step) if args.node_id == 0 and proc_id == 0 else enumerate(dataset, start=start_step) 101 | model.train() 102 | for step, (images, captions) in pbar: 103 | images = images.to(device) 104 | with torch.no_grad(): 105 | image_indices = encode(vqmodel, images) 106 | r = torch.rand(images.size(0), device=device) 107 | noised_indices, mask = model.module.add_noise(image_indices, r) 108 | 109 | if np.random.rand() < 0.1: # 10% of the times -> unconditional training for classifier-free-guidance 110 | text_embeddings = images.new_zeros(images.size(0), 1024) 111 | else: 112 | text_tokens = tokenizer.tokenize(captions) 113 | text_tokens = text_tokens.to(device) 114 | text_embeddings = clip_model.encode_text(text_tokens).float() 115 | 116 | pred = model(noised_indices, text_embeddings, r) 117 | loss = criterion(pred, image_indices) 118 | loss_adjusted = loss / args.accum_grad 119 | 120 | loss_adjusted.backward() 121 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 5).item() 122 | if (step + 1) % args.accum_grad == 0: 123 | optimizer.step() 124 | scheduler.step() 125 | optimizer.zero_grad() 126 | 127 | acc = (pred.argmax(1) == image_indices).float() 128 | acc = acc.mean() 129 | 130 | total_loss += loss.item() 131 | total_acc += acc.item() 132 | 133 | if not proc_id and args.node_id == 0: 134 | log = { 135 | "loss": total_loss / (step + 1), 136 | "acc": total_acc / (step + 1), 137 | "curr_loss": loss.item(), 138 | "curr_acc": acc.item(), 139 | "ppx": np.exp(total_loss / (step + 1)), 140 | "lr": optimizer.param_groups[0]['lr'], 141 | "grad_norm": grad_norm 142 | } 143 
            pbar.set_postfix(log)
            wandb.log(log)

        if args.node_id == 0 and proc_id == 0 and step % args.log_period == 0:
            print(f"Step {step} - loss {total_loss / (step + 1)} - acc {total_acc / (step + 1)} - ppx {np.exp(total_loss / (step + 1))}")

            losses.append(total_loss / (step + 1))
            accuracies.append(total_acc / (step + 1))

            model.eval()
            with torch.no_grad():
                n = 1
                images = images[:10]
                image_indices = image_indices[:10]
                captions = captions[:10]
                text_embeddings = text_embeddings[:10]
                sampled = sample(model.module, c=text_embeddings)[-1]
                sampled = decode(vqmodel, sampled)
                recon_images = decode(vqmodel, image_indices)

                if args.log_captions:
                    cool_captions_data = torch.load("cool_captions.pth")
                    cool_captions_text = cool_captions_data["captions"]

                    text_tokens = tokenizer.tokenize(cool_captions_text)
                    text_tokens = text_tokens.to(device)
                    cool_captions_embeddings = clip_model.encode_text(text_tokens).float()

                    cool_captions = DataLoader(TensorDataset(cool_captions_embeddings.repeat_interleave(n, dim=0)), batch_size=11)
                    cool_captions_sampled = []
                    st = time.time()
                    for caption_embedding in cool_captions:
                        caption_embedding = caption_embedding[0].float().to(device)
                        sampled_text = sample(model.module, c=caption_embedding)[-1]
                        sampled_text = decode(vqmodel, sampled_text)
                        # This training loop keeps no EMA copy of the model, so only the live model is sampled.
                        for s in sampled_text:
                            cool_captions_sampled.append(s.cpu())
                    print(f"Took {time.time() - st} seconds to sample {len(cool_captions_text) * n} captions.")

                    cool_captions_sampled = torch.stack(cool_captions_sampled)
                    torchvision.utils.save_image(
                        torchvision.utils.make_grid(cool_captions_sampled, nrow=11),
                        os.path.join(f"results/{args.run_name}", f"cool_captions_{step:03d}.png")
                    )

                log_images = torch.cat([
                    torch.cat([i for i in sampled.cpu()], dim=-1),
                ], dim=-2)

            model.train()

            torchvision.utils.save_image(log_images, os.path.join(f"results/{args.run_name}", f"{step:03d}.png"))

            # One row per caption: caption, sample, original image, VQ reconstruction.
            log_data = [[captions[i]] + [wandb.Image(sampled[i])] + [wandb.Image(images[i])] + [wandb.Image(recon_images[i])] for i in range(len(captions))]
            log_table = wandb.Table(data=log_data, columns=["Caption", "Image", "Orig", "Recon"])
            wandb.log({"Log": log_table})

            if args.log_captions:
                log_data_cool = [[cool_captions_text[i]] + [wandb.Image(cool_captions_sampled[i])] for i in range(len(cool_captions_text))]
                log_table_cool = wandb.Table(data=log_data_cool, columns=["Caption", "Image"])
                wandb.log({"Log Cool": log_table_cool})
                del sampled_text, log_data_cool

            del sampled, log_data

            if step % args.extra_ckpt == 0:
                torch.save(model.module.state_dict(), f"models/{args.run_name}/model_{step}.pt")
                torch.save(optimizer.state_dict(), f"models/{args.run_name}/model_{step}_optim.pt")
            torch.save(model.module.state_dict(), f"models/{args.run_name}/model.pt")
            torch.save(optimizer.state_dict(), f"models/{args.run_name}/optim.pt")
            torch.save({'step': step, 'losses': losses, 'accuracies': accuracies}, f"results/{args.run_name}/log.pt")

        del images, image_indices, r, text_embeddings
        del noised_indices, mask, pred, loss, loss_adjusted, acc


def launch(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(d) for d in args.devices])
    if len(args.devices) == 1:
        train(0, args)
    else:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "33751"
        p = mp.spawn(train, nprocs=len(args.devices), args=(args,))


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    args.run_name = "run_name"
    args.model = "UNet"
    args.dataset_type = "webdataset"
    args.total_steps = 501_000
    args.batch_size = 22
    args.image_size = 256
    args.num_workers = 10
    args.log_period = 5000
    args.extra_ckpt = 50_000
    args.accum_grad = 1
    args.num_codebook_vectors = 8192
    args.log_captions = True
    args.finetune = False

    args.n_nodes = 8
    args.node_id = int(os.environ["SLURM_PROCID"])
    args.devices = [0, 1, 2, 3, 4, 5, 6, 7]

    args.dataset_path = ""
    print("Launching with args: ", args)
    launch(
        args
    )

--------------------------------------------------------------------------------
/evaluation/evaluation_generation.py:
--------------------------------------------------------------------------------
import json
import queue
import torch.multiprocessing as mp
from collections import OrderedDict
import os
import time
import torch
import pandas as pd
from itertools import product
import torchvision
import numpy as np
import open_clip
from open_clip import tokenizer
from rudalle import get_vae
from einops import rearrange
import tensorflow.compat.v1 as tf
from PIL import Image
from evaluator import Evaluator
from modules import DenoiseUNet


def chunk(lst, n):
    return [lst[i:i + n] for i in range(0, len(lst), n)]


def save_images_npz(path, save_path):
    arr = []
    base_shape = None
    for item in os.listdir(path):
        if os.path.isfile(os.path.join(path, item)) and item.endswith(".jpg"):
            img = Image.open(os.path.join(path, item))
            # Image.ANTIALIAS was removed in Pillow 10; it was an alias of Image.LANCZOS.
            img = np.array(img.resize((256, 256), Image.LANCZOS))
            if base_shape is None:
                base_shape = img.shape
            try:
                if img.shape == base_shape:
                    arr.append(img)
            except Exception as e:
                print(e)
                continue
    arr = np.stack(arr)
    print(arr.shape)
    np.savez(save_path, arr)


def log(t, eps=1e-20):
    return torch.log(t + eps)


def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))


def gumbel_sample(t, temperature=1., dim=-1):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)


def sample(model, c, x=None, mask=None, T=12, size=(32, 32), starting_t=0, temp_range=[1.0, 1.0], typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=-1, renoise_steps=11, renoise_mode='start'):
    with torch.inference_mode():
        r_range = torch.linspace(0, 1, T+1)[:-1][:, None].expand(-1, c.size(0)).to(c.device)
        temperatures = torch.linspace(temp_range[0], temp_range[1], T)
        if x is None:
            x = torch.randint(0, model.num_labels, size=(c.size(0), *size), device=c.device)
        if renoise_mode == 'start':
            init_x = x.clone()
        for i in range(starting_t, T):
            if renoise_mode == 'prev':
                prev_x = x.clone()
            r, temp = r_range[i], temperatures[i]
            logits = model(x, c, r)
            if classifier_free_scale >= 0:
                logits_uncond = model(x, torch.zeros_like(c), r)
                logits = torch.lerp(logits_uncond, logits, classifier_free_scale)
            x = logits
            x_flat = x.permute(0, 2, 3, 1).reshape(-1, x.size(1))
            if typical_filtering:
                x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1)
                x_flat_norm_p = torch.exp(x_flat_norm)
                entropy = -(x_flat_norm * x_flat_norm_p).nansum(-1, keepdim=True)
                c_flat_shifted = torch.abs((-x_flat_norm) - entropy)
                c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False)
                x_flat_cumsum = x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1)

                last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1)
                sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(1, last_ind.view(-1, 1))
                if typical_min_tokens > 1:
                    sorted_indices_to_remove[..., :typical_min_tokens] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(1, x_flat_indices, sorted_indices_to_remove)
                x_flat = x_flat.masked_fill(indices_to_remove, -float("Inf"))
            x_flat = torch.multinomial(x_flat.div(temp).softmax(-1), num_samples=1)[:, 0]
            # print(x_flat.shape)
            # x_flat = gumbel_sample(x_flat, temperature=temp)
            x = x_flat.view(x.size(0), *x.shape[2:])
            if i < renoise_steps:
                if renoise_mode == 'start':
                    x, _ = model.add_noise(x, r_range[i+1], random_x=init_x)
                elif renoise_mode == 'prev':
                    x, _ = model.add_noise(x, r_range[i+1], random_x=prev_x)
                else:  # 'rand'
                    x, _ = model.add_noise(x, r_range[i+1])
    return x.detach()


def encode(vq, x):
    return vq.encode(x)[-1]


def decode(vq, z):
    return vq.decode_indices(z)


class DatasetWriter:
    def __init__(self, date, base_path="/home/data"):
        self.date = date
        self.base_path = base_path

    def saveimages(self, imgs, captions, **kwargs):
        for img, caption in zip(imgs, captions):
            # The caption doubles as the file name, so strip characters that would break the path.
            caption = caption.replace(" ", "_").replace(".", "").replace("/", "_")
            path = os.path.join(self.base_path, caption + ".jpg")
            try:
                torchvision.utils.save_image(img, path, **kwargs)
            except Exception as e:  # skip only this image, not the rest of the batch
                print(e)

    def save(self, que):
        while True:
            try:
                data = que.get(True, 1)
            except (queue.Empty, FileNotFoundError):
                continue
            if data is None:
                print("Finished")
                return
            sampled, captions = data["payload"]
            self.saveimages(sampled, captions)


class Sample:
    def __init__(self, date, device, cfg_weight=5, steps=8, typical_filtering=True, batch_size=8, base_path="/home/data", dataset="coco", captions_path="cap.parquet"):
        self.date = date
        self.cfg_weight = cfg_weight
        self.steps = steps
        self.typical_filtering = typical_filtering
        self.dataset = dataset
        self.captions_path = captions_path
        self.device = torch.device(device)
        self.batch_size = batch_size
        self.base_path = base_path
        self.path = os.path.join(base_path, f"{steps}_{cfg_weight}_{typical_filtering}")
        self.model, self.vqmodel, self.clip, self.clip_preprocess = self.load_models()
        self.setup()
        self.que = mp.Queue()
        mp.Process(target=DatasetWriter(date, base_path=self.path).save, args=(self.que,)).start()

    def setup(self):
        if self.dataset not in ("coco", "laion"):
            raise ValueError(f"Unknown dataset: {self.dataset}")
        self.captions = pd.read_parquet(self.captions_path)["caption"]
        # Resume support: skip captions whose images already exist in this run's folder.
        os.makedirs(self.path, exist_ok=True)
        num_sampled = len([f for f in os.listdir(self.path) if f.endswith(".jpg")])
        self.captions = self.captions[num_sampled:]
        with open(os.path.join(self.path, "log.json"), "w") as f:
            json.dump({
                "date": self.date,
                "cfg": self.cfg_weight,
                "steps": self.steps
            }, f)

    def load_models(self):
        # --- Paella MODEL ---
        model_path = f"./models/Paella_f8_8192/model_600000.pt"
        state_dict = torch.load(model_path, map_location=self.device)
        # new_state_dict = OrderedDict()
        # for k, v in state_dict.items():
        #     name = k[7:]  # remove `module.`
        #     new_state_dict[name] = v
        model = DenoiseUNet(num_labels=8192, c_clip=1024).to(self.device)
        model.load_state_dict(state_dict)
        model.eval().requires_grad_(False)
        # --- VQ MODEL ---
        vqmodel = get_vae().to(self.device)
        vqmodel.eval().requires_grad_(False)
        # --- CLIP MODEL ---
        clip_model, _, _ = open_clip.create_model_and_transforms('ViT-g-14', pretrained='laion2b_s12b_b42k', cache_dir="/fsx/mas/.cache")
        del clip_model.visual
        clip_model = clip_model.to(self.device).eval().requires_grad_(False)
        clip_preprocess = torchvision.transforms.Compose([
            torchvision.transforms.Resize(224),
            torchvision.transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                                             std=(0.26862954, 0.26130258, 0.27577711)),
        ])
        return model, vqmodel, clip_model, clip_preprocess

    @torch.no_grad()
    def r_decode(self, img_seq, shape=(32, 32)):
        img_seq = img_seq.view(img_seq.shape[0], -1)
        one_hot_indices = torch.nn.functional.one_hot(img_seq, num_classes=self.vqmodel.num_tokens).float()
        z = (one_hot_indices @ self.vqmodel.model.quantize.embed.weight)
        z = rearrange(z, 'b (h w) c -> b c h w', h=shape[0], w=shape[1])
        img = self.vqmodel.model.decode(z)
        img = (img.clamp(-1., 1.) + 1) * 0.5
        return img

    def convert_dataset(self):
        """
        Sample one image per caption in batches of self.batch_size and hand each
        batch to the DatasetWriter process, which saves it under self.path.
        """
        # np.array_split takes the number of chunks, not the chunk size.
        num_chunks = max(1, int(np.ceil(len(self.captions) / self.batch_size)))
        latent_shape = (32, 32)
        for cap in np.array_split(self.captions, num_chunks):
            cap = list(cap)
            s = time.time()
            # print(len(cap))
            text = tokenizer.tokenize(cap).to(self.device)
            with torch.inference_mode():
                with torch.autocast(device_type="cuda"):
                    clip_embeddings = self.clip.encode_text(text).float()

                    sampled = sample(self.model, clip_embeddings, T=self.steps, size=latent_shape, starting_t=0,
                                     temp_range=[1.0, 1.0],
                                     typical_filtering=self.typical_filtering, typical_mass=0.2, typical_min_tokens=1,
                                     classifier_free_scale=self.cfg_weight, renoise_steps=self.steps - 1)
                    sampled = self.r_decode(sampled, latent_shape)
            data = {
                "payload": [sampled, cap]
            }
            self.que.put(data)
            # print(f"Sampled {len(cap)} in {time.time() - s} seconds.")
        # Signal the DatasetWriter process that no more batches are coming.
        self.que.put(None)


if __name__ == '__main__':
    mp.set_start_method('spawn')

    date = "f8_600k_no_ema"
    dataset = "coco"
    base_dir = "/fsx/mas/paella_unet/evaluation"
    ref_images = os.path.join(base_dir, f"{dataset}_30k.npz")
    ref_captions = os.path.join(base_dir, f"{dataset}_30k.parquet")
    base_path = os.path.join(base_dir, date)
    os.makedirs(base_path, exist_ok=True)
    devices = [0, 1, 2, 3, 4, 5, 6, 7]
    cfgs = [3, 4, 5]  # [3, 4, 5, 8]
    steps = [12]  # [6, 8, 10, 12]
    typical_filtering = [True, False]

    combinations = iter(list(product(steps, cfgs, typical_filtering)))
    # chunked_combinations = chunk(combinations, n=len(devices))

    try:
        while True:
            processes = []
            for proc_id in devices:
                steps, cfg_weight, typical_filtering = next(combinations)
                while os.path.exists(os.path.join(base_path, f"{steps}_{cfg_weight}_{typical_filtering}")):
                    print(os.path.join(base_path, f"{steps}_{cfg_weight}_{typical_filtering}") + " already done. skipping....")
                    steps, cfg_weight, typical_filtering = next(combinations)
                print(f"Starting sampling with steps={steps}, cfg_weight={cfg_weight}.")
                conv = Sample(date, proc_id, steps=steps, cfg_weight=cfg_weight, typical_filtering=typical_filtering, batch_size=8, base_path=base_path, dataset=dataset, captions_path=ref_captions)
                processes.append(mp.Process(target=conv.convert_dataset))
                processes[-1].start()  # index by position, not by device id
            for p in processes:
                p.join()
    except StopIteration:
        if len(processes) > 0:
            for p in processes:
                p.join()
        print("Finished sampling....")

    for run in os.listdir(base_path):
        stat_dict = {}
        run_path = os.path.join(base_path, run)
        batch_path = os.path.join(run_path, "batch.npz")
        print(f"Converting {run_path} to npz....")
        save_images_npz(run_path, batch_path)

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        evaluator = Evaluator(tf.Session(config=config))

        evaluator.warmup()

        ref_acts = evaluator.read_activations(ref_images)
        ref_stats, ref_stats_spatial = evaluator.read_statistics(ref_images, ref_acts)

        sample_acts = evaluator.read_activations(batch_path)
        sample_stats, sample_stats_spatial = evaluator.read_statistics(batch_path, sample_acts)

        stat_dict["inception_score"] = evaluator.compute_inception_score(sample_acts[0])
        stat_dict["fid"] = sample_stats.frechet_distance(ref_stats)
        stat_dict["sfid"] = sample_stats_spatial.frechet_distance(ref_stats_spatial)
        stat_dict["prec"], stat_dict["recall"] = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
        print("---------------------------------------------------------")
        print(f"Metrics for {run_path}")
        print(stat_dict)
        print("---------------------------------------------------------")
        with open(os.path.join(run_path, "stat_dict.json"), "w") as f:
            json.dump(stat_dict, f)
        os.remove(batch_path)


--------------------------------------------------------------------------------
/evaluation/evaluator.py:
--------------------------------------------------------------------------------
import argparse
import io
import os
import random
import warnings
import zipfile
from abc import ABC, abstractmethod
from contextlib import contextmanager
from functools import partial
from multiprocessing import cpu_count
from multiprocessing.pool import ThreadPool
from typing import Iterable, Optional, Tuple

import numpy as np
import requests
import tensorflow.compat.v1 as tf
from scipy import linalg
from tqdm.auto import tqdm

INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
INCEPTION_V3_PATH = "classify_image_graph_def.pb"

FID_POOL_NAME = "pool_3:0"
FID_SPATIAL_NAME = "mixed_6/conv:0"


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("ref_batch", help="path to reference batch npz file")
    parser.add_argument("sample_batch", help="path to sample batch npz file")
    args = parser.parse_args()

    config = tf.ConfigProto(
        allow_soft_placement=True  # allows DecodeJpeg to run on CPU in Inception graph
    )
    config.gpu_options.allow_growth = True
    evaluator = Evaluator(tf.Session(config=config))

    print("warming up TensorFlow...")
    # This will cause TF to print a bunch of verbose stuff now rather
    # than after the next print(), to help prevent confusion.
    evaluator.warmup()

    print("computing reference batch activations...")
    ref_acts = evaluator.read_activations(args.ref_batch)
    print("computing/reading reference batch statistics...")
    ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)

    print("computing sample batch activations...")
    sample_acts = evaluator.read_activations(args.sample_batch)
    print("computing/reading sample batch statistics...")
    sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts)

    print("Computing evaluations...")
    print("Inception Score:", evaluator.compute_inception_score(sample_acts[0]))
    print("FID:", sample_stats.frechet_distance(ref_stats))
    print("sFID:", sample_stats_spatial.frechet_distance(ref_stats_spatial))
    prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
    print("Precision:", prec)
    print("Recall:", recall)


class InvalidFIDException(Exception):
    pass


class FIDStatistics:
    def __init__(self, mu: np.ndarray, sigma: np.ndarray):
        self.mu = mu
        self.sigma = sigma

    def frechet_distance(self, other, eps=1e-6):
        """
        Compute the Frechet distance between two sets of statistics.
        """
        # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
        mu1, sigma1 = self.mu, self.sigma
        mu2, sigma2 = other.mu, other.sigma

        mu1 = np.atleast_1d(mu1)
        mu2 = np.atleast_1d(mu2)

        sigma1 = np.atleast_2d(sigma1)
        sigma2 = np.atleast_2d(sigma2)

        assert (
            mu1.shape == mu2.shape
        ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
        assert (
            sigma1.shape == sigma2.shape
        ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"

        diff = mu1 - mu2

        # product might be almost singular
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if not np.isfinite(covmean).all():
            msg = (
                "fid calculation produces singular product; adding %s to diagonal of cov estimates"
                % eps
            )
            warnings.warn(msg)
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

        # numerical error might give slight imaginary component
        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                m = np.max(np.abs(covmean.imag))
                raise ValueError("Imaginary component {}".format(m))
            covmean = covmean.real

        tr_covmean = np.trace(covmean)

        return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean


class Evaluator:
    def __init__(
        self,
        session,
        batch_size=64,
        softmax_batch_size=512,
    ):
        self.sess = session
        self.batch_size = batch_size
        self.softmax_batch_size = softmax_batch_size
        self.manifold_estimator = ManifoldEstimator(session)
        with self.sess.graph.as_default():
            self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
            self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
            self.pool_features, self.spatial_features = _create_feature_graph(self.image_input)
            self.softmax = _create_softmax_graph(self.softmax_input)

    def warmup(self):
        self.compute_activations(np.zeros([1, 8, 64, 64, 3]))

    def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
        with open_npz_array(npz_path, "arr_0") as reader:
            return self.compute_activations(reader.read_batches(self.batch_size))

    def compute_activations(self, batches: Iterable[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute image features for downstream evals.

        :param batches: an iterator over NHWC numpy arrays in [0, 255].
        :return: a tuple of numpy arrays of shape [N x X], where X is a feature
                 dimension. The tuple is (pool_3, spatial).
        """
        preds = []
        spatial_preds = []
        for batch in tqdm(batches):
            batch = batch.astype(np.float32)
            pred, spatial_pred = self.sess.run(
                [self.pool_features, self.spatial_features], {self.image_input: batch}
            )
            preds.append(pred.reshape([pred.shape[0], -1]))
            spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
        return (
            np.concatenate(preds, axis=0),
            np.concatenate(spatial_preds, axis=0),
        )

    def read_statistics(
        self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]
    ) -> Tuple[FIDStatistics, FIDStatistics]:
        obj = np.load(npz_path)
        if "mu" in list(obj.keys()):
            return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
                obj["mu_s"], obj["sigma_s"]
            )
        return tuple(self.compute_statistics(x) for x in activations)

    def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
        mu = np.mean(activations, axis=0)
        sigma = np.cov(activations, rowvar=False)
        return FIDStatistics(mu, sigma)

    def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float:
        softmax_out = []
        for i in range(0, len(activations), self.softmax_batch_size):
            acts = activations[i : i + self.softmax_batch_size]
            softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}))
        preds = np.concatenate(softmax_out, axis=0)
        # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
        scores = []
        for i in range(0, len(preds), split_size):
            part = preds[i : i + split_size]
            kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
            kl = np.mean(np.sum(kl, 1))
            scores.append(np.exp(kl))
        return float(np.mean(scores))

    def compute_prec_recall(
        self, activations_ref: np.ndarray, activations_sample: np.ndarray
    ) -> Tuple[float, float]:
        radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
        radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
        pr = self.manifold_estimator.evaluate_pr(
            activations_ref, radii_1, activations_sample, radii_2
        )
        return (float(pr[0][0]), float(pr[1][0]))


class ManifoldEstimator:
    """
    A helper for comparing manifolds of feature vectors.

    Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
    """

    def __init__(
        self,
        session,
        row_batch_size=10000,
        col_batch_size=10000,
        nhood_sizes=(3,),
        clamp_to_percentile=None,
        eps=1e-5,
    ):
        """
        Estimate the manifold of given feature vectors.

        :param session: the TensorFlow session.
        :param row_batch_size: row batch size to compute pairwise distances
                               (parameter to trade-off between memory usage and performance).
        :param col_batch_size: column batch size to compute pairwise distances.
        :param nhood_sizes: number of neighbors used to estimate the manifold.
        :param clamp_to_percentile: prune hyperspheres that have radius larger than
                                    the given percentile.
        :param eps: small number for numerical stability.
        """
        self.distance_block = DistanceBlock(session)
        self.row_batch_size = row_batch_size
        self.col_batch_size = col_batch_size
        self.nhood_sizes = nhood_sizes
        self.num_nhoods = len(nhood_sizes)
        self.clamp_to_percentile = clamp_to_percentile
        self.eps = eps

    def warmup(self):
        feats, radii = (
            np.zeros([1, 2048], dtype=np.float32),
            np.zeros([1, 1], dtype=np.float32),
        )
        self.evaluate_pr(feats, radii, feats, radii)

    def manifold_radii(self, features: np.ndarray) -> np.ndarray:
        num_images = len(features)

        # Estimate manifold of features by calculating distances to k-NN of each sample.
        radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
        distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
        seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)

        for begin1 in range(0, num_images, self.row_batch_size):
            end1 = min(begin1 + self.row_batch_size, num_images)
            row_batch = features[begin1:end1]

            for begin2 in range(0, num_images, self.col_batch_size):
                end2 = min(begin2 + self.col_batch_size, num_images)
                col_batch = features[begin2:end2]

                # Compute distances between batches.
                distance_batch[
                    0 : end1 - begin1, begin2:end2
                ] = self.distance_block.pairwise_distances(row_batch, col_batch)

            # Find the k-nearest neighbor from the current batch.
            radii[begin1:end1, :] = np.concatenate(
                [
                    x[:, self.nhood_sizes]
                    for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1)
                ],
                axis=0,
            )

        if self.clamp_to_percentile is not None:
            max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
            radii[radii > max_distances] = 0
        return radii

    def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray):
        """
        Evaluate if new feature vectors are at the manifold.
        """
        num_eval_images = eval_features.shape[0]
        num_ref_images = radii.shape[0]
        distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)
        batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
        max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
        nearest_indices = np.zeros([num_eval_images], dtype=np.int32)

        for begin1 in range(0, num_eval_images, self.row_batch_size):
            end1 = min(begin1 + self.row_batch_size, num_eval_images)
            feature_batch = eval_features[begin1:end1]

            for begin2 in range(0, num_ref_images, self.col_batch_size):
                end2 = min(begin2 + self.col_batch_size, num_ref_images)
                ref_batch = features[begin2:end2]

                distance_batch[
                    0 : end1 - begin1, begin2:end2
                ] = self.distance_block.pairwise_distances(feature_batch, ref_batch)

            # From the minibatch of new feature vectors, determine if they are in the estimated manifold.
            # If a feature vector is inside a hypersphere of some reference sample, then
            # the new sample lies at the estimated manifold.
            # The radii of the hyperspheres are determined from distances of neighborhood size k.
            samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
            batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)

            max_realism_score[begin1:end1] = np.max(
                radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
            )
            nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1)

        return {
            "fraction": float(np.mean(batch_predictions)),
            "batch_predictions": batch_predictions,
            "max_realisim_score": max_realism_score,
            "nearest_indices": nearest_indices,
        }

    def evaluate_pr(
        self,
        features_1: np.ndarray,
        radii_1: np.ndarray,
        features_2: np.ndarray,
        radii_2: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Evaluate precision and recall efficiently.

        :param features_1: [N1 x D] feature vectors for reference batch.
        :param radii_1: [N1 x K1] radii for reference vectors.
        :param features_2: [N2 x D] feature vectors for the other batch.
        :param radii_2: [N2 x K2] radii for other vectors.
        :return: a tuple of arrays for (precision, recall):
                 - precision: an np.ndarray of length K1
                 - recall: an np.ndarray of length K2
        """
        features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=bool)
        features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=bool)
        for begin_1 in range(0, len(features_1), self.row_batch_size):
            end_1 = begin_1 + self.row_batch_size
            batch_1 = features_1[begin_1:end_1]
            for begin_2 in range(0, len(features_2), self.col_batch_size):
                end_2 = begin_2 + self.col_batch_size
                batch_2 = features_2[begin_2:end_2]
                batch_1_in, batch_2_in = self.distance_block.less_thans(
                    batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
                )
                features_1_status[begin_1:end_1] |= batch_1_in
                features_2_status[begin_2:end_2] |= batch_2_in
        return (
            np.mean(features_2_status.astype(np.float64), axis=0),
            np.mean(features_1_status.astype(np.float64), axis=0),
        )


class DistanceBlock:
    """
    Calculate pairwise distances between vectors.

    Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
    """

    def __init__(self, session):
        self.session = session

        # Initialize TF graph to calculate pairwise distances.
        with session.graph.as_default():
            self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
            self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
            distance_block_16 = _batch_pairwise_distances(
                tf.cast(self._features_batch1, tf.float16),
                tf.cast(self._features_batch2, tf.float16),
            )
            self.distance_block = tf.cond(
                tf.reduce_all(tf.math.is_finite(distance_block_16)),
                lambda: tf.cast(distance_block_16, tf.float32),
                lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2),
            )

            # Extra logic for less thans.
            self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
            self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
            dist32 = tf.cast(self.distance_block, tf.float32)[..., None]
            self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)
            self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0)

    def pairwise_distances(self, U, V):
        """
        Evaluate pairwise distances between two batches of feature vectors.
        """
        return self.session.run(
            self.distance_block,
            feed_dict={self._features_batch1: U, self._features_batch2: V},
        )

    def less_thans(self, batch_1, radii_1, batch_2, radii_2):
        return self.session.run(
            [self._batch_1_in, self._batch_2_in],
            feed_dict={
                self._features_batch1: batch_1,
                self._features_batch2: batch_2,
                self._radii1: radii_1,
                self._radii2: radii_2,
            },
        )


def _batch_pairwise_distances(U, V):
    """
    Compute pairwise squared Euclidean distances between two batches of feature vectors.
    """
    with tf.variable_scope("pairwise_dist_block"):
        # Squared norms of each row in U and V.
        norm_u = tf.reduce_sum(tf.square(U), 1)
        norm_v = tf.reduce_sum(tf.square(V), 1)

        # Reshape norm_u into a column vector and norm_v into a row vector.
424 | norm_u = tf.reshape(norm_u, [-1, 1]) 425 | norm_v = tf.reshape(norm_v, [1, -1]) 426 | 427 | # Pairwise squared Euclidean distances. 428 | D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0) 429 | 430 | return D 431 | 432 | 433 | class NpzArrayReader(ABC): 434 | @abstractmethod 435 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]: 436 | pass 437 | 438 | @abstractmethod 439 | def remaining(self) -> int: 440 | pass 441 | 442 | def read_batches(self, batch_size: int) -> Iterable[np.ndarray]: 443 | def gen_fn(): 444 | while True: 445 | batch = self.read_batch(batch_size) 446 | if batch is None: 447 | break 448 | yield batch 449 | 450 | rem = self.remaining() 451 | num_batches = rem // batch_size + int(rem % batch_size != 0) 452 | return BatchIterator(gen_fn, num_batches) 453 | 454 | 455 | class BatchIterator: 456 | def __init__(self, gen_fn, length): 457 | self.gen_fn = gen_fn 458 | self.length = length 459 | 460 | def __len__(self): 461 | return self.length 462 | 463 | def __iter__(self): 464 | return self.gen_fn() 465 | 466 | 467 | class StreamingNpzArrayReader(NpzArrayReader): 468 | def __init__(self, arr_f, shape, dtype): 469 | self.arr_f = arr_f 470 | self.shape = shape 471 | self.dtype = dtype 472 | self.idx = 0 473 | 474 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]: 475 | if self.idx >= self.shape[0]: 476 | return None 477 | 478 | bs = min(batch_size, self.shape[0] - self.idx) 479 | self.idx += bs 480 | 481 | if self.dtype.itemsize == 0: 482 | return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype) 483 | 484 | read_count = bs * np.prod(self.shape[1:]) 485 | read_size = int(read_count * self.dtype.itemsize) 486 | data = _read_bytes(self.arr_f, read_size, "array data") 487 | return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]]) 488 | 489 | def remaining(self) -> int: 490 | return max(0, self.shape[0] - self.idx) 491 | 492 | 493 | class MemoryNpzArrayReader(NpzArrayReader): 494 | 
def __init__(self, arr): 495 | self.arr = arr 496 | self.idx = 0 497 | 498 | @classmethod 499 | def load(cls, path: str, arr_name: str): 500 | with open(path, "rb") as f: 501 | arr = np.load(f)[arr_name] 502 | return cls(arr) 503 | 504 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]: 505 | if self.idx >= self.arr.shape[0]: 506 | return None 507 | 508 | res = self.arr[self.idx : self.idx + batch_size] 509 | self.idx += batch_size 510 | return res 511 | 512 | def remaining(self) -> int: 513 | return max(0, self.arr.shape[0] - self.idx) 514 | 515 | 516 | @contextmanager 517 | def open_npz_array(path: str, arr_name: str) -> NpzArrayReader: 518 | with _open_npy_file(path, arr_name) as arr_f: 519 | version = np.lib.format.read_magic(arr_f) 520 | if version == (1, 0): 521 | header = np.lib.format.read_array_header_1_0(arr_f) 522 | elif version == (2, 0): 523 | header = np.lib.format.read_array_header_2_0(arr_f) 524 | else: 525 | yield MemoryNpzArrayReader.load(path, arr_name) 526 | return 527 | shape, fortran, dtype = header 528 | if fortran or dtype.hasobject: 529 | yield MemoryNpzArrayReader.load(path, arr_name) 530 | else: 531 | yield StreamingNpzArrayReader(arr_f, shape, dtype) 532 | 533 | 534 | def _read_bytes(fp, size, error_template="ran out of data"): 535 | """ 536 | Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886 537 | 538 | Read from file-like object until size bytes are read. 539 | Raises ValueError if not EOF is encountered before size bytes are read. 540 | Non-blocking objects only supported if they derive from io objects. 541 | Required as e.g. ZipExtFile in python 2.6 can return less data than 542 | requested. 543 | """ 544 | data = bytes() 545 | while True: 546 | # io files (default in python3) return None or raise on 547 | # would-block, python2 file will truncate, probably nothing can be 548 | # done about that. 
note that regular files can't be non-blocking 549 | try: 550 | r = fp.read(size - len(data)) 551 | data += r 552 | if len(r) == 0 or len(data) == size: 553 | break 554 | except io.BlockingIOError: 555 | pass 556 | if len(data) != size: 557 | msg = "EOF: reading %s, expected %d bytes got %d" 558 | raise ValueError(msg % (error_template, size, len(data))) 559 | else: 560 | return data 561 | 562 | 563 | @contextmanager 564 | def _open_npy_file(path: str, arr_name: str): 565 | with open(path, "rb") as f: 566 | with zipfile.ZipFile(f, "r") as zip_f: 567 | if f"{arr_name}.npy" not in zip_f.namelist(): 568 | raise ValueError(f"missing {arr_name} in npz file") 569 | with zip_f.open(f"{arr_name}.npy", "r") as arr_f: 570 | yield arr_f 571 | 572 | 573 | def _download_inception_model(): 574 | if os.path.exists(INCEPTION_V3_PATH): 575 | return 576 | print("downloading InceptionV3 model...") 577 | with requests.get(INCEPTION_V3_URL, stream=True) as r: 578 | r.raise_for_status() 579 | tmp_path = INCEPTION_V3_PATH + ".tmp" 580 | with open(tmp_path, "wb") as f: 581 | for chunk in tqdm(r.iter_content(chunk_size=8192)): 582 | f.write(chunk) 583 | os.rename(tmp_path, INCEPTION_V3_PATH) 584 | 585 | 586 | def _create_feature_graph(input_batch): 587 | _download_inception_model() 588 | prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}" 589 | with open(INCEPTION_V3_PATH, "rb") as f: 590 | graph_def = tf.GraphDef() 591 | graph_def.ParseFromString(f.read()) 592 | pool3, spatial = tf.import_graph_def( 593 | graph_def, 594 | input_map={f"ExpandDims:0": input_batch}, 595 | return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME], 596 | name=prefix, 597 | ) 598 | _update_shapes(pool3) 599 | spatial = spatial[..., :7] 600 | return pool3, spatial 601 | 602 | 603 | def _create_softmax_graph(input_batch): 604 | _download_inception_model() 605 | prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}" 606 | with open(INCEPTION_V3_PATH, "rb") as f: 607 | graph_def = tf.GraphDef() 608 | 
graph_def.ParseFromString(f.read()) 609 | (matmul,) = tf.import_graph_def( 610 | graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix 611 | ) 612 | w = matmul.inputs[1] 613 | logits = tf.matmul(input_batch, w) 614 | return tf.nn.softmax(logits) 615 | 616 | 617 | def _update_shapes(pool3): 618 | # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63 619 | ops = pool3.graph.get_operations() 620 | for op in ops: 621 | for o in op.outputs: 622 | shape = o.get_shape() 623 | if shape._dims is not None: # pylint: disable=protected-access 624 | # shape = [s.value for s in shape] TF 1.x 625 | shape = [s for s in shape] # TF 2.x 626 | new_shape = [] 627 | for j, s in enumerate(shape): 628 | if s == 1 and j == 0: 629 | new_shape.append(None) 630 | else: 631 | new_shape.append(s) 632 | o.__dict__["_shape_val"] = tf.TensorShape(new_shape) 633 | return pool3 634 | 635 | 636 | def _numpy_partition(arr, kth, **kwargs): 637 | num_workers = min(cpu_count(), len(arr)) 638 | chunk_size = len(arr) // num_workers 639 | extra = len(arr) % num_workers 640 | 641 | start_idx = 0 642 | batches = [] 643 | for i in range(num_workers): 644 | size = chunk_size + (1 if i < extra else 0) 645 | batches.append(arr[start_idx : start_idx + size]) 646 | start_idx += size 647 | 648 | with ThreadPool(num_workers) as pool: 649 | return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches)) 650 | 651 | 652 | if __name__ == "__main__": 653 | main() 654 | -------------------------------------------------------------------------------- /paella_sampling.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "b3ea0ec1-436f-4fca-99fa-5122a73b52d7", 7 | "metadata": { 8 | "id": "b3ea0ec1-436f-4fca-99fa-5122a73b52d7", 9 | "outputId": "b90ef2e0-156a-458d-f495-8a0fda0253e5", 10 | "scrolled": true, 11 | "tags": [] 12 | }, 13 | 
"outputs": [], 14 | "source": [ 15 | "!pip install kornia lpips einops rudalle open_clip_torch pytorch_lightning webdataset timm git+https://github.com/pabloppp/pytorch-tools git+https://github.com/openai/CLIP.git -U\n", 16 | "!pip uninstall torch torchvision torchaudio -y\n", 17 | "!pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116\n", 18 | "!pip install --upgrade Pillow" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "id": "e05c89a6-4648-46e2-90c7-cf7c1b0dd861", 25 | "metadata": { 26 | "id": "788a2a72", 27 | "outputId": "702cbb5b-8c71-4b33-e86e-f84b718b7a1b" 28 | }, 29 | "outputs": [], 30 | "source": [ 31 | "import os\n", 32 | "import time\n", 33 | "import torch\n", 34 | "from torch import nn\n", 35 | "import torchvision\n", 36 | "import matplotlib.pyplot as plt\n", 37 | "from tqdm import tqdm\n", 38 | "from PIL import Image\n", 39 | "import requests\n", 40 | "from io import BytesIO\n", 41 | "from modules import DenoiseUNet\n", 42 | "import open_clip\n", 43 | "from open_clip import tokenizer\n", 44 | "from rudalle import get_vae\n", 45 | "from einops import rearrange\n", 46 | "\n", 47 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 48 | "print(\"Using device:\", device)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "id": "d6640155", 55 | "metadata": { 56 | "id": "d6640155", 57 | "outputId": "99d4752b-3a85-49da-cbf3-72637160e2b9", 58 | "scrolled": true 59 | }, 60 | "outputs": [], 61 | "source": [ 62 | "def showmask(mask):\n", 63 | " plt.axis(\"off\")\n", 64 | " plt.imshow(torch.cat([\n", 65 | " torch.cat([i for i in mask[0:1].cpu()], dim=-1),\n", 66 | " ], dim=-2).cpu())\n", 67 | " plt.show()\n", 68 | "\n", 69 | "def showimages(imgs, **kwargs):\n", 70 | " plt.figure(figsize=(kwargs.get(\"width\", 32), kwargs.get(\"height\", 32)))\n", 71 | " plt.axis(\"off\")\n", 72 | " plt.imshow(torch.cat([\n", 73 | " 
torch.cat([i for i in imgs], dim=-1),\n", 74 | " ], dim=-2).permute(1, 2, 0).cpu())\n", 75 | " plt.show()\n", 76 | " \n", 77 | "def saveimages(imgs, name, **kwargs):\n", 78 | " name = name.replace(\" \", \"_\").replace(\".\", \"\")\n", 79 | " path = os.path.join(\"outputs\", name + \".jpg\")\n", 80 | " while os.path.exists(path):\n", 81 | " base, ext = path.split(\".\")\n", 82 | " num = base.split(\"_\")[-1]\n", 83 | " if num.isdigit():\n", 84 | " num = int(num) + 1\n", 85 | " base = \"_\".join(base.split(\"_\")[:-1])\n", 86 | " else:\n", 87 | " num = 0\n", 88 | " path = base + \"_\" + str(num) + \".\" + ext\n", 89 | " torchvision.utils.save_image(imgs, path, **kwargs)" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "id": "7e2e1f26-b4ca-4e11-947d-be80196d440f", 96 | "metadata": { 97 | "tags": [] 98 | }, 99 | "outputs": [], 100 | "source": [ 101 | "def log(t, eps=1e-20):\n", 102 | " return torch.log(t + eps)\n", 103 | "\n", 104 | "def gumbel_noise(t):\n", 105 | " noise = torch.zeros_like(t).uniform_(0, 1)\n", 106 | " return -log(-log(noise))\n", 107 | "\n", 108 | "def gumbel_sample(t, temperature=1., dim=-1):\n", 109 | " return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)\n", 110 | "\n", 111 | "def sample(model, c, x=None, mask=None, T=12, size=(32, 32), starting_t=0, temp_range=[1.0, 1.0], typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=-1, renoise_steps=11, renoise_mode='start'):\n", 112 | " with torch.inference_mode():\n", 113 | " r_range = torch.linspace(0, 1, T+1)[:-1][:, None].expand(-1, c.size(0)).to(c.device)\n", 114 | " temperatures = torch.linspace(temp_range[0], temp_range[1], T)\n", 115 | " preds = []\n", 116 | " if x is None:\n", 117 | " x = torch.randint(0, model.num_labels, size=(c.size(0), *size), device=c.device)\n", 118 | " elif mask is not None:\n", 119 | " noise = torch.randint(0, model.num_labels, size=(c.size(0), *size), device=c.device)\n", 120 | " x 
= noise * mask + (1-mask) * x\n", 121 | " init_x = x.clone()\n", 122 | " for i in range(starting_t, T):\n", 123 | " if renoise_mode == 'prev':\n", 124 | " prev_x = x.clone()\n", 125 | " r, temp = r_range[i], temperatures[i]\n", 126 | " logits = model(x, c, r)\n", 127 | " if classifier_free_scale >= 0:\n", 128 | " logits_uncond = model(x, torch.zeros_like(c), r)\n", 129 | " logits = torch.lerp(logits_uncond, logits, classifier_free_scale)\n", 130 | " x = logits\n", 131 | " x_flat = x.permute(0, 2, 3, 1).reshape(-1, x.size(1))\n", 132 | " if typical_filtering:\n", 133 | " x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1)\n", 134 | " x_flat_norm_p = torch.exp(x_flat_norm)\n", 135 | " entropy = -(x_flat_norm * x_flat_norm_p).nansum(-1, keepdim=True)\n", 136 | "\n", 137 | " c_flat_shifted = torch.abs((-x_flat_norm) - entropy)\n", 138 | " c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False)\n", 139 | " x_flat_cumsum = x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1)\n", 140 | "\n", 141 | " last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1)\n", 142 | " sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(1, last_ind.view(-1, 1))\n", 143 | " if typical_min_tokens > 1:\n", 144 | " sorted_indices_to_remove[..., :typical_min_tokens] = 0\n", 145 | " indices_to_remove = sorted_indices_to_remove.scatter(1, x_flat_indices, sorted_indices_to_remove)\n", 146 | " x_flat = x_flat.masked_fill(indices_to_remove, -float(\"Inf\"))\n", 147 | " # x_flat = torch.multinomial(x_flat.div(temp).softmax(-1), num_samples=1)[:, 0]\n", 148 | " x_flat = gumbel_sample(x_flat, temperature=temp)\n", 149 | " x = x_flat.view(x.size(0), *x.shape[2:])\n", 150 | " if mask is not None:\n", 151 | " x = x * mask + (1-mask) * init_x\n", 152 | " if i < renoise_steps:\n", 153 | " if renoise_mode == 'start':\n", 154 | " x, _ = model.add_noise(x, r_range[i+1], random_x=init_x)\n", 155 | " elif renoise_mode == 'prev':\n", 156 | " x, _ = 
model.add_noise(x, r_range[i+1], random_x=prev_x)\n", 157 | " else: # 'rand'\n", 158 | " x, _ = model.add_noise(x, r_range[i+1])\n", 159 | " preds.append(x.detach())\n", 160 | " return preds" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "id": "34c7c626", 167 | "metadata": { 168 | "id": "34c7c626", 169 | "outputId": "16108313-6b3d-4751-d84b-a92d7f8ebf68", 170 | "tags": [] 171 | }, 172 | "outputs": [], 173 | "source": [ 174 | "vqmodel = get_vae().to(device)\n", 175 | "vqmodel.eval().requires_grad_(False)\n", 176 | "\n", 177 | "clip_model, _, _ = open_clip.create_model_and_transforms('ViT-g-14', pretrained='laion2b_s12b_b42k')\n", 178 | "clip_model = clip_model.to(device).eval().requires_grad_(False)\n", 179 | "\n", 180 | "clip_preprocess = torchvision.transforms.Compose([\n", 181 | " torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),\n", 182 | " torchvision.transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),\n", 183 | "])\n", 184 | "\n", 185 | "preprocess = torchvision.transforms.Compose([\n", 186 | " torchvision.transforms.Resize(256),\n", 187 | " # torchvision.transforms.CenterCrop(256),\n", 188 | " torchvision.transforms.ToTensor(),\n", 189 | "])\n", 190 | "\n", 191 | "def encode(x):\n", 192 | " return vqmodel.model.encode((2 * x - 1))[-1][-1]\n", 193 | " \n", 194 | "def decode(img_seq, shape=(32,32)):\n", 195 | " img_seq = img_seq.view(img_seq.shape[0], -1)\n", 196 | " b, n = img_seq.shape\n", 197 | " one_hot_indices = torch.nn.functional.one_hot(img_seq, num_classes=vqmodel.num_tokens).float()\n", 198 | " z = (one_hot_indices @ vqmodel.model.quantize.embed.weight)\n", 199 | " z = rearrange(z, 'b (h w) c -> b c h w', h=shape[0], w=shape[1])\n", 200 | " img = vqmodel.model.decode(z)\n", 201 | " img = (img.clamp(-1., 1.) 
+ 1) * 0.5\n", 202 | " return img\n", 203 | " \n", 204 | "state_dict = torch.load(\"./models/f8_600000.pt\", map_location=device)\n", 205 | "# state_dict = torch.load(\"./models/f8_img_40000.pt\", map_location=device)\n", 206 | "model = DenoiseUNet(num_labels=8192).to(device)\n", 207 | "model.load_state_dict(state_dict)\n", 208 | "model.eval().requires_grad_(False)\n", 209 | "print()" 210 | ] 211 | }, 212 | { 213 | "cell_type": "markdown", 214 | "id": "753a98f2-bf40-4059-86f6-c83dac8c15eb", 215 | "metadata": {}, 216 | "source": [ 217 | "# Text-Conditional" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": null, 223 | "id": "d39cd5d7-220a-438b-8229-823fa9e3fff9", 224 | "metadata": { 225 | "tags": [] 226 | }, 227 | "outputs": [], 228 | "source": [ 229 | "mode = \"text\"\n", 230 | "batch_size = 6\n", 231 | "text = \"highly detailed photograph of darth vader. artstation\"\n", 232 | "latent_shape = (32, 32)\n", 233 | "tokenized_text = tokenizer.tokenize([text] * batch_size).to(device)\n", 234 | "with torch.inference_mode():\n", 235 | " with torch.autocast(device_type=\"cuda\"):\n", 236 | " clip_embeddings = clip_model.encode_text(tokenized_text)\n", 237 | " s = time.time()\n", 238 | " sampled = sample(model, clip_embeddings, T=12, size=latent_shape, starting_t=0, temp_range=[1.0, 1.0],\n", 239 | " typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=5, renoise_steps=11,\n", 240 | " renoise_mode=\"start\")\n", 241 | " print(time.time() - s)\n", 242 | " sampled = decode(sampled[-1], latent_shape)\n", 243 | "\n", 244 | "showimages(sampled)\n", 245 | "saveimages(sampled, mode + \"_\" + text, nrow=len(sampled))" 246 | ] 247 | }, 248 | { 249 | "cell_type": "markdown", 250 | "id": "c3f5cecc-4e30-465e-bc51-425ec76a7d06", 251 | "metadata": {}, 252 | "source": [ 253 | "# Interpolation" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "id": "35ae32e2", 260 | "metadata": { 261 | "id": 
"35ae32e2", 262 | "outputId": "28283340-fe61-44c6-9d92-720d04cf56bf" 263 | }, 264 | "outputs": [], 265 | "source": [ 266 | "mode = \"interpolation\"\n", 267 | "text = \"surreal painting of a yellow tulip. artstation\"\n", 268 | "text2 = \"surreal painting of a red tulip. artstation\"\n", 269 | "text_encoded = tokenizer.tokenize([text]).to(device)\n", 270 | "text2_encoded = tokenizer.tokenize([text2]).to(device)\n", 271 | "with torch.inference_mode():\n", 272 | " with torch.autocast(device_type=\"cuda\"):\n", 273 | " clip_embeddings = clip_model.encode_text(text_encoded).float()\n", 274 | " clip_embeddings2 = clip_model.encode_text(text2_encoded).float()\n", 275 | "\n", 276 | " l = torch.linspace(0, 1, 10).to(device)\n", 277 | " embeddings = []\n", 278 | " for i in l:\n", 279 | " lerp = torch.lerp(clip_embeddings, clip_embeddings2, i)\n", 280 | " embeddings.append(lerp)\n", 281 | " embeddings = torch.cat(embeddings)\n", 282 | " \n", 283 | " s = time.time()\n", 284 | " sampled = sample(model, embeddings, T=12, size=(32, 32), starting_t=0, temp_range=[1.0, 1.0],\n", 285 | " typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=4, renoise_steps=11)\n", 286 | " print(time.time() - s)\n", 287 | " sampled = decode(sampled[-1])\n", 288 | "showimages(sampled)\n", 289 | "saveimages(sampled, mode + \"_\" + text + \"_\" + text2, nrow=len(sampled))" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": null, 295 | "id": "026fade5-0384-4521-992e-8cd66929d26d", 296 | "metadata": {}, 297 | "outputs": [], 298 | "source": [ 299 | "mode = \"interpolation\"\n", 300 | "text = \"High quality front portrait photo of a tiger.\"\n", 301 | "text2 = \"High quality front portrait photo of a dog.\"\n", 302 | "text_encoded = tokenizer.tokenize([text]).to(device)\n", 303 | "text2_encoded = tokenizer.tokenize([text2]).to(device)\n", 304 | "with torch.inference_mode():\n", 305 | " with torch.autocast(device_type=\"cuda\"):\n", 306 | " 
clip_embeddings = clip_model.encode_text(text_encoded).float()\n", 307 | " clip_embeddings2 = clip_model.encode_text(text2_encoded).float()\n", 308 | "\n", 309 | " l = torch.linspace(0, 1, 10).to(device)\n", 310 | " s = time.time()\n", 311 | " outputs = []\n", 312 | " for i in l:\n", 313 | " # lerp = torch.lerp(clip_embeddings, clip_embeddings2, i)\n", 314 | " low, high = clip_embeddings, clip_embeddings2\n", 315 | " low_norm = low/torch.norm(low, dim=1, keepdim=True)\n", 316 | " high_norm = high/torch.norm(high, dim=1, keepdim=True)\n", 317 | " omega = torch.acos((low_norm*high_norm).sum(1)).unsqueeze(1)\n", 318 | " so = torch.sin(omega)\n", 319 | " lerp = (torch.sin((1.0-i)*omega)/so)*low + (torch.sin(i*omega)/so) * high\n", 320 | " with torch.random.fork_rng():\n", 321 | " torch.random.manual_seed(32)\n", 322 | " sampled = sample(model, lerp, T=20, size=(32, 32), starting_t=0, temp_range=[1.0, 1.0],\n", 323 | " typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=5, renoise_steps=11)\n", 324 | " outputs.append(sampled[-1])\n", 325 | " print(time.time() - s)\n", 326 | " sampled = torch.cat(outputs)\n", 327 | " sampled = decode(sampled)\n", 328 | "showimages(sampled)\n", 329 | "saveimages(sampled, mode + \"_\" + text + \"_\" + text2, nrow=len(sampled))" 330 | ] 331 | }, 332 | { 333 | "cell_type": "markdown", 334 | "id": "561133a9", 335 | "metadata": { 336 | "id": "0bd51975" 337 | }, 338 | "source": [ 339 | "# Multi-Conditioning" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": null, 345 | "id": "83e9ac9f", 346 | "metadata": { 347 | "id": "83e9ac9f" 348 | }, 349 | "outputs": [], 350 | "source": [ 351 | "batch_size = 4\n", 352 | "latent_shape = (32, 32)\n", 353 | "text_a = \"a cute portrait of a dog\"\n", 354 | "text_b = \"a cute portrait of a cat\"\n", 355 | "mode = \"vertical\"\n", 356 | "# mode = \"horizontal\"\n", 357 | "text = tokenizer.tokenize([text_a, text_b] * batch_size).to(device)\n", 358 | 
"\n", 359 | "with torch.inference_mode():\n", 360 | " with torch.autocast(device_type=\"cuda\"):\n", 361 | " clip_embeddings = clip_model.encode_text(text).float()[:, :, None, None].expand(-1, -1, latent_shape[0], latent_shape[1])\n", 362 | " if mode == 'vertical':\n", 363 | " interp_mask = torch.linspace(0, 1, latent_shape[0], device=device)[None, None, :, None].expand(batch_size, 1, -1, latent_shape[1])\n", 364 | " else: \n", 365 | " interp_mask = torch.linspace(0, 1, latent_shape[1], device=device)[None, None, None, :].expand(batch_size, 1, latent_shape[0], -1)\n", 366 | " # LERP\n", 367 | " clip_embeddings = clip_embeddings[0::2] * (1-interp_mask) + clip_embeddings[1::2] * interp_mask\n", 368 | " # # SLERP\n", 369 | " # low, high = clip_embeddings[0::2], clip_embeddings[1::2]\n", 370 | " # low_norm = low/torch.norm(low, dim=1, keepdim=True)\n", 371 | " # high_norm = high/torch.norm(high, dim=1, keepdim=True)\n", 372 | " # omega = torch.acos((low_norm*high_norm).sum(1)).unsqueeze(1)\n", 373 | " # so = torch.sin(omega)\n", 374 | " # clip_embeddings = (torch.sin((1.0-interp_mask)*omega)/so)*low + (torch.sin(interp_mask*omega)/so) * high\n", 375 | " \n", 376 | " sampled = sample(model, clip_embeddings, T=12, size=latent_shape, starting_t=0, temp_range=[1.0, 1.0],\n", 377 | " typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=5, renoise_steps=11,\n", 378 | " renoise_mode=\"start\")\n", 379 | " sampled = decode(sampled[-1], latent_shape)\n", 380 | "\n", 381 | "showimages(sampled)" 382 | ] 383 | }, 384 | { 385 | "cell_type": "code", 386 | "execution_count": null, 387 | "id": "3838e233-34fd-4637-b0df-476f4e66cd83", 388 | "metadata": {}, 389 | "outputs": [], 390 | "source": [ 391 | "mode = \"multiconditioning\"\n", 392 | "batch_size = 4\n", 393 | "latent_shape = (32, 32)\n", 394 | "conditions = [\n", 395 | " [\"High quality portrait of a dog.\", 16],\n", 396 | " [\"High quality portrait of a wolf.\", 32],\n", 397 | "]\n", 398 | 
"clip_embedding = torch.zeros(batch_size, 1024, *latent_shape).to(device)\n", 399 | "last_pos = 0\n", 400 | "for text, pos in conditions:\n", 401 | " tokenized_text = tokenizer.tokenize([text] * batch_size).to(device)\n", 402 | " part_clip_embedding = clip_model.encode_text(tokenized_text).float()[:, :, None, None]\n", 403 | " print(f\"{last_pos}:{pos}={text}\")\n", 404 | " clip_embedding[:, :, :, last_pos:pos] = part_clip_embedding\n", 405 | " last_pos = pos\n", 406 | "with torch.inference_mode():\n", 407 | " with torch.autocast(device_type=\"cuda\"):\n", 408 | " sampled = sample(model, clip_embedding, T=12, size=latent_shape, starting_t=0, temp_range=[1.0, 1.0],\n", 409 | " typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=5, renoise_steps=11,\n", 410 | " renoise_mode=\"start\")\n", 411 | " sampled = decode(sampled[-1], latent_shape)\n", 412 | " \n", 413 | "showimages(sampled)\n", 414 | "saveimages(sampled, mode + \"_\" + \":\".join(list(map(lambda x: x[0], conditions))), nrow=batch_size)" 415 | ] 416 | }, 417 | { 418 | "cell_type": "markdown", 419 | "id": "0a65e885-e166-4997-bf81-94e344e46a78", 420 | "metadata": {}, 421 | "source": [ 422 | "#### Load Image: Disk or Web" 423 | ] 424 | }, 425 | { 426 | "cell_type": "code", 427 | "execution_count": null, 428 | "id": "9fc57b1a-de96-4ef4-b699-0417078b9da3", 429 | "metadata": {}, 430 | "outputs": [], 431 | "source": [ 432 | "images = preprocess(Image.open(\"path_to_image\")).unsqueeze(0).expand(4, -1, -1, -1).to(device)[:, :3]\n", 433 | "showimages(images)" 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": null, 439 | "id": "7fcf26db-a58a-4f53-9c0b-d83707e7713d", 440 | "metadata": {}, 441 | "outputs": [], 442 | "source": [ 443 | "url = \"https://media.istockphoto.com/id/1193591781/photo/obedient-dog-breed-welsh-corgi-pembroke-sitting-and-smiles-on-a-white-background-not-isolate.jpg?s=612x612&w=0&k=20&c=ZDKTgSFQFG9QvuDziGsnt55kvQoqJtIhrmVRkpYqxtQ=\"\n", 444 
| "# url = \"https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1200px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg\"\n", 445 | "response = requests.get(url)\n", 446 | "img = Image.open(BytesIO(response.content)).convert(\"RGB\")\n", 447 | "images = preprocess(img).unsqueeze(0).expand(4, -1, -1, -1).to(device)[:, :3]\n", 448 | "showimages(images)" 449 | ] 450 | }, 451 | { 452 | "cell_type": "markdown", 453 | "id": "11eb27d7-5d3b-4a2d-9b2a-86cb81850dad", 454 | "metadata": {}, 455 | "source": [ 456 | "# Inpainting" 457 | ] 458 | }, 459 | { 460 | "cell_type": "code", 461 | "execution_count": null, 462 | "id": "7d59c35c-b7a5-4468-99b8-4e4377df63c2", 463 | "metadata": {}, 464 | "outputs": [], 465 | "source": [ 466 | "mode = \"inpainting\"\n", 467 | "text = \"a delicious spanish paella\"\n", 468 | "tokenized_text = tokenizer.tokenize([text] * images.shape[0]).to(device)\n", 469 | "with torch.inference_mode():\n", 470 | " with torch.autocast(device_type=\"cuda\"):\n", 471 | " # clip_embeddings = clip_model.encode_image(clip_preprocess(images)).float() # clip_embeddings = clip_model.encode_text(text).float()\n", 472 | " clip_embeddings = clip_model.encode_text(tokenized_text).float()\n", 473 | " encoded_tokens = encode(images)\n", 474 | " latent_shape = encoded_tokens.shape[1:]\n", 475 | " mask = torch.zeros_like(encoded_tokens)\n", 476 | " mask[:,5:28,5:28] = 1\n", 477 | " sampled = sample(model, clip_embeddings, x=encoded_tokens, mask=mask, T=12, size=latent_shape, starting_t=0, temp_range=[1.0, 1.0],\n", 478 | " typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=6, renoise_steps=11)\n", 479 | " sampled = decode(sampled[-1], latent_shape)\n", 480 | "\n", 481 | "showimages(images[0:1], height=10, width=10)\n", 482 | "showmask(mask[0:1])\n", 483 | "showimages(sampled, height=16, width=16)\n", 484 | "saveimages(torch.cat([images[0:1], sampled]), mode + \"_\" + text, 
nrow=images.shape[0]+1)" 485 | ] 486 | }, 487 | { 488 | "cell_type": "markdown", 489 | "id": "86f2bd06-aef7-4887-a5e8-dcc4af883dab", 490 | "metadata": { 491 | "tags": [] 492 | }, 493 | "source": [ 494 | "# Outpainting" 495 | ] 496 | }, 497 | { 498 | "cell_type": "code", 499 | "execution_count": null, 500 | "id": "97e4243a-186e-40e1-9b71-c11df0c60963", 501 | "metadata": {}, 502 | "outputs": [], 503 | "source": [ 504 | "mode = \"outpainting\"\n", 505 | "size = (40, 64)\n", 506 | "top_left = (0, 16)\n", 507 | "text = \"black & white photograph of a rocket from the bottom.\"\n", 508 | "tokenized_text = tokenizer.tokenize([text] * images.shape[0]).to(device)\n", 509 | "with torch.inference_mode():\n", 510 | " with torch.autocast(device_type=\"cuda\"):\n", 511 | " # clip_embeddings = clip_model.encode_image(clip_preprocess(images)).float()\n", 512 | " clip_embeddings = clip_model.encode_text(tokenized_text).float()\n", 513 | " encoded_tokens = encode(images)\n", 514 | " canvas = torch.zeros((images.shape[0], *size), dtype=torch.long).to(device)\n", 515 | " canvas[:, top_left[0]:top_left[0]+encoded_tokens.shape[1], top_left[1]:top_left[1]+encoded_tokens.shape[2]] = encoded_tokens\n", 516 | " mask = torch.ones_like(canvas)\n", 517 | " mask[:, top_left[0]:top_left[0]+encoded_tokens.shape[1], top_left[1]:top_left[1]+encoded_tokens.shape[2]] = 0\n", 518 | " sampled = sample(model, clip_embeddings, x=canvas, mask=mask, T=12, size=size, starting_t=0, temp_range=[1.0, 1.0],\n", 519 | " typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=4, renoise_steps=11)\n", 520 | " sampled = decode(sampled[-1], size)\n", 521 | "\n", 522 | "showimages(images[0:1], height=10, width=10)\n", 523 | "showmask(mask[0:1])\n", 524 | "showimages(sampled, height=16, width=16)\n", 525 | "saveimages(sampled, mode + \"_\" + text, nrow=images.shape[0])" 526 | ] 527 | }, 528 | { 529 | "cell_type": "markdown", 530 | "id": "8ad2d6da-ae1c-4c76-a494-ceb5469a649b", 531 | 
"metadata": {}, 532 | "source": [ 533 | "# Structural Morphing" 534 | ] 535 | }, 536 | { 537 | "cell_type": "code", 538 | "execution_count": null, 539 | "id": "cd46fd02-6770-4824-ad1d-356c6f197eaf", 540 | "metadata": {}, 541 | "outputs": [], 542 | "source": [ 543 | "mode = \"morphing\"\n", 544 | "max_steps = 24\n", 545 | "init_step = 8\n", 546 | "\n", 547 | "text = \"A fox posing for a photo. stock photo. highly detailed. 4k\"\n", 548 | "\n", 549 | "with torch.inference_mode():\n", 550 | " with torch.autocast(device_type=\"cuda\"):\n", 551 | " # images = preprocess(Image.open(\"data/city sketch.png\")).unsqueeze(0).expand(4, -1, -1, -1).to(device)[:, :3]\n", 552 | " latent_image = encode(images)\n", 553 | " latent_shape = latent_image.shape[-2:]\n", 554 | " r = torch.ones(latent_image.size(0), device=device) * (init_step/max_steps)\n", 555 | " noised_latent_image, _ = model.add_noise(latent_image, r)\n", 556 | " \n", 557 | " tokenized_text = tokenizer.tokenize([text] * images.size(0)).to(device)\n", 558 | " clip_embeddings = clip_model.encode_text(tokenized_text).float()\n", 559 | " \n", 560 | " sampled = sample(model, clip_embeddings, x=noised_latent_image, T=max_steps, size=latent_shape, starting_t=init_step, temp_range=[1.0, 1.0],\n", 561 | " typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=6, renoise_steps=max_steps-1,\n", 562 | " renoise_mode=\"prev\")\n", 563 | " sampled = decode(sampled[-1], latent_shape)\n", 564 | "showimages(sampled)\n", 565 | "showimages(images)\n", 566 | "saveimages(torch.cat([images[0:1], sampled]), mode + \"_\" + text, nrow=images.shape[0]+1)" 567 | ] 568 | }, 569 | { 570 | "cell_type": "markdown", 571 | "id": "109dbf24-8b68-4af1-9fa2-492bbf5c6a26", 572 | "metadata": {}, 573 | "source": [ 574 | "# Image Variations" 575 | ] 576 | }, 577 | { 578 | "cell_type": "code", 579 | "execution_count": null, 580 | "id": "cf81732b-9adc-4cda-9961-12480d932ff4", 581 | "metadata": {}, 582 | "outputs": [], 583 | 
"source": [ 584 | "clip_preprocess = torchvision.transforms.Compose([\n", 585 | " torchvision.transforms.Resize((224, 224), interpolation=torchvision.transforms.InterpolationMode.BICUBIC),\n", 586 | " torchvision.transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),\n", 587 | "])" 588 | ] 589 | }, 590 | { 591 | "cell_type": "code", 592 | "execution_count": null, 593 | "id": "ac72cd16-0f3e-407e-ba5e-ad7664254c7f", 594 | "metadata": {}, 595 | "outputs": [], 596 | "source": [ 597 | "latent_shape = (32, 32)\n", 598 | "with torch.inference_mode():\n", 599 | " with torch.autocast(device_type=\"cuda\"):\n", 600 | " clip_embeddings = clip_model.encode_image(clip_preprocess(images)).float() # clip_embeddings = clip_model.encode_text(text).float() \n", 601 | " sampled = sample(model, clip_embeddings, T=12, size=latent_shape, starting_t=0, temp_range=[1.0, 1.0],\n", 602 | " typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=5, renoise_steps=11)\n", 603 | " sampled = decode(sampled[-1], latent_shape)\n", 604 | "\n", 605 | "showimages(images)\n", 606 | "showimages(sampled)" 607 | ] 608 | }, 609 | { 610 | "cell_type": "markdown", 611 | "id": "4446993a-e5b4-40b9-a898-f2b8e9e03576", 612 | "metadata": { 613 | "jp-MarkdownHeadingCollapsed": true, 614 | "tags": [] 615 | }, 616 | "source": [ 617 | "# Experimental: Concept Learning" 618 | ] 619 | }, 620 | { 621 | "cell_type": "code", 622 | "execution_count": null, 623 | "id": "4e32d746-3cb1-4a27-a537-0e0baad20830", 624 | "metadata": {}, 625 | "outputs": [], 626 | "source": [ 627 | "def text_encode(x, clip_model, insertion_index):\n", 628 | " # x = x.type(clip_model.dtype)\n", 629 | " x = x + clip_model.positional_embedding\n", 630 | " x = x.permute(1, 0, 2) # NLD -> LND\n", 631 | " x = clip_model.transformer(x)\n", 632 | " x = x.permute(1, 0, 2) # LND -> NLD\n", 633 | " x = clip_model.ln_final(x)\n", 634 | "\n", 635 | " # x.shape = [batch_size, 
n_ctx, transformer.width]\n", 636 | " # take features from the eot embedding (eot_token is the highest number in each sequence)\n", 637 | " x = x[torch.arange(x.shape[0]), insertion_index] @ clip_model.text_projection\n", 638 | "\n", 639 | " return x" 640 | ] 641 | }, 642 | { 643 | "cell_type": "code", 644 | "execution_count": null, 645 | "id": "e391e34d-ddab-4718-83d3-3cbf56c19363", 646 | "metadata": {}, 647 | "outputs": [], 648 | "source": [ 649 | "from torch.optim import AdamW\n", 650 | "batch_size = 1\n", 651 | "asteriks_emb = clip_model.token_embedding(tokenizer.tokenize([\"*\"]).to(device))[0][1]\n", 652 | "context_word = torch.randn(batch_size, 1, asteriks_emb.shape[-1]).to(device)\n", 653 | "context_word.requires_grad_(True)\n", 654 | "optim = AdamW(params=[context_word], lr=0.1)\n", 655 | "criterion = nn.CrossEntropyLoss(label_smoothing=0.1)" 656 | ] 657 | }, 658 | { 659 | "cell_type": "code", 660 | "execution_count": null, 661 | "id": "ff0c9101-c62e-465c-a3aa-2d2afd5de075", 662 | "metadata": {}, 663 | "outputs": [], 664 | "source": [ 665 | "from io import BytesIO\n", 666 | "import requests\n", 667 | "from torch.utils.data import DataLoader, TensorDataset\n", 668 | "_preprocess = torchvision.transforms.Compose([\n", 669 | " torchvision.transforms.Resize(256),\n", 670 | " torchvision.transforms.CenterCrop(256),\n", 671 | " torchvision.transforms.ToTensor(),\n", 672 | "])\n", 673 | "\n", 674 | "urls = [\n", 675 | " \"https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcStVHtFcMqIP4xuDYn8n_FzPDKjPtP_iTSbOQ&usqp=CAU\",\n", 676 | " \"https://i.insider.com/58d919eaf2d0331b008b4bbd?width=700\",\n", 677 | " \"https://media.cntraveler.com/photos/5539216cab60aad20f3f3aaa/16:9/w_2560%2Cc_limit/eiffel-tower-paris-secret-apartment.jpg\",\n", 678 | " \"https://static.independent.co.uk/s3fs-public/thumbnails/image/2014/03/25/12/eiffel.jpg?width=1200\"\n", 679 | "]\n", 680 | "images = []\n", 681 | "for url in urls:\n", 682 | " response = requests.get(url, timeout=30)\n",
683 | " img = Image.open(BytesIO(response.content)).convert(\"RGB\") # ensure 3 channels (drops alpha/grayscale)\n", 684 | " images.append(_preprocess(img))\n", 685 | "\n", 686 | "data = torch.stack(images)\n", 687 | "dataloader = DataLoader(TensorDataset(data), batch_size=1, shuffle=True)\n", 688 | "loader = iter(dataloader)" 689 | ] 690 | }, 691 | { 692 | "cell_type": "code", 693 | "execution_count": null, 694 | "id": "d3db5f60-41d5-47e5-9c4b-b9b2911467e5", 695 | "metadata": { 696 | "scrolled": true, 697 | "tags": [] 698 | }, 699 | "outputs": [], 700 | "source": [ 701 | "steps = 100\n", 702 | "total_loss = 0\n", 703 | "total_acc = 0\n", 704 | "pbar = tqdm(range(steps))\n", 705 | "for i in pbar:\n", 706 | " try:\n", 707 | " images = next(loader)[0]\n", 708 | " except StopIteration:\n", 709 | " loader = iter(dataloader)\n", 710 | " images = next(loader)[0]\n", 711 | " images = images.to(device)\n", 712 | " text = \"a photo of *\"\n", 713 | " tokenized_text = tokenizer.tokenize([text]).to(device)\n", 714 | " insertion_index = tokenized_text.argmax(dim=-1)\n", 715 | " neutral_text_encoded = clip_model.token_embedding(tokenized_text)\n", 716 | " insertion_idx = torch.where(neutral_text_encoded == asteriks_emb)[1].unique()\n", 717 | " neutral_text_encoded[:, insertion_idx, :] = context_word\n", 718 | " clip_embeddings = text_encode(neutral_text_encoded, clip_model, insertion_index)\n", 719 | " with torch.no_grad():\n", 720 | " image_indices = encode(images)\n", 721 | " r = torch.rand(images.size(0), device=device)\n", 722 | " noised_indices, mask = model.add_noise(image_indices, r)\n", 723 | "\n", 724 | " # with torch.autocast(device_type=\"cuda\"):\n", 725 | " pred = model(noised_indices, clip_embeddings, r)\n", 726 | " loss = criterion(pred, image_indices)\n", 727 | " \n", 728 | " loss.backward()\n", 729 | " optim.step()\n", 730 | " optim.zero_grad()\n", 731 | " \n", 732 | " acc = (pred.argmax(1) == image_indices).float()\n", 733 | " acc = acc.mean()\n", 734 | "\n", 735 | " total_loss += loss.item()\n", 736 | " total_acc += acc.item()\n", 
737 | " pbar.set_postfix({\"total_loss\": total_loss / (i+1), \"total_acc\": total_acc / (i+1)})\n" 738 | ] 739 | }, 740 | { 741 | "cell_type": "code", 742 | "execution_count": null, 743 | "id": "ce95a42f-e121-43da-b432-b5f81c408ba8", 744 | "metadata": {}, 745 | "outputs": [], 746 | "source": [ 747 | "with torch.inference_mode():\n", 748 | " with torch.autocast(device_type=\"cuda\"):\n", 749 | " sampled = sample(model, clip_embeddings.expand(4, -1), T=12, size=(32, 32), starting_t=0, temp_range=[1., 1.],\n", 750 | " typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=4, renoise_steps=11)\n", 751 | " sampled = decode(sampled[-1])\n", 752 | "\n", 753 | "plt.figure(figsize=(32, 32))\n", 754 | "plt.axis(\"off\")\n", 755 | "plt.imshow(torch.cat([\n", 756 | " torch.cat([i for i in images.expand(4, -1, -1, -1).cpu()], dim=-1),\n", 757 | " torch.cat([i for i in sampled.cpu()], dim=-1),\n", 758 | "], dim=-2).permute(1, 2, 0).cpu())\n", 759 | "plt.show()" 760 | ] 761 | }, 762 | { 763 | "cell_type": "code", 764 | "execution_count": null, 765 | "id": "31eac484-b3bf-428a-8e59-1742dbca6cef", 766 | "metadata": {}, 767 | "outputs": [], 768 | "source": [ 769 | "text = \"* at night\"\n", 770 | "tokenized_text = tokenizer.tokenize([text]).to(device)\n", 771 | "insertion_index = tokenized_text.argmax(dim=-1)\n", 772 | "neutral_text_encoded = clip_model.token_embedding(tokenized_text)\n", 773 | "insertion_idx = torch.where(neutral_text_encoded == asteriks_emb)[1].unique()\n", 774 | "neutral_text_encoded[:, insertion_idx, :] = context_word\n", 775 | "clip_embeddings = text_encode(neutral_text_encoded, clip_model, insertion_index)\n", 776 | "with torch.inference_mode():\n", 777 | " with torch.autocast(device_type=\"cuda\"):\n", 778 | " sampled = sample(model, clip_embeddings.expand(4, -1), T=12, size=(32, 32), starting_t=0, temp_range=[1., 1.],\n", 779 | " typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=4, 
renoise_steps=11)\n", 780 | " sampled = decode(sampled[-1])\n", 781 | "\n", 782 | "plt.figure(figsize=(32, 32))\n", 783 | "plt.axis(\"off\")\n", 784 | "plt.imshow(torch.cat([\n", 785 | " torch.cat([i for i in images.expand(4, -1, -1, -1).cpu()], dim=-1),\n", 786 | " torch.cat([i for i in sampled.cpu()], dim=-1),\n", 787 | "], dim=-2).permute(1, 2, 0).cpu())\n", 788 | "plt.show()" 789 | ] 790 | } 791 | ], 792 | "metadata": { 793 | "colab": { 794 | "collapsed_sections": [], 795 | "name": "DenoiseGIT_sampling.ipynb", 796 | "provenance": [] 797 | }, 798 | "kernelspec": { 799 | "display_name": "Python 3 (ipykernel)", 800 | "language": "python", 801 | "name": "python3" 802 | }, 803 | "language_info": { 804 | "codemirror_mode": { 805 | "name": "ipython", 806 | "version": 3 807 | }, 808 | "file_extension": ".py", 809 | "mimetype": "text/x-python", 810 | "name": "python", 811 | "nbconvert_exporter": "python", 812 | "pygments_lexer": "ipython3", 813 | "version": "3.8.12" 814 | } 815 | }, 816 | "nbformat": 4, 817 | "nbformat_minor": 5 818 | } 819 | --------------------------------------------------------------------------------
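The *Experimental: Concept Learning* cells follow a textual-inversion-style recipe: the placeholder token `*` is embedded, its embedding slot is overwritten by a single trainable vector (`context_word`), and only that vector is handed to AdamW while Paella and CLIP are left out of the optimizer. The following is a minimal, self-contained sketch of that pattern on a toy model so it runs without CLIP or Paella. Everything in it is a hypothetical stand-in, not the notebook's code: `token_embedding` and `encoder` replace the frozen CLIP text encoder, `PLACEHOLDER_ID` plays the role of the `*` token, mean pooling replaces CLIP's EOT pooling, and an MSE loss to a random target replaces the cross-entropy on Paella's predicted image tokens.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy stand-ins for the frozen CLIP text encoder (not the real CLIP API).
vocab_size, dim = 16, 8
PLACEHOLDER_ID = 3  # hypothetical token id playing the role of "*"

token_embedding = nn.Embedding(vocab_size, dim)
encoder = nn.Sequential(nn.Linear(dim, dim), nn.Tanh(), nn.Linear(dim, dim))
for p in list(token_embedding.parameters()) + list(encoder.parameters()):
    p.requires_grad_(False)  # freeze everything except the new "word"


def encode_with_concept(token_ids, concept):
    """Embed token_ids, put `concept` into every placeholder slot, pool, and encode."""
    emb = token_embedding(token_ids)                    # (B, L, D)
    mask = (token_ids == PLACEHOLDER_ID).unsqueeze(-1)  # (B, L, 1): where "*" sits
    emb = torch.where(mask, concept.expand_as(emb), emb)
    return encoder(emb.mean(dim=1))                     # toy mean pooling -> (B, D)


prompt = torch.tensor([[1, 7, PLACEHOLDER_ID, 2, 0]])  # stands in for "a photo of * <eot>"
target = torch.randn(1, dim)                            # stand-in training signal

concept = torch.randn(1, 1, dim, requires_grad=True)   # the only trainable tensor
optim = torch.optim.AdamW([concept], lr=0.1)

losses = []
for step in range(200):
    loss = nn.functional.mse_loss(encode_with_concept(prompt, concept), target)
    optim.zero_grad()
    loss.backward()
    optim.step()
    losses.append(loss.item())

print(f"loss: {losses[0]:.4f} -> {losses[-1]:.4f}")
```

Locating the slot by comparing token ids (`token_ids == PLACEHOLDER_ID`) is one alternative to the notebook's comparison of embedding vectors (`neutral_text_encoded == asteriks_emb`); both should select the same position, but the id comparison avoids equality checks on floating-point vectors. Explicitly setting `requires_grad_(False)` on the frozen modules also keeps gradients from accumulating in parameters the optimizer never updates.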