├── assets
│   ├── ocr_v2.png
│   ├── original.png
│   ├── reconstruction.png
│   ├── table_comparison_1.PNG
│   └── table_comparison_2.PNG
├── taming
│   ├── modules
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── vgg16_bn.py
│   │   │   ├── craft.py
│   │   │   ├── lpips.py
│   │   │   └── vqperceptual.py
│   │   ├── discriminator
│   │   │   └── model.py
│   │   ├── util.py
│   │   ├── vqvae
│   │   │   └── quantize.py
│   │   └── diffusionmodules
│   │       └── model.py
│   ├── data
│   │   ├── custom.py
│   │   ├── base.py
│   │   ├── image_transforms.py
│   │   └── utils.py
│   ├── lr_scheduler.py
│   ├── util.py
│   └── models
│       └── vqgan.py
├── setup.py
├── .gitignore
├── environment.yaml
├── scripts
│   ├── parse_paper2fig1_img_to_VQGAN.py
│   ├── parse_ICDAR2013_img_to_VQGAN.py
│   ├── compute_ocr_perceptual_loss.py
│   ├── compute_ssim.py
│   ├── prepare_eval_samples.py
│   ├── plot_ocr_features.py
│   ├── generate_qualitative_results.py
│   └── evaluate_DALLE_VQVAE.py
├── configs
│   ├── ocr-vqgan-f16-c16384-d256.yaml
│   └── ocr-vqgan-imagenet-16384.yaml
├── README.md
└── main.py

/assets/ocr_v2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joanrod/ocr-vqgan/HEAD/assets/ocr_v2.png
--------------------------------------------------------------------------------
/assets/original.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joanrod/ocr-vqgan/HEAD/assets/original.png
--------------------------------------------------------------------------------
/assets/reconstruction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joanrod/ocr-vqgan/HEAD/assets/reconstruction.png
--------------------------------------------------------------------------------
/taming/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from taming.modules.losses.vqperceptual import DummyLoss
2 | 
3 | 
--------------------------------------------------------------------------------
/assets/table_comparison_1.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joanrod/ocr-vqgan/HEAD/assets/table_comparison_1.PNG
--------------------------------------------------------------------------------
/assets/table_comparison_2.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joanrod/ocr-vqgan/HEAD/assets/table_comparison_2.PNG
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | setup(
4 |     name='taming-transformers-ocrvqgan',
5 |     version='0.0.1',
6 |     description='Taming Transformers for High-Resolution Image Synthesis + OCR Perceptual Loss',
7 |     packages=find_packages(),
8 |     install_requires=[
9 |         'torch',
10 |         'numpy',
11 |         'tqdm',
12 |     ],
13 | )
14 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | logs/*
2 | test_*
3 | .git
4 | __pycache__/*
5 | *__pycache__/
6 | *.egg-info/
7 | ckpts/*
8 | taming/__pycache__/
9 | taming/data/__pycache__/
10 | taming/models/__pycache__/
11 | taming_transformers.egg-info/
12 | output/*
13 | lightning_logs/*
14 | scripts/taming/*
15 | taming/modules/autoencoder/lpips/vgg.pth
16 | taming/modules/autoencoder/ocr_perceptual/craft_mlt_25k.pth
17 | 
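
Note on the last two ignored paths: vgg.pth (the VGG weights used by LPIPS) and craft_mlt_25k.pth (the CRAFT text-detection weights used by the OCR perceptual loss) are downloaded by get_ckpt_path in taming/util.py, reproduced later in this listing. A minimal sketch of pre-fetching both checkpoints by hand, assuming the directory layout implied by the .gitignore entries above (the loss modules may pass different root directories):

from taming.util import get_ckpt_path

# Download the LPIPS (VGG) and CRAFT checkpoints if they are not on disk yet.
# The keys "vgg_lpips" and "ocr_craft" come from URL_MAP/CKPT_MAP in taming/util.py;
# the target directories below mirror the ignored paths and are an assumption,
# not something the code enforces.
vgg_ckpt = get_ckpt_path("vgg_lpips", "taming/modules/autoencoder/lpips")
craft_ckpt = get_ckpt_path("ocr_craft", "taming/modules/autoencoder/ocr_perceptual")
print(vgg_ckpt, craft_ckpt)
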
-------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: ocr-vqgan 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=11.3 9 | - pytorch=1.11.0 10 | - torchvision=0.12.0 11 | - numpy=1.19.2 12 | - pip: 13 | - albumentations==0.4.3 14 | - opencv-python==4.1.2.30 15 | - pudb==2019.2 16 | - invisible-watermark 17 | - imageio==2.9.0 18 | - imageio-ffmpeg==0.4.2 19 | - pytorch-lightning==1.4.2 20 | - omegaconf==2.1.1 21 | - streamlit>=0.73.1 22 | - einops==0.3.0 23 | - torch-fidelity==0.3.0 24 | - torchmetrics==0.6.0 25 | - wandb==0.13.4 26 | - SSIM-PIL 27 | - torch-fidelity 28 | - -e . 29 | -------------------------------------------------------------------------------- /taming/data/custom.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | 5 | from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex 6 | 7 | 8 | class CustomBase(Dataset): 9 | def __init__(self, *args, **kwargs): 10 | super().__init__() 11 | self.data = None 12 | 13 | def __len__(self): 14 | return len(self.data) 15 | 16 | def __getitem__(self, i): 17 | example = self.data[i] 18 | return example 19 | 20 | 21 | 22 | class CustomTrain(CustomBase): 23 | def __init__(self, size, training_images_list_file, random_crop=True, augment=True): 24 | super().__init__() 25 | with open(training_images_list_file, "r") as f: 26 | paths = f.read().splitlines() 27 | self.data = ImagePaths(paths=paths, size=size, random_crop=random_crop, augment=augment) 28 | 29 | 30 | class CustomTest(CustomBase): 31 | def __init__(self, size, test_images_list_file, random_crop=False, augment=False): 32 | super().__init__() 33 | with open(test_images_list_file, "r") as f: 34 | paths = f.read().splitlines() 35 | self.data = ImagePaths(paths=paths, size=size, random_crop=random_crop, augment=augment) 36 | 37 | 38 | -------------------------------------------------------------------------------- /scripts/parse_paper2fig1_img_to_VQGAN.py: -------------------------------------------------------------------------------- 1 | # File: parse_paper2fig1_img_to_VQGAN.py 2 | # Created by Juan A. 
Rodriguez on 12/06/2022 3 | # Goal: This script is intended to access the json files corresponding to the paper2fig dataset (train, val) 4 | # and convert them to the format required by the VQ-GAN, that is, a txt file containing the image path, 5 | 6 | import json 7 | import os 8 | from tqdm import tqdm 9 | import argparse 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--path", type=str, required=True, help="Path to dataset root, containing json files and figures directory") 13 | args = parser.parse_args() 14 | 15 | if __name__ == '__main__': 16 | path = args.path 17 | splits = ['train', 'test'] 18 | count = 0 19 | for split in splits: 20 | with open(os.path.join(path, f'paper2fig_{split}.json')) as f: 21 | data = json.load(f) 22 | for item in tqdm(data): 23 | path_img = os.path.join(path, 'figures', f'{item["figure_id"]}.png') 24 | # append to txt file 25 | with open(path + '/paper2fig1_img_'+split+'.txt', 'a') as f: 26 | f.write(path_img + '\n') 27 | count += 1 28 | print(f"Stored {count} images in paper2fig1_img_{split}.txt") 29 | -------------------------------------------------------------------------------- /taming/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n): 33 | return self.schedule(n) 34 | 35 | -------------------------------------------------------------------------------- /scripts/parse_ICDAR2013_img_to_VQGAN.py: -------------------------------------------------------------------------------- 1 | # File: parse_ICDAR2013_img_to_VQGAN.py 2 | # Created by Juan A. 
Rodriguez on 12/06/2022
3 | # Goal: This script accesses the image directories of the ICDAR13 dataset (train, test)
4 | # and converts them to the format required by the VQGAN, that is, a txt file containing the image paths
5 | 
6 | import os
7 | import argparse
8 | 
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("--path", type=str, required=True, help="Path to dataset root, containing image directories (train and test)")
11 | args = parser.parse_args()
12 | 
13 | if __name__ == '__main__':
14 |     splits = ['train', 'test']
15 |     count = 0
16 |     for split in splits:
17 |         if split == "train":
18 |             split_dir_name = "Challenge2_Training_Task12_Images"
19 |         else:
20 |             split_dir_name = "Challenge2_Test_Task12_Images"
21 |         path = os.path.join(args.path, split_dir_name)
22 |         for filename in os.listdir(path):
23 |             if filename.endswith(".jpg"):
24 |                 path_img = os.path.join(path, filename)
25 |                 # append to txt file
26 |                 with open(os.path.join(path, 'ICDAR_2013_img_'+split+'.txt'), 'a') as f:
27 |                     f.write(path_img + '\n')
28 |                 count += 1
29 |     print(f"Stored {count} images in ICDAR_2013_img_{split}.txt")
30 | 
--------------------------------------------------------------------------------
/scripts/compute_ocr_perceptual_loss.py:
--------------------------------------------------------------------------------
1 | # File: compute_ocr_perceptual_loss.py
2 | # Created by Juan A. Rodriguez on 18/06/2022
3 | # Goal: Script to compute OCR perceptual loss from a pair of images (input and reconstruction)
4 | 
5 | import argparse
6 | from taming.modules.losses.lpips import OCR_CRAFT_LPIPS
7 | from PIL import Image
8 | import numpy as np
9 | import torchvision.transforms as T
10 | import torch
11 | 
12 | # argument parsing
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("--input_path", type=str, required=True, help="Path to input image")
15 | parser.add_argument("--recons_path", type=str, required=True, help="Path to reconstructed image")
16 | args = parser.parse_args()
17 | 
18 | def get_image_tensor(image_path):
19 |     image = Image.open(image_path)
20 |     if not image.mode == "RGB":
21 |         image = image.convert("RGB")
22 |     image = np.array(image).astype(np.uint8)
23 |     image = (image/127.5 - 1.0).astype(np.float32)
24 |     return torch.unsqueeze(T.ToTensor()(image), 0)
25 | 
26 | if __name__ == "__main__":
27 |     # Load image and reconstruction to tensors
28 |     input_path = args.input_path
29 |     recons_path = args.recons_path
30 | 
31 |     input_tensor = get_image_tensor(input_path).cuda()
32 |     rec_tensor = get_image_tensor(recons_path).cuda()
33 | 
34 |     OCR_perceptual_loss = OCR_CRAFT_LPIPS().eval()
35 |     OCR_perceptual_loss.cuda()
36 | 
37 |     ocr_sim = OCR_perceptual_loss(input_tensor, rec_tensor)
38 |     print(f"OCR perceptual loss: {ocr_sim.item()}")  # report the OCR perceptual distance between input and reconstruction
39 | 
--------------------------------------------------------------------------------
/scripts/compute_ssim.py:
--------------------------------------------------------------------------------
1 | # File: compute_ssim.py
2 | # Created by Juan A.
Rodriguez on 18/06/2022 3 | # Goal: Util script to compute Structural Similarity Index (SSIM) 4 | # from two sets of images (original and reconstructuted) 5 | # You must pass both directory paths in input1 and input2 6 | 7 | import argparse 8 | import os 9 | from PIL import Image 10 | from SSIM_PIL import compare_ssim 11 | from tqdm import tqdm 12 | import multiprocessing as mp 13 | 14 | # argument parsing 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--input1', type=str, required = True, help='path to directory 1 ') 17 | parser.add_argument('--input2', type=str, required = True, help='path to directory 2') 18 | args = parser.parse_args() 19 | 20 | if __name__ == "__main__": 21 | MAX_WORKERS = mp.cpu_count() 22 | CHUNK_SIZE = 50 23 | 24 | print(f'Working with :{MAX_WORKERS} CPUs on Multiprocessing') 25 | 26 | files_input1 = os.listdir(args.input1) 27 | files_input2 = os.listdir(args.input2) 28 | 29 | pair_image_tuples = zip(files_input1, files_input2) 30 | 31 | def compute_ssim(pair): 32 | im_1 = Image.open(os.path.join(args.input1, pair[0])) 33 | im_2 = Image.open(os.path.join(args.input2, pair[1])) 34 | return compare_ssim(im_1, im_2) 35 | 36 | with mp.Pool(processes=MAX_WORKERS) as p: 37 | ssim_list = list( 38 | tqdm(p.imap(compute_ssim, list(pair_image_tuples), CHUNK_SIZE), total=len(files_input2))) 39 | 40 | ssim_score = sum(ssim_list) / len(ssim_list) 41 | 42 | print(f'SSIM score: {ssim_score}') 43 | -------------------------------------------------------------------------------- /configs/ocr-vqgan-f16-c16384-d256.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-04 3 | target: taming.models.vqgan.VQModel 4 | params: 5 | embed_dim: 256 6 | n_embed: 16384 7 | ddconfig: 8 | double_z: false 9 | z_channels: 256 10 | resolution: 256 11 | in_channels: 3 12 | out_ch: 3 13 | ch: 128 14 | ch_mult: 15 | - 1 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 16 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminatorOCR 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 100000 30 | disc_weight: 0.75 31 | disc_num_layers: 2 32 | codebook_weight: 1.0 33 | 34 | perceptual_weight: 0.2 35 | ocr_perceptual_weight: 1.0 36 | 37 | 38 | data: 39 | target: main.DataModuleFromConfig 40 | params: 41 | batch_size: 1 42 | num_workers: 4 43 | train: 44 | target: taming.data.custom.CustomTrain 45 | params: 46 | training_images_list_file: /Paper2Fig100k/paper2fig1_img_train.txt 47 | size: 384 48 | random_crop: True 49 | augment: True 50 | validation: 51 | target: taming.data.custom.CustomTest 52 | params: 53 | test_images_list_file: /Paper2Fig100k/paper2fig1_img_test.txt 54 | size: 384 55 | random_crop: False 56 | augment: False 57 | test: 58 | target: taming.data.custom.CustomTest 59 | params: 60 | test_images_list_file: /Paper2Fig100k/paper2fig1_img_test.txt 61 | size: 384 62 | random_crop: False 63 | augment: False -------------------------------------------------------------------------------- /configs/ocr-vqgan-imagenet-16384.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-04 3 | target: taming.models.vqgan.VQModel 4 | params: 5 | embed_dim: 256 6 | n_embed: 16384 7 | ddconfig: 8 | double_z: false 9 | z_channels: 256 10 | resolution: 256 11 | in_channels: 3 12 | out_ch: 3 13 | ch: 128 14 | ch_mult: 15 | - 1 16 | - 
1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 16 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminatorOCR 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 50000 30 | disc_weight: 0.75 31 | disc_num_layers: 2 32 | codebook_weight: 1.0 33 | 34 | perceptual_weight: 0.3 35 | ocr_perceptual_weight: 0.7 36 | 37 | 38 | data: 39 | target: main.DataModuleFromConfig 40 | params: 41 | batch_size: 1 42 | num_workers: 12 43 | train: 44 | target: taming.data.custom.CustomTrain 45 | params: 46 | training_images_list_file: /Paper2Fig100k/paper2fig1_img_train.txt 47 | size: 384 48 | random_crop: True 49 | augment: True 50 | validation: 51 | target: taming.data.custom.CustomTest 52 | params: 53 | test_images_list_file: /Paper2Fig100k/paper2fig1_img_test.txt 54 | size: 384 55 | random_crop: False 56 | augment: False 57 | test: 58 | target: taming.data.custom.CustomTest 59 | params: 60 | test_images_list_file: /Paper2Fig100k/paper2fig1_img_test.txt 61 | size: 384 62 | random_crop: False 63 | augment: False -------------------------------------------------------------------------------- /scripts/prepare_eval_samples.py: -------------------------------------------------------------------------------- 1 | # File: prepare_eval_samples.py 2 | # Created by Juan A. Rodriguez on 23/06/2022 3 | # Goal: Util script to process the samples in the test set and prepare them for evaluation 4 | 5 | from torch.utils.data import DataLoader 6 | from PIL import Image 7 | import numpy as np 8 | import os 9 | import argparse 10 | from taming.data.custom import CustomTest 11 | import torch 12 | 13 | # Args for the 4 models that we evaluate 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--image_txt_path", type=str, required=True, help="Path to directory containing images") 16 | parser.add_argument("--store_path", type=str, required=True, help="Path to directory containing images") 17 | parser.add_argument("--size", default=384, help="Image size") 18 | args = parser.parse_args() 19 | 20 | if __name__ == "__main__": 21 | 22 | # Create folder for output 23 | dataset = os.path.split(args.image_txt_path)[-1].split(".")[0] 24 | outpath = os.path.join(args.store_path, dataset) 25 | os.makedirs(outpath, exist_ok=True) 26 | 27 | ds = CustomTest(384, args.image_txt_path) 28 | dl = DataLoader(ds, batch_size=1, shuffle=False) 29 | 30 | for i, sample in enumerate(dl): 31 | image = sample['image'] 32 | if len(image.shape) == 3: 33 | image = image[..., None] 34 | image = image.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 35 | 36 | image = image.detach().cpu() 37 | image = torch.clamp(image, -1., 1.) 38 | image = (image+1.0)/2.0 # -1,1 -> 0,1; c,h,w 39 | image = image.transpose(1, 2).transpose(2, 3) 40 | image = image.numpy() 41 | image = (image*255).astype(np.uint8) 42 | 43 | filename = f"input_{i}.png" 44 | path = os.path.join(outpath, filename) 45 | Image.fromarray(image[0]).save(path) 46 | -------------------------------------------------------------------------------- /scripts/plot_ocr_features.py: -------------------------------------------------------------------------------- 1 | # File: plot_ocr_features.py 2 | # Created by Juan A. Rodriguez on 27/6/2022 3 | # Goal: Util script to plot OCR features from images. 
4 | # It is itended to input a pair of images, (original and reconstruction/decoded), and will store ocr deep features for qualitative analysis 5 | 6 | import torch 7 | import os 8 | from PIL import Image 9 | import numpy as np 10 | import torchvision.transforms as T 11 | from taming.modules.losses.craft import CRAFT 12 | from taming.modules.util import copyStateDict 13 | import argparse 14 | 15 | # Args for the 4 models that we evaluate 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--input_path", type=str, required=True, help="Path to input image") 18 | parser.add_argument("--recons_path", type=str, required=True, help="Path to reconstructed image") 19 | parser.add_argument("--out", type=str, default='output', help="output path") 20 | args = parser.parse_args() 21 | 22 | def get_image_tensor(image_path): 23 | image = Image.open(image_path) 24 | if not image.mode == "RGB": 25 | image = image.convert("RGB") 26 | image = np.array(image).astype(np.uint8) 27 | image = (image/127.5 - 1.0).astype(np.float32) 28 | 29 | return torch.unsqueeze(T.ToTensor()(image), 0) 30 | 31 | def save_features(tensor, path): 32 | # get some feature maps 33 | feats = tensor[2][4][:,:12].squeeze(0) 34 | 35 | feats = feats.detach().cpu() 36 | feats = torch.clamp(feats, -1., 1.) 37 | feats = (feats+1.0)/2.0 # -1,1 -> 0,1; c,h,w 38 | feats = feats.reshape(-1, 3, 192, 192) 39 | feats = feats.transpose(1, 2).transpose(2, 3) 40 | feats = feats.numpy() 41 | feats = (feats*255).astype(np.uint8) 42 | 43 | for k in range(feats.shape[0]): 44 | filename = f"feature_{k}.png" 45 | im = feats[k] 46 | Image.fromarray(im).save(os.path.join(path, filename)) 47 | 48 | if __name__ == '__main__': 49 | # Load image and reconstruction to tensors 50 | input_path = args.input_path 51 | recons_path = args.recons_path 52 | 53 | input_tensor = get_image_tensor(input_path) 54 | rec_tensor = get_image_tensor(recons_path) 55 | model_path = '' 56 | craft = CRAFT(pretrained=True, freeze=True, amp=False) 57 | param = torch.load(model_path) 58 | print("Loading craft model from {}".format(model_path)) 59 | craft.load_state_dict(copyStateDict(param)) 60 | 61 | in_feats = craft(input_tensor) 62 | out_feats = craft(rec_tensor) 63 | path_store_in = '' 64 | path_store_rec = '' 65 | 66 | save_features(in_feats, path_store_in) 67 | save_features(out_feats, path_store_rec) -------------------------------------------------------------------------------- /taming/modules/discriminator/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | 4 | 5 | from taming.modules.util import ActNorm 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | if classname.find('Conv') != -1: 11 | nn.init.normal_(m.weight.data, 0.0, 0.02) 12 | elif classname.find('BatchNorm') != -1: 13 | nn.init.normal_(m.weight.data, 1.0, 0.02) 14 | nn.init.constant_(m.bias.data, 0) 15 | 16 | 17 | class NLayerDiscriminator(nn.Module): 18 | """Defines a PatchGAN discriminator as in Pix2Pix 19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 20 | """ 21 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 22 | """Construct a PatchGAN discriminator 23 | Parameters: 24 | input_nc (int) -- the number of channels in input images 25 | ndf (int) -- the number of filters in the last conv layer 26 | n_layers (int) -- the number of conv layers in the discriminator 27 | norm_layer -- normalization layer 28 | """ 29 | 
super(NLayerDiscriminator, self).__init__() 30 | if not use_actnorm: 31 | norm_layer = nn.BatchNorm2d 32 | else: 33 | norm_layer = ActNorm 34 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 35 | use_bias = norm_layer.func != nn.BatchNorm2d 36 | else: 37 | use_bias = norm_layer != nn.BatchNorm2d 38 | 39 | kw = 4 40 | padw = 1 41 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 42 | nf_mult = 1 43 | nf_mult_prev = 1 44 | for n in range(1, n_layers): # gradually increase the number of filters 45 | nf_mult_prev = nf_mult 46 | nf_mult = min(2 ** n, 8) 47 | sequence += [ 48 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 49 | norm_layer(ndf * nf_mult), 50 | nn.LeakyReLU(0.2, True) 51 | ] 52 | 53 | nf_mult_prev = nf_mult 54 | nf_mult = min(2 ** n_layers, 8) 55 | sequence += [ 56 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 57 | norm_layer(ndf * nf_mult), 58 | nn.LeakyReLU(0.2, True) 59 | ] 60 | 61 | sequence += [ 62 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 63 | self.main = nn.Sequential(*sequence) 64 | 65 | def forward(self, input): 66 | """Standard forward.""" 67 | return self.main(input) 68 | -------------------------------------------------------------------------------- /taming/modules/losses/vgg16_bn.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.init as init 6 | from torchvision import models 7 | from torchvision.models.vgg import model_urls 8 | 9 | def init_weights(modules): 10 | for m in modules: 11 | if isinstance(m, nn.Conv2d): 12 | init.xavier_uniform_(m.weight.data) 13 | if m.bias is not None: 14 | m.bias.data.zero_() 15 | elif isinstance(m, nn.BatchNorm2d): 16 | m.weight.data.fill_(1) 17 | m.bias.data.zero_() 18 | elif isinstance(m, nn.Linear): 19 | m.weight.data.normal_(0, 0.01) 20 | m.bias.data.zero_() 21 | 22 | class vgg16_bn(torch.nn.Module): 23 | def __init__(self, pretrained=True, freeze=True): 24 | super(vgg16_bn, self).__init__() 25 | model_urls['vgg16_bn'] = model_urls['vgg16_bn'].replace('https://', 'http://') 26 | vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features 27 | self.slice1 = torch.nn.Sequential() 28 | self.slice2 = torch.nn.Sequential() 29 | self.slice3 = torch.nn.Sequential() 30 | self.slice4 = torch.nn.Sequential() 31 | self.slice5 = torch.nn.Sequential() 32 | for x in range(12): # conv2_2 33 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 34 | for x in range(12, 19): # conv3_3 35 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 36 | for x in range(19, 29): # conv4_3 37 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 38 | for x in range(29, 39): # conv5_3 39 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 40 | 41 | # fc6, fc7 without atrous conv 42 | self.slice5 = torch.nn.Sequential( 43 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1), 44 | nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6), 45 | nn.Conv2d(1024, 1024, kernel_size=1) 46 | ) 47 | 48 | if not pretrained: 49 | init_weights(self.slice1.modules()) 50 | init_weights(self.slice2.modules()) 51 | init_weights(self.slice3.modules()) 52 | init_weights(self.slice4.modules()) 53 | 54 | 
init_weights(self.slice5.modules()) # no pretrained model for fc6 and fc7 55 | 56 | if freeze: 57 | for param in self.slice1.parameters(): # only first conv 58 | param.requires_grad= False 59 | 60 | def forward(self, X): 61 | h = self.slice1(X) 62 | h_relu2_2 = h 63 | h = self.slice2(h) 64 | h_relu3_2 = h 65 | h = self.slice3(h) 66 | h_relu4_3 = h 67 | h = self.slice4(h) 68 | h_relu5_3 = h 69 | h = self.slice5(h) 70 | h_fc7 = h 71 | vgg_outputs = namedtuple("VggOutputs", ['fc7', 'relu5_3', 'relu4_3', 'relu3_2', 'relu2_2']) 72 | out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2) 73 | return out -------------------------------------------------------------------------------- /scripts/generate_qualitative_results.py: -------------------------------------------------------------------------------- 1 | # File: generate_qualitative_results.py 2 | # Created by Juan A. Rodriguez on 18/06/2022 3 | # Goal: Access generated images from different models 4 | # and randomly pick samples and analize qualitative results. 5 | 6 | import argparse 7 | import os 8 | import random 9 | import datetime 10 | import shutil 11 | 12 | # Args for the 4 models that we evaluate 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--num_images", type=int, default=1, help="Number of random samples to extract") 15 | parser.add_argument("--DALLE", type=str, default = None, help="Path to directory containing DALLE generated images (using pretrained DALLE VQVAE)") 16 | parser.add_argument("--VQGAN_pretrained", default = None, type=str,help="Path to directory containing VQGAN_pretrained generated samples (using pretrained VQGAN on imagenet)") 17 | parser.add_argument("--VQGAN_finetuned", default = None, type=str, help="Path to directory containing VQGAN finetuned with Paper2Fig100k generated images") 18 | parser.add_argument("--OCR_VQGAN", type=str, default = None, help="Path to directory containing OCR-VQGAN finetuned with Paper2Fig100k generated images") 19 | parser.add_argument("--test_dataset", type=str, required=True, help="Path to directory containing images of test (original input images)") 20 | args = parser.parse_args() 21 | 22 | test_dataset = args.test_dataset 23 | 24 | # Note: This script must be executed one the evaluate script has been executed. 
Images should be in the "evaluate directory" 25 | models_to_evaluate = { 26 | "DALLE": args.DALLE if args.DALLE else None , 27 | "VQGAN_pretrained": args.VQGAN_pretrained if args.VQGAN_pretrained else None , 28 | "VQGAN_finetuned": args.VQGAN_finetuned if args.VQGAN_finetuned else None , 29 | "OCR_VQGAN":args.OCR_VQGAN if args.OCR_VQGAN else None, 30 | "Input":test_dataset 31 | } 32 | 33 | # Obtain number total images 34 | total_images = len(os.listdir(test_dataset)) 35 | 36 | # Extract num_images random generated samples 37 | rand_indeces = random.sample(range(0, total_images), k=args.num_images) 38 | 39 | # Create new directory to store comparison of samples 40 | out_dir = os.path.join('output', f'model_comparison_{datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")}') 41 | if not os.path.exists(out_dir): os.makedirs(out_dir) 42 | 43 | for key in models_to_evaluate: 44 | if models_to_evaluate[key]: 45 | model_name = key 46 | path_to_reconstructed_images = models_to_evaluate[key] 47 | list_paths_samples = os.listdir(path_to_reconstructed_images) 48 | 49 | # index those samples and store them in a folder 50 | out_dir_model = os.path.join(out_dir, model_name) 51 | if not os.path.exists(out_dir_model): os.makedirs(out_dir_model) 52 | count = 0 53 | for sample in rand_indeces: 54 | format_image = list_paths_samples[sample].split('.')[1] 55 | path_sample = os.path.join(path_to_reconstructed_images, list_paths_samples[sample]) 56 | # Parse values and mat 57 | path_out = os.path.join(out_dir_model, f"sample_{count}_id_{sample}.{format_image}") 58 | shutil.copy(path_sample, path_out) 59 | count+=1 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /taming/data/base.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import numpy as np 3 | import albumentations as A 4 | from PIL import Image 5 | from torch.utils.data import Dataset, ConcatDataset 6 | 7 | 8 | class ConcatDatasetWithIndex(ConcatDataset): 9 | """Modified from original pytorch code to return dataset idx""" 10 | def __getitem__(self, idx): 11 | if idx < 0: 12 | if -idx > len(self): 13 | raise ValueError("absolute value of index should not exceed dataset length") 14 | idx = len(self) + idx 15 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 16 | if dataset_idx == 0: 17 | sample_idx = idx 18 | else: 19 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 20 | return self.datasets[dataset_idx][sample_idx], dataset_idx 21 | 22 | 23 | class ImagePaths(Dataset): 24 | def __init__(self, paths, size=None, random_crop=False, labels=None, augment=False): 25 | self.size = size 26 | self.random_crop = random_crop 27 | self.augment = augment 28 | self.labels = dict() if labels is None else labels 29 | self.labels["file_path_"] = paths 30 | self._length = len(paths) 31 | 32 | if self.size is not None and self.size > 0: 33 | self.rescaler = A.SmallestMaxSize(max_size = self.size) 34 | if not self.random_crop: 35 | self.cropper = A.CenterCrop(height=self.size,width=self.size) 36 | else: 37 | self.cropper = A.RandomCrop(height=self.size,width=self.size) 38 | self.preprocessor = A.Compose([self.rescaler, self.cropper]) 39 | else: 40 | self.preprocessor = lambda **kwargs: kwargs 41 | 42 | if self.augment: 43 | # Add data aug transformations 44 | self.data_augmentation = A.Compose([ 45 | A.GaussianBlur(p=0.1), 46 | A.OneOf([ 47 | A.HueSaturationValue (p=0.3), 48 | A.ToGray(p=0.3), 49 | A.ChannelShuffle(p=0.3) 50 | ], p=0.3) 51 | 
]) 52 | 53 | def __len__(self): 54 | return self._length 55 | 56 | def preprocess_image(self, image_path): 57 | image = Image.open(image_path) 58 | if not image.mode == "RGB": 59 | image = image.convert("RGB") 60 | image = np.array(image).astype(np.uint8) 61 | image = self.preprocessor(image=image)["image"] 62 | if self.augment: 63 | image = self.data_augmentation(image=image)['image'] 64 | image = (image/127.5 - 1.0).astype(np.float32) 65 | return image 66 | 67 | def __getitem__(self, i): 68 | example = dict() 69 | example["image"] = self.preprocess_image(self.labels["file_path_"][i]) 70 | for k in self.labels: 71 | example[k] = self.labels[k][i] 72 | return example 73 | 74 | 75 | class NumpyPaths(ImagePaths): 76 | def preprocess_image(self, image_path): 77 | image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024 78 | image = np.transpose(image, (1,2,0)) 79 | image = Image.fromarray(image, mode="RGB") 80 | image = np.array(image).astype(np.uint8) 81 | image = self.preprocessor(image=image)["image"] 82 | image = (image/127.5 - 1.0).astype(np.float32) 83 | return image 84 | -------------------------------------------------------------------------------- /taming/modules/losses/craft.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2019-present NAVER Corp. 3 | MIT License 4 | """ 5 | 6 | # -*- coding: utf-8 -*- 7 | from numpy import source 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from collections import namedtuple 12 | 13 | from taming.modules.losses.vgg16_bn import vgg16_bn, init_weights 14 | 15 | class double_conv(nn.Module): 16 | def __init__(self, in_ch, mid_ch, out_ch): 17 | super(double_conv, self).__init__() 18 | self.conv = nn.Sequential( 19 | nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1), 20 | nn.BatchNorm2d(mid_ch), 21 | nn.ReLU(inplace=True), 22 | nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1), 23 | nn.BatchNorm2d(out_ch), 24 | nn.ReLU(inplace=True) 25 | ) 26 | 27 | def forward(self, x): 28 | x = self.conv(x) 29 | return x 30 | 31 | 32 | class CRAFT(nn.Module): 33 | def __init__(self, pretrained=True, freeze=False, amp=False): 34 | super(CRAFT, self).__init__() 35 | 36 | self.amp = amp 37 | 38 | """ Base network """ 39 | self.basenet = vgg16_bn(pretrained, freeze) 40 | 41 | """ U network """ 42 | self.upconv1 = double_conv(1024, 512, 256) 43 | self.upconv2 = double_conv(512, 256, 128) 44 | self.upconv3 = double_conv(256, 128, 64) 45 | self.upconv4 = double_conv(128, 64, 32) 46 | 47 | num_class = 2 48 | self.conv_cls = nn.Sequential( 49 | nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True), 50 | nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True), 51 | nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU(inplace=True), 52 | nn.Conv2d(16, 16, kernel_size=1), nn.ReLU(inplace=True), 53 | nn.Conv2d(16, num_class, kernel_size=1), 54 | ) 55 | 56 | init_weights(self.upconv1.modules()) 57 | init_weights(self.upconv2.modules()) 58 | init_weights(self.upconv3.modules()) 59 | init_weights(self.upconv4.modules()) 60 | init_weights(self.conv_cls.modules()) 61 | 62 | def forward(self, x): 63 | """ Base network """ 64 | if self.amp: 65 | with torch.cuda.amp.autocast(): 66 | sources = self.basenet(x) 67 | 68 | """ U network """ 69 | y = torch.cat([sources[0], sources[1]], dim=1) 70 | y = self.upconv1(y) 71 | 72 | y = F.interpolate(y, size=sources[2].size()[2:], mode='bilinear', align_corners=False) 73 | y = torch.cat([y, sources[2]], dim=1) 74 | y = 
self.upconv2(y) 75 | 76 | y = F.interpolate(y, size=sources[3].size()[2:], mode='bilinear', align_corners=False) 77 | y = torch.cat([y, sources[3]], dim=1) 78 | y = self.upconv3(y) 79 | 80 | y = F.interpolate(y, size=sources[4].size()[2:], mode='bilinear', align_corners=False) 81 | y = torch.cat([y, sources[4]], dim=1) 82 | feature = self.upconv4(y) 83 | 84 | y = self.conv_cls(feature) 85 | 86 | return y.permute(0,2,3,1), feature 87 | else: 88 | 89 | sources = self.basenet(x) 90 | 91 | """ U network """ 92 | y = torch.cat([sources[0], sources[1]], dim=1) 93 | y = self.upconv1(y) 94 | y1 = y 95 | 96 | y = F.interpolate(y, size=sources[2].size()[2:], mode='bilinear', align_corners=False) 97 | y = torch.cat([y, sources[2]], dim=1) 98 | y = self.upconv2(y) 99 | y2 = y 100 | 101 | y = F.interpolate(y, size=sources[3].size()[2:], mode='bilinear', align_corners=False) 102 | y = torch.cat([y, sources[3]], dim=1) 103 | y = self.upconv3(y) 104 | y3 = y 105 | 106 | y = F.interpolate(y, size=sources[4].size()[2:], mode='bilinear', align_corners=False) 107 | y = torch.cat([y, sources[4]], dim=1) 108 | feature = self.upconv4(y) 109 | 110 | y = self.conv_cls(feature) 111 | 112 | OCRoutputs = namedtuple("OCROutputs", ['feature', 'y3', 'y2', 'y1']) 113 | out = OCRoutputs(feature, y3, y2, y1) 114 | 115 | 116 | return y.permute(0, 2, 3, 1), feature, sources 117 | 118 | if __name__ == '__main__': 119 | model = CRAFT(pretrained=True).cuda() 120 | output, _ = model(torch.randn(1, 3, 768, 768).cuda()) 121 | print(output.shape) -------------------------------------------------------------------------------- /scripts/evaluate_DALLE_VQVAE.py: -------------------------------------------------------------------------------- 1 | # File: compute_ssim.py 2 | # Created by Juan A. Rodriguez on 18/06/2022 3 | # Goal: Evaluate VQVAE model in the text-within-image reconstruction task. 
4 | # Two datasets are used, Paper2Fig100k and ICDAR13 5 | 6 | # Note: This script is no longer working, because DALLE-Pytorch is not compatble with torch==1.11.0 7 | # TODO: Find a workaround to run VQVAEs in the project 8 | 9 | from dalle_pytorch import OpenAIDiscreteVAE # pip install dalle-pytorch 10 | from torch.utils.data import DataLoader, Dataset 11 | from taming.modules.losses.lpips import LPIPS, OCR_CRAFT_LPIPS 12 | import torchvision.transforms as T 13 | import torchvision.transforms.functional as TF 14 | import albumentations as A 15 | from albumentations.pytorch import ToTensorV2 16 | from packaging import version 17 | 18 | from PIL import Image 19 | import numpy as np 20 | import os 21 | import torch 22 | from tqdm import tqdm 23 | import argparse 24 | 25 | # Args for the 4 models that we evaluate 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument("--data_path", type=str, required=True, help="Path to directory containing images") 28 | parser.add_argument("--store_path", type=str, required=True, help="Path to directory containing images") 29 | args = parser.parse_args() 30 | 31 | # Add args, and add that it is all stored in VQGAN logs 32 | class ImageDataset(Dataset): 33 | def __init__(self, paths, size=None, random_crop=False, labels=None): 34 | self.size = size 35 | self.random_crop = False #random_crop 36 | 37 | with open(paths, "r") as f: 38 | self.data = f.read().splitlines() 39 | 40 | self.image_transform = A.Compose([ 41 | A.SmallestMaxSize(max_size = self.size), 42 | # A.RandomCrop(height=image_size,width=image_size), 43 | A.CenterCrop(height=self.size,width=self.size), 44 | ToTensorV2() 45 | ]) 46 | self._length = len(self.data) 47 | 48 | def __len__(self): 49 | return self._length 50 | 51 | def preprocess_image(self, image_path): 52 | image = Image.open(image_path) 53 | if not image.mode == "RGB": 54 | image = image.convert("RGB") 55 | image = np.array(image).astype(np.uint8) 56 | image = self.image_transform(image=image)["image"] 57 | return image 58 | 59 | def preprocess(self, img_path): 60 | img = Image.open(img_path) 61 | s = min(img.size) 62 | r = self.size / s 63 | s = (round(r * img.size[1]), round(r * img.size[0])) 64 | img = TF.resize(img, s, interpolation=Image.LANCZOS) 65 | img = TF.center_crop(img, output_size=2 * [self.size]) 66 | img = T.ToTensor()(img) 67 | return img 68 | 69 | def __getitem__(self, i): 70 | sample = self.preprocess(self.data[i]) 71 | return sample 72 | 73 | if __name__ == "__main__": 74 | 75 | datasets = ["Paper2Fig100k", "ICDAR2013"] 76 | 77 | save_dir = args.store_path 78 | if not os.path.exists(save_dir): os.makedirs(save_dir) 79 | 80 | B = 1 81 | IMAGE_SIZE = 384 82 | IMAGE_MODE = 'RGB' 83 | 84 | for d in datasets: 85 | name_experiment = 'openAI_VAE_' + d 86 | path_results = os.path.join(save_dir, name_experiment) 87 | if not os.path.exists(path_results): os.makedirs(path_results) 88 | 89 | if d == 'ICDAR2013': 90 | IMAGE_PATH = args.data_path + '/ICDAR2013/Challenge2_Test_Task12_Images/ICDAR_2013_img_test.txt' 91 | else: 92 | IMAGE_PATH = args.data_path + '/Paper2Fig100k/paper2fig1_img_test.txt' 93 | 94 | ds = ImageDataset(IMAGE_PATH, IMAGE_SIZE) 95 | dl = DataLoader(ds, batch_size = B, num_workers=1, shuffle=False) 96 | 97 | # Losses 98 | perceptual_loss = LPIPS().eval() 99 | perceptual_loss.cuda() 100 | ocr_perceptual_loss = OCR_CRAFT_LPIPS().eval().cuda() 101 | ocr_perceptual_loss.cuda() 102 | vae = OpenAIDiscreteVAE().cuda() 103 | vae.eval() 104 | print(f"dataset with {ds.__len__()} images") 105 | LPIPS_list = [] 
106 | OCR_list = [] 107 | for i, images in tqdm(enumerate((dl))): 108 | images = images.cuda() 109 | image_tokens = vae.get_codebook_indices(images) 110 | rec_images = vae.decode(image_tokens) 111 | 112 | # Compute LPIPS 113 | LPIPS_list.append(perceptual_loss(images.contiguous(), rec_images.contiguous()).item()) 114 | # Compute OCR SIM 115 | OCR_list.append(ocr_perceptual_loss(images.contiguous(), rec_images.contiguous()).item()) 116 | 117 | # Store samples 118 | for k in range(rec_images.shape[0]): 119 | filename = f"reconstruction_batch_{i}_id_{k}.png" 120 | path = os.path.join(path_results, filename) 121 | x_rec = T.ToPILImage(mode='RGB')(rec_images[k]).save(path) 122 | 123 | print(f"LPIPS loss: {np.mean(LPIPS_list)}, OCR loss: {np.mean(OCR_list)}") 124 | -------------------------------------------------------------------------------- /taming/data/image_transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import warnings 3 | from typing import Union 4 | 5 | import torch 6 | from torch import Tensor 7 | from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor 8 | from torchvision.transforms.functional import _get_image_size as get_image_size 9 | 10 | from taming.data.helper_types import BoundingBox, Image 11 | 12 | pil_to_tensor = PILToTensor() 13 | 14 | 15 | def convert_pil_to_tensor(image: Image) -> Tensor: 16 | with warnings.catch_warnings(): 17 | # to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194 18 | warnings.simplefilter("ignore") 19 | return pil_to_tensor(image) 20 | 21 | 22 | class RandomCrop1dReturnCoordinates(RandomCrop): 23 | def forward(self, img: Image) -> (BoundingBox, Image): 24 | """ 25 | Additionally to cropping, returns the relative coordinates of the crop bounding box. 26 | Args: 27 | img (PIL Image or Tensor): Image to be cropped. 28 | 29 | Returns: 30 | Bounding box: x0, y0, w, h 31 | PIL Image or Tensor: Cropped image. 32 | 33 | Based on: 34 | torchvision.transforms.RandomCrop, torchvision 1.7.0 35 | """ 36 | if self.padding is not None: 37 | img = F.pad(img, self.padding, self.fill, self.padding_mode) 38 | 39 | width, height = get_image_size(img) 40 | # pad the width if needed 41 | if self.pad_if_needed and width < self.size[1]: 42 | padding = [self.size[1] - width, 0] 43 | img = F.pad(img, padding, self.fill, self.padding_mode) 44 | # pad the height if needed 45 | if self.pad_if_needed and height < self.size[0]: 46 | padding = [0, self.size[0] - height] 47 | img = F.pad(img, padding, self.fill, self.padding_mode) 48 | 49 | i, j, h, w = self.get_params(img, self.size) 50 | bbox = (j / width, i / height, w / width, h / height) # x0, y0, w, h 51 | return bbox, F.crop(img, i, j, h, w) 52 | 53 | 54 | class Random2dCropReturnCoordinates(torch.nn.Module): 55 | """ 56 | Additionally to cropping, returns the relative coordinates of the crop bounding box. 57 | Args: 58 | img (PIL Image or Tensor): Image to be cropped. 59 | 60 | Returns: 61 | Bounding box: x0, y0, w, h 62 | PIL Image or Tensor: Cropped image. 
63 | 64 | Based on: 65 | torchvision.transforms.RandomCrop, torchvision 1.7.0 66 | """ 67 | 68 | def __init__(self, min_size: int): 69 | super().__init__() 70 | self.min_size = min_size 71 | 72 | def forward(self, img: Image) -> (BoundingBox, Image): 73 | width, height = get_image_size(img) 74 | max_size = min(width, height) 75 | if max_size <= self.min_size: 76 | size = max_size 77 | else: 78 | size = random.randint(self.min_size, max_size) 79 | top = random.randint(0, height - size) 80 | left = random.randint(0, width - size) 81 | bbox = left / width, top / height, size / width, size / height 82 | return bbox, F.crop(img, top, left, size, size) 83 | 84 | 85 | class CenterCropReturnCoordinates(CenterCrop): 86 | @staticmethod 87 | def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox: 88 | if width > height: 89 | w = height / width 90 | h = 1.0 91 | x0 = 0.5 - w / 2 92 | y0 = 0. 93 | else: 94 | w = 1.0 95 | h = width / height 96 | x0 = 0. 97 | y0 = 0.5 - h / 2 98 | return x0, y0, w, h 99 | 100 | def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]): 101 | """ 102 | Additionally to cropping, returns the relative coordinates of the crop bounding box. 103 | Args: 104 | img (PIL Image or Tensor): Image to be cropped. 105 | 106 | Returns: 107 | Bounding box: x0, y0, w, h 108 | PIL Image or Tensor: Cropped image. 109 | Based on: 110 | torchvision.transforms.RandomHorizontalFlip (version 1.7.0) 111 | """ 112 | width, height = get_image_size(img) 113 | return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size) 114 | 115 | 116 | class RandomHorizontalFlipReturn(RandomHorizontalFlip): 117 | def forward(self, img: Image) -> (bool, Image): 118 | """ 119 | Additionally to flipping, returns a boolean whether it was flipped or not. 120 | Args: 121 | img (PIL Image or Tensor): Image to be flipped. 122 | 123 | Returns: 124 | flipped: whether the image was flipped or not 125 | PIL Image or Tensor: Randomly flipped image. 
126 | 127 | Based on: 128 | torchvision.transforms.RandomHorizontalFlip (version 1.7.0) 129 | """ 130 | if torch.rand(1) < self.p: 131 | return True, F.hflip(img) 132 | return False, img 133 | -------------------------------------------------------------------------------- /taming/modules/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from collections import OrderedDict 4 | 5 | def copyStateDict(state_dict): 6 | if list(state_dict.keys())[0].startswith("module"): 7 | start_idx = 1 8 | else: 9 | start_idx = 0 10 | new_state_dict = OrderedDict() 11 | for k, v in state_dict.items(): 12 | name = ".".join(k.split(".")[start_idx:]) 13 | new_state_dict[name] = v 14 | return new_state_dict 15 | 16 | 17 | def count_params(model): 18 | total_params = sum(p.numel() for p in model.parameters()) 19 | return total_params 20 | 21 | 22 | class ActNorm(nn.Module): 23 | def __init__(self, num_features, logdet=False, affine=True, 24 | allow_reverse_init=False): 25 | assert affine 26 | super().__init__() 27 | self.logdet = logdet 28 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 29 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 30 | self.allow_reverse_init = allow_reverse_init 31 | 32 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 33 | 34 | def initialize(self, input): 35 | with torch.no_grad(): 36 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 37 | mean = ( 38 | flatten.mean(1) 39 | .unsqueeze(1) 40 | .unsqueeze(2) 41 | .unsqueeze(3) 42 | .permute(1, 0, 2, 3) 43 | ) 44 | std = ( 45 | flatten.std(1) 46 | .unsqueeze(1) 47 | .unsqueeze(2) 48 | .unsqueeze(3) 49 | .permute(1, 0, 2, 3) 50 | ) 51 | 52 | self.loc.data.copy_(-mean) 53 | self.scale.data.copy_(1 / (std + 1e-6)) 54 | 55 | def forward(self, input, reverse=False): 56 | if reverse: 57 | return self.reverse(input) 58 | if len(input.shape) == 2: 59 | input = input[:,:,None,None] 60 | squeeze = True 61 | else: 62 | squeeze = False 63 | 64 | _, _, height, width = input.shape 65 | 66 | if self.training and self.initialized.item() == 0: 67 | self.initialize(input) 68 | self.initialized.fill_(1) 69 | 70 | h = self.scale * (input + self.loc) 71 | 72 | if squeeze: 73 | h = h.squeeze(-1).squeeze(-1) 74 | 75 | if self.logdet: 76 | log_abs = torch.log(torch.abs(self.scale)) 77 | logdet = height*width*torch.sum(log_abs) 78 | logdet = logdet * torch.ones(input.shape[0]).to(input) 79 | return h, logdet 80 | 81 | return h 82 | 83 | def reverse(self, output): 84 | if self.training and self.initialized.item() == 0: 85 | if not self.allow_reverse_init: 86 | raise RuntimeError( 87 | "Initializing ActNorm in reverse direction is " 88 | "disabled by default. Use allow_reverse_init=True to enable." 
89 | ) 90 | else: 91 | self.initialize(output) 92 | self.initialized.fill_(1) 93 | 94 | if len(output.shape) == 2: 95 | output = output[:,:,None,None] 96 | squeeze = True 97 | else: 98 | squeeze = False 99 | 100 | h = output / self.scale - self.loc 101 | 102 | if squeeze: 103 | h = h.squeeze(-1).squeeze(-1) 104 | return h 105 | 106 | 107 | class AbstractEncoder(nn.Module): 108 | def __init__(self): 109 | super().__init__() 110 | 111 | def encode(self, *args, **kwargs): 112 | raise NotImplementedError 113 | 114 | 115 | class Labelator(AbstractEncoder): 116 | """Net2Net Interface for Class-Conditional Model""" 117 | def __init__(self, n_classes, quantize_interface=True): 118 | super().__init__() 119 | self.n_classes = n_classes 120 | self.quantize_interface = quantize_interface 121 | 122 | def encode(self, c): 123 | c = c[:,None] 124 | if self.quantize_interface: 125 | return c, None, [None, None, c.long()] 126 | return c 127 | 128 | 129 | class SOSProvider(AbstractEncoder): 130 | # for unconditional training 131 | def __init__(self, sos_token, quantize_interface=True): 132 | super().__init__() 133 | self.sos_token = sos_token 134 | self.quantize_interface = quantize_interface 135 | 136 | def encode(self, x): 137 | # get batch size from data and replicate sos_token 138 | c = torch.ones(x.shape[0], 1)*self.sos_token 139 | c = c.long().to(x.device) 140 | if self.quantize_interface: 141 | return c, None, [None, None, c] 142 | return c 143 | 144 | def copyStateDict(state_dict): 145 | if list(state_dict.keys())[0].startswith("module"): 146 | start_idx = 1 147 | else: 148 | start_idx = 0 149 | new_state_dict = OrderedDict() 150 | for k, v in state_dict.items(): 151 | name = ".".join(k.split(".")[start_idx:]) 152 | new_state_dict[name] = v 153 | return new_state_dict -------------------------------------------------------------------------------- /taming/util.py: -------------------------------------------------------------------------------- 1 | import os, hashlib 2 | import requests 3 | from tqdm import tqdm 4 | 5 | URL_MAP = { 6 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1", 7 | "ocr_craft": "1Jk4eGD7crsqCCg9C9VjCLkMN3ze8kutZ" 8 | } 9 | 10 | CKPT_MAP = { 11 | "vgg_lpips": "vgg.pth", 12 | "ocr_craft":"craft_mlt_25k.pth" 13 | } 14 | 15 | MD5_MAP = { 16 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a", 17 | "ocr_craft": None 18 | 19 | } 20 | 21 | 22 | def download(name, local_path, chunk_size=1024): 23 | url = URL_MAP[name] 24 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 25 | if name == "ocr_craft": 26 | URL = 'https://drive.google.com/uc?export=download' 27 | session = requests.Session() 28 | response = session.get(URL, params = { 'id' : URL_MAP[name] , 'confirm': 1 }, stream = True) 29 | total_size = int(response.headers.get('Content-Length')) 30 | chunk_size = 1024 31 | 32 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 33 | with open(local_path, "wb") as f: 34 | for data in response.iter_content(chunk_size=chunk_size): 35 | if data: 36 | f.write(data) 37 | pbar.update(chunk_size) 38 | else: 39 | with requests.get(url, stream=True) as r: 40 | total_size = int(r.headers.get("content-length", 0)) 41 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 42 | with open(local_path, "wb") as f: 43 | for data in r.iter_content(chunk_size=chunk_size): 44 | if data: 45 | f.write(data) 46 | pbar.update(chunk_size) 47 | 48 | 49 | def md5_hash(path): 50 | with open(path, "rb") as f: 51 | content = f.read() 52 | return 
hashlib.md5(content).hexdigest() 53 | 54 | 55 | def get_ckpt_path(name, root, check=False): 56 | assert name in URL_MAP 57 | path = os.path.join(root, CKPT_MAP[name]) 58 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 59 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 60 | download(name, path) 61 | if MD5_MAP[name]: 62 | md5 = md5_hash(path) 63 | assert md5 == MD5_MAP[name], md5 64 | return path 65 | 66 | 67 | 68 | class KeyNotFoundError(Exception): 69 | def __init__(self, cause, keys=None, visited=None): 70 | self.cause = cause 71 | self.keys = keys 72 | self.visited = visited 73 | messages = list() 74 | if keys is not None: 75 | messages.append("Key not found: {}".format(keys)) 76 | if visited is not None: 77 | messages.append("Visited: {}".format(visited)) 78 | messages.append("Cause:\n{}".format(cause)) 79 | message = "\n".join(messages) 80 | super().__init__(message) 81 | 82 | 83 | def retrieve( 84 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False 85 | ): 86 | """Given a nested list or dict return the desired value at key expanding 87 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion 88 | is done in-place. 89 | 90 | Parameters 91 | ---------- 92 | list_or_dict : list or dict 93 | Possibly nested list or dictionary. 94 | key : str 95 | key/to/value, path like string describing all keys necessary to 96 | consider to get to the desired value. List indices can also be 97 | passed here. 98 | splitval : str 99 | String that defines the delimiter between keys of the 100 | different depth levels in `key`. 101 | default : obj 102 | Value returned if :attr:`key` is not found. 103 | expand : bool 104 | Whether to expand callable nodes on the path or not. 105 | 106 | Returns 107 | ------- 108 | The desired value or if :attr:`default` is not ``None`` and the 109 | :attr:`key` is not found returns ``default``. 110 | 111 | Raises 112 | ------ 113 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is 114 | ``None``. 115 | """ 116 | 117 | keys = key.split(splitval) 118 | 119 | success = True 120 | try: 121 | visited = [] 122 | parent = None 123 | last_key = None 124 | for key in keys: 125 | if callable(list_or_dict): 126 | if not expand: 127 | raise KeyNotFoundError( 128 | ValueError( 129 | "Trying to get past callable node with expand=False." 
130 | ), 131 | keys=keys, 132 | visited=visited, 133 | ) 134 | list_or_dict = list_or_dict() 135 | parent[last_key] = list_or_dict 136 | 137 | last_key = key 138 | parent = list_or_dict 139 | 140 | try: 141 | if isinstance(list_or_dict, dict): 142 | list_or_dict = list_or_dict[key] 143 | else: 144 | list_or_dict = list_or_dict[int(key)] 145 | except (KeyError, IndexError, ValueError) as e: 146 | raise KeyNotFoundError(e, keys=keys, visited=visited) 147 | 148 | visited += [key] 149 | # final expansion of retrieved value 150 | if expand and callable(list_or_dict): 151 | list_or_dict = list_or_dict() 152 | parent[last_key] = list_or_dict 153 | except KeyNotFoundError as e: 154 | if default is None: 155 | raise e 156 | else: 157 | list_or_dict = default 158 | success = False 159 | 160 | if not pass_success: 161 | return list_or_dict 162 | else: 163 | return list_or_dict, success 164 | 165 | 166 | 167 | if __name__ == "__main__": 168 | config = {"keya": "a", 169 | "keyb": "b", 170 | "keyc": 171 | {"cc1": 1, 172 | "cc2": 2, 173 | } 174 | } 175 | from omegaconf import OmegaConf 176 | config = OmegaConf.create(config) 177 | print(config) 178 | retrieve(config, "keya") 179 | 180 | -------------------------------------------------------------------------------- /taming/data/utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | import tarfile 4 | import urllib 5 | import zipfile 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | import torch 10 | from taming.data.helper_types import Annotation 11 | from torch._six import string_classes 12 | from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format 13 | from tqdm import tqdm 14 | 15 | 16 | def unpack(path): 17 | if path.endswith("tar.gz"): 18 | with tarfile.open(path, "r:gz") as tar: 19 | tar.extractall(path=os.path.split(path)[0]) 20 | elif path.endswith("tar"): 21 | with tarfile.open(path, "r:") as tar: 22 | tar.extractall(path=os.path.split(path)[0]) 23 | elif path.endswith("zip"): 24 | with zipfile.ZipFile(path, "r") as f: 25 | f.extractall(path=os.path.split(path)[0]) 26 | else: 27 | raise NotImplementedError( 28 | "Unknown file extension: {}".format(os.path.splitext(path)[1]) 29 | ) 30 | 31 | 32 | def reporthook(bar): 33 | """tqdm progress bar for downloads.""" 34 | 35 | def hook(b=1, bsize=1, tsize=None): 36 | if tsize is not None: 37 | bar.total = tsize 38 | bar.update(b * bsize - bar.n) 39 | 40 | return hook 41 | 42 | 43 | def get_root(name): 44 | base = "data/" 45 | root = os.path.join(base, name) 46 | os.makedirs(root, exist_ok=True) 47 | return root 48 | 49 | 50 | def is_prepared(root): 51 | return Path(root).joinpath(".ready").exists() 52 | 53 | 54 | def mark_prepared(root): 55 | Path(root).joinpath(".ready").touch() 56 | 57 | 58 | def prompt_download(file_, source, target_dir, content_dir=None): 59 | targetpath = os.path.join(target_dir, file_) 60 | while not os.path.exists(targetpath): 61 | if content_dir is not None and os.path.exists( 62 | os.path.join(target_dir, content_dir) 63 | ): 64 | break 65 | print( 66 | "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath) 67 | ) 68 | if content_dir is not None: 69 | print( 70 | "Or place its content into '{}'.".format( 71 | os.path.join(target_dir, content_dir) 72 | ) 73 | ) 74 | input("Press Enter when done...") 75 | return targetpath 76 | 77 | 78 | def download_url(file_, url, target_dir): 79 | targetpath = os.path.join(target_dir, file_) 
80 | os.makedirs(target_dir, exist_ok=True) 81 | with tqdm( 82 | unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_ 83 | ) as bar: 84 | urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar)) 85 | return targetpath 86 | 87 | 88 | def download_urls(urls, target_dir): 89 | paths = dict() 90 | for fname, url in urls.items(): 91 | outpath = download_url(fname, url, target_dir) 92 | paths[fname] = outpath 93 | return paths 94 | 95 | 96 | def quadratic_crop(x, bbox, alpha=1.0): 97 | """bbox is xmin, ymin, xmax, ymax""" 98 | im_h, im_w = x.shape[:2] 99 | bbox = np.array(bbox, dtype=np.float32) 100 | bbox = np.clip(bbox, 0, max(im_h, im_w)) 101 | center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3]) 102 | w = bbox[2] - bbox[0] 103 | h = bbox[3] - bbox[1] 104 | l = int(alpha * max(w, h)) 105 | l = max(l, 2) 106 | 107 | required_padding = -1 * min( 108 | center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l) 109 | ) 110 | required_padding = int(np.ceil(required_padding)) 111 | if required_padding > 0: 112 | padding = [ 113 | [required_padding, required_padding], 114 | [required_padding, required_padding], 115 | ] 116 | padding += [[0, 0]] * (len(x.shape) - 2) 117 | x = np.pad(x, padding, "reflect") 118 | center = center[0] + required_padding, center[1] + required_padding 119 | xmin = int(center[0] - l / 2) 120 | ymin = int(center[1] - l / 2) 121 | return np.array(x[ymin : ymin + l, xmin : xmin + l, ...]) 122 | 123 | 124 | def custom_collate(batch): 125 | r"""source: pytorch 1.9.0, only one modification to original code """ 126 | 127 | elem = batch[0] 128 | elem_type = type(elem) 129 | if isinstance(elem, torch.Tensor): 130 | out = None 131 | if torch.utils.data.get_worker_info() is not None: 132 | # If we're in a background process, concatenate directly into a 133 | # shared memory tensor to avoid an extra copy 134 | numel = sum([x.numel() for x in batch]) 135 | storage = elem.storage()._new_shared(numel) 136 | out = elem.new(storage) 137 | return torch.stack(batch, 0, out=out) 138 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 139 | and elem_type.__name__ != 'string_': 140 | if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': 141 | # array of string classes and object 142 | if np_str_obj_array_pattern.search(elem.dtype.str) is not None: 143 | raise TypeError(default_collate_err_msg_format.format(elem.dtype)) 144 | 145 | return custom_collate([torch.as_tensor(b) for b in batch]) 146 | elif elem.shape == (): # scalars 147 | return torch.as_tensor(batch) 148 | elif isinstance(elem, float): 149 | return torch.tensor(batch, dtype=torch.float64) 150 | elif isinstance(elem, int): 151 | return torch.tensor(batch) 152 | elif isinstance(elem, string_classes): 153 | return batch 154 | elif isinstance(elem, collections.abc.Mapping): 155 | return {key: custom_collate([d[key] for d in batch]) for key in elem} 156 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple 157 | return elem_type(*(custom_collate(samples) for samples in zip(*batch))) 158 | if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation): # added 159 | return batch # added 160 | elif isinstance(elem, collections.abc.Sequence): 161 | # check to make sure that the elements in batch have consistent size 162 | it = iter(batch) 163 | elem_size = len(next(it)) 164 | if not all(len(elem) == elem_size for elem in it): 165 | raise RuntimeError('each element in list of batch should be of equal size') 166 
| transposed = zip(*batch) 167 | return [custom_collate(samples) for samples in transposed] 168 | 169 | raise TypeError(default_collate_err_msg_format.format(elem_type)) 170 | -------------------------------------------------------------------------------- /taming/modules/losses/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import models 6 | from collections import namedtuple 7 | from taming.util import get_ckpt_path 8 | from taming.modules.losses.craft import CRAFT 9 | from taming.modules.util import copyStateDict 10 | 11 | 12 | class OCR_CRAFT_LPIPS(nn.Module): 13 | def __init__(self, use_dropout=True): 14 | super().__init__() 15 | self.scaling_layer = ScalingLayer() 16 | self.craft = CRAFT(pretrained=True, freeze=True, amp=False) 17 | self.load_from_pretrained() 18 | self.chns = [1024, 512, 512, 256, 128] 19 | 20 | def load_from_pretrained(self): 21 | ckpt = get_ckpt_path('ocr_craft', "taming/modules/autoencoder/ocr_perceptual") 22 | param = torch.load(ckpt) 23 | print("Loading craft model from {}".format(ckpt)) 24 | self.craft.load_state_dict(copyStateDict(param)) 25 | 26 | def forward(self, inputs, reconstructions): 27 | in0_input, in1_input = (self.scaling_layer(inputs), self.scaling_layer(reconstructions)) 28 | _,_,outs0 = self.craft(in0_input) 29 | _,_,outs1 = self.craft(in1_input) 30 | 31 | feats0, feats1, diffs = {}, {}, {} 32 | for kk in range(len(self.chns)): 33 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 34 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 35 | 36 | res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(len(self.chns))] 37 | val = res[0] 38 | for l in range(1, len(self.chns)): 39 | val += res[l] 40 | return val 41 | 42 | class LPIPS(nn.Module): 43 | # Learned perceptual metric 44 | def __init__(self, use_dropout=True): 45 | super().__init__() 46 | self.scaling_layer = ScalingLayer() 47 | self.chns = [64, 128, 256, 512, 512] # vg16 features 48 | self.net = vgg16(pretrained=True, requires_grad=False) 49 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 50 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 51 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 52 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 53 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 54 | self.load_from_pretrained() 55 | for param in self.parameters(): 56 | param.requires_grad = False 57 | 58 | def load_from_pretrained(self, name="vgg_lpips"): 59 | ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips") 60 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 61 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 62 | 63 | @classmethod 64 | def from_pretrained(cls, name="vgg_lpips"): 65 | if name != "vgg_lpips": 66 | raise NotImplementedError 67 | model = cls() 68 | ckpt = get_ckpt_path(name) 69 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 70 | return model 71 | 72 | def forward(self, input, target): 73 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 74 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 75 | feats0, feats1, diffs = {}, {}, {} 76 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, 
self.lin4] 77 | for kk in range(len(self.chns)): 78 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 79 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 80 | 81 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 82 | val = res[0] 83 | for l in range(1, len(self.chns)): 84 | val += res[l] 85 | return val 86 | 87 | 88 | class ScalingLayer(nn.Module): 89 | def __init__(self): 90 | super(ScalingLayer, self).__init__() 91 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 92 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 93 | 94 | def forward(self, inp): 95 | return (inp - self.shift) / self.scale 96 | 97 | 98 | class NetLinLayer(nn.Module): 99 | """ A single linear layer which does a 1x1 conv """ 100 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 101 | super(NetLinLayer, self).__init__() 102 | layers = [nn.Dropout(), ] if (use_dropout) else [] 103 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 104 | self.model = nn.Sequential(*layers) 105 | 106 | 107 | class vgg16(torch.nn.Module): 108 | def __init__(self, requires_grad=False, pretrained=True): 109 | super(vgg16, self).__init__() 110 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 111 | self.slice1 = torch.nn.Sequential() 112 | self.slice2 = torch.nn.Sequential() 113 | self.slice3 = torch.nn.Sequential() 114 | self.slice4 = torch.nn.Sequential() 115 | self.slice5 = torch.nn.Sequential() 116 | self.N_slices = 5 117 | for x in range(4): 118 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 119 | for x in range(4, 9): 120 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 121 | for x in range(9, 16): 122 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 123 | for x in range(16, 23): 124 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 125 | for x in range(23, 30): 126 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 127 | if not requires_grad: 128 | for param in self.parameters(): 129 | param.requires_grad = False 130 | 131 | def forward(self, X): 132 | h = self.slice1(X) 133 | h_relu1_2 = h 134 | h = self.slice2(h) 135 | h_relu2_2 = h 136 | h = self.slice3(h) 137 | h_relu3_3 = h 138 | h = self.slice4(h) 139 | h_relu4_3 = h 140 | h = self.slice5(h) 141 | h_relu5_3 = h 142 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 143 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 144 | return out 145 | 146 | 147 | def normalize_tensor(x,eps=1e-10): 148 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 149 | return x/(norm_factor+eps) 150 | 151 | 152 | def spatial_average(x, keepdim=True): 153 | return x.mean([2,3],keepdim=keepdim) 154 | -------------------------------------------------------------------------------- /taming/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from taming.modules.losses.lpips import LPIPS, OCR_CRAFT_LPIPS 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | 8 | 9 | class DummyLoss(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | 14 | def adopt_weight(weight, global_step, threshold=0, value=0.): 15 | if global_step < threshold: 16 | weight = value 17 | 
return weight 18 | 19 | 20 | def hinge_d_loss(logits_real, logits_fake): 21 | loss_real = torch.mean(F.relu(1. - logits_real)) 22 | loss_fake = torch.mean(F.relu(1. + logits_fake)) 23 | d_loss = 0.5 * (loss_real + loss_fake) 24 | return d_loss 25 | 26 | 27 | def vanilla_d_loss(logits_real, logits_fake): 28 | d_loss = 0.5 * ( 29 | torch.mean(torch.nn.functional.softplus(-logits_real)) + 30 | torch.mean(torch.nn.functional.softplus(logits_fake))) 31 | return d_loss 32 | 33 | 34 | class VQLPIPSWithDiscriminator(nn.Module): 35 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 36 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 37 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 38 | disc_ndf=64, disc_loss="hinge"): 39 | super().__init__() 40 | assert disc_loss in ["hinge", "vanilla"] 41 | self.codebook_weight = codebook_weight 42 | self.pixel_weight = pixelloss_weight 43 | self.perceptual_loss = LPIPS().eval() 44 | self.perceptual_weight = perceptual_weight 45 | 46 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 47 | n_layers=disc_num_layers, 48 | use_actnorm=use_actnorm, 49 | ndf=disc_ndf 50 | ).apply(weights_init) 51 | self.discriminator_iter_start = disc_start 52 | if disc_loss == "hinge": 53 | self.disc_loss = hinge_d_loss 54 | elif disc_loss == "vanilla": 55 | self.disc_loss = vanilla_d_loss 56 | else: 57 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 58 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 59 | self.disc_factor = disc_factor 60 | self.discriminator_weight = disc_weight 61 | self.disc_conditional = disc_conditional 62 | 63 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 64 | if last_layer is not None: 65 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 66 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 67 | else: 68 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 69 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 70 | 71 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 72 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 73 | d_weight = d_weight * self.discriminator_weight 74 | return d_weight 75 | 76 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 77 | global_step, last_layer=None, cond=None, split="train"): 78 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 79 | if self.perceptual_weight > 0: 80 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 81 | rec_loss = rec_loss + self.perceptual_weight * p_loss 82 | else: 83 | p_loss = torch.tensor([0.0]) 84 | 85 | nll_loss = rec_loss 86 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 87 | nll_loss = torch.mean(nll_loss) 88 | 89 | # now the GAN part 90 | if optimizer_idx == 0: 91 | # generator update 92 | if cond is None: 93 | assert not self.disc_conditional 94 | logits_fake = self.discriminator(reconstructions.contiguous()) 95 | else: 96 | assert self.disc_conditional 97 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 98 | g_loss = -torch.mean(logits_fake) 99 | 100 | try: 101 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 102 | except RuntimeError: 103 | assert not self.training 104 | d_weight = torch.tensor(0.0) 105 | 106 | disc_factor = adopt_weight(self.disc_factor, 
global_step, threshold=self.discriminator_iter_start) 107 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 108 | 109 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 110 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 111 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 112 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 113 | "{}/p_loss".format(split): p_loss.detach().mean(), 114 | "{}/d_weight".format(split): d_weight.detach(), 115 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 116 | "{}/g_loss".format(split): g_loss.detach().mean(), 117 | } 118 | return loss, log 119 | 120 | if optimizer_idx == 1: 121 | # second pass for discriminator update 122 | if cond is None: 123 | logits_real = self.discriminator(inputs.contiguous().detach()) 124 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 125 | else: 126 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 127 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 128 | 129 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 130 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 131 | 132 | log = { 133 | "{}/disc_loss".format(split): d_loss.clone().detach().mean(), 134 | "{}/logits_real".format(split): logits_real.detach().mean(), 135 | "{}/logits_fake".format(split): logits_fake.detach().mean() 136 | } 137 | return d_loss, log 138 | 139 | class VQLPIPSWithDiscriminatorOCR(nn.Module): 140 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 141 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 142 | perceptual_weight=0.2, ocr_perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 143 | disc_ndf=64, disc_loss="hinge"): 144 | super().__init__() 145 | assert disc_loss in ["hinge", "vanilla"] 146 | self.codebook_weight = codebook_weight 147 | self.pixel_weight = pixelloss_weight 148 | self.perceptual_loss = LPIPS().eval() 149 | self.perceptual_weight = perceptual_weight 150 | 151 | # Definition of OCR perceptual losses 152 | self.ocr_perceptual_loss = OCR_CRAFT_LPIPS().eval() 153 | 154 | self.ocr_perceptual_weight = ocr_perceptual_weight 155 | 156 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 157 | n_layers=disc_num_layers, 158 | use_actnorm=use_actnorm, 159 | ndf=disc_ndf 160 | ).apply(weights_init) 161 | self.discriminator_iter_start = disc_start 162 | if disc_loss == "hinge": 163 | self.disc_loss = hinge_d_loss 164 | elif disc_loss == "vanilla": 165 | self.disc_loss = vanilla_d_loss 166 | else: 167 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 168 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 169 | self.disc_factor = disc_factor 170 | self.discriminator_weight = disc_weight 171 | self.disc_conditional = disc_conditional 172 | 173 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 174 | if last_layer is not None: 175 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 176 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 177 | else: 178 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 179 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 180 | 181 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 
1e-4) 182 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 183 | d_weight = d_weight * self.discriminator_weight 184 | return d_weight 185 | 186 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 187 | global_step, last_layer=None, cond=None, split="train"): 188 | if split == 'test': 189 | self.perceptual_weight = 1 # Set this to one in the test set, to evaluate it 190 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 191 | if self.perceptual_weight > 0: 192 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 193 | rec_loss = rec_loss + self.perceptual_weight * p_loss 194 | else: 195 | p_loss = torch.tensor([0.0]) 196 | 197 | if self.ocr_perceptual_weight > 0: 198 | p_ocr_loss = self.ocr_perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 199 | rec_loss = rec_loss + self.ocr_perceptual_weight * p_ocr_loss 200 | else: 201 | p_ocr_loss = torch.tensor([0.0]) 202 | 203 | 204 | nll_loss = rec_loss 205 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 206 | nll_loss = torch.mean(nll_loss) 207 | 208 | # now the GAN part 209 | if optimizer_idx == 0: 210 | # generator update 211 | if cond is None: 212 | assert not self.disc_conditional 213 | logits_fake = self.discriminator(reconstructions.contiguous()) 214 | else: 215 | assert self.disc_conditional 216 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 217 | g_loss = -torch.mean(logits_fake) 218 | 219 | try: 220 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 221 | except RuntimeError: 222 | assert not self.training 223 | d_weight = torch.tensor(0.0) 224 | 225 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 226 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 227 | 228 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 229 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 230 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 231 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 232 | "{}/p_loss".format(split): p_loss.detach().mean(), 233 | "{}/d_weight".format(split): d_weight.detach(), 234 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 235 | "{}/g_loss".format(split): g_loss.detach().mean(), 236 | "{}/p_ocr_loss".format(split): p_ocr_loss.detach().mean() 237 | } 238 | return loss, log 239 | 240 | if optimizer_idx == 1: 241 | # second pass for discriminator update 242 | if cond is None: 243 | logits_real = self.discriminator(inputs.contiguous().detach()) 244 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 245 | else: 246 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 247 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 248 | 249 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 250 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 251 | 252 | log = { 253 | "{}/disc_loss".format(split): d_loss.clone().detach().mean(), 254 | "{}/logits_real".format(split): logits_real.detach().mean(), 255 | "{}/logits_fake".format(split): logits_fake.detach().mean() 256 | } 257 | return d_loss, log -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # OCR-VQGAN @WACV 2023 🏝️ 2 | 3 | ## [OCR-VQGAN: Taming Text-within-Image Generation](https://arxiv.org/abs/2210.11248) 4 | 5 | [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2210.11248) 6 | 7 | [Juan A. Rodríguez](https://scholar.google.es/citations?user=0selhb4AAAAJ&hl=en), [David Vázquez](https://scholar.google.es/citations?user=1jHvtfsAAAAJ&hl=en), [Issam Laradji](https://scholar.google.ca/citations?user=8vRS7F0AAAAJ&hl=en), [Marco Pedersoli](https://scholar.google.com/citations?user=aVfyPAoAAAAJ&hl=en), [Pau Rodríguez](https://scholar.google.com/citations?user=IwBx73wAAAAJ) 8 | 9 | ----------- 10 | [Computer Vision Center, Autonomous University of Barcelona](http://www.cvc.uab.es/) 11 | 12 | [ServiceNow Research, Montréal, Canada](https://www.servicenow.com/research/) 13 | 14 | [ÉTS Montreal, University of Québec](https://www.etsmtl.ca/) 15 | 16 | ------------------ 17 | OCR-VQGAN is an image encoder and decoder designed to generate images that display clear and readable text. We propose to add an **OCR perceptual loss** term to the overall VQGAN loss, which encourages the learned discrete latent space to encode text patterns (i.e. learn rich latent representations to decode clear text-within-images). 18 | 19 | We experiment with OCR-VQGAN on a novel dataset of images of figures and diagrams from research papers, called [**Paper2Fig100k dataset**](https://zenodo.org/record/7299423#.Y2lzonbMKUl). We find that using OCR-VQGAN to encode images in Paper2Fig100k results in much better figure reconstructions. 20 | 21 | This code is adapted from **VQGAN** at [CompVis/taming-transformers](https://github.com/CompVis/taming-transformers), and [CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion). The OCR detector model used in the OCR perceptual loss is the **CRAFT** model from [clovaai/CRAFT-pytorch](https://github.com/clovaai/CRAFT-pytorch). 22 | 23 |
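As a rough sketch of how this term enters the training objective (the weights below correspond to `perceptual_weight` and `ocr_perceptual_weight` in `VQLPIPSWithDiscriminatorOCR`, defined in `taming/modules/losses/vqperceptual.py`), the reconstruction part of the loss is

$$
\mathcal{L}_{\mathrm{rec}}(x, \hat{x}) = \lVert x - \hat{x} \rVert_1 + \lambda_{\mathrm{LPIPS}}\,\mathcal{L}_{\mathrm{LPIPS}}(x, \hat{x}) + \lambda_{\mathrm{OCR}}\,\mathcal{L}_{\mathrm{OCR}}(x, \hat{x}),
$$

to which the usual VQGAN codebook/commitment loss and the adaptively weighted adversarial term are added.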

24 | comparison 25 |

26 | 27 | **Abstract** 28 | >Synthetic image generation has recently experienced significant improvements in domains such as natural image or 29 | art generation. However, the problem of figure and diagram generation remains unexplored. A challenging aspect of generating figures and diagrams is effectively rendering readable texts within the images. To alleviate this problem, we present OCR-VQGAN, an image encoder, and decoder that leverages OCR pre-trained features to optimize a text perceptual loss, encouraging the architecture to preserve 30 | high-fidelity text and diagram structure. To explore our approach, we introduce the Paper2Fig100k dataset, with over 100k images of figures and texts from research papers. The figures show architecture diagrams and methodologies of articles available at arXiv.org from fields like artificial intelligence and computer vision. Figures usually include text and discrete objects, e.g., boxes in a diagram, with lines and arrows that connect them. We demonstrate the superiority of our method by conducting several experiments on the task of figure reconstruction. Additionally, we explore the qualitative and quantitative impact of weighting different perceptual metrics in the overall loss function. 31 | 32 | ## Installation 33 | Create a [conda](https://conda.io/) environment named `ocr-vqgan`, 34 | and activate with: 35 | 36 | ```bash 37 | conda env create -f environment.yaml 38 | conda activate ocr-vqgan 39 | pip install -e . 40 | ``` 41 | 42 | ## How to use OCR Perceptual loss and OCR Similarity 43 | Because we are working with images of figures (i.e., images are non-natural), a VGG perceptual loss (LPIPS) is not enough to effectively encode and decode clear texts and sharp diagrams. We propose an additional [**OCR perceptual loss**](https://arxiv.org/abs/2210.11248) to encourage vqgan to learn a rich latent space and reconstruct clear and readable text-within-images. 44 | 45 | The OCR perceptual loss can be computed as follows. You can pass a pair of input and reconstructed images (using any type of image encoder/decoder): 46 | 47 | ```python 48 | from taming.modules.losses.lpips import OCR_CRAFT_LPIPS 49 | from PIL import Image 50 | import numpy as np 51 | import torchvision.transforms as T 52 | import torch 53 | 54 | def get_image_tensor(image_path): 55 | image = Image.open(image_path) 56 | if not image.mode == "RGB": 57 | image = image.convert("RGB") 58 | image = np.array(image).astype(np.uint8) 59 | image = (image/127.5 - 1.0).astype(np.float32) 60 | return torch.unsqueeze(T.ToTensor()(image), 0) 61 | 62 | # Load image and reconstruction to tensors 63 | input_path = 'assets/original.png' 64 | recons_path = 'assets/reconstruction.png' 65 | 66 | input_tensor = get_image_tensor(input_path).cuda() 67 | rec_tensor = get_image_tensor(recons_path).cuda() 68 | 69 | OCR_perceptual_loss = OCR_CRAFT_LPIPS().eval() 70 | OCR_perceptual_loss.cuda() 71 | 72 | ocr_sim = OCR_perceptual_loss(input_tensor, rec_tensor) 73 | ``` 74 | 75 | Our OCR-VQGAN method uses OCR perceptual loss as an additional term in the overall VQGAN loss (see [VQLPIPSWithDiscriminatorOCR](https://github.com/joanrod/ocr-vqgan/blob/bd122c0b7ae02a59c87e568aab72d1e82b754973/taming/modules/losses/vqperceptual.py#L139)). 76 | 77 | ---------------------- 78 | 79 | ## Training OCR-VQGANs 80 | 81 | Logs and checkpoints for experiments are saved into a `logs` directory. 
By default, this directory will be created inside the project, but we recommend passing the argument `-l dir_path` with a path where you have sufficient disk space. 82 | 83 | ### Download Paper2Fig100k dataset 84 | 85 | We train our models using the Paper2Fig100k dataset, which can be downloaded [here](https://zenodo.org/record/7299423#.Y2lv7nbMKUk). Once downloaded, you will find the following structure: 86 | 87 | ``` 88 | 📂Paper2Fig100k/ 89 | ├── 📂figures 90 | │ ├── 🖼️1001.1968v1-Figure1-1.png 91 | │ ├── 🖼️1001.1988v1-Figure1-1.png 92 | │ ├── ... 93 | ├── 📜paper2fig_train.json 94 | ├── 📜paper2fig_test.json 95 | 96 | ``` 97 | 98 | The directory `figures` contains all images in the dataset, and the train and test JSON files contain metadata about each figure (id, captions, etc.). Run the following command to prepare Paper2Fig100k samples for OCR-VQGAN training: 99 | 100 | ```bash 101 | python scripts/parse_paper2fig1_img_to_VQGAN.py --path <path_to_Paper2Fig100k> 102 | ``` 103 | 104 | ### Download ICDAR 13 105 | We also use ICDAR13 to evaluate OCR-VQGAN. [Download the ICDAR13 dataset](https://rrc.cvc.uab.es/?ch=2&com=downloads) (train and test sets). Create a root directory `ICDAR13` and add both downloaded sets. 106 | 107 | ``` 108 | 📂ICDAR13/ 109 | ├── 📂Challenge2_Test_Task12_Images 110 | ├── 📂Challenge2_Training_Task12_Images 111 | ``` 112 | 113 | Run the following command to prepare images for evaluation of ICDAR13 with OCR-VQGAN: 114 | 115 | ```bash 116 | python scripts/parse_ICDAR2013_img_to_VQGAN.py --path <path_to_ICDAR13> 117 | ``` 118 | 119 | This will create a .txt file with the paths of the images in ICDAR13 (we unify both splits for validation). 120 | 121 | ----------------- 122 | 123 | ### Training OCR-VQGAN from scratch 124 | 125 | Create a new configuration for your model using a `config.yaml` file, or use one from the folder `configs`. Using the config file with the argument `--base` will create a new experiment directory, based on the given configuration, to store checkpoints and configs. 126 | 127 | You need to modify the `training_images_list_file` and `test_images_list_file` inside the `config.yaml` file (inside `data`), to point at the .txt files that contain paths to images: 128 | ```yaml 129 | data: 130 | target: main.DataModuleFromConfig 131 | params: 132 | ... 133 | train: 134 | target: taming.data.custom.CustomTrain 135 | params: 136 | training_images_list_file: <path>/paper2fig_train.txt 137 | ... 138 | validation: 139 | target: taming.data.custom.CustomTest 140 | params: 141 | test_images_list_file: <path>/paper2fig1_img_test.txt 142 | ... 143 | test: 144 | target: taming.data.custom.CustomTest 145 | params: 146 | test_images_list_file: <path>/paper2fig1_img_test.txt 147 | ... 148 | ``` 149 | 150 | Then run the following command to start training. You may need to [configure wandb](https://docs.wandb.ai/quickstart): 151 | 152 | ```bash 153 | python main.py --base configs/<config_name>.yaml --logdir <path_to_logdir> -t --gpus 0, -p <project_name> 154 | ``` 155 | ---------------------- 156 | 157 | ### Fine-tuning pre-trained VQGANs with Paper2Fig100k 🚀 158 | 159 | You can also start with VQGAN pre-trained weights and fine-tune the model with figures from Paper2Fig100k. There are several VQGAN pre-trained models in [this model zoo](https://github.com/CompVis/latent-diffusion#model-zoo). For instance, we will resume from the `vqgan_imagenet_16384` model available [here](https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/). The steps are the following: 160 | 161 | 1. Create a directory for the new experiment.
Create the `configs` and `checkpoints` directories, and add `ocr-vqgan/configs/ocr-vqgan-imagenet-16384.yaml` and `last.ckpt` as follows: 162 | 163 | ``` 164 | 📂vqgan_imagenet_16384_paper2fig/ 165 | ├── 📂configs 166 | | ├── 📜ocr-vqgan-imagenet-16384.yaml 167 | ├── 📂checkpoints 168 | | ├── 📜last.ckpt 169 | ``` 170 | 171 | 2. Running the following command will automatically load the `last.ckpt` weights: 172 | 173 | ```bash 174 | python main.py -r <path_to>/vqgan_imagenet_16384_paper2fig -t --gpus 0, -p <project_name> 175 | ``` 176 | or resume from a specific checkpoint with: 177 | 178 | ```bash 179 | python main.py -r <path_to_checkpoint>/model.ckpt -t --gpus 0, 180 | ``` 181 | 182 | >NOTE: The first time training is executed, it will crash because the OCR weights are not present in the pre-trained model. However, before crashing it updates the `last.ckpt` checkpoint inside `checkpoints`, so the next run will load that checkpoint and work fine. 183 | --------------------------- 184 | 185 | ## Evaluation of OCR-VQGAN 186 | 187 | The evaluation of OCR-VQGAN consists of computing quantitative metrics for **LPIPS** and **OCR Similarity** during inference (see the proposed metric in the [paper](https://arxiv.org/abs/2210.11248)) in a test epoch. This process also stores reconstructions in an `evaluation` directory. 188 | 189 | ```bash 190 | python main.py -r dir_model --gpus 0 191 | ``` 192 | ### Computing FID, SSIM and Qualitative results 193 | We also compute FID and SSIM of the generated images with respect to the inputs. Both operations are done over the complete sets (after the test epoch performed in the previous step). 194 | 195 | #### Prepare test images 196 | Before computing FID and SSIM metrics, we need to process the test samples so that they are all inside a single directory and center-cropped. 197 | 198 | ```bash 199 | python scripts/prepare_eval_samples.py --image_txt_path <path_to_txt> --store_path <output_dir> 200 | ``` 201 | 202 | where `--image_txt_path` indicates where the txt file is located and `--store_path` defines the folder to store results. 203 | 204 | #### Compute FID 205 | FID is a metric that measures the similarity of two sets of images in terms of their data distribution. It is computed over full sets of images rather than image by image: FID extracts InceptionV3 features of all the images and compares the two sets using the mean and covariance of those features. We use the [torch-fidelity](https://github.com/toshas/torch-fidelity) library to compute FID between the two sets defined by `--input1` and `--input2`. 206 | 207 | ```bash 208 | pip install torch-fidelity # should already be installed 209 | fidelity --gpu 0 --fid --input1 test_samples_dir --input2 evaluation_samples_dir 210 | ``` 211 | #### Compute SSIM 212 | Similarly, we propose to compute SSIM scores by passing two sets of images (input and reconstruction sets), again defining the sets as `--input1` and `--input2`. 213 | 214 | ```bash 215 | python scripts/compute_ssim.py --input1 test_samples_dir --input2 evaluation_samples_dir 216 | ``` 217 | 218 | >This script does not use the GPU, but it uses multiprocessing to accelerate the computation. For 20k images and 32 CPU cores, it takes around 7 minutes. 219 | 220 | #### Extract qualitative results 221 | Extract random validation samples from different models (i.e., qualitatively compare how different methods reconstruct the same samples):
222 | ```bash 223 | python generate_qualitative_results.py --test_dataset dir_original__samples\ 224 | --VQGAN_pretrained dir_VQVAE_samples\ 225 | --VQGAN_finetuned dir_VQVAE_samples\ 226 | --OCR_VQGAN dir_VQVAE_samples\ 227 | ``` 228 | ------------ 229 | 230 | 231 | ### Results and models 232 | We provide quantitative and qualitative results of our model, and [links to download](https://zenodo.org/record/7299220#.Y2kr0XbMKUk). Config files in yaml format are available at `configs`. The model is defined by `f`, the downsampling factor, `Z`, the discrete codebook size, and `d`, the model embedding size. 233 | 234 | | Model | LPIPS | OCR SIM |FID | SSIM | Link | Config | 235 | |-----------------|------------|-------|----------------|---------------|-------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------| 236 | | f=16, Z=16384, d=256 | 0.08 | 0.45 | 2.02 | 0.77 | [download](https://zenodo.org/record/7299220/files/ocr-vqgan-f16-c16384-d256.zip?download=1) | configs/ocr-vqgan-f16-c16384-d256.yaml | 237 | 238 | ------------------------ 239 | 240 |
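To quickly try a downloaded model outside of `main.py`, a minimal sketch (assuming you have unzipped the release so that the listed config and a `last.ckpt` checkpoint are available locally; the paths below are placeholders, and the script should be run from the repository root) is:

```python
# Minimal sketch, not a packaged entry point: load an OCR-VQGAN checkpoint and
# reconstruct a single image. Paths are assumptions about where the download was unzipped.
import torch
from omegaconf import OmegaConf
from taming.models.vqgan import VQModel

config = OmegaConf.load("configs/ocr-vqgan-f16-c16384-d256.yaml")
model = VQModel(**config.model.params)
state_dict = torch.load("last.ckpt", map_location="cpu")["state_dict"]
model.load_state_dict(state_dict, strict=False)
model.eval()

# get_image_tensor is the helper shown in the "How to use OCR Perceptual loss" section;
# the input spatial size should be divisible by the downsampling factor (16 here).
x = get_image_tensor("assets/original.png")
with torch.no_grad():
    x_rec, _ = model(x)  # forward returns (reconstruction, codebook loss)
```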

241 | comparison 242 |

243 | 244 |

245 | comparison 246 |

247 | 248 | >More details in our [paper](https://arxiv.org/abs/2210.11248) 249 | 250 | ## Related work 251 | 252 | **[Taming Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2012.09841) by Esser et al, CVPR 2021.** 253 | 254 | **[High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) by Rombach et al, CVPR 2022 Oral.** 255 | 256 | **[Character Region Awareness for Text Detection](https://arxiv.org/abs/1904.01941) by Baek et al, CVPR 2019.** 257 | 258 | ----------------- 259 | 260 | ## Citation 261 | If you use this code please cite the following paper: 262 | ```bibtex 263 | @inproceedings{rodriguez2023ocr, 264 | title={OCR-VQGAN: Taming Text-within-Image Generation}, 265 | author={Rodriguez, Juan A and Vazquez, David and Laradji, Issam and Pedersoli, Marco and Rodriguez, Pau}, 266 | booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision}, 267 | pages={3689--3698}, 268 | year={2023} 269 | } 270 | ``` 271 | 272 | ## Contact 273 | Juan A. Rodríguez (joanrg.ai@gmail.com). **We welcome collaborators!** so don't hesitate to ask us about the project. 274 | -------------------------------------------------------------------------------- /taming/models/vqgan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import pytorch_lightning as pl 4 | 5 | from main import instantiate_from_config 6 | 7 | from taming.modules.diffusionmodules.model import Encoder, Decoder 8 | from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer 9 | from taming.modules.vqvae.quantize import GumbelQuantize 10 | from taming.modules.vqvae.quantize import EMAVectorQuantizer 11 | 12 | import numpy as np 13 | import os 14 | from PIL import Image 15 | 16 | class VQModel(pl.LightningModule): 17 | def __init__(self, 18 | ddconfig, 19 | lossconfig, 20 | n_embed, 21 | embed_dim, 22 | ckpt_path=None, 23 | ignore_keys=[], 24 | image_key="image", 25 | colorize_nlabels=None, 26 | monitor=None, 27 | remap=None, 28 | sane_index_shape=False, # tell vector quantizer to return indices as bhw 29 | ): 30 | super().__init__() 31 | self.image_key = image_key 32 | self.encoder = Encoder(**ddconfig) 33 | self.decoder = Decoder(**ddconfig) 34 | self.loss = instantiate_from_config(lossconfig) 35 | self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, 36 | remap=remap, sane_index_shape=sane_index_shape) 37 | self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) 38 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 39 | self.ckpt_path = ckpt_path 40 | if ckpt_path is not None: 41 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 42 | self.image_key = image_key 43 | if colorize_nlabels is not None: 44 | assert type(colorize_nlabels)==int 45 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 46 | if monitor is not None: 47 | self.monitor = monitor 48 | 49 | def init_from_ckpt(self, path, ignore_keys=list()): 50 | sd = torch.load(path, map_location="cpu")["state_dict"] 51 | keys = list(sd.keys()) 52 | for k in keys: 53 | for ik in ignore_keys: 54 | if k.startswith(ik): 55 | print("Deleting key {} from state_dict.".format(k)) 56 | del sd[k] 57 | self.load_state_dict(sd, strict=False) 58 | print(f"Restored from {path}") 59 | 60 | def encode(self, x): 61 | h = self.encoder(x) 62 | h = self.quant_conv(h) 63 | quant, emb_loss, info = self.quantize(h) 64 | return 
quant, emb_loss, info 65 | 66 | def decode(self, quant): 67 | quant = self.post_quant_conv(quant) 68 | dec = self.decoder(quant) 69 | return dec 70 | 71 | def decode_code(self, code_b): 72 | quant_b = self.quantize.embed_code(code_b) 73 | dec = self.decode(quant_b) 74 | return dec 75 | 76 | def forward(self, input): 77 | quant, diff, _ = self.encode(input) 78 | dec = self.decode(quant) 79 | return dec, diff 80 | 81 | def get_input(self, batch, k): 82 | x = batch[k] 83 | if len(x.shape) == 3: 84 | x = x[..., None] 85 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) 86 | return x.float() 87 | 88 | def training_step(self, batch, batch_idx, optimizer_idx): 89 | x = self.get_input(batch, self.image_key) 90 | xrec, qloss = self(x) 91 | 92 | if optimizer_idx == 0: 93 | # autoencode 94 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, 95 | last_layer=self.get_last_layer(), split="train") 96 | 97 | self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 98 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) 99 | 100 | return aeloss 101 | 102 | if optimizer_idx == 1: 103 | # discriminator 104 | discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, 105 | last_layer=self.get_last_layer(), split="train") 106 | self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 107 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) 108 | return discloss 109 | 110 | def validation_step(self, batch, batch_idx): 111 | x = self.get_input(batch, self.image_key) 112 | xrec, qloss = self(x) 113 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step, 114 | last_layer=self.get_last_layer(), split="val") 115 | 116 | discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step, 117 | last_layer=self.get_last_layer(), split="val") 118 | rec_loss = log_dict_ae["val/rec_loss"] 119 | self.log("val/rec_loss", rec_loss, 120 | prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) 121 | self.log("val/aeloss", aeloss, 122 | prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) 123 | return self.log_dict 124 | 125 | def test_step(self, batch, batch_idx): 126 | x = self.get_input(batch, self.image_key) 127 | xrec, qloss = self(x) 128 | 129 | _, log_dict = self.loss(qloss, x, xrec, 0, self.global_step, 130 | last_layer=self.get_last_layer(), split="test") 131 | 132 | # Compute and Store reconstructions 133 | im_rec = xrec.detach().cpu() 134 | im_rec = torch.clamp(im_rec, -1., 1.) 
135 | im_rec = (im_rec+1.0)/2.0 # -1,1 -> 0,1; c,h,w 136 | im_rec = im_rec.transpose(1, 2).transpose(2, 3) 137 | im_rec = im_rec.numpy() 138 | im_rec = (im_rec*255).astype(np.uint8) 139 | 140 | for k in range(im_rec.shape[0]): 141 | filename = f"reconstruction_batch_{batch_idx}_id_{k}.png" 142 | path = os.path.join(self.trainer.logdir, 'evaluation', filename) 143 | im = im_rec[k] 144 | Image.fromarray(im).save(path) 145 | 146 | # Compute LPIPS 147 | LPIPS = log_dict["test/p_loss"] 148 | try: 149 | OCR_loss = log_dict["test/p_ocr_loss"] 150 | except: 151 | OCR_loss= 0.0 152 | 153 | 154 | output = dict({ 155 | 'LPIPS': LPIPS, 156 | 'OCR_loss': OCR_loss 157 | }) 158 | self.log_dict(output) 159 | return output 160 | 161 | def configure_optimizers(self): 162 | lr = self.learning_rate 163 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ 164 | list(self.decoder.parameters())+ 165 | list(self.quantize.parameters())+ 166 | list(self.quant_conv.parameters())+ 167 | list(self.post_quant_conv.parameters()), 168 | lr=lr, betas=(0.5, 0.9)) 169 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 170 | lr=lr, betas=(0.5, 0.9)) 171 | return [opt_ae, opt_disc], [] 172 | 173 | def get_last_layer(self): 174 | return self.decoder.conv_out.weight 175 | 176 | def log_images(self, batch, **kwargs): 177 | log = dict() 178 | x = self.get_input(batch, self.image_key) 179 | x = x.to(self.device) 180 | xrec, _ = self(x) 181 | if x.shape[1] > 3: 182 | # colorize with random projection 183 | assert xrec.shape[1] > 3 184 | x = self.to_rgb(x) 185 | xrec = self.to_rgb(xrec) 186 | log["inputs"] = x 187 | log["reconstructions"] = xrec 188 | return log 189 | 190 | def to_rgb(self, x): 191 | assert self.image_key == "segmentation" 192 | if not hasattr(self, "colorize"): 193 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 194 | x = F.conv2d(x, weight=self.colorize) 195 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
196 | return x 197 | 198 | 199 | class VQSegmentationModel(VQModel): 200 | def __init__(self, n_labels, *args, **kwargs): 201 | super().__init__(*args, **kwargs) 202 | self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1)) 203 | 204 | def configure_optimizers(self): 205 | lr = self.learning_rate 206 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ 207 | list(self.decoder.parameters())+ 208 | list(self.quantize.parameters())+ 209 | list(self.quant_conv.parameters())+ 210 | list(self.post_quant_conv.parameters()), 211 | lr=lr, betas=(0.5, 0.9)) 212 | return opt_ae 213 | 214 | def training_step(self, batch, batch_idx): 215 | x = self.get_input(batch, self.image_key) 216 | xrec, qloss = self(x) 217 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train") 218 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) 219 | return aeloss 220 | 221 | def validation_step(self, batch, batch_idx): 222 | x = self.get_input(batch, self.image_key) 223 | xrec, qloss = self(x) 224 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val") 225 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) 226 | total_loss = log_dict_ae["val/total_loss"] 227 | self.log("val/total_loss", total_loss, 228 | prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) 229 | return aeloss 230 | 231 | @torch.no_grad() 232 | def log_images(self, batch, **kwargs): 233 | log = dict() 234 | x = self.get_input(batch, self.image_key) 235 | x = x.to(self.device) 236 | xrec, _ = self(x) 237 | if x.shape[1] > 3: 238 | # colorize with random projection 239 | assert xrec.shape[1] > 3 240 | # convert logits to indices 241 | xrec = torch.argmax(xrec, dim=1, keepdim=True) 242 | xrec = F.one_hot(xrec, num_classes=x.shape[1]) 243 | xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float() 244 | x = self.to_rgb(x) 245 | xrec = self.to_rgb(xrec) 246 | log["inputs"] = x 247 | log["reconstructions"] = xrec 248 | return log 249 | 250 | 251 | class VQNoDiscModel(VQModel): 252 | def __init__(self, 253 | ddconfig, 254 | lossconfig, 255 | n_embed, 256 | embed_dim, 257 | ckpt_path=None, 258 | ignore_keys=[], 259 | image_key="image", 260 | colorize_nlabels=None 261 | ): 262 | super().__init__(ddconfig=ddconfig, lossconfig=lossconfig, n_embed=n_embed, embed_dim=embed_dim, 263 | ckpt_path=ckpt_path, ignore_keys=ignore_keys, image_key=image_key, 264 | colorize_nlabels=colorize_nlabels) 265 | 266 | def training_step(self, batch, batch_idx): 267 | x = self.get_input(batch, self.image_key) 268 | xrec, qloss = self(x) 269 | # autoencode 270 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train") 271 | output = pl.TrainResult(minimize=aeloss) 272 | output.log("train/aeloss", aeloss, 273 | prog_bar=True, logger=True, on_step=True, on_epoch=True) 274 | output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) 275 | return output 276 | 277 | def validation_step(self, batch, batch_idx): 278 | x = self.get_input(batch, self.image_key) 279 | xrec, qloss = self(x) 280 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val") 281 | rec_loss = log_dict_ae["val/rec_loss"] 282 | output = pl.EvalResult(checkpoint_on=rec_loss) 283 | output.log("val/rec_loss", rec_loss, 284 | prog_bar=True, logger=True, on_step=True, on_epoch=True) 285 | output.log("val/aeloss", aeloss, 286 | prog_bar=True, logger=True, on_step=True, on_epoch=True) 287 | output.log_dict(log_dict_ae) 288 | 289 | return output 
290 | 291 | def configure_optimizers(self): 292 | optimizer = torch.optim.Adam(list(self.encoder.parameters())+ 293 | list(self.decoder.parameters())+ 294 | list(self.quantize.parameters())+ 295 | list(self.quant_conv.parameters())+ 296 | list(self.post_quant_conv.parameters()), 297 | lr=self.learning_rate, betas=(0.5, 0.9)) 298 | return optimizer 299 | 300 | 301 | class GumbelVQ(VQModel): 302 | def __init__(self, 303 | ddconfig, 304 | lossconfig, 305 | n_embed, 306 | embed_dim, 307 | temperature_scheduler_config, 308 | ckpt_path=None, 309 | ignore_keys=[], 310 | image_key="image", 311 | colorize_nlabels=None, 312 | monitor=None, 313 | kl_weight=1e-8, 314 | remap=None, 315 | ): 316 | 317 | z_channels = ddconfig["z_channels"] 318 | super().__init__(ddconfig, 319 | lossconfig, 320 | n_embed, 321 | embed_dim, 322 | ckpt_path=None, 323 | ignore_keys=ignore_keys, 324 | image_key=image_key, 325 | colorize_nlabels=colorize_nlabels, 326 | monitor=monitor, 327 | ) 328 | 329 | self.loss.n_classes = n_embed 330 | self.vocab_size = n_embed 331 | 332 | self.quantize = GumbelQuantize(z_channels, embed_dim, 333 | n_embed=n_embed, 334 | kl_weight=kl_weight, temp_init=1.0, 335 | remap=remap) 336 | 337 | self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config) # annealing of temp 338 | 339 | if ckpt_path is not None: 340 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 341 | 342 | def temperature_scheduling(self): 343 | self.quantize.temperature = self.temperature_scheduler(self.global_step) 344 | 345 | def encode_to_prequant(self, x): 346 | h = self.encoder(x) 347 | h = self.quant_conv(h) 348 | return h 349 | 350 | def decode_code(self, code_b): 351 | raise NotImplementedError 352 | 353 | def training_step(self, batch, batch_idx, optimizer_idx): 354 | self.temperature_scheduling() 355 | x = self.get_input(batch, self.image_key) 356 | xrec, qloss = self(x) 357 | 358 | if optimizer_idx == 0: 359 | # autoencode 360 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, 361 | last_layer=self.get_last_layer(), split="train") 362 | 363 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) 364 | self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True) 365 | return aeloss 366 | 367 | if optimizer_idx == 1: 368 | # discriminator 369 | discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, 370 | last_layer=self.get_last_layer(), split="train") 371 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) 372 | return discloss 373 | 374 | def validation_step(self, batch, batch_idx): 375 | x = self.get_input(batch, self.image_key) 376 | xrec, qloss = self(x, return_pred_indices=True) 377 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step, 378 | last_layer=self.get_last_layer(), split="val") 379 | 380 | discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step, 381 | last_layer=self.get_last_layer(), split="val") 382 | rec_loss = log_dict_ae["val/rec_loss"] 383 | self.log("val/rec_loss", rec_loss, 384 | prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) 385 | self.log("val/aeloss", aeloss, 386 | prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) 387 | self.log_dict(log_dict_ae) 388 | self.log_dict(log_dict_disc) 389 | return self.log_dict 390 | 391 | def log_images(self, batch, **kwargs): 392 | log = dict() 393 | x = 
self.get_input(batch, self.image_key) 394 | x = x.to(self.device) 395 | # encode 396 | h = self.encoder(x) 397 | h = self.quant_conv(h) 398 | quant, _, _ = self.quantize(h) 399 | # decode 400 | x_rec = self.decode(quant) 401 | log["inputs"] = x 402 | log["reconstructions"] = x_rec 403 | return log 404 | 405 | 406 | class EMAVQ(VQModel): 407 | def __init__(self, 408 | ddconfig, 409 | lossconfig, 410 | n_embed, 411 | embed_dim, 412 | ckpt_path=None, 413 | ignore_keys=[], 414 | image_key="image", 415 | colorize_nlabels=None, 416 | monitor=None, 417 | remap=None, 418 | sane_index_shape=False, # tell vector quantizer to return indices as bhw 419 | ): 420 | super().__init__(ddconfig, 421 | lossconfig, 422 | n_embed, 423 | embed_dim, 424 | ckpt_path=None, 425 | ignore_keys=ignore_keys, 426 | image_key=image_key, 427 | colorize_nlabels=colorize_nlabels, 428 | monitor=monitor, 429 | ) 430 | self.quantize = EMAVectorQuantizer(n_embed=n_embed, 431 | embedding_dim=embed_dim, 432 | beta=0.25, 433 | remap=remap) 434 | def configure_optimizers(self): 435 | lr = self.learning_rate 436 | #Remove self.quantize from parameter list since it is updated via EMA 437 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ 438 | list(self.decoder.parameters())+ 439 | list(self.quant_conv.parameters())+ 440 | list(self.post_quant_conv.parameters()), 441 | lr=lr, betas=(0.5, 0.9)) 442 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 443 | lr=lr, betas=(0.5, 0.9)) 444 | return [opt_ae, opt_disc], [] -------------------------------------------------------------------------------- /taming/modules/vqvae/quantize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from torch import einsum 6 | from einops import rearrange 7 | 8 | 9 | class VectorQuantizer(nn.Module): 10 | """ 11 | see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py 12 | ____________________________________________ 13 | Discretization bottleneck part of the VQ-VAE. 14 | Inputs: 15 | - n_e : number of embeddings 16 | - e_dim : dimension of embedding 17 | - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 18 | _____________________________________________ 19 | """ 20 | 21 | # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for 22 | # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be 23 | # used wherever VectorQuantizer has been used before and is additionally 24 | # more efficient. 25 | def __init__(self, n_e, e_dim, beta): 26 | super(VectorQuantizer, self).__init__() 27 | self.n_e = n_e 28 | self.e_dim = e_dim 29 | self.beta = beta 30 | 31 | self.embedding = nn.Embedding(self.n_e, self.e_dim) 32 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) 33 | 34 | def forward(self, z): 35 | """ 36 | Inputs the output of the encoder network z and maps it to a discrete 37 | one-hot vector that is the index of the closest embedding vector e_j 38 | z (continuous) -> z_q (discrete) 39 | z.shape = (batch, channel, height, width) 40 | quantization pipeline: 41 | 1. get encoder input (B,C,H,W) 42 | 2. 
flatten input to (B*H*W,C) 43 | """ 44 | # reshape z -> (batch, height, width, channel) and flatten 45 | z = z.permute(0, 2, 3, 1).contiguous() 46 | z_flattened = z.view(-1, self.e_dim) 47 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 48 | 49 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ 50 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \ 51 | torch.matmul(z_flattened, self.embedding.weight.t()) 52 | 53 | ## could possible replace this here 54 | # #\start... 55 | # find closest encodings 56 | min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) 57 | 58 | min_encodings = torch.zeros( 59 | min_encoding_indices.shape[0], self.n_e).to(z) 60 | min_encodings.scatter_(1, min_encoding_indices, 1) 61 | 62 | # dtype min encodings: torch.float32 63 | # min_encodings shape: torch.Size([2048, 512]) 64 | # min_encoding_indices.shape: torch.Size([2048, 1]) 65 | 66 | # get quantized latent vectors 67 | z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) 68 | #.........\end 69 | 70 | # with: 71 | # .........\start 72 | #min_encoding_indices = torch.argmin(d, dim=1) 73 | #z_q = self.embedding(min_encoding_indices) 74 | # ......\end......... (TODO) 75 | 76 | # compute loss for embedding 77 | loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ 78 | torch.mean((z_q - z.detach()) ** 2) 79 | 80 | # preserve gradients 81 | z_q = z + (z_q - z).detach() 82 | 83 | # perplexity 84 | e_mean = torch.mean(min_encodings, dim=0) 85 | perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) 86 | 87 | # reshape back to match original input shape 88 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 89 | 90 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices) 91 | 92 | def get_codebook_entry(self, indices, shape): 93 | # shape specifying (batch, height, width, channel) 94 | # TODO: check for more easy handling with nn.Embedding 95 | min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices) 96 | min_encodings.scatter_(1, indices[:,None], 1) 97 | 98 | # get quantized latent vectors 99 | z_q = torch.matmul(min_encodings.float(), self.embedding.weight) 100 | 101 | if shape is not None: 102 | z_q = z_q.view(shape) 103 | 104 | # reshape back to match original input shape 105 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 106 | 107 | return z_q 108 | 109 | 110 | class GumbelQuantize(nn.Module): 111 | """ 112 | credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) 113 | Gumbel Softmax trick quantizer 114 | Categorical Reparameterization with Gumbel-Softmax, Jang et al. 
2016 115 | https://arxiv.org/abs/1611.01144 116 | """ 117 | def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True, 118 | kl_weight=5e-4, temp_init=1.0, use_vqinterface=True, 119 | remap=None, unknown_index="random"): 120 | super().__init__() 121 | 122 | self.embedding_dim = embedding_dim 123 | self.n_embed = n_embed 124 | 125 | self.straight_through = straight_through 126 | self.temperature = temp_init 127 | self.kl_weight = kl_weight 128 | 129 | self.proj = nn.Conv2d(num_hiddens, n_embed, 1) 130 | self.embed = nn.Embedding(n_embed, embedding_dim) 131 | 132 | self.use_vqinterface = use_vqinterface 133 | 134 | self.remap = remap 135 | if self.remap is not None: 136 | self.register_buffer("used", torch.tensor(np.load(self.remap))) 137 | self.re_embed = self.used.shape[0] 138 | self.unknown_index = unknown_index # "random" or "extra" or integer 139 | if self.unknown_index == "extra": 140 | self.unknown_index = self.re_embed 141 | self.re_embed = self.re_embed+1 142 | print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. " 143 | f"Using {self.unknown_index} for unknown indices.") 144 | else: 145 | self.re_embed = n_embed 146 | 147 | def remap_to_used(self, inds): 148 | ishape = inds.shape 149 | assert len(ishape)>1 150 | inds = inds.reshape(ishape[0],-1) 151 | used = self.used.to(inds) 152 | match = (inds[:,:,None]==used[None,None,...]).long() 153 | new = match.argmax(-1) 154 | unknown = match.sum(2)<1 155 | if self.unknown_index == "random": 156 | new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) 157 | else: 158 | new[unknown] = self.unknown_index 159 | return new.reshape(ishape) 160 | 161 | def unmap_to_all(self, inds): 162 | ishape = inds.shape 163 | assert len(ishape)>1 164 | inds = inds.reshape(ishape[0],-1) 165 | used = self.used.to(inds) 166 | if self.re_embed > self.used.shape[0]: # extra token 167 | inds[inds>=self.used.shape[0]] = 0 # simply set to zero 168 | back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) 169 | return back.reshape(ishape) 170 | 171 | def forward(self, z, temp=None, return_logits=False): 172 | # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work 173 | hard = self.straight_through if self.training else True 174 | temp = self.temperature if temp is None else temp 175 | 176 | logits = self.proj(z) 177 | if self.remap is not None: 178 | # continue only with used logits 179 | full_zeros = torch.zeros_like(logits) 180 | logits = logits[:,self.used,...] 181 | 182 | soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) 183 | if self.remap is not None: 184 | # go back to all entries but unused set to zero 185 | full_zeros[:,self.used,...] 
= soft_one_hot 186 | soft_one_hot = full_zeros 187 | z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight) 188 | 189 | # + kl divergence to the prior loss 190 | qy = F.softmax(logits, dim=1) 191 | diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean() 192 | 193 | ind = soft_one_hot.argmax(dim=1) 194 | if self.remap is not None: 195 | ind = self.remap_to_used(ind) 196 | if self.use_vqinterface: 197 | if return_logits: 198 | return z_q, diff, (None, None, ind), logits 199 | return z_q, diff, (None, None, ind) 200 | return z_q, diff, ind 201 | 202 | def get_codebook_entry(self, indices, shape): 203 | b, h, w, c = shape 204 | assert b*h*w == indices.shape[0] 205 | indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w) 206 | if self.remap is not None: 207 | indices = self.unmap_to_all(indices) 208 | one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float() 209 | z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight) 210 | return z_q 211 | 212 | 213 | class VectorQuantizer2(nn.Module): 214 | """ 215 | Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly 216 | avoids costly matrix multiplications and allows for post-hoc remapping of indices. 217 | """ 218 | # NOTE: due to a bug the beta term was applied to the wrong term. for 219 | # backwards compatibility we use the buggy version by default, but you can 220 | # specify legacy=False to fix it. 221 | def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", 222 | sane_index_shape=False, legacy=True): 223 | super().__init__() 224 | self.n_e = n_e 225 | self.e_dim = e_dim 226 | self.beta = beta 227 | self.legacy = legacy 228 | 229 | self.embedding = nn.Embedding(self.n_e, self.e_dim) 230 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) 231 | 232 | self.remap = remap 233 | if self.remap is not None: 234 | self.register_buffer("used", torch.tensor(np.load(self.remap))) 235 | self.re_embed = self.used.shape[0] 236 | self.unknown_index = unknown_index # "random" or "extra" or integer 237 | if self.unknown_index == "extra": 238 | self.unknown_index = self.re_embed 239 | self.re_embed = self.re_embed+1 240 | print(f"Remapping {self.n_e} indices to {self.re_embed} indices. 
" 241 | f"Using {self.unknown_index} for unknown indices.") 242 | else: 243 | self.re_embed = n_e 244 | 245 | self.sane_index_shape = sane_index_shape 246 | 247 | def remap_to_used(self, inds): 248 | ishape = inds.shape 249 | assert len(ishape)>1 250 | inds = inds.reshape(ishape[0],-1) 251 | used = self.used.to(inds) 252 | match = (inds[:,:,None]==used[None,None,...]).long() 253 | new = match.argmax(-1) 254 | unknown = match.sum(2)<1 255 | if self.unknown_index == "random": 256 | new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) 257 | else: 258 | new[unknown] = self.unknown_index 259 | return new.reshape(ishape) 260 | 261 | def unmap_to_all(self, inds): 262 | ishape = inds.shape 263 | assert len(ishape)>1 264 | inds = inds.reshape(ishape[0],-1) 265 | used = self.used.to(inds) 266 | if self.re_embed > self.used.shape[0]: # extra token 267 | inds[inds>=self.used.shape[0]] = 0 # simply set to zero 268 | back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) 269 | return back.reshape(ishape) 270 | 271 | def forward(self, z, temp=None, rescale_logits=False, return_logits=False): 272 | assert temp is None or temp==1.0, "Only for interface compatible with Gumbel" 273 | assert rescale_logits==False, "Only for interface compatible with Gumbel" 274 | assert return_logits==False, "Only for interface compatible with Gumbel" 275 | # reshape z -> (batch, height, width, channel) and flatten 276 | z = rearrange(z, 'b c h w -> b h w c').contiguous() 277 | z_flattened = z.view(-1, self.e_dim) 278 | # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z 279 | 280 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ 281 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \ 282 | torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) 283 | 284 | min_encoding_indices = torch.argmin(d, dim=1) 285 | z_q = self.embedding(min_encoding_indices).view(z.shape) 286 | perplexity = None 287 | min_encodings = None 288 | 289 | # compute loss for embedding 290 | if not self.legacy: 291 | loss = self.beta * torch.mean((z_q.detach()-z)**2) + \ 292 | torch.mean((z_q - z.detach()) ** 2) 293 | else: 294 | loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ 295 | torch.mean((z_q - z.detach()) ** 2) 296 | 297 | # preserve gradients 298 | z_q = z + (z_q - z).detach() 299 | 300 | # reshape back to match original input shape 301 | z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() 302 | 303 | if self.remap is not None: 304 | min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis 305 | min_encoding_indices = self.remap_to_used(min_encoding_indices) 306 | min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten 307 | 308 | if self.sane_index_shape: 309 | min_encoding_indices = min_encoding_indices.reshape( 310 | z_q.shape[0], z_q.shape[2], z_q.shape[3]) 311 | 312 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices) 313 | 314 | def get_codebook_entry(self, indices, shape): 315 | # shape specifying (batch, height, width, channel) 316 | if self.remap is not None: 317 | indices = indices.reshape(shape[0],-1) # add batch axis 318 | indices = self.unmap_to_all(indices) 319 | indices = indices.reshape(-1) # flatten again 320 | 321 | # get quantized latent vectors 322 | z_q = self.embedding(indices) 323 | 324 | if shape is not None: 325 | z_q = z_q.view(shape) 326 | # reshape back to match original input shape 327 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 328 | 329 | return z_q 
330 | 331 | class EmbeddingEMA(nn.Module): 332 | def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5): 333 | super().__init__() 334 | self.decay = decay 335 | self.eps = eps 336 | weight = torch.randn(num_tokens, codebook_dim) 337 | self.weight = nn.Parameter(weight, requires_grad = False) 338 | self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False) 339 | self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False) 340 | self.update = True 341 | 342 | def forward(self, embed_id): 343 | return F.embedding(embed_id, self.weight) 344 | 345 | def cluster_size_ema_update(self, new_cluster_size): 346 | self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay) 347 | 348 | def embed_avg_ema_update(self, new_embed_avg): 349 | self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) 350 | 351 | def weight_update(self, num_tokens): 352 | n = self.cluster_size.sum() 353 | smoothed_cluster_size = ( 354 | (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n 355 | ) 356 | #normalize embedding average with smoothed cluster size 357 | embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) 358 | self.weight.data.copy_(embed_normalized) 359 | 360 | 361 | class EMAVectorQuantizer(nn.Module): 362 | def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5, 363 | remap=None, unknown_index="random"): 364 | super().__init__() 365 | self.codebook_dim = embedding_dim 366 | self.num_tokens = n_embed 367 | self.beta = beta 368 | self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps) 369 | 370 | self.remap = remap 371 | if self.remap is not None: 372 | self.register_buffer("used", torch.tensor(np.load(self.remap))) 373 | self.re_embed = self.used.shape[0] 374 | self.unknown_index = unknown_index # "random" or "extra" or integer 375 | if self.unknown_index == "extra": 376 | self.unknown_index = self.re_embed 377 | self.re_embed = self.re_embed+1 378 | print(f"Remapping {self.num_tokens} indices to {self.re_embed} indices. 
" 379 | f"Using {self.unknown_index} for unknown indices.") 380 | else: 381 | self.re_embed = n_embed 382 | 383 | def remap_to_used(self, inds): 384 | ishape = inds.shape 385 | assert len(ishape)>1 386 | inds = inds.reshape(ishape[0],-1) 387 | used = self.used.to(inds) 388 | match = (inds[:,:,None]==used[None,None,...]).long() 389 | new = match.argmax(-1) 390 | unknown = match.sum(2)<1 391 | if self.unknown_index == "random": 392 | new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) 393 | else: 394 | new[unknown] = self.unknown_index 395 | return new.reshape(ishape) 396 | 397 | def unmap_to_all(self, inds): 398 | ishape = inds.shape 399 | assert len(ishape)>1 400 | inds = inds.reshape(ishape[0],-1) 401 | used = self.used.to(inds) 402 | if self.re_embed > self.used.shape[0]: # extra token 403 | inds[inds>=self.used.shape[0]] = 0 # simply set to zero 404 | back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) 405 | return back.reshape(ishape) 406 | 407 | def forward(self, z): 408 | # reshape z -> (batch, height, width, channel) and flatten 409 | #z, 'b c h w -> b h w c' 410 | z = rearrange(z, 'b c h w -> b h w c') 411 | z_flattened = z.reshape(-1, self.codebook_dim) 412 | 413 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 414 | d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ 415 | self.embedding.weight.pow(2).sum(dim=1) - 2 * \ 416 | torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n' 417 | 418 | 419 | encoding_indices = torch.argmin(d, dim=1) 420 | 421 | z_q = self.embedding(encoding_indices).view(z.shape) 422 | encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) 423 | avg_probs = torch.mean(encodings, dim=0) 424 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) 425 | 426 | if self.training and self.embedding.update: 427 | #EMA cluster size 428 | encodings_sum = encodings.sum(0) 429 | self.embedding.cluster_size_ema_update(encodings_sum) 430 | #EMA embedding average 431 | embed_sum = encodings.transpose(0,1) @ z_flattened 432 | self.embedding.embed_avg_ema_update(embed_sum) 433 | #normalize embed_avg and update weight 434 | self.embedding.weight_update(self.num_tokens) 435 | 436 | # compute loss for embedding 437 | loss = self.beta * F.mse_loss(z_q.detach(), z) 438 | 439 | # preserve gradients 440 | z_q = z + (z_q - z).detach() 441 | 442 | # reshape back to match original input shape 443 | #z_q, 'b h w c -> b c h w' 444 | z_q = rearrange(z_q, 'b h w c -> b c h w') 445 | return z_q, loss, (perplexity, encodings, encoding_indices) 446 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # main file for training vqgan models, addapted from taming-transformers and latent-diffusion repos. 
2 | # Thanks to those communities the awesome work so far 3 | 4 | import argparse, os, sys, datetime, glob, importlib 5 | from omegaconf import OmegaConf 6 | import numpy as np 7 | from PIL import Image 8 | import torch 9 | import torchvision 10 | from torch.utils.data import random_split, DataLoader, Dataset 11 | import pytorch_lightning as pl 12 | from pytorch_lightning import seed_everything 13 | from pytorch_lightning.trainer import Trainer 14 | from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor 15 | from pytorch_lightning.utilities.distributed import rank_zero_only 16 | from pytorch_lightning.utilities import rank_zero_info 17 | from taming.data.utils import custom_collate 18 | from packaging import version 19 | import time 20 | import wandb 21 | 22 | def get_obj_from_str(string, reload=False): 23 | 24 | module, cls = string.rsplit(".", 1) 25 | if reload: 26 | module_imp = importlib.import_module(module) 27 | importlib.reload(module_imp) 28 | return getattr(importlib.import_module(module, package=None), cls) 29 | 30 | 31 | def get_parser(**parser_kwargs): 32 | def str2bool(v): 33 | if isinstance(v, bool): 34 | return v 35 | if v.lower() in ("yes", "true", "t", "y", "1"): 36 | return True 37 | elif v.lower() in ("no", "false", "f", "n", "0"): 38 | return False 39 | else: 40 | raise argparse.ArgumentTypeError("Boolean value expected.") 41 | 42 | parser = argparse.ArgumentParser(**parser_kwargs) 43 | parser.add_argument( 44 | "-n", 45 | "--name", 46 | type=str, 47 | const=True, 48 | default="", 49 | nargs="?", 50 | help="postfix for logdir", 51 | ) 52 | parser.add_argument( 53 | "-r", 54 | "--resume", 55 | type=str, 56 | const=True, 57 | default="", 58 | nargs="?", 59 | help="resume from logdir or checkpoint in logdir", 60 | ) 61 | parser.add_argument( 62 | "-b", 63 | "--base", 64 | nargs="*", 65 | metavar="base_config.yaml", 66 | help="paths to base configs. Loaded from left-to-right. 
" 67 | "Parameters can be overwritten or added with command-line options of the form `--key value`.", 68 | default=list(), 69 | ) 70 | parser.add_argument( 71 | "-t", 72 | "--train", 73 | type=str2bool, 74 | const=True, 75 | default=False, 76 | nargs="?", 77 | help="train", 78 | ) 79 | parser.add_argument( 80 | "--no-test", 81 | type=str2bool, 82 | const=True, 83 | default=False, 84 | nargs="?", 85 | help="disable test", 86 | ) 87 | parser.add_argument( 88 | "-p", 89 | "--project", 90 | help="name of new or path to existing project" 91 | ) 92 | parser.add_argument( 93 | "-d", 94 | "--debug", 95 | type=str2bool, 96 | nargs="?", 97 | const=True, 98 | default=False, 99 | help="enable post-mortem debugging", 100 | ) 101 | parser.add_argument( 102 | "-s", 103 | "--seed", 104 | type=int, 105 | default=7, 106 | help="seed for seed_everything", 107 | ) 108 | parser.add_argument( 109 | "-f", 110 | "--postfix", 111 | type=str, 112 | default="", 113 | help="post-postfix for default name", 114 | ) 115 | 116 | parser.add_argument( 117 | "-l", 118 | "--logdir", 119 | type=str, 120 | default="logs", 121 | help="directory for logging", 122 | ) 123 | 124 | return parser 125 | 126 | 127 | def nondefault_trainer_args(opt): 128 | parser = argparse.ArgumentParser() 129 | parser = Trainer.add_argparse_args(parser) 130 | args = parser.parse_args([]) 131 | return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k)) 132 | 133 | 134 | def instantiate_from_config(config): 135 | 136 | if not "target" in config: 137 | raise KeyError("Expected key `target` to instantiate.") 138 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 139 | 140 | 141 | class WrappedDataset(Dataset): 142 | """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset""" 143 | def __init__(self, dataset): 144 | self.data = dataset 145 | 146 | def __len__(self): 147 | return len(self.data) 148 | 149 | def __getitem__(self, idx): 150 | return self.data[idx] 151 | 152 | 153 | class DataModuleFromConfig(pl.LightningDataModule): 154 | def __init__(self, batch_size, train=None, validation=None, test=None, 155 | wrap=False, num_workers=None): 156 | super().__init__() 157 | self.batch_size = batch_size 158 | self.dataset_configs = dict() 159 | self.num_workers = num_workers if num_workers is not None else batch_size*2 160 | if train is not None: 161 | self.dataset_configs["train"] = train 162 | self.train_dataloader = self._train_dataloader 163 | if validation is not None: 164 | self.dataset_configs["validation"] = validation 165 | self.val_dataloader = self._val_dataloader 166 | if test is not None: 167 | self.dataset_configs["test"] = test 168 | self.test_dataloader = self._test_dataloader 169 | self.wrap = wrap 170 | 171 | def prepare_data(self): 172 | for data_cfg in self.dataset_configs.values(): 173 | instantiate_from_config(data_cfg) 174 | 175 | def setup(self, stage=None): 176 | self.datasets = dict( 177 | (k, instantiate_from_config(self.dataset_configs[k])) 178 | for k in self.dataset_configs) 179 | if self.wrap: 180 | for k in self.datasets: 181 | self.datasets[k] = WrappedDataset(self.datasets[k]) 182 | 183 | def _train_dataloader(self): 184 | return DataLoader(self.datasets["train"], batch_size=self.batch_size, 185 | num_workers=self.num_workers, shuffle=True, worker_init_fn=None) 186 | 187 | def _val_dataloader(self): 188 | return DataLoader(self.datasets["validation"], 189 | batch_size=self.batch_size, 190 | num_workers=self.num_workers, shuffle=False, worker_init_fn=None) 191 | 
192 | def _test_dataloader(self): 193 | return DataLoader(self.datasets["test"], batch_size=self.batch_size, 194 | num_workers=self.num_workers, shuffle=False, worker_init_fn=None) 195 | 196 | 197 | class SetupCallback(Callback): 198 | def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config): 199 | super().__init__() 200 | self.resume = resume 201 | self.now = now 202 | self.logdir = logdir 203 | self.ckptdir = ckptdir 204 | self.cfgdir = cfgdir 205 | self.config = config 206 | self.lightning_config = lightning_config 207 | 208 | def on_keyboard_interrupt(self, trainer, pl_module): 209 | if trainer.global_rank == 0: 210 | print("Summoning checkpoint.") 211 | ckpt_path = os.path.join(self.ckptdir, "last.ckpt") 212 | trainer.save_checkpoint(ckpt_path) 213 | 214 | def on_pretrain_routine_start(self, trainer, pl_module): 215 | if trainer.global_rank == 0: 216 | if not self.resume: # Avoid storing configs if we are resuming, already there 217 | # Create logdirs and save configs 218 | os.makedirs(self.logdir, exist_ok=True) 219 | os.makedirs(self.ckptdir, exist_ok=True) 220 | os.makedirs(self.cfgdir, exist_ok=True) 221 | 222 | print("Project config") 223 | print(OmegaConf.to_yaml(self.config)) 224 | OmegaConf.save(self.config, os.path.join(self.cfgdir, "{}-project.yaml".format(self.now))) 225 | print("Lightning config") 226 | print(OmegaConf.to_yaml(self.lightning_config)) 227 | OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}), os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now))) 228 | else: 229 | # ModelCheckpoint callback created log directory --- remove it 230 | if not self.resume and os.path.exists(self.logdir): 231 | dst, name = os.path.split(self.logdir) 232 | dst = os.path.join(dst, "child_runs", name) 233 | os.makedirs(os.path.split(dst)[0], exist_ok=True) 234 | try: 235 | os.rename(self.logdir, dst) 236 | except FileNotFoundError: 237 | pass 238 | 239 | class ImageLogger(Callback): 240 | def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True, 241 | rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False, 242 | log_images_kwargs=None, log_local=False): 243 | super().__init__() 244 | self.rescale = rescale 245 | self.batch_freq = batch_frequency 246 | self.max_images = max_images 247 | self.logger_log_images = { 248 | pl.loggers.WandbLogger: self._wandb, 249 | pl.loggers.TestTubeLogger: self._testtube, 250 | } 251 | self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)] 252 | if not increase_log_steps: 253 | self.log_steps = [self.batch_freq] 254 | self.clamp = clamp 255 | self.disabled = disabled 256 | self.log_on_batch_idx = log_on_batch_idx 257 | self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} 258 | self.log_first_step = log_first_step 259 | self.log_local = log_local 260 | 261 | @rank_zero_only 262 | def _testtube(self, pl_module, images, batch_idx, split): 263 | for k in images: 264 | grid = torchvision.utils.make_grid(images[k]) 265 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w 266 | 267 | tag = f"{split}/{k}" 268 | pl_module.logger.experiment.add_image( 269 | tag, grid, 270 | global_step=pl_module.global_step) 271 | 272 | @rank_zero_only 273 | def _wandb(self, pl_module, images, batch_idx, split): 274 | grids = dict() 275 | for k in images: 276 | grid = torchvision.utils.make_grid(images[k]) 277 | grids[f"{split}/{k}"] = wandb.Image(grid) 278 | pl_module.logger.experiment.log(grids) 279 | 280 | 281 | @rank_zero_only 282 | def 
log_local(self, save_dir, split, images, 283 | global_step, current_epoch, batch_idx): 284 | root = os.path.join(save_dir, "images", split) 285 | for k in images: 286 | grid = torchvision.utils.make_grid(images[k], nrow=4) 287 | if self.rescale: 288 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w 289 | grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) 290 | grid = grid.numpy() 291 | grid = (grid * 255).astype(np.uint8) 292 | filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format( 293 | k, 294 | global_step, 295 | current_epoch, 296 | batch_idx) 297 | path = os.path.join(root, filename) 298 | os.makedirs(os.path.split(path)[0], exist_ok=True) 299 | Image.fromarray(grid).save(path) 300 | 301 | def log_img(self, pl_module, batch, batch_idx, split="train"): 302 | check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step 303 | if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 304 | hasattr(pl_module, "log_images") and 305 | callable(pl_module.log_images) and 306 | self.max_images > 0): 307 | logger = type(pl_module.logger) 308 | 309 | is_train = pl_module.training 310 | if is_train: 311 | pl_module.eval() 312 | 313 | with torch.no_grad(): 314 | images = pl_module.log_images(batch, split=split, **self.log_images_kwargs) 315 | 316 | for k in images: 317 | N = min(images[k].shape[0], self.max_images) 318 | images[k] = images[k][:N] 319 | if isinstance(images[k], torch.Tensor): 320 | images[k] = images[k].detach().cpu() 321 | if self.clamp: 322 | images[k] = torch.clamp(images[k], -1., 1.) 323 | if self.log_local: 324 | self.log_local(pl_module.logger.save_dir, split, images, 325 | pl_module.global_step, pl_module.current_epoch, batch_idx) 326 | 327 | logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None) 328 | logger_log_images(pl_module, images, pl_module.global_step, split) 329 | 330 | if is_train: 331 | pl_module.train() 332 | 333 | def check_frequency(self, check_idx): 334 | if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and ( 335 | check_idx > 0 or self.log_first_step): 336 | try: 337 | self.log_steps.pop(0) 338 | except IndexError as e: 339 | print(e) 340 | pass 341 | return True 342 | return False 343 | 344 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 345 | if not self.disabled and (pl_module.global_step > 0 or self.log_first_step): 346 | self.log_img(pl_module, batch, batch_idx, split="train") 347 | 348 | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 349 | if not self.disabled and pl_module.global_step > 0: 350 | self.log_img(pl_module, batch, batch_idx, split="val") 351 | if hasattr(pl_module, 'calibrate_grad_norm'): 352 | if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0: 353 | self.log_gradients(trainer, pl_module, batch_idx=batch_idx) 354 | 355 | class CUDACallback(Callback): 356 | def on_train_epoch_start(self, trainer, pl_module): 357 | # Reset the memory use counter 358 | torch.cuda.reset_peak_memory_stats(trainer.root_gpu) 359 | torch.cuda.synchronize(trainer.root_gpu) 360 | self.start_time = time.time() 361 | 362 | def on_train_epoch_end(self, trainer, pl_module, outputs): 363 | torch.cuda.synchronize(trainer.root_gpu) 364 | max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20 365 | epoch_time = time.time() - self.start_time 366 | 367 | try: 368 | max_memory = trainer.training_type_plugin.reduce(max_memory) 369 | epoch_time = 
trainer.training_type_plugin.reduce(epoch_time) 370 | 371 | rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds") 372 | rank_zero_info(f"Average Peak memory: {max_memory:.2f} MiB") 373 | except AttributeError: 374 | pass 375 | 376 | if __name__ == "__main__": 377 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 378 | sys.path.append(os.getcwd()) # add cwd to make main.py classes available 379 | parser = get_parser() 380 | parser = Trainer.add_argparse_args(parser) 381 | 382 | opt, unknown = parser.parse_known_args() 383 | if opt.name and opt.resume: 384 | raise ValueError( 385 | "-n/--name and -r/--resume cannot be specified both." 386 | "If you want to resume training in a new log folder, " 387 | "use -n/--name in combination with --resume_from_checkpoint" 388 | ) 389 | if opt.resume: 390 | # Manage path resume not exist 391 | if not os.path.exists(opt.resume): 392 | raise ValueError("Cannot find {}".format(opt.resume)) 393 | 394 | if os.path.isfile(opt.resume): 395 | paths = opt.resume.split("/") 396 | logdir = "/".join(paths[:-2]) 397 | ckpt = opt.resume 398 | else: 399 | assert os.path.isdir(opt.resume), opt.resume 400 | logdir = opt.resume.rstrip("/") 401 | ckpt = os.path.join(logdir, "checkpoints", "last.ckpt") 402 | 403 | opt.resume_from_checkpoint = ckpt 404 | base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml"))) 405 | opt.base = base_configs+opt.base 406 | _tmp = logdir.split("/") 407 | nowname = _tmp[-1] 408 | else: 409 | if opt.name: 410 | name = "_" + opt.name 411 | elif opt.base: 412 | cfg_fname = os.path.split(opt.base[0])[-1] 413 | cfg_name = os.path.splitext(cfg_fname)[0] 414 | name = "_" + cfg_name 415 | else: 416 | name = "" 417 | nowname = now + name + opt.postfix 418 | logdir = os.path.join(opt.logdir, nowname) 419 | 420 | ckptdir = os.path.join(logdir, "checkpoints") 421 | cfgdir = os.path.join(logdir, "configs") 422 | seed_everything(opt.seed) 423 | 424 | try: 425 | # init and save configs 426 | configs = [OmegaConf.load(cfg) for cfg in opt.base] 427 | cli = OmegaConf.from_dotlist(unknown) 428 | config = OmegaConf.merge(*configs, cli) 429 | lightning_config = config.pop("lightning", OmegaConf.create()) 430 | 431 | # merge trainer cli with config 432 | trainer_config = lightning_config.get("trainer", OmegaConf.create()) 433 | # default to ddp 434 | trainer_config["accelerator"] = "ddp" 435 | for k in nondefault_trainer_args(opt): 436 | trainer_config[k] = getattr(opt, k) 437 | if not "gpus" in trainer_config: 438 | del trainer_config["distributed_backend"] 439 | cpu = True 440 | else: 441 | gpuinfo = trainer_config["gpus"] 442 | print(f"Running on GPUs {gpuinfo}") 443 | cpu = False 444 | trainer_opt = argparse.Namespace(**trainer_config) 445 | lightning_config.trainer = trainer_config 446 | 447 | # model 448 | model = instantiate_from_config(config.model) 449 | 450 | # trainer and callbacks 451 | trainer_kwargs = dict() 452 | 453 | if opt.train: 454 | # default logger configs 455 | default_logger_cfgs = { 456 | "wandb": { 457 | "target": "pytorch_lightning.loggers.WandbLogger", 458 | "params": { 459 | "project":opt.project, 460 | "save_dir": opt.logdir, # Set this to the "--logdir" dir, only trick to make wandb work 461 | "offline": opt.debug, 462 | "config": dict(config) 463 | }, 464 | }, 465 | "testtube": { 466 | "target": "pytorch_lightning.loggers.TestTubeLogger", 467 | "params": { 468 | "name": "testtube", 469 | "save_dir": logdir, 470 | } 471 | }, 472 | } 473 | default_logger_cfg = default_logger_cfgs["wandb"] 474 | if 
"logger" in lightning_config: 475 | logger_cfg = lightning_config.logger 476 | else: 477 | logger_cfg = OmegaConf.create() 478 | logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg) 479 | trainer_kwargs["logger"] = instantiate_from_config(logger_cfg) 480 | 481 | # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to 482 | # specify which metric is used to determine best models 483 | default_modelckpt_cfg = { 484 | "target": "pytorch_lightning.callbacks.ModelCheckpoint", 485 | "params": { 486 | "dirpath": ckptdir, 487 | "filename": "{epoch:06}", 488 | "verbose": True, 489 | "save_last": True, 490 | } 491 | } 492 | if hasattr(model, "monitor"): 493 | print(f"Monitoring {model.monitor} as checkpoint metric.") 494 | default_modelckpt_cfg["params"]["monitor"] = model.monitor 495 | default_modelckpt_cfg["params"]["save_top_k"] = 1 496 | 497 | if "modelcheckpoint" in lightning_config: 498 | modelckpt_cfg = lightning_config.modelcheckpoint 499 | else: 500 | modelckpt_cfg = OmegaConf.create() 501 | modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg) 502 | print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}") 503 | if version.parse(pl.__version__) < version.parse('1.4.0'): 504 | trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg) 505 | 506 | # add callback which sets up log directory 507 | default_callbacks_cfg = { 508 | "setup_callback": { 509 | "target": "main.SetupCallback", 510 | "params": { 511 | "resume": opt.resume, 512 | "now": now, 513 | "logdir": logdir, 514 | "ckptdir": ckptdir, 515 | "cfgdir": cfgdir, 516 | "config": config, 517 | "lightning_config": lightning_config, 518 | } 519 | }, 520 | "image_logger": { 521 | "target": "main.ImageLogger", 522 | "params": { 523 | "batch_frequency": 750, 524 | "max_images": 4, 525 | "clamp": True 526 | } 527 | }, 528 | "learning_rate_logger": { 529 | "target": "main.LearningRateMonitor", 530 | "params": { 531 | "logging_interval": "step", 532 | #"log_momentum": True 533 | } 534 | }, 535 | "cuda_callback": { 536 | "target": "main.CUDACallback" 537 | }, 538 | } 539 | if version.parse(pl.__version__) >= version.parse('1.4.0'): 540 | default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg}) 541 | 542 | if "callbacks" in lightning_config: 543 | callbacks_cfg = lightning_config.callbacks 544 | else: 545 | callbacks_cfg = OmegaConf.create() 546 | 547 | if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg: 548 | print( 549 | 'Caution: Saving checkpoints every n train steps without deleting. 
This might require some free space.') 550 | default_metrics_over_trainsteps_ckpt_dict = { 551 | 'metrics_over_trainsteps_checkpoint': 552 | {"target": 'pytorch_lightning.callbacks.ModelCheckpoint', 553 | 'params': { 554 | "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'), 555 | "filename": "{epoch:06}-{step:09}", 556 | "verbose": True, 557 | 'save_top_k': -1, 558 | 'every_n_train_steps': 10000, 559 | 'save_weights_only': True 560 | } 561 | } 562 | } 563 | default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict) 564 | 565 | callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg) 566 | if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'): 567 | callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint 568 | elif 'ignore_keys_callback' in callbacks_cfg: 569 | del callbacks_cfg['ignore_keys_callback'] 570 | 571 | trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] 572 | 573 | trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs) 574 | trainer.logdir = logdir 575 | 576 | # Datasets 577 | data = instantiate_from_config(config.data) 578 | data.prepare_data() 579 | data.setup() 580 | print("#### Datasets #####") 581 | for k in data.datasets: 582 | print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}") 583 | 584 | # Configure learning rate 585 | bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate 586 | 587 | # Devices 588 | if not cpu: 589 | ngpu = len(lightning_config.trainer.gpus.strip(",").split(',')) 590 | else: 591 | ngpu = 1 592 | 593 | if 'accumulate_grad_batches' in lightning_config.trainer: 594 | accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches 595 | else: 596 | accumulate_grad_batches = 1 597 | print(f"accumulate_grad_batches = {accumulate_grad_batches}") 598 | lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches 599 | 600 | model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr 601 | print("Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format( 602 | model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr)) 603 | 604 | # allow checkpointing via USR1 605 | def melk(*args, **kwargs): 606 | # run all checkpoint hooks 607 | if trainer.global_rank == 0: 608 | print("Summoning checkpoint.") 609 | ckpt_path = os.path.join(ckptdir, "last.ckpt") 610 | trainer.save_checkpoint(ckpt_path) 611 | 612 | def divein(*args, **kwargs): 613 | if trainer.global_rank == 0: 614 | import pudb; pudb.set_trace() 615 | 616 | import signal 617 | signal.signal(signal.SIGUSR1, melk) 618 | signal.signal(signal.SIGUSR2, divein) 619 | 620 | # run 621 | if opt.train: 622 | try: 623 | trainer.fit(model, data) 624 | except Exception: 625 | melk() 626 | raise 627 | if not opt.no_test and not trainer.interrupted: 628 | # Create directory for storing test results (image reconstructions) 629 | os.makedirs(os.path.join(logdir, 'evaluation'), exist_ok=True) 630 | # We need to load the model this way when we dont do fit (otherwise weights are not loaded) 631 | model.load_state_dict(torch.load(opt.resume_from_checkpoint)['state_dict']) 632 | trainer.test(model, data) 633 | except Exception: 634 | if opt.debug and trainer.global_rank==0: 635 | try: 636 | import pudb as debugger 637 | except ImportError: 638 | import pdb as debugger 639 | debugger.post_mortem() 640 | raise 641 | 
finally: 642 | # move newly created debug project to debug_runs 643 | if opt.debug and not opt.resume and trainer.global_rank==0: 644 | dst, name = os.path.split(logdir) 645 | dst = os.path.join(dst, "debug_runs", name) 646 | os.makedirs(os.path.split(dst)[0], exist_ok=True) 647 | os.rename(logdir, dst) -------------------------------------------------------------------------------- /taming/modules/diffusionmodules/model.py: -------------------------------------------------------------------------------- 1 | # pytorch_diffusion + derived encoder decoder 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | 7 | 8 | def get_timestep_embedding(timesteps, embedding_dim): 9 | """ 10 | This matches the implementation in Denoising Diffusion Probabilistic Models: 11 | From Fairseq. 12 | Build sinusoidal embeddings. 13 | This matches the implementation in tensor2tensor, but differs slightly 14 | from the description in Section 3.5 of "Attention Is All You Need". 15 | """ 16 | assert len(timesteps.shape) == 1 17 | 18 | half_dim = embedding_dim // 2 19 | emb = math.log(10000) / (half_dim - 1) 20 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) 21 | emb = emb.to(device=timesteps.device) 22 | emb = timesteps.float()[:, None] * emb[None, :] 23 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 24 | if embedding_dim % 2 == 1: # zero pad 25 | emb = torch.nn.functional.pad(emb, (0,1,0,0)) 26 | return emb 27 | 28 | 29 | def nonlinearity(x): 30 | # swish 31 | return x*torch.sigmoid(x) 32 | 33 | 34 | def Normalize(in_channels): 35 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 36 | 37 | 38 | class Upsample(nn.Module): 39 | def __init__(self, in_channels, with_conv): 40 | super().__init__() 41 | self.with_conv = with_conv 42 | if self.with_conv: 43 | self.conv = torch.nn.Conv2d(in_channels, 44 | in_channels, 45 | kernel_size=3, 46 | stride=1, 47 | padding=1) 48 | 49 | def forward(self, x): 50 | x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 51 | if self.with_conv: 52 | x = self.conv(x) #Here we get -infs in level 1 block 2 53 | # for param in self.conv.parameters(): 54 | # print(param) 55 | return x 56 | 57 | 58 | class Downsample(nn.Module): 59 | def __init__(self, in_channels, with_conv): 60 | super().__init__() 61 | self.with_conv = with_conv 62 | if self.with_conv: 63 | # no asymmetric padding in torch conv, must do it ourselves 64 | self.conv = torch.nn.Conv2d(in_channels, 65 | in_channels, 66 | kernel_size=3, 67 | stride=2, 68 | padding=0) 69 | 70 | def forward(self, x): 71 | if self.with_conv: 72 | pad = (0,1,0,1) 73 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0) 74 | x = self.conv(x) 75 | else: 76 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) 77 | return x 78 | 79 | 80 | class ResnetBlock(nn.Module): 81 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, 82 | dropout, temb_channels=512): 83 | super().__init__() 84 | self.in_channels = in_channels 85 | out_channels = in_channels if out_channels is None else out_channels 86 | self.out_channels = out_channels 87 | self.use_conv_shortcut = conv_shortcut 88 | 89 | self.norm1 = Normalize(in_channels) 90 | self.conv1 = torch.nn.Conv2d(in_channels, 91 | out_channels, 92 | kernel_size=3, 93 | stride=1, 94 | padding=1) 95 | if temb_channels > 0: 96 | self.temb_proj = torch.nn.Linear(temb_channels, 97 | out_channels) 98 | self.norm2 = Normalize(out_channels) 99 | 
self.dropout = torch.nn.Dropout(dropout) 100 | self.conv2 = torch.nn.Conv2d(out_channels, 101 | out_channels, 102 | kernel_size=3, 103 | stride=1, 104 | padding=1) 105 | if self.in_channels != self.out_channels: 106 | if self.use_conv_shortcut: 107 | self.conv_shortcut = torch.nn.Conv2d(in_channels, 108 | out_channels, 109 | kernel_size=3, 110 | stride=1, 111 | padding=1) 112 | else: 113 | self.nin_shortcut = torch.nn.Conv2d(in_channels, 114 | out_channels, 115 | kernel_size=1, 116 | stride=1, 117 | padding=0) 118 | 119 | def forward(self, x, temb): 120 | h = x 121 | h = self.norm1(h) 122 | h = nonlinearity(h) 123 | h = self.conv1(h) 124 | 125 | if temb is not None: 126 | h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] 127 | 128 | h = self.norm2(h) 129 | h = nonlinearity(h) 130 | h = self.dropout(h) 131 | h = self.conv2(h) 132 | 133 | if self.in_channels != self.out_channels: 134 | if self.use_conv_shortcut: 135 | x = self.conv_shortcut(x) 136 | else: 137 | x = self.nin_shortcut(x) 138 | 139 | return x+h 140 | 141 | 142 | class AttnBlock(nn.Module): 143 | def __init__(self, in_channels): 144 | super().__init__() 145 | self.in_channels = in_channels 146 | 147 | self.norm = Normalize(in_channels) 148 | self.q = torch.nn.Conv2d(in_channels, 149 | in_channels, 150 | kernel_size=1, 151 | stride=1, 152 | padding=0) 153 | self.k = torch.nn.Conv2d(in_channels, 154 | in_channels, 155 | kernel_size=1, 156 | stride=1, 157 | padding=0) 158 | self.v = torch.nn.Conv2d(in_channels, 159 | in_channels, 160 | kernel_size=1, 161 | stride=1, 162 | padding=0) 163 | self.proj_out = torch.nn.Conv2d(in_channels, 164 | in_channels, 165 | kernel_size=1, 166 | stride=1, 167 | padding=0) 168 | 169 | 170 | def forward(self, x): 171 | h_ = x 172 | h_ = self.norm(h_) 173 | q = self.q(h_) 174 | k = self.k(h_) 175 | v = self.v(h_) 176 | 177 | # compute attention 178 | b,c,h,w = q.shape 179 | q = q.reshape(b,c,h*w) 180 | q = q.permute(0,2,1) # b,hw,c 181 | k = k.reshape(b,c,h*w) # b,c,hw 182 | w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 183 | w_ = w_ * (int(c)**(-0.5)) 184 | w_ = torch.nn.functional.softmax(w_, dim=2) 185 | 186 | # attend to values 187 | v = v.reshape(b,c,h*w) 188 | w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) 189 | h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 190 | h_ = h_.reshape(b,c,h,w) 191 | 192 | h_ = self.proj_out(h_) 193 | 194 | return x+h_ 195 | 196 | 197 | class Model(nn.Module): 198 | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, 199 | attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, 200 | resolution, use_timestep=True): 201 | super().__init__() 202 | self.ch = ch 203 | self.temb_ch = self.ch*4 204 | self.num_resolutions = len(ch_mult) 205 | self.num_res_blocks = num_res_blocks 206 | self.resolution = resolution 207 | self.in_channels = in_channels 208 | 209 | self.use_timestep = use_timestep 210 | if self.use_timestep: 211 | # timestep embedding 212 | self.temb = nn.Module() 213 | self.temb.dense = nn.ModuleList([ 214 | torch.nn.Linear(self.ch, 215 | self.temb_ch), 216 | torch.nn.Linear(self.temb_ch, 217 | self.temb_ch), 218 | ]) 219 | 220 | # downsampling 221 | self.conv_in = torch.nn.Conv2d(in_channels, 222 | self.ch, 223 | kernel_size=3, 224 | stride=1, 225 | padding=1) 226 | 227 | curr_res = resolution 228 | in_ch_mult = (1,)+tuple(ch_mult) 229 | self.down = nn.ModuleList() 230 | for i_level in range(self.num_resolutions): 231 | block = nn.ModuleList() 232 | attn = 
nn.ModuleList() 233 | block_in = ch*in_ch_mult[i_level] 234 | block_out = ch*ch_mult[i_level] 235 | for i_block in range(self.num_res_blocks): 236 | block.append(ResnetBlock(in_channels=block_in, 237 | out_channels=block_out, 238 | temb_channels=self.temb_ch, 239 | dropout=dropout)) 240 | block_in = block_out 241 | if curr_res in attn_resolutions: 242 | attn.append(AttnBlock(block_in)) 243 | down = nn.Module() 244 | down.block = block 245 | down.attn = attn 246 | if i_level != self.num_resolutions-1: 247 | down.downsample = Downsample(block_in, resamp_with_conv) 248 | curr_res = curr_res // 2 249 | self.down.append(down) 250 | 251 | # middle 252 | self.mid = nn.Module() 253 | self.mid.block_1 = ResnetBlock(in_channels=block_in, 254 | out_channels=block_in, 255 | temb_channels=self.temb_ch, 256 | dropout=dropout) 257 | self.mid.attn_1 = AttnBlock(block_in) 258 | self.mid.block_2 = ResnetBlock(in_channels=block_in, 259 | out_channels=block_in, 260 | temb_channels=self.temb_ch, 261 | dropout=dropout) 262 | 263 | # upsampling 264 | self.up = nn.ModuleList() 265 | for i_level in reversed(range(self.num_resolutions)): 266 | block = nn.ModuleList() 267 | attn = nn.ModuleList() 268 | block_out = ch*ch_mult[i_level] 269 | skip_in = ch*ch_mult[i_level] 270 | for i_block in range(self.num_res_blocks+1): 271 | if i_block == self.num_res_blocks: 272 | skip_in = ch*in_ch_mult[i_level] 273 | block.append(ResnetBlock(in_channels=block_in+skip_in, 274 | out_channels=block_out, 275 | temb_channels=self.temb_ch, 276 | dropout=dropout)) 277 | block_in = block_out 278 | if curr_res in attn_resolutions: 279 | attn.append(AttnBlock(block_in)) 280 | up = nn.Module() 281 | up.block = block 282 | up.attn = attn 283 | if i_level != 0: 284 | up.upsample = Upsample(block_in, resamp_with_conv) 285 | curr_res = curr_res * 2 286 | self.up.insert(0, up) # prepend to get consistent order 287 | 288 | # end 289 | self.norm_out = Normalize(block_in) 290 | self.conv_out = torch.nn.Conv2d(block_in, 291 | out_ch, 292 | kernel_size=3, 293 | stride=1, 294 | padding=1) 295 | 296 | 297 | def forward(self, x, t=None): 298 | #assert x.shape[2] == x.shape[3] == self.resolution 299 | 300 | if self.use_timestep: 301 | # timestep embedding 302 | assert t is not None 303 | temb = get_timestep_embedding(t, self.ch) 304 | temb = self.temb.dense[0](temb) 305 | temb = nonlinearity(temb) 306 | temb = self.temb.dense[1](temb) 307 | else: 308 | temb = None 309 | 310 | # downsampling 311 | hs = [self.conv_in(x)] 312 | for i_level in range(self.num_resolutions): 313 | for i_block in range(self.num_res_blocks): 314 | h = self.down[i_level].block[i_block](hs[-1], temb) 315 | if len(self.down[i_level].attn) > 0: 316 | h = self.down[i_level].attn[i_block](h) 317 | hs.append(h) 318 | if i_level != self.num_resolutions-1: 319 | hs.append(self.down[i_level].downsample(hs[-1])) 320 | 321 | # middle 322 | h = hs[-1] 323 | h = self.mid.block_1(h, temb) 324 | h = self.mid.attn_1(h) 325 | h = self.mid.block_2(h, temb) 326 | 327 | # upsampling 328 | for i_level in reversed(range(self.num_resolutions)): 329 | for i_block in range(self.num_res_blocks+1): 330 | h = self.up[i_level].block[i_block]( 331 | torch.cat([h, hs.pop()], dim=1), temb) 332 | if len(self.up[i_level].attn) > 0: 333 | h = self.up[i_level].attn[i_block](h) 334 | if i_level != 0: 335 | h = self.up[i_level].upsample(h) 336 | 337 | # end 338 | h = self.norm_out(h) 339 | h = nonlinearity(h) 340 | h = self.conv_out(h) 341 | return h 342 | 343 | 344 | class Encoder(nn.Module): 345 | def __init__(self, 
*, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, 346 | attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, 347 | resolution, z_channels, double_z=True, **ignore_kwargs): 348 | super().__init__() 349 | self.ch = ch 350 | self.temb_ch = 0 351 | self.num_resolutions = len(ch_mult) 352 | self.num_res_blocks = num_res_blocks 353 | self.resolution = resolution 354 | self.in_channels = in_channels 355 | 356 | # downsampling 357 | self.conv_in = torch.nn.Conv2d(in_channels, 358 | self.ch, 359 | kernel_size=3, 360 | stride=1, 361 | padding=1) 362 | 363 | curr_res = resolution 364 | in_ch_mult = (1,)+tuple(ch_mult) 365 | self.down = nn.ModuleList() 366 | for i_level in range(self.num_resolutions): 367 | block = nn.ModuleList() 368 | attn = nn.ModuleList() 369 | block_in = ch*in_ch_mult[i_level] 370 | block_out = ch*ch_mult[i_level] 371 | for i_block in range(self.num_res_blocks): 372 | block.append(ResnetBlock(in_channels=block_in, 373 | out_channels=block_out, 374 | temb_channels=self.temb_ch, 375 | dropout=dropout)) 376 | block_in = block_out 377 | if curr_res in attn_resolutions: 378 | attn.append(AttnBlock(block_in)) 379 | down = nn.Module() 380 | down.block = block 381 | down.attn = attn 382 | if i_level != self.num_resolutions-1: 383 | down.downsample = Downsample(block_in, resamp_with_conv) 384 | curr_res = curr_res // 2 385 | self.down.append(down) 386 | 387 | # middle 388 | self.mid = nn.Module() 389 | self.mid.block_1 = ResnetBlock(in_channels=block_in, 390 | out_channels=block_in, 391 | temb_channels=self.temb_ch, 392 | dropout=dropout) 393 | self.mid.attn_1 = AttnBlock(block_in) 394 | self.mid.block_2 = ResnetBlock(in_channels=block_in, 395 | out_channels=block_in, 396 | temb_channels=self.temb_ch, 397 | dropout=dropout) 398 | 399 | # end 400 | self.norm_out = Normalize(block_in) 401 | self.conv_out = torch.nn.Conv2d(block_in, 402 | 2*z_channels if double_z else z_channels, 403 | kernel_size=3, 404 | stride=1, 405 | padding=1) 406 | 407 | 408 | def forward(self, x): 409 | #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) 410 | 411 | # timestep embedding 412 | temb = None 413 | 414 | # downsampling 415 | hs = [self.conv_in(x)] 416 | for i_level in range(self.num_resolutions): 417 | for i_block in range(self.num_res_blocks): 418 | h = self.down[i_level].block[i_block](hs[-1], temb) 419 | if len(self.down[i_level].attn) > 0: 420 | h = self.down[i_level].attn[i_block](h) 421 | hs.append(h) 422 | if i_level != self.num_resolutions-1: 423 | hs.append(self.down[i_level].downsample(hs[-1])) 424 | 425 | # middle 426 | h = hs[-1] 427 | h = self.mid.block_1(h, temb) 428 | h = self.mid.attn_1(h) 429 | h = self.mid.block_2(h, temb) 430 | 431 | # end 432 | h = self.norm_out(h) 433 | h = nonlinearity(h) 434 | h = self.conv_out(h) 435 | return h 436 | 437 | 438 | class Decoder(nn.Module): 439 | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, 440 | attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, 441 | resolution, z_channels, give_pre_end=False, **ignorekwargs): 442 | super().__init__() 443 | self.ch = ch 444 | self.temb_ch = 0 445 | self.num_resolutions = len(ch_mult) 446 | self.num_res_blocks = num_res_blocks 447 | self.resolution = resolution 448 | self.in_channels = in_channels 449 | self.give_pre_end = give_pre_end 450 | 451 | # compute in_ch_mult, block_in and curr_res at lowest res 452 | in_ch_mult = (1,)+tuple(ch_mult) 453 | block_in = ch*ch_mult[self.num_resolutions-1] 454 | 
curr_res = resolution // 2**(self.num_resolutions-1) 455 | self.z_shape = (1,z_channels,curr_res,curr_res) 456 | print("Working with z of shape {} = {} dimensions.".format( 457 | self.z_shape, np.prod(self.z_shape))) 458 | 459 | # z to block_in 460 | self.conv_in = torch.nn.Conv2d(z_channels, 461 | block_in, 462 | kernel_size=3, 463 | stride=1, 464 | padding=1) 465 | 466 | # middle 467 | self.mid = nn.Module() 468 | self.mid.block_1 = ResnetBlock(in_channels=block_in, 469 | out_channels=block_in, 470 | temb_channels=self.temb_ch, 471 | dropout=dropout) 472 | self.mid.attn_1 = AttnBlock(block_in) 473 | self.mid.block_2 = ResnetBlock(in_channels=block_in, 474 | out_channels=block_in, 475 | temb_channels=self.temb_ch, 476 | dropout=dropout) 477 | 478 | # upsampling 479 | self.up = nn.ModuleList() 480 | for i_level in reversed(range(self.num_resolutions)): 481 | block = nn.ModuleList() 482 | attn = nn.ModuleList() 483 | block_out = ch*ch_mult[i_level] 484 | for i_block in range(self.num_res_blocks+1): 485 | block.append(ResnetBlock(in_channels=block_in, 486 | out_channels=block_out, 487 | temb_channels=self.temb_ch, 488 | dropout=dropout)) 489 | block_in = block_out 490 | if curr_res in attn_resolutions: 491 | attn.append(AttnBlock(block_in)) 492 | up = nn.Module() 493 | up.block = block 494 | up.attn = attn 495 | if i_level != 0: 496 | up.upsample = Upsample(block_in, resamp_with_conv) 497 | curr_res = curr_res * 2 498 | self.up.insert(0, up) # prepend to get consistent order 499 | 500 | # end 501 | self.norm_out = Normalize(block_in) 502 | self.conv_out = torch.nn.Conv2d(block_in, 503 | out_ch, 504 | kernel_size=3, 505 | stride=1, 506 | padding=1) 507 | 508 | def forward(self, z): 509 | #assert z.shape[1:] == self.z_shape[1:] 510 | self.last_z_shape = z.shape 511 | 512 | # timestep embedding 513 | temb = None 514 | 515 | # z to block_in 516 | h = self.conv_in(z) 517 | 518 | # middle 519 | h = self.mid.block_1(h, temb) 520 | h = self.mid.attn_1(h) 521 | h = self.mid.block_2(h, temb) 522 | 523 | # upsampling #error here 524 | for i_level in reversed(range(self.num_resolutions)): 525 | for i_block in range(self.num_res_blocks+1): 526 | # if i_level == 1: 527 | # print(i_level) 528 | # print() 529 | # print(f"level: {i_level}, resnet_block: {i_block}, mean of h: {torch.mean(h)}") 530 | # print(f" - shape h: {h.shape}") 531 | h = self.up[i_level].block[i_block](h, temb) 532 | # print(torch.mean(h)) 533 | if len(self.up[i_level].attn) > 0: 534 | h = self.up[i_level].attn[i_block](h) 535 | if i_level != 0: 536 | # if i_level == 1 and i_block == 2: 537 | # print("hi") 538 | # print(f"Upsampling at level {i_level}") 539 | h = self.up[i_level].upsample(h) 540 | # print(f" - shape h: {h.shape}, with mean: {torch.mean(h)}") 541 | # ------- 542 | # end 543 | if self.give_pre_end: 544 | return h 545 | 546 | h = self.norm_out(h) 547 | h = nonlinearity(h) 548 | h = self.conv_out(h) 549 | return h 550 | 551 | 552 | class VUNet(nn.Module): 553 | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, 554 | attn_resolutions, dropout=0.0, resamp_with_conv=True, 555 | in_channels, c_channels, 556 | resolution, z_channels, use_timestep=False, **ignore_kwargs): 557 | super().__init__() 558 | self.ch = ch 559 | self.temb_ch = self.ch*4 560 | self.num_resolutions = len(ch_mult) 561 | self.num_res_blocks = num_res_blocks 562 | self.resolution = resolution 563 | 564 | self.use_timestep = use_timestep 565 | if self.use_timestep: 566 | # timestep embedding 567 | self.temb = nn.Module() 568 | self.temb.dense 
= nn.ModuleList([ 569 | torch.nn.Linear(self.ch, 570 | self.temb_ch), 571 | torch.nn.Linear(self.temb_ch, 572 | self.temb_ch), 573 | ]) 574 | 575 | # downsampling 576 | self.conv_in = torch.nn.Conv2d(c_channels, 577 | self.ch, 578 | kernel_size=3, 579 | stride=1, 580 | padding=1) 581 | 582 | curr_res = resolution 583 | in_ch_mult = (1,)+tuple(ch_mult) 584 | self.down = nn.ModuleList() 585 | for i_level in range(self.num_resolutions): 586 | block = nn.ModuleList() 587 | attn = nn.ModuleList() 588 | block_in = ch*in_ch_mult[i_level] 589 | block_out = ch*ch_mult[i_level] 590 | for i_block in range(self.num_res_blocks): 591 | block.append(ResnetBlock(in_channels=block_in, 592 | out_channels=block_out, 593 | temb_channels=self.temb_ch, 594 | dropout=dropout)) 595 | block_in = block_out 596 | if curr_res in attn_resolutions: 597 | attn.append(AttnBlock(block_in)) 598 | down = nn.Module() 599 | down.block = block 600 | down.attn = attn 601 | if i_level != self.num_resolutions-1: 602 | down.downsample = Downsample(block_in, resamp_with_conv) 603 | curr_res = curr_res // 2 604 | self.down.append(down) 605 | 606 | self.z_in = torch.nn.Conv2d(z_channels, 607 | block_in, 608 | kernel_size=1, 609 | stride=1, 610 | padding=0) 611 | # middle 612 | self.mid = nn.Module() 613 | self.mid.block_1 = ResnetBlock(in_channels=2*block_in, 614 | out_channels=block_in, 615 | temb_channels=self.temb_ch, 616 | dropout=dropout) 617 | self.mid.attn_1 = AttnBlock(block_in) 618 | self.mid.block_2 = ResnetBlock(in_channels=block_in, 619 | out_channels=block_in, 620 | temb_channels=self.temb_ch, 621 | dropout=dropout) 622 | 623 | # upsampling 624 | self.up = nn.ModuleList() 625 | for i_level in reversed(range(self.num_resolutions)): 626 | block = nn.ModuleList() 627 | attn = nn.ModuleList() 628 | block_out = ch*ch_mult[i_level] 629 | skip_in = ch*ch_mult[i_level] 630 | for i_block in range(self.num_res_blocks+1): 631 | if i_block == self.num_res_blocks: 632 | skip_in = ch*in_ch_mult[i_level] 633 | block.append(ResnetBlock(in_channels=block_in+skip_in, 634 | out_channels=block_out, 635 | temb_channels=self.temb_ch, 636 | dropout=dropout)) 637 | block_in = block_out 638 | if curr_res in attn_resolutions: 639 | attn.append(AttnBlock(block_in)) 640 | up = nn.Module() 641 | up.block = block 642 | up.attn = attn 643 | if i_level != 0: 644 | up.upsample = Upsample(block_in, resamp_with_conv) 645 | curr_res = curr_res * 2 646 | self.up.insert(0, up) # prepend to get consistent order 647 | 648 | # end 649 | self.norm_out = Normalize(block_in) 650 | self.conv_out = torch.nn.Conv2d(block_in, 651 | out_ch, 652 | kernel_size=3, 653 | stride=1, 654 | padding=1) 655 | 656 | 657 | def forward(self, x, z): 658 | #assert x.shape[2] == x.shape[3] == self.resolution 659 | 660 | if self.use_timestep: 661 | # timestep embedding 662 | assert t is not None 663 | temb = get_timestep_embedding(t, self.ch) 664 | temb = self.temb.dense[0](temb) 665 | temb = nonlinearity(temb) 666 | temb = self.temb.dense[1](temb) 667 | else: 668 | temb = None 669 | 670 | # downsampling 671 | hs = [self.conv_in(x)] 672 | for i_level in range(self.num_resolutions): 673 | for i_block in range(self.num_res_blocks): 674 | h = self.down[i_level].block[i_block](hs[-1], temb) 675 | if len(self.down[i_level].attn) > 0: 676 | h = self.down[i_level].attn[i_block](h) 677 | hs.append(h) 678 | if i_level != self.num_resolutions-1: 679 | hs.append(self.down[i_level].downsample(hs[-1])) 680 | 681 | # middle 682 | h = hs[-1] 683 | z = self.z_in(z) 684 | h = torch.cat((h,z),dim=1) 685 
| h = self.mid.block_1(h, temb) 686 | h = self.mid.attn_1(h) 687 | h = self.mid.block_2(h, temb) 688 | 689 | # upsampling 690 | for i_level in reversed(range(self.num_resolutions)): 691 | for i_block in range(self.num_res_blocks+1): 692 | h = self.up[i_level].block[i_block]( 693 | torch.cat([h, hs.pop()], dim=1), temb) 694 | if len(self.up[i_level].attn) > 0: 695 | h = self.up[i_level].attn[i_block](h) 696 | if i_level != 0: 697 | h = self.up[i_level].upsample(h) 698 | 699 | # end 700 | h = self.norm_out(h) 701 | h = nonlinearity(h) 702 | h = self.conv_out(h) 703 | return h 704 | 705 | 706 | class SimpleDecoder(nn.Module): 707 | def __init__(self, in_channels, out_channels, *args, **kwargs): 708 | super().__init__() 709 | self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), 710 | ResnetBlock(in_channels=in_channels, 711 | out_channels=2 * in_channels, 712 | temb_channels=0, dropout=0.0), 713 | ResnetBlock(in_channels=2 * in_channels, 714 | out_channels=4 * in_channels, 715 | temb_channels=0, dropout=0.0), 716 | ResnetBlock(in_channels=4 * in_channels, 717 | out_channels=2 * in_channels, 718 | temb_channels=0, dropout=0.0), 719 | nn.Conv2d(2*in_channels, in_channels, 1), 720 | Upsample(in_channels, with_conv=True)]) 721 | # end 722 | self.norm_out = Normalize(in_channels) 723 | self.conv_out = torch.nn.Conv2d(in_channels, 724 | out_channels, 725 | kernel_size=3, 726 | stride=1, 727 | padding=1) 728 | 729 | def forward(self, x): 730 | for i, layer in enumerate(self.model): 731 | if i in [1,2,3]: 732 | x = layer(x, None) 733 | else: 734 | x = layer(x) 735 | 736 | h = self.norm_out(x) 737 | h = nonlinearity(h) 738 | x = self.conv_out(h) 739 | return x 740 | 741 | 742 | class UpsampleDecoder(nn.Module): 743 | def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, 744 | ch_mult=(2,2), dropout=0.0): 745 | super().__init__() 746 | # upsampling 747 | self.temb_ch = 0 748 | self.num_resolutions = len(ch_mult) 749 | self.num_res_blocks = num_res_blocks 750 | block_in = in_channels 751 | curr_res = resolution // 2 ** (self.num_resolutions - 1) 752 | self.res_blocks = nn.ModuleList() 753 | self.upsample_blocks = nn.ModuleList() 754 | for i_level in range(self.num_resolutions): 755 | res_block = [] 756 | block_out = ch * ch_mult[i_level] 757 | for i_block in range(self.num_res_blocks + 1): 758 | res_block.append(ResnetBlock(in_channels=block_in, 759 | out_channels=block_out, 760 | temb_channels=self.temb_ch, 761 | dropout=dropout)) 762 | block_in = block_out 763 | self.res_blocks.append(nn.ModuleList(res_block)) 764 | if i_level != self.num_resolutions - 1: 765 | self.upsample_blocks.append(Upsample(block_in, True)) 766 | curr_res = curr_res * 2 767 | 768 | # end 769 | self.norm_out = Normalize(block_in) 770 | self.conv_out = torch.nn.Conv2d(block_in, 771 | out_channels, 772 | kernel_size=3, 773 | stride=1, 774 | padding=1) 775 | 776 | def forward(self, x): 777 | # upsampling 778 | h = x 779 | for k, i_level in enumerate(range(self.num_resolutions)): 780 | for i_block in range(self.num_res_blocks + 1): 781 | h = self.res_blocks[i_level][i_block](h, None) 782 | if i_level != self.num_resolutions - 1: 783 | h = self.upsample_blocks[k](h) 784 | h = self.norm_out(h) 785 | h = nonlinearity(h) 786 | h = self.conv_out(h) 787 | return h 788 | 789 | --------------------------------------------------------------------------------
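# ---------------------------------------------------------------------------
# Illustrative end-to-end sketch (not a file from this repo): how the Encoder /
# Decoder above combine with VectorQuantizer2 into the VQ autoencoding path that
# main.py instantiates from config.model. Toy sizes are used so it runs quickly
# on CPU; the released configs train a much larger model (deeper ch_mult, bigger
# codebook), and only the forward pass is shown here, not the training losses.
import torch

from taming.modules.diffusionmodules.model import Encoder, Decoder
from taming.modules.vqvae.quantize import VectorQuantizer2

if __name__ == "__main__":
    enc = Encoder(ch=32, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=1,
                  attn_resolutions=[16], in_channels=3, resolution=64,
                  z_channels=64, double_z=False)
    dec = Decoder(ch=32, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=1,
                  attn_resolutions=[16], in_channels=3, resolution=64,
                  z_channels=64)
    quantizer = VectorQuantizer2(n_e=1024, e_dim=64, beta=0.25)

    x = torch.randn(1, 3, 64, 64)       # dummy image, values roughly in [-1, 1]
    z = enc(x)                          # (1, 64, 16, 16) continuous latent (f4 here)
    z_q, emb_loss, _ = quantizer(z)     # snap every latent vector to its nearest codeword
    x_rec = dec(z_q)                    # (1, 3, 64, 64) reconstruction
    print(x_rec.shape, float(emb_loss))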