├── .github ├── CODE_OF_CONDUCT.md └── CONTRIBUTING.md ├── LICENSE ├── README.md ├── decoding.ipynb ├── finetune_ldm_decoder.py ├── hidden ├── README.md ├── attenuations.py ├── ckpts │ └── hidden_replicate.pth ├── data_augmentation.py ├── imgs │ └── 00.png ├── main.py ├── models.py ├── notebooks │ └── demo.ipynb ├── requirements.txt ├── utils.py └── utils_img.py ├── requirements.txt ├── run_evals.py ├── src ├── README.md ├── ldm │ ├── data │ │ ├── __init__.py │ │ ├── base.py │ │ ├── imagenet.py │ │ └── lsun.py │ ├── lr_scheduler.py │ ├── models │ │ ├── autoencoder.py │ │ └── diffusion │ │ │ ├── __init__.py │ │ │ ├── ddim.py │ │ │ ├── ddpm.py │ │ │ ├── dpm_solver │ │ │ ├── __init__.py │ │ │ ├── dpm_solver.py │ │ │ └── sampler.py │ │ │ ├── plms.py │ │ │ └── sampling_util.py │ ├── modules │ │ ├── attention.py │ │ ├── diffusionmodules │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ ├── openaimodel.py │ │ │ ├── upscaling.py │ │ │ └── util.py │ │ ├── distributions │ │ │ ├── __init__.py │ │ │ └── distributions.py │ │ ├── ema.py │ │ ├── encoders │ │ │ ├── __init__.py │ │ │ └── modules.py │ │ ├── image_degradation │ │ │ ├── __init__.py │ │ │ ├── bsrgan.py │ │ │ ├── bsrgan_light.py │ │ │ ├── utils │ │ │ │ └── test.png │ │ │ └── utils_image.py │ │ ├── losses │ │ │ ├── __init__.py │ │ │ ├── contperceptual.py │ │ │ └── vqperceptual.py │ │ ├── midas │ │ │ ├── __init__.py │ │ │ ├── api.py │ │ │ ├── midas │ │ │ │ ├── __init__.py │ │ │ │ ├── base_model.py │ │ │ │ ├── blocks.py │ │ │ │ ├── dpt_depth.py │ │ │ │ ├── midas_net.py │ │ │ │ ├── midas_net_custom.py │ │ │ │ ├── transforms.py │ │ │ │ └── vit.py │ │ │ └── utils.py │ │ └── x_transformer.py │ └── util.py ├── loss │ ├── __init__.py │ ├── color_wrapper.py │ ├── dct2d.py │ ├── deep_loss.py │ ├── loss_provider.py │ ├── rfft2d.py │ ├── shift_wrapper.py │ ├── ssim.py │ ├── watson.py │ ├── watson_fft.py │ └── watson_vgg.py └── taming │ ├── data │ ├── ade20k.py │ ├── annotated_objects_coco.py │ ├── annotated_objects_dataset.py │ ├── annotated_objects_open_images.py │ ├── base.py │ ├── coco.py │ ├── conditional_builder │ │ ├── objects_bbox.py │ │ ├── objects_center_points.py │ │ └── utils.py │ ├── custom.py │ ├── faceshq.py │ ├── helper_types.py │ ├── image_transforms.py │ ├── imagenet.py │ ├── open_images_helper.py │ ├── sflckr.py │ └── utils.py │ ├── lr_scheduler.py │ ├── models │ ├── cond_transformer.py │ ├── dummy_cond_stage.py │ └── vqgan.py │ ├── modules │ ├── diffusionmodules │ │ └── model.py │ ├── discriminator │ │ └── model.py │ ├── losses │ │ ├── __init__.py │ │ ├── lpips.py │ │ ├── segmentation.py │ │ └── vqperceptual.py │ ├── misc │ │ └── coord.py │ ├── transformer │ │ ├── mingpt.py │ │ └── permuter.py │ ├── util.py │ └── vqvae │ │ └── quantize.py │ └── util.py ├── utils.py ├── utils_img.py └── utils_model.py /.github/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | Meta has adopted a Code of Conduct that we expect project participants to adhere to. Please read the full text so that you can understand what actions will and will not be tolerated. -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to active_indexing 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. 
Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to active_indexing, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /hidden/attenuations.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from torchvision.transforms import functional 12 | 13 | class JND(nn.Module): 14 | """ Same as in https://github.com/facebookresearch/active_indexing """ 15 | 16 | def __init__(self, preprocess = lambda x: x): 17 | super(JND, self).__init__() 18 | kernel_x = [[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]] 19 | kernel_y = [[-1., -2., -1.], [0., 0., 0.], [1., 2., 1.]] 20 | kernel_lum = [[1, 1, 1, 1, 1], [1, 2, 2, 2, 1], [1, 2, 0, 2, 1], [1, 2, 2, 2, 1], [1, 1, 1, 1, 1]] 21 | 22 | kernel_x = torch.FloatTensor(kernel_x).unsqueeze(0).unsqueeze(0) 23 | kernel_y = torch.FloatTensor(kernel_y).unsqueeze(0).unsqueeze(0) 24 | kernel_lum = torch.FloatTensor(kernel_lum).unsqueeze(0).unsqueeze(0) 25 | 26 | self.weight_x = nn.Parameter(data=kernel_x, requires_grad=False) 27 | self.weight_y = nn.Parameter(data=kernel_y, requires_grad=False) 28 | self.weight_lum = nn.Parameter(data=kernel_lum, requires_grad=False) 29 | 30 | self.preprocess = preprocess 31 | 32 | def jnd_la(self, x, alpha=1.0): 33 | """ Luminance masking: x must be in [0,255] """ 34 | la = F.conv2d(x, self.weight_lum, padding=2) / 32 35 | mask_lum = la <= 127 36 | la[mask_lum] = 17 * (1 - torch.sqrt(la[mask_lum]/127)) + 3 37 | la[~mask_lum] = 3/128 * (la[~mask_lum] - 127) + 3 38 | return alpha * la 39 | 40 | def jnd_cm(self, x, beta=0.117): 41 | """ Contrast masking: x must be in [0,255] """ 42 | grad_x = F.conv2d(x, self.weight_x, padding=1) 43 | grad_y = F.conv2d(x, self.weight_y, padding=1) 44 | cm = torch.sqrt(grad_x**2 + grad_y**2) 45 | cm = 16 * cm**2.4 / (cm**2 + 26**2) 46 | return beta * cm 47 | 48 | def heatmaps(self, x, clc=0.3): 49 | """ x must be in [0,1] """ 50 | x = 255 * self.preprocess(x) 51 | x = 0.299 * x[...,0:1,:,:] + 0.587 * x[...,1:2,:,:] + 0.114 * x[...,2:3,:,:] 52 | la = self.jnd_la(x) 53 | cm = self.jnd_cm(x) 54 | return (la + cm - clc * torch.minimum(la, cm))/255 # b 1 h w 55 | 
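# --- Usage sketch (added for illustration; not part of the original file) ---
# The JND heatmap gives a per-pixel perceptual budget: larger where the image
# masks distortion (textured regions, very dark or very bright areas), smaller
# in smooth mid-grey regions. A watermark residual can be modulated by it as
# below; `delta_w` is a hypothetical residual, not something defined in this repo.
if __name__ == "__main__":
    jnd = JND()                              # default identity preprocess; images assumed in [0, 1]
    imgs = torch.rand(2, 3, 256, 256)        # dummy batch of RGB images in [0, 1]
    delta_w = 0.01 * torch.randn_like(imgs)  # hypothetical watermark residual
    heatmaps = jnd.heatmaps(imgs)            # b 1 h w perceptual budget
    imgs_w = imgs + heatmaps * delta_w       # push the residual harder where it is less visible
    print(heatmaps.shape, imgs_w.shape)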
-------------------------------------------------------------------------------- /hidden/ckpts/hidden_replicate.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/stable_signature/48261686883ba86f533836ad47c35040afcfe37b/hidden/ckpts/hidden_replicate.pth -------------------------------------------------------------------------------- /hidden/data_augmentation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torchvision import transforms 13 | from torchvision.transforms import functional 14 | 15 | import kornia.augmentation as K 16 | from kornia.augmentation import AugmentationBase2D 17 | 18 | import utils_img 19 | 20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 21 | 22 | class DiffJPEG(nn.Module): 23 | def __init__(self, quality=50): 24 | super().__init__() 25 | self.quality = quality 26 | 27 | def forward(self, x): 28 | with torch.no_grad(): 29 | img_clip = utils_img.clamp_pixel(x) 30 | img_jpeg = utils_img.jpeg_compress(img_clip, self.quality) 31 | img_gap = img_jpeg - x 32 | img_gap = img_gap.detach() 33 | img_aug = x+img_gap 34 | return img_aug 35 | 36 | class RandomDiffJPEG(AugmentationBase2D): 37 | def __init__(self, p, low=10, high=100) -> None: 38 | super().__init__(p=p) 39 | self.diff_jpegs = [DiffJPEG(quality=qf).to(device) for qf in range(low,high,10)] 40 | 41 | def generate_parameters(self, input_shape: torch.Size): 42 | qf = torch.randint(high=len(self.diff_jpegs), size=input_shape[0:1]) 43 | return dict(qf=qf) 44 | 45 | def compute_transformation(self, input, params, flags): 46 | return self.identity_matrix(input) 47 | 48 | def apply_transform(self, input, params, *args, **kwargs): 49 | B, C, H, W = input.shape 50 | qf = params['qf'] 51 | output = torch.zeros_like(input) 52 | for ii in range(B): 53 | output[ii] = self.diff_jpegs[qf[ii]](input[ii:ii+1]) 54 | return output 55 | 56 | class RandomBlur(AugmentationBase2D): 57 | def __init__(self, blur_size, p=1) -> None: 58 | super().__init__(p=p) 59 | self.gaussian_blurs = [K.RandomGaussianBlur(kernel_size=(kk,kk), sigma= (kk*0.15 + 0.35, kk*0.15 + 0.35)) for kk in range(1,int(blur_size),2)] 60 | 61 | def generate_parameters(self, input_shape: torch.Size): 62 | blur_strength = torch.randint(high=len(self.gaussian_blurs), size=input_shape[0:1]) 63 | return dict(blur_strength=blur_strength) 64 | 65 | def compute_transformation(self, input, params, flags): 66 | return self.identity_matrix(input) 67 | 68 | def apply_transform(self, input, params, *args, **kwargs): 69 | B, C, H, W = input.shape 70 | blur_strength = params['blur_strength'] 71 | output = torch.zeros_like(input) 72 | for ii in range(B): 73 | output[ii] = self.gaussian_blurs[blur_strength[ii]](input[ii:ii+1]) 74 | return output 75 | 76 | 77 | class HiddenAug(nn.Module): 78 | """Dropout p = 0.3,Dropout p = 0.7, Cropout p = 0.3, Cropout p = 0.7, Crop p = 0.3, Crop p = 0.7, Gaussian blur σ = 2, Gaussian blur σ = 4, JPEG-drop, JPEG-mask and the Identity layer""" 79 | def __init__(self, img_size, p_crop=0.3, p_blur=0.3, p_jpeg=0.3, p_rot=0.3, p_color_jitter=0.3, p_res=0.3): 80 | super().__init__() 
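        # Added note: the constructor below builds a fixed pool of augmentations
        # (identity, horizontal flip, crops and downsizing resizes to ~30%/70% of the
        # original area, Gaussian blur sigma=2, JPEG-50/80, small and 90-degree
        # rotations, colour jitter). With random_apply=1, the AugmentationSequential
        # at the end draws exactly one of them per forward pass, i.e. one simulated
        # attack per batch. Hypothetical usage sketch (commented, not original code):
        #   aug = HiddenAug(img_size=256)
        #   imgs_aug = aug(imgs_w)   # imgs_w: watermarked batch, preprocessed as in main.py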
81 | augmentations = [] 82 | hflip = K.RandomHorizontalFlip(p=1) 83 | augmentations += [nn.Identity(), hflip] 84 | if p_crop > 0: 85 | crop1 = int(img_size * np.sqrt(0.3)) 86 | crop2 = int(img_size * np.sqrt(0.7)) 87 | crop1 = K.RandomCrop(size=(crop1, crop1), p=1) # Crop 0.3 88 | crop2 = K.RandomCrop(size=(crop2, crop2), p=1) # Crop 0.7 89 | augmentations += [crop1, crop2] 90 | if p_res > 0: 91 | res1 = int(img_size * np.sqrt(0.3)) 92 | res2 = int(img_size * np.sqrt(0.7)) 93 | res1 = K.RandomResizedCrop(size=(res1, res1), scale=(1.0,1.0), p=1) # Resize 0.3 94 | res2 = K.RandomResizedCrop(size=(res2, res2), scale=(1.0,1.0), p=1) # Resize 0.7 95 | augmentations += [res1, res2] 96 | if p_blur > 0: 97 | blur1 = K.RandomGaussianBlur(kernel_size=(11,11), sigma= (2.0, 2.0), p=1) # Gaussian blur σ = 2 98 | # blur2 = K.RandomGaussianBlur(kernel_size=(25,25), sigma= (4.0, 4.0), p=1) # Gaussian blur σ = 4 99 | augmentations += [blur1] 100 | # augmentations += [blur1, blur2] 101 | if p_jpeg > 0: 102 | diff_jpeg1 = DiffJPEG(quality=50) # JPEG50 103 | diff_jpeg2 = DiffJPEG(quality=80) # JPEG80 104 | augmentations += [diff_jpeg1, diff_jpeg2] 105 | if p_rot > 0: 106 | aff1 = K.RandomAffine(degrees=(-10,10), p=1) 107 | aff2 = K.RandomAffine(degrees=(90,90), p=1) 108 | aff3 = K.RandomAffine(degrees=(-90,-90), p=1) 109 | augmentations += [aff1] 110 | augmentations += [aff2, aff3] 111 | if p_color_jitter > 0: 112 | jitter1 = K.ColorJiggle(brightness=(1.5, 1.5), contrast=0, saturation=0, hue=0, p=1) 113 | jitter2 = K.ColorJiggle(brightness=0, contrast=(1.5, 1.5), saturation=0, hue=0, p=1) 114 | jitter3 = K.ColorJiggle(brightness=0, contrast=0, saturation=(1.5,1.5), hue=0, p=1) 115 | jitter4 = K.ColorJiggle(brightness=0, contrast=0, saturation=0, hue=(0.25, 0.25), p=1) 116 | augmentations += [jitter1, jitter2, jitter3, jitter4] 117 | self.hidden_aug = K.AugmentationSequential(*augmentations, random_apply=1).to(device) 118 | 119 | def forward(self, x): 120 | return self.hidden_aug(x) 121 | 122 | class KorniaAug(nn.Module): 123 | def __init__(self, degrees=30, crop_scale=(0.2, 1.0), crop_ratio=(3/4, 4/3), blur_size=17, color_jitter=(1.0, 1.0, 1.0, 0.3), diff_jpeg=10, 124 | p_crop=0.5, p_aff=0.5, p_blur=0.5, p_color_jitter=0.5, p_diff_jpeg=0.5, 125 | cropping_mode='slice', img_size=224 126 | ): 127 | super(KorniaAug, self).__init__() 128 | self.jitter = K.ColorJitter(*color_jitter, p=p_color_jitter).to(device) 129 | # self.jitter = K.RandomPlanckianJitter(p=p_color_jitter).to(device) 130 | self.aff = K.RandomAffine(degrees=degrees, p=p_aff).to(device) 131 | self.crop = K.RandomResizedCrop(size=(img_size,img_size),scale=crop_scale,ratio=crop_ratio, p=p_crop, cropping_mode=cropping_mode).to(device) 132 | self.hflip = K.RandomHorizontalFlip().to(device) 133 | self.blur = RandomBlur(blur_size, p_blur).to(device) 134 | self.diff_jpeg = RandomDiffJPEG(p=p_diff_jpeg, low=diff_jpeg).to(device) 135 | 136 | def forward(self, input): 137 | input = self.diff_jpeg(input) 138 | input = self.aff(input) 139 | input = self.crop(input) 140 | input = self.blur(input) 141 | input = self.jitter(input) 142 | input = self.hflip(input) 143 | return input 144 | -------------------------------------------------------------------------------- /hidden/imgs/00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/stable_signature/48261686883ba86f533836ad47c35040afcfe37b/hidden/imgs/00.png 
-------------------------------------------------------------------------------- /hidden/requirements.txt: -------------------------------------------------------------------------------- 1 | timm==0.9.5 2 | kornia==0.7.0 3 | scipy==1.10.1 4 | augly==1.0.0 5 | scikit-image==0.20.0 6 | pandas==1.5.3 7 | matplotlib -------------------------------------------------------------------------------- /hidden/utils_img.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import os 9 | import numpy as np 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch.autograd.variable import Variable 15 | from torchvision import transforms 16 | from torchvision.transforms import functional 17 | from augly.image import functional as aug_functional 18 | 19 | import kornia.augmentation as K 20 | 21 | from PIL import Image 22 | 23 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 24 | 25 | NORMALIZE_IMAGENET = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 26 | default_transform = transforms.Compose([ 27 | transforms.ToTensor(), 28 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 29 | ]) 30 | image_mean = torch.Tensor([0.485, 0.456, 0.406]).view(-1, 1, 1) 31 | image_std = torch.Tensor([0.229, 0.224, 0.225]).view(-1, 1, 1) 32 | 33 | normalize_rgb = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 34 | unnormalize_rgb = transforms.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], std=[1/0.229, 1/0.224, 1/0.225]) 35 | normalize_yuv = transforms.Normalize(mean=[0.5, 0, 0], std=[0.5, 1, 1]) 36 | unnormalize_yuv = transforms.Normalize(mean=[-0.5/0.5, 0, 0], std=[1/0.5, 1/1, 1/1]) 37 | 38 | 39 | def normalize_img(x): 40 | """ Normalize image to approx. [-1,1] """ 41 | return (x - image_mean.to(x.device)) / image_std.to(x.device) 42 | 43 | def unnormalize_img(x): 44 | """ Unnormalize image to [0,1] """ 45 | return (x * image_std.to(x.device)) + image_mean.to(x.device) 46 | 47 | def round_pixel(x): 48 | """ 49 | Round pixel values to nearest integer. 50 | Args: 51 | x: Image tensor with values approx. between [-1,1] 52 | Returns: 53 | y: Rounded image tensor with values approx. between [-1,1] 54 | """ 55 | x_pixel = 255 * unnormalize_img(x) 56 | y = torch.round(x_pixel).clamp(0, 255) 57 | y = normalize_img(y/255.0) 58 | return y 59 | 60 | def clamp_pixel(x): 61 | """ 62 | Clamp pixel values to 0 255. 63 | Args: 64 | x: Image tensor with values approx. between [-1,1] 65 | Returns: 66 | y: Rounded image tensor with values approx. between [-1,1] 67 | """ 68 | x_pixel = 255 * unnormalize_img(x) 69 | y = x_pixel.clamp(0, 255) 70 | y = normalize_img(y/255.0) 71 | return y 72 | 73 | def project_linf(x, y, radius): 74 | """ 75 | Clamp x so that Linf(x,y)<=radius 76 | Args: 77 | x: Image tensor with values approx. between [-1,1] 78 | y: Image tensor with values approx. 
between [-1,1], ex: original image 79 | radius: Radius of Linf ball for the images in pixel space [0, 255] 80 | """ 81 | delta = x - y 82 | delta = 255 * (delta * image_std.to(x.device)) 83 | delta = torch.clamp(delta, -radius, radius) 84 | delta = (delta / 255.0) / image_std.to(x.device) 85 | return y + delta 86 | 87 | def psnr(x, y): 88 | """ 89 | Return PSNR 90 | Args: 91 | x: Image tensor with values approx. between [-1,1] 92 | y: Image tensor with values approx. between [-1,1], ex: original image 93 | """ 94 | delta = x - y 95 | delta = 255 * (delta * image_std.to(x.device)) 96 | delta = delta.reshape(-1, x.shape[-3], x.shape[-2], x.shape[-1]) # BxCxHxW 97 | psnr = 20*np.log10(255) - 10*torch.log10(torch.mean(delta**2, dim=(1,2,3))) # B 98 | return psnr 99 | 100 | def center_crop(x, scale): 101 | """ Perform center crop such that the target area of the crop is at a given scale 102 | Args: 103 | x: PIL image 104 | scale: target area scale 105 | """ 106 | scale = np.sqrt(scale) 107 | new_edges_size = [int(s*scale) for s in x.shape[-2:]][::-1] 108 | 109 | # left = int(x.size[0]/2-new_edges_size[0]/2) 110 | # upper = int(x.size[1]/2-new_edges_size[1]/2) 111 | # right = left + new_edges_size[0] 112 | # lower = upper + new_edges_size[1] 113 | 114 | # return x.crop((left, upper, right, lower)) 115 | x = functional.center_crop(x, new_edges_size) 116 | return x 117 | 118 | def resize(x, scale): 119 | """ Perform center crop such that the target area of the crop is at a given scale 120 | Args: 121 | x: PIL image 122 | scale: target area scale 123 | """ 124 | scale = np.sqrt(scale) 125 | new_edges_size = [int(s*scale) for s in x.shape[-2:]][::-1] 126 | return functional.resize(x, new_edges_size) 127 | 128 | def rotate(x, angle): 129 | """ Rotate image by angle 130 | Args: 131 | x: image (PIl or tensor) 132 | angle: angle in degrees 133 | """ 134 | return functional.rotate(x, angle) 135 | 136 | def adjust_brightness(x, brightness_factor): 137 | """ Adjust brightness of an image 138 | Args: 139 | x: PIL image 140 | brightness_factor: brightness factor 141 | """ 142 | return normalize_img(functional.adjust_brightness(unnormalize_img(x), brightness_factor)) 143 | 144 | def adjust_contrast(x, contrast_factor): 145 | """ Adjust constrast of an image 146 | Args: 147 | x: PIL image 148 | contrast_factor: contrast factor 149 | """ 150 | return normalize_img(functional.adjust_contrast(unnormalize_img(x), contrast_factor)) 151 | 152 | def jpeg_compress(x, quality_factor): 153 | """ Apply jpeg compression to image 154 | Args: 155 | x: Tensor image 156 | quality_factor: quality factor 157 | """ 158 | to_pil = transforms.ToPILImage() 159 | to_tensor = transforms.ToTensor() 160 | img_aug = torch.zeros_like(x, device=x.device) 161 | x = unnormalize_img(x) 162 | for ii,img in enumerate(x): 163 | pil_img = to_pil(img) 164 | img_aug[ii] = to_tensor(aug_functional.encoding_quality(pil_img, quality=quality_factor)) 165 | return normalize_img(img_aug) 166 | 167 | def gaussian_blur(x, sigma=1): 168 | """ Add gaussian blur to image 169 | Args: 170 | x: Tensor image 171 | sigma: sigma of gaussian kernel 172 | """ 173 | x = unnormalize_img(x) 174 | x = functional.gaussian_blur(x, sigma=sigma, kernel_size=21) 175 | x = normalize_img(x) 176 | return x 177 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | omegaconf==2.1.1 2 | einops==0.3.0 3 | transformers==4.19.2 4 | 
open_clip_torch==2.0.2 5 | torchmetrics==0.6.0 6 | scipy==1.10.1 7 | augly==1.0.0 8 | scikit-image==0.20.0 9 | pytorch-fid==0.3.0 10 | pandas==1.5.3 11 | matplotlib 12 | -------------------------------------------------------------------------------- /src/README.md: -------------------------------------------------------------------------------- 1 | The repositories present in this folder are based on the following repositories: 2 | 3 | - Taming Transformers for High-Resolution Image Synthesis: https://github.com/CompVis/taming-transformers 4 | - Latent Diffusion Models: https://github.com/CompVis/latent-diffusion 5 | - Perceptual Similarity: https://github.com/SteffenCzolbe/PerceptualSimilarity -------------------------------------------------------------------------------- /src/ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/stable_signature/48261686883ba86f533836ad47c35040afcfe37b/src/ldm/data/__init__.py -------------------------------------------------------------------------------- /src/ldm/data/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 3 | 4 | 5 | class Txt2ImgIterableBaseDataset(IterableDataset): 6 | ''' 7 | Define an interface to make the IterableDatasets for text2img data chainable 8 | ''' 9 | def __init__(self, num_records=0, valid_ids=None, size=256): 10 | super().__init__() 11 | self.num_records = num_records 12 | self.valid_ids = valid_ids 13 | self.sample_ids = valid_ids 14 | self.size = size 15 | 16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 17 | 18 | def __len__(self): 19 | return self.num_records 20 | 21 | @abstractmethod 22 | def __iter__(self): 23 | pass -------------------------------------------------------------------------------- /src/ldm/data/lsun.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | 9 | class LSUNBase(Dataset): 10 | def __init__(self, 11 | txt_file, 12 | data_root, 13 | size=None, 14 | interpolation="bicubic", 15 | flip_p=0.5 16 | ): 17 | self.data_paths = txt_file 18 | self.data_root = data_root 19 | with open(self.data_paths, "r") as f: 20 | self.image_paths = f.read().splitlines() 21 | self._length = len(self.image_paths) 22 | self.labels = { 23 | "relative_file_path_": [l for l in self.image_paths], 24 | "file_path_": [os.path.join(self.data_root, l) 25 | for l in self.image_paths], 26 | } 27 | 28 | self.size = size 29 | self.interpolation = {"linear": PIL.Image.LINEAR, 30 | "bilinear": PIL.Image.BILINEAR, 31 | "bicubic": PIL.Image.BICUBIC, 32 | "lanczos": PIL.Image.LANCZOS, 33 | }[interpolation] 34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 35 | 36 | def __len__(self): 37 | return self._length 38 | 39 | def __getitem__(self, i): 40 | example = dict((k, self.labels[k][i]) for k in self.labels) 41 | image = Image.open(example["file_path_"]) 42 | if not image.mode == "RGB": 43 | image = image.convert("RGB") 44 | 45 | # default to score-sde preprocessing 46 | img = np.array(image).astype(np.uint8) 47 | crop = min(img.shape[0], img.shape[1]) 48 | h, w, = img.shape[0], img.shape[1] 49 | img = img[(h - crop) // 2:(h + crop) // 2, 50 | 
(w - crop) // 2:(w + crop) // 2] 51 | 52 | image = Image.fromarray(img) 53 | if self.size is not None: 54 | image = image.resize((self.size, self.size), resample=self.interpolation) 55 | 56 | image = self.flip(image) 57 | image = np.array(image).astype(np.uint8) 58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32) 59 | return example 60 | 61 | 62 | class LSUNChurchesTrain(LSUNBase): 63 | def __init__(self, **kwargs): 64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) 65 | 66 | 67 | class LSUNChurchesValidation(LSUNBase): 68 | def __init__(self, flip_p=0., **kwargs): 69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", 70 | flip_p=flip_p, **kwargs) 71 | 72 | 73 | class LSUNBedroomsTrain(LSUNBase): 74 | def __init__(self, **kwargs): 75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) 76 | 77 | 78 | class LSUNBedroomsValidation(LSUNBase): 79 | def __init__(self, flip_p=0.0, **kwargs): 80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", 81 | flip_p=flip_p, **kwargs) 82 | 83 | 84 | class LSUNCatsTrain(LSUNBase): 85 | def __init__(self, **kwargs): 86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) 87 | 88 | 89 | class LSUNCatsValidation(LSUNBase): 90 | def __init__(self, flip_p=0., **kwargs): 91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", 92 | flip_p=flip_p, **kwargs) 93 | -------------------------------------------------------------------------------- /src/ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 
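        # Added note: cum_cycles stores the cumulative step at which each cycle ends,
        # so find_in_interval can map a global step to its cycle. Within a cycle the
        # multiplier ramps linearly from f_start to f_max over warm_up_steps, then
        # follows a half-cosine from f_max down to f_min for the rest of the cycle
        # (LambdaLinearScheduler below replaces the cosine decay with a linear one).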
50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /src/ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/stable_signature/48261686883ba86f533836ad47c35040afcfe37b/src/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /src/ldm/models/diffusion/dpm_solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import DPMSolverSampler -------------------------------------------------------------------------------- /src/ldm/models/diffusion/dpm_solver/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | import torch 3 | 4 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver 5 | 6 | 7 | MODEL_TYPES = { 8 | "eps": "noise", 9 | "v": "v" 10 | } 11 | 12 | 13 | class DPMSolverSampler(object): 14 | def __init__(self, model, **kwargs): 15 | super().__init__() 16 | self.model = model 17 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 18 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) 19 | 20 | def register_buffer(self, name, attr): 21 | if type(attr) == torch.Tensor: 22 | if attr.device != torch.device("cuda"): 23 | attr = attr.to(torch.device("cuda")) 24 | setattr(self, name, attr) 25 | 26 | @torch.no_grad() 27 | def sample(self, 28 | S, 29 | batch_size, 30 | shape, 31 | conditioning=None, 32 | callback=None, 33 | normals_sequence=None, 34 | img_callback=None, 35 | quantize_x0=False, 36 | eta=0., 37 | mask=None, 38 | x0=None, 39 | temperature=1., 40 | 
noise_dropout=0., 41 | score_corrector=None, 42 | corrector_kwargs=None, 43 | verbose=True, 44 | x_T=None, 45 | log_every_t=100, 46 | unconditional_guidance_scale=1., 47 | unconditional_conditioning=None, 48 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 49 | **kwargs 50 | ): 51 | if conditioning is not None: 52 | if isinstance(conditioning, dict): 53 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 54 | if cbs != batch_size: 55 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 56 | else: 57 | if conditioning.shape[0] != batch_size: 58 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 59 | 60 | # sampling 61 | C, H, W = shape 62 | size = (batch_size, C, H, W) 63 | 64 | print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') 65 | 66 | device = self.model.betas.device 67 | if x_T is None: 68 | img = torch.randn(size, device=device) 69 | else: 70 | img = x_T 71 | 72 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 73 | 74 | model_fn = model_wrapper( 75 | lambda x, t, c: self.model.apply_model(x, t, c), 76 | ns, 77 | model_type=MODEL_TYPES[self.model.parameterization], 78 | guidance_type="classifier-free", 79 | condition=conditioning, 80 | unconditional_condition=unconditional_conditioning, 81 | guidance_scale=unconditional_guidance_scale, 82 | ) 83 | 84 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) 85 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) 86 | 87 | return x.to(device), None -------------------------------------------------------------------------------- /src/ldm/models/diffusion/sampling_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def append_dims(x, target_dims): 6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions. 
7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" 8 | dims_to_append = target_dims - x.ndim 9 | if dims_to_append < 0: 10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 11 | return x[(...,) + (None,) * dims_to_append] 12 | 13 | 14 | def norm_thresholding(x0, value): 15 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) 16 | return x0 * (value / s) 17 | 18 | 19 | def spatial_norm_thresholding(x0, value): 20 | # b c h w 21 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) 22 | return x0 * (value / s) -------------------------------------------------------------------------------- /src/ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/stable_signature/48261686883ba86f533836ad47c35040afcfe37b/src/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /src/ldm/modules/diffusionmodules/upscaling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from functools import partial 5 | 6 | from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule 7 | from ldm.util import default 8 | 9 | 10 | class AbstractLowScaleModel(nn.Module): 11 | # for concatenating a downsampled image to the latent representation 12 | def __init__(self, noise_schedule_config=None): 13 | super(AbstractLowScaleModel, self).__init__() 14 | if noise_schedule_config is not None: 15 | self.register_schedule(**noise_schedule_config) 16 | 17 | def register_schedule(self, beta_schedule="linear", timesteps=1000, 18 | linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 19 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, 20 | cosine_s=cosine_s) 21 | alphas = 1. - betas 22 | alphas_cumprod = np.cumprod(alphas, axis=0) 23 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 24 | 25 | timesteps, = betas.shape 26 | self.num_timesteps = int(timesteps) 27 | self.linear_start = linear_start 28 | self.linear_end = linear_end 29 | assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' 30 | 31 | to_torch = partial(torch.tensor, dtype=torch.float32) 32 | 33 | self.register_buffer('betas', to_torch(betas)) 34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 35 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) 36 | 37 | # calculations for diffusion q(x_t | x_{t-1}) and others 38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) 39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) 40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) 41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) 42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. 
/ alphas_cumprod - 1))) 43 | 44 | def q_sample(self, x_start, t, noise=None): 45 | noise = default(noise, lambda: torch.randn_like(x_start)) 46 | return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 47 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) 48 | 49 | def forward(self, x): 50 | return x, None 51 | 52 | def decode(self, x): 53 | return x 54 | 55 | 56 | class SimpleImageConcat(AbstractLowScaleModel): 57 | # no noise level conditioning 58 | def __init__(self): 59 | super(SimpleImageConcat, self).__init__(noise_schedule_config=None) 60 | self.max_noise_level = 0 61 | 62 | def forward(self, x): 63 | # fix to constant noise level 64 | return x, torch.zeros(x.shape[0], device=x.device).long() 65 | 66 | 67 | class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): 68 | def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): 69 | super().__init__(noise_schedule_config=noise_schedule_config) 70 | self.max_noise_level = max_noise_level 71 | 72 | def forward(self, x, noise_level=None): 73 | if noise_level is None: 74 | noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() 75 | else: 76 | assert isinstance(noise_level, torch.Tensor) 77 | z = self.q_sample(x, noise_level) 78 | return z, noise_level 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /src/ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/stable_signature/48261686883ba86f533836ad47c35040afcfe37b/src/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /src/ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 
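        # Added note: this is the exact Gaussian negative log-likelihood
        #   0.5 * sum_i [ log(2*pi) + logvar_i + (x_i - mean_i)^2 / var_i ],
        # summed over `dims` (channel and spatial dimensions by default).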
57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /src/ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 
61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /src/ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/stable_signature/48261686883ba86f533836ad47c35040afcfe37b/src/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /src/ldm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.checkpoint import checkpoint 4 | 5 | from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel 6 | 7 | import open_clip 8 | from ldm.util import default, count_params 9 | 10 | 11 | class AbstractEncoder(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def encode(self, *args, **kwargs): 16 | raise NotImplementedError 17 | 18 | 19 | class IdentityEncoder(AbstractEncoder): 20 | 21 | def encode(self, x): 22 | return x 23 | 24 | 25 | class ClassEmbedder(nn.Module): 26 | def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): 27 | super().__init__() 28 | self.key = key 29 | self.embedding = nn.Embedding(n_classes, embed_dim) 30 | self.n_classes = n_classes 31 | self.ucg_rate = ucg_rate 32 | 33 | def forward(self, batch, key=None, disable_dropout=False): 34 | if key is None: 35 | key = self.key 36 | # this is for use in crossattn 37 | c = batch[key][:, None] 38 | if self.ucg_rate > 0. and not disable_dropout: 39 | mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) 40 | c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1) 41 | c = c.long() 42 | c = self.embedding(c) 43 | return c 44 | 45 | def get_unconditional_conditioning(self, bs, device="cuda"): 46 | uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) 47 | uc = torch.ones((bs,), device=device) * uc_class 48 | uc = {self.key: uc} 49 | return uc 50 | 51 | 52 | def disabled_train(self, mode=True): 53 | """Overwrite model.train with this function to make sure train/eval mode 54 | does not change anymore.""" 55 | return self 56 | 57 | 58 | class FrozenT5Embedder(AbstractEncoder): 59 | """Uses the T5 transformer encoder for text""" 60 | def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl 61 | super().__init__() 62 | self.tokenizer = T5Tokenizer.from_pretrained(version) 63 | self.transformer = T5EncoderModel.from_pretrained(version) 64 | self.device = device 65 | self.max_length = max_length # TODO: typical value? 
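        # Added note: 77 is simply kept to match the CLIP text encoders defined below;
        # T5 itself does not require this particular context length.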
66 | if freeze: 67 | self.freeze() 68 | 69 | def freeze(self): 70 | self.transformer = self.transformer.eval() 71 | #self.train = disabled_train 72 | for param in self.parameters(): 73 | param.requires_grad = False 74 | 75 | def forward(self, text): 76 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 77 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 78 | tokens = batch_encoding["input_ids"].to(self.device) 79 | outputs = self.transformer(input_ids=tokens) 80 | 81 | z = outputs.last_hidden_state 82 | return z 83 | 84 | def encode(self, text): 85 | return self(text) 86 | 87 | 88 | class FrozenCLIPEmbedder(AbstractEncoder): 89 | """Uses the CLIP transformer encoder for text (from huggingface)""" 90 | LAYERS = [ 91 | "last", 92 | "pooled", 93 | "hidden" 94 | ] 95 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, 96 | freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 97 | super().__init__() 98 | assert layer in self.LAYERS 99 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 100 | self.transformer = CLIPTextModel.from_pretrained(version) 101 | self.device = device 102 | self.max_length = max_length 103 | if freeze: 104 | self.freeze() 105 | self.layer = layer 106 | self.layer_idx = layer_idx 107 | if layer == "hidden": 108 | assert layer_idx is not None 109 | assert 0 <= abs(layer_idx) <= 12 110 | 111 | def freeze(self): 112 | self.transformer = self.transformer.eval() 113 | #self.train = disabled_train 114 | for param in self.parameters(): 115 | param.requires_grad = False 116 | 117 | def forward(self, text): 118 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 119 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 120 | tokens = batch_encoding["input_ids"].to(self.device) 121 | outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") 122 | if self.layer == "last": 123 | z = outputs.last_hidden_state 124 | elif self.layer == "pooled": 125 | z = outputs.pooler_output[:, None, :] 126 | else: 127 | z = outputs.hidden_states[self.layer_idx] 128 | return z 129 | 130 | def encode(self, text): 131 | return self(text) 132 | 133 | 134 | class FrozenOpenCLIPEmbedder(AbstractEncoder): 135 | """ 136 | Uses the OpenCLIP transformer encoder for text 137 | """ 138 | LAYERS = [ 139 | #"pooled", 140 | "last", 141 | "penultimate" 142 | ] 143 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, 144 | freeze=True, layer="last"): 145 | super().__init__() 146 | assert layer in self.LAYERS 147 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) 148 | del model.visual 149 | self.model = model 150 | 151 | self.device = device 152 | self.max_length = max_length 153 | if freeze: 154 | self.freeze() 155 | self.layer = layer 156 | if self.layer == "last": 157 | self.layer_idx = 0 158 | elif self.layer == "penultimate": 159 | self.layer_idx = 1 160 | else: 161 | raise NotImplementedError() 162 | 163 | def freeze(self): 164 | self.model = self.model.eval() 165 | for param in self.parameters(): 166 | param.requires_grad = False 167 | 168 | def forward(self, text): 169 | tokens = open_clip.tokenize(text) 170 | z = self.encode_with_transformer(tokens.to(self.device)) 171 | return z 172 | 173 | def encode_with_transformer(self, text): 174 | x = 
self.model.token_embedding(text) # [batch_size, n_ctx, d_model] 175 | x = x + self.model.positional_embedding 176 | x = x.permute(1, 0, 2) # NLD -> LND 177 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) 178 | x = x.permute(1, 0, 2) # LND -> NLD 179 | x = self.model.ln_final(x) 180 | return x 181 | 182 | def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): 183 | for i, r in enumerate(self.model.transformer.resblocks): 184 | if i == len(self.model.transformer.resblocks) - self.layer_idx: 185 | break 186 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): 187 | x = checkpoint(r, x, attn_mask) 188 | else: 189 | x = r(x, attn_mask=attn_mask) 190 | return x 191 | 192 | def encode(self, text): 193 | return self(text) 194 | 195 | 196 | class FrozenCLIPT5Encoder(AbstractEncoder): 197 | def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", 198 | clip_max_length=77, t5_max_length=77): 199 | super().__init__() 200 | self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) 201 | self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) 202 | print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " 203 | f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.") 204 | 205 | def encode(self, text): 206 | return self(text) 207 | 208 | def forward(self, text): 209 | clip_z = self.clip_encoder.encode(text) 210 | t5_z = self.t5_encoder.encode(text) 211 | return [clip_z, t5_z] 212 | 213 | 214 | -------------------------------------------------------------------------------- /src/ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /src/ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/stable_signature/48261686883ba86f533836ad47c35040afcfe37b/src/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /src/ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /src/ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
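# Added note: the star import above is what supplies LPIPS, NLayerDiscriminator,
# weights_init, hinge_d_loss, vanilla_d_loss and adopt_weight used below, so the
# taming package is a hard dependency of this loss.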
5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | 
"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /src/ldm/modules/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/stable_signature/48261686883ba86f533836ad47c35040afcfe37b/src/ldm/modules/midas/__init__.py -------------------------------------------------------------------------------- /src/ldm/modules/midas/api.py: -------------------------------------------------------------------------------- 1 | # based on https://github.com/isl-org/MiDaS 2 | 3 | import cv2 4 | import torch 5 | import torch.nn as nn 6 | from torchvision.transforms import Compose 7 | 8 | from ldm.modules.midas.midas.dpt_depth import DPTDepthModel 9 | from ldm.modules.midas.midas.midas_net import MidasNet 10 | from ldm.modules.midas.midas.midas_net_custom import MidasNet_small 11 | from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet 12 | 13 | 14 | ISL_PATHS = { 15 | "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt", 16 | "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt", 17 | "midas_v21": "", 18 | "midas_v21_small": "", 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | def load_midas_transform(model_type): 29 | # https://github.com/isl-org/MiDaS/blob/master/run.py 30 | # load transform only 31 | if model_type == "dpt_large": # DPT-Large 32 | net_w, net_h = 384, 384 33 | resize_mode = "minimal" 34 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 35 | 36 | elif model_type == "dpt_hybrid": # DPT-Hybrid 37 | net_w, net_h = 384, 384 38 | resize_mode = "minimal" 39 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 40 | 41 | elif model_type == "midas_v21": 42 | net_w, net_h = 384, 384 43 | resize_mode = "upper_bound" 44 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 45 | 46 | elif model_type == "midas_v21_small": 47 | net_w, net_h = 256, 256 48 | resize_mode = "upper_bound" 49 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 50 | 
51 | else: 52 | assert False, f"model_type '{model_type}' not implemented, use: --model_type large" 53 | 54 | transform = Compose( 55 | [ 56 | Resize( 57 | net_w, 58 | net_h, 59 | resize_target=None, 60 | keep_aspect_ratio=True, 61 | ensure_multiple_of=32, 62 | resize_method=resize_mode, 63 | image_interpolation_method=cv2.INTER_CUBIC, 64 | ), 65 | normalization, 66 | PrepareForNet(), 67 | ] 68 | ) 69 | 70 | return transform 71 | 72 | 73 | def load_model(model_type): 74 | # https://github.com/isl-org/MiDaS/blob/master/run.py 75 | # load network 76 | model_path = ISL_PATHS[model_type] 77 | if model_type == "dpt_large": # DPT-Large 78 | model = DPTDepthModel( 79 | path=model_path, 80 | backbone="vitl16_384", 81 | non_negative=True, 82 | ) 83 | net_w, net_h = 384, 384 84 | resize_mode = "minimal" 85 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 86 | 87 | elif model_type == "dpt_hybrid": # DPT-Hybrid 88 | model = DPTDepthModel( 89 | path=model_path, 90 | backbone="vitb_rn50_384", 91 | non_negative=True, 92 | ) 93 | net_w, net_h = 384, 384 94 | resize_mode = "minimal" 95 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 96 | 97 | elif model_type == "midas_v21": 98 | model = MidasNet(model_path, non_negative=True) 99 | net_w, net_h = 384, 384 100 | resize_mode = "upper_bound" 101 | normalization = NormalizeImage( 102 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 103 | ) 104 | 105 | elif model_type == "midas_v21_small": 106 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, 107 | non_negative=True, blocks={'expand': True}) 108 | net_w, net_h = 256, 256 109 | resize_mode = "upper_bound" 110 | normalization = NormalizeImage( 111 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 112 | ) 113 | 114 | else: 115 | print(f"model_type '{model_type}' not implemented, use: --model_type large") 116 | assert False 117 | 118 | transform = Compose( 119 | [ 120 | Resize( 121 | net_w, 122 | net_h, 123 | resize_target=None, 124 | keep_aspect_ratio=True, 125 | ensure_multiple_of=32, 126 | resize_method=resize_mode, 127 | image_interpolation_method=cv2.INTER_CUBIC, 128 | ), 129 | normalization, 130 | PrepareForNet(), 131 | ] 132 | ) 133 | 134 | return model.eval(), transform 135 | 136 | 137 | class MiDaSInference(nn.Module): 138 | MODEL_TYPES_TORCH_HUB = [ 139 | "DPT_Large", 140 | "DPT_Hybrid", 141 | "MiDaS_small" 142 | ] 143 | MODEL_TYPES_ISL = [ 144 | "dpt_large", 145 | "dpt_hybrid", 146 | "midas_v21", 147 | "midas_v21_small", 148 | ] 149 | 150 | def __init__(self, model_type): 151 | super().__init__() 152 | assert (model_type in self.MODEL_TYPES_ISL) 153 | model, _ = load_model(model_type) 154 | self.model = model 155 | self.model.train = disabled_train 156 | 157 | def forward(self, x): 158 | # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array 159 | # NOTE: we expect that the correct transform has been called during dataloading. 
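        # Inference below: run the frozen depth network without gradient
        # tracking, give the (B, H, W) prediction a channel axis, and
        # bicubic-resample it back to the input resolution, so the result
        # has shape (B, 1, H, W).
        # Hedged usage sketch (assumes the matching checkpoint exists under
        # midas_models/ and that x was produced by the corresponding MiDaS
        # transform):
        #   midas = MiDaSInference("dpt_hybrid").cuda().eval()
        #   depth = midas(x)  # x: (B, 3, 384, 384) -> depth: (B, 1, 384, 384)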
160 | with torch.no_grad(): 161 | prediction = self.model(x) 162 | prediction = torch.nn.functional.interpolate( 163 | prediction.unsqueeze(1), 164 | size=x.shape[2:], 165 | mode="bicubic", 166 | align_corners=False, 167 | ) 168 | assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) 169 | return prediction 170 | 171 | -------------------------------------------------------------------------------- /src/ldm/modules/midas/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/stable_signature/48261686883ba86f533836ad47c35040afcfe37b/src/ldm/modules/midas/midas/__init__.py -------------------------------------------------------------------------------- /src/ldm/modules/midas/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /src/ldm/modules/midas/midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | ): 36 | 37 | super(DPT, self).__init__() 38 | 39 | self.channels_last = channels_last 40 | 41 | hooks = { 42 | "vitb_rn50_384": [0, 1, 8, 11], 43 | "vitb16_384": [2, 5, 8, 11], 44 | "vitl16_384": [5, 11, 17, 23], 45 | } 46 | 47 | # Instantiate backbone and reassemble blocks 48 | self.pretrained, self.scratch = _make_encoder( 49 | backbone, 50 | features, 51 | False, # Set to true of you want to train from scratch, uses ImageNet weights 52 | groups=1, 53 | expand=False, 54 | exportable=False, 55 | hooks=hooks[backbone], 56 | use_readout=readout, 57 | ) 58 | 59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 63 | 64 | self.scratch.output_conv = head 65 | 66 | 67 | def forward(self, x): 68 | if self.channels_last == True: 69 | x.contiguous(memory_format=torch.channels_last) 70 | 71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 72 | 73 | layer_1_rn = self.scratch.layer1_rn(layer_1) 74 | layer_2_rn = self.scratch.layer2_rn(layer_2) 75 | layer_3_rn = self.scratch.layer3_rn(layer_3) 76 | layer_4_rn = self.scratch.layer4_rn(layer_4) 77 | 78 | path_4 = self.scratch.refinenet4(layer_4_rn) 79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 80 | 
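        # Each refinenet stage fuses the coarser decoder path with the next
        # finer backbone feature map, so path_1 carries the highest-resolution
        # fused features fed to the output head.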
path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 82 | 83 | out = self.scratch.output_conv(path_1) 84 | 85 | return out 86 | 87 | 88 | class DPTDepthModel(DPT): 89 | def __init__(self, path=None, non_negative=True, **kwargs): 90 | features = kwargs["features"] if "features" in kwargs else 256 91 | 92 | head = nn.Sequential( 93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 96 | nn.ReLU(True), 97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 98 | nn.ReLU(True) if non_negative else nn.Identity(), 99 | nn.Identity(), 100 | ) 101 | 102 | super().__init__(head, **kwargs) 103 | 104 | if path is not None: 105 | self.load(path) 106 | 107 | def forward(self, x): 108 | return super().forward(x).squeeze(dim=1) 109 | 110 | -------------------------------------------------------------------------------- /src/ldm/modules/midas/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 
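        The four encoder stages of the ResNeXt-101 WSL backbone are projected
        by the scratch layers and fused top-down by refinenet4..1 before the
        output head predicts a single-channel depth map.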
51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /src/ldm/modules/midas/midas/midas_net_custom.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet_small(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, 17 | blocks={'expand': True}): 18 | """Init. 19 | 20 | Args: 21 | path (str, optional): Path to saved model. Defaults to None. 22 | features (int, optional): Number of features. Defaults to 256. 23 | backbone (str, optional): Backbone network for encoder. 
Defaults to resnet50 24 | """ 25 | print("Loading weights: ", path) 26 | 27 | super(MidasNet_small, self).__init__() 28 | 29 | use_pretrained = False if path else True 30 | 31 | self.channels_last = channels_last 32 | self.blocks = blocks 33 | self.backbone = backbone 34 | 35 | self.groups = 1 36 | 37 | features1=features 38 | features2=features 39 | features3=features 40 | features4=features 41 | self.expand = False 42 | if "expand" in self.blocks and self.blocks['expand'] == True: 43 | self.expand = True 44 | features1=features 45 | features2=features*2 46 | features3=features*4 47 | features4=features*8 48 | 49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) 50 | 51 | self.scratch.activation = nn.ReLU(False) 52 | 53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) 57 | 58 | 59 | self.scratch.output_conv = nn.Sequential( 60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), 61 | Interpolate(scale_factor=2, mode="bilinear"), 62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), 63 | self.scratch.activation, 64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 65 | nn.ReLU(True) if non_negative else nn.Identity(), 66 | nn.Identity(), 67 | ) 68 | 69 | if path: 70 | self.load(path) 71 | 72 | 73 | def forward(self, x): 74 | """Forward pass. 
75 | 76 | Args: 77 | x (tensor): input data (image) 78 | 79 | Returns: 80 | tensor: depth 81 | """ 82 | if self.channels_last==True: 83 | print("self.channels_last = ", self.channels_last) 84 | x.contiguous(memory_format=torch.channels_last) 85 | 86 | 87 | layer_1 = self.pretrained.layer1(x) 88 | layer_2 = self.pretrained.layer2(layer_1) 89 | layer_3 = self.pretrained.layer3(layer_2) 90 | layer_4 = self.pretrained.layer4(layer_3) 91 | 92 | layer_1_rn = self.scratch.layer1_rn(layer_1) 93 | layer_2_rn = self.scratch.layer2_rn(layer_2) 94 | layer_3_rn = self.scratch.layer3_rn(layer_3) 95 | layer_4_rn = self.scratch.layer4_rn(layer_4) 96 | 97 | 98 | path_4 = self.scratch.refinenet4(layer_4_rn) 99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 102 | 103 | out = self.scratch.output_conv(path_1) 104 | 105 | return torch.squeeze(out, dim=1) 106 | 107 | 108 | 109 | def fuse_model(m): 110 | prev_previous_type = nn.Identity() 111 | prev_previous_name = '' 112 | previous_type = nn.Identity() 113 | previous_name = '' 114 | for name, module in m.named_modules(): 115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: 116 | # print("FUSED ", prev_previous_name, previous_name, name) 117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) 118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: 119 | # print("FUSED ", prev_previous_name, previous_name) 120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) 121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: 122 | # print("FUSED ", previous_name, name) 123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) 124 | 125 | prev_previous_type = previous_type 126 | prev_previous_name = previous_name 127 | previous_type = type(module) 128 | previous_name = name -------------------------------------------------------------------------------- /src/ldm/modules/midas/utils.py: -------------------------------------------------------------------------------- 1 | """Utils for monoDepth.""" 2 | import sys 3 | import re 4 | import numpy as np 5 | import cv2 6 | import torch 7 | 8 | 9 | def read_pfm(path): 10 | """Read pfm file. 11 | 12 | Args: 13 | path (str): path to file 14 | 15 | Returns: 16 | tuple: (data, scale) 17 | """ 18 | with open(path, "rb") as file: 19 | 20 | color = None 21 | width = None 22 | height = None 23 | scale = None 24 | endian = None 25 | 26 | header = file.readline().rstrip() 27 | if header.decode("ascii") == "PF": 28 | color = True 29 | elif header.decode("ascii") == "Pf": 30 | color = False 31 | else: 32 | raise Exception("Not a PFM file: " + path) 33 | 34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 35 | if dim_match: 36 | width, height = list(map(int, dim_match.groups())) 37 | else: 38 | raise Exception("Malformed PFM header.") 39 | 40 | scale = float(file.readline().decode("ascii").rstrip()) 41 | if scale < 0: 42 | # little-endian 43 | endian = "<" 44 | scale = -scale 45 | else: 46 | # big-endian 47 | endian = ">" 48 | 49 | data = np.fromfile(file, endian + "f") 50 | shape = (height, width, 3) if color else (height, width) 51 | 52 | data = np.reshape(data, shape) 53 | data = np.flipud(data) 54 | 55 | return data, scale 56 | 57 | 58 | def write_pfm(path, image, scale=1): 59 | """Write pfm file. 
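    The image is flipped vertically before writing and the sign of the scale
    encodes endianness (negative for little-endian), following the PFM
    convention. For example (hypothetical array):
    write_pfm("depth.pfm", np.zeros((480, 640), dtype=np.float32)).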
60 | 61 | Args: 62 | path (str): pathto file 63 | image (array): data 64 | scale (int, optional): Scale. Defaults to 1. 65 | """ 66 | 67 | with open(path, "wb") as file: 68 | color = None 69 | 70 | if image.dtype.name != "float32": 71 | raise Exception("Image dtype must be float32.") 72 | 73 | image = np.flipud(image) 74 | 75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 76 | color = True 77 | elif ( 78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 79 | ): # greyscale 80 | color = False 81 | else: 82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") 83 | 84 | file.write("PF\n" if color else "Pf\n".encode()) 85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) 86 | 87 | endian = image.dtype.byteorder 88 | 89 | if endian == "<" or endian == "=" and sys.byteorder == "little": 90 | scale = -scale 91 | 92 | file.write("%f\n".encode() % scale) 93 | 94 | image.tofile(file) 95 | 96 | 97 | def read_image(path): 98 | """Read image and output RGB image (0-1). 99 | 100 | Args: 101 | path (str): path to file 102 | 103 | Returns: 104 | array: RGB image (0-1) 105 | """ 106 | img = cv2.imread(path) 107 | 108 | if img.ndim == 2: 109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 110 | 111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 112 | 113 | return img 114 | 115 | 116 | def resize_image(img): 117 | """Resize image and make it fit for network. 118 | 119 | Args: 120 | img (array): image 121 | 122 | Returns: 123 | tensor: data ready for network 124 | """ 125 | height_orig = img.shape[0] 126 | width_orig = img.shape[1] 127 | 128 | if width_orig > height_orig: 129 | scale = width_orig / 384 130 | else: 131 | scale = height_orig / 384 132 | 133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int) 134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int) 135 | 136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 137 | 138 | img_resized = ( 139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() 140 | ) 141 | img_resized = img_resized.unsqueeze(0) 142 | 143 | return img_resized 144 | 145 | 146 | def resize_depth(depth, width, height): 147 | """Resize depth map and bring to CPU (numpy). 148 | 149 | Args: 150 | depth (tensor): depth 151 | width (int): image width 152 | height (int): image height 153 | 154 | Returns: 155 | array: processed depth 156 | """ 157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu") 158 | 159 | depth_resized = cv2.resize( 160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC 161 | ) 162 | 163 | return depth_resized 164 | 165 | def write_depth(path, depth, bits=1): 166 | """Write depth map to pfm and png file. 
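    The raw values are written to "<path>.pfm"; for the PNG the map is
    min-max scaled to the full 8-bit (bits=1) or 16-bit (bits=2) range.
    For example (hypothetical array): write_depth("out/frame0", depth_np, bits=2)
    produces out/frame0.pfm and a 16-bit out/frame0.png.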
167 | 168 | Args: 169 | path (str): filepath without extension 170 | depth (array): depth 171 | """ 172 | write_pfm(path + ".pfm", depth.astype(np.float32)) 173 | 174 | depth_min = depth.min() 175 | depth_max = depth.max() 176 | 177 | max_val = (2**(8*bits))-1 178 | 179 | if depth_max - depth_min > np.finfo("float").eps: 180 | out = max_val * (depth - depth_min) / (depth_max - depth_min) 181 | else: 182 | out = np.zeros(depth.shape, dtype=depth.dtype) 183 | 184 | if bits == 1: 185 | cv2.imwrite(path + ".png", out.astype("uint8")) 186 | elif bits == 2: 187 | cv2.imwrite(path + ".png", out.astype("uint16")) 188 | 189 | return 190 | -------------------------------------------------------------------------------- /src/ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | 6 | from inspect import isfunction 7 | from PIL import Image, ImageDraw, ImageFont 8 | 9 | def log_txt_as_img(wh, xc, size=10): 10 | # wh a tuple of (width, height) 11 | # xc a list of captions to plot 12 | b = len(xc) 13 | txts = list() 14 | for bi in range(b): 15 | txt = Image.new("RGB", wh, color="white") 16 | draw = ImageDraw.Draw(txt) 17 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 18 | nc = int(40 * (wh[0] / 256)) 19 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 20 | 21 | try: 22 | draw.text((0, 0), lines, fill="black", font=font) 23 | except UnicodeEncodeError: 24 | print("Can't encode string for logging. Skipping.") 25 | 26 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 27 | txts.append(txt) 28 | txts = np.stack(txts) 29 | txts = torch.tensor(txts) 30 | return txts 31 | 32 | 33 | def ismap(x): 34 | if not isinstance(x, torch.Tensor): 35 | return False 36 | return (len(x.shape) == 4) and (x.shape[1] > 3) 37 | 38 | 39 | def isimage(x): 40 | if not isinstance(x,torch.Tensor): 41 | return False 42 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 43 | 44 | 45 | def exists(x): 46 | return x is not None 47 | 48 | 49 | def default(val, d): 50 | if exists(val): 51 | return val 52 | return d() if isfunction(d) else d 53 | 54 | 55 | def mean_flat(tensor): 56 | """ 57 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 58 | Take the mean over all non-batch dimensions.
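    For example, mean_flat(torch.ones(4, 3, 8, 8)) returns a tensor of shape (4,) filled with 1.0.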
59 | """ 60 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 61 | 62 | 63 | def count_params(model, verbose=False): 64 | total_params = sum(p.numel() for p in model.parameters()) 65 | if verbose: 66 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 67 | return total_params 68 | 69 | 70 | def instantiate_from_config(config): 71 | if not "target" in config: 72 | if config == '__is_first_stage__': 73 | return None 74 | elif config == "__is_unconditional__": 75 | return None 76 | raise KeyError("Expected key `target` to instantiate.") 77 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 78 | 79 | 80 | def get_obj_from_str(string, reload=False): 81 | module, cls = string.rsplit(".", 1) 82 | if reload: 83 | module_imp = importlib.import_module(module) 84 | importlib.reload(module_imp) 85 | return getattr(importlib.import_module(module, package=None), cls) -------------------------------------------------------------------------------- /src/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/stable_signature/48261686883ba86f533836ad47c35040afcfe37b/src/loss/__init__.py -------------------------------------------------------------------------------- /src/loss/color_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class RGB2YCbCr(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | transf = torch.tensor([[.299, .587, .114], [-.1687, -.3313, .5], [.5, -.4187, -.0813]]).transpose(0, 1) 9 | self.transform = nn.Parameter(transf, requires_grad=False) 10 | bias = torch.tensor([0, 0.5, 0.5]) 11 | self.bias = nn.Parameter(bias, requires_grad=False) 12 | 13 | def forward(self, rgb): 14 | N, C, H, W = rgb.shape 15 | assert C == 3 16 | rgb = rgb.transpose(1,3) 17 | cbcr = torch.matmul(rgb, self.transform) 18 | cbcr += self.bias 19 | return cbcr.transpose(1,3) 20 | 21 | class ColorWrapper(nn.Module): 22 | """ 23 | Extension for single-channel loss to work on color images 24 | """ 25 | def __init__(self, lossclass, args, kwargs, trainable=False): 26 | """ 27 | Parameters: 28 | lossclass: class of the individual loss functions 29 | trainable: bool, if True parameters of the loss are trained. 
30 | args: tuple, arguments for instantiation of loss fun 31 | kwargs: dict, key word arguments for instantiation of loss fun 32 | """ 33 | super().__init__() 34 | 35 | # submodules 36 | self.add_module('to_YCbCr', RGB2YCbCr()) 37 | self.add_module('ly', lossclass(*args, **kwargs)) 38 | self.add_module('lcb', lossclass(*args, **kwargs)) 39 | self.add_module('lcr', lossclass(*args, **kwargs)) 40 | 41 | # weights 42 | self.w_tild = nn.Parameter(torch.zeros(3), requires_grad=trainable) 43 | 44 | @property 45 | def w(self): 46 | return F.softmax(self.w_tild, dim=0) 47 | 48 | def forward(self, input, target): 49 | # convert color space 50 | input = self.to_YCbCr(input) 51 | target = self.to_YCbCr(target) 52 | 53 | ly = self.ly(input[:,[0],:,:], target[:,[0],:,:]) 54 | lcb = self.lcb(input[:,[1],:,:], target[:,[1],:,:]) 55 | lcr = self.lcr(input[:,[2],:,:], target[:,[2],:,:]) 56 | 57 | w = self.w 58 | 59 | return ly * w[0] + lcb * w[1] + lcr * w[2] 60 | 61 | class GreyscaleWrapper(nn.Module): 62 | """ 63 | Maps 3 channel RGB or 1 channel greyscale input to 3 greyscale channels 64 | """ 65 | def __init__(self, lossclass, args, kwargs): 66 | """ 67 | Parameters: 68 | lossclass: class of the individual loss function 69 | args: tuple, arguments for instantiation of loss fun 70 | kwargs: dict, key word arguments for instantiation of loss fun 71 | """ 72 | super().__init__() 73 | 74 | # submodules 75 | self.add_module('loss', lossclass(*args, **kwargs)) 76 | 77 | def to_greyscale(self, tensor): 78 | return tensor[:,[0],:,:] * 0.3 + tensor[:,[1],:,:] * 0.59 + tensor[:,[2],:,:] * 0.11 79 | 80 | def forward(self, input, target): 81 | (N,C,X,Y) = input.size() 82 | 83 | if N == 3: 84 | # convert input to greyscale 85 | input = self.to_greyscale(input) 86 | target = self.to_greyscale(target) 87 | 88 | # input in now greyscale, expand to 3 channels 89 | input = input.expand(N, 3, X, Y) 90 | target = target.expand(N, 3, X, Y) 91 | 92 | return self.loss.forward(input, target) 93 | -------------------------------------------------------------------------------- /src/loss/dct2d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class Dct2d(nn.Module): 7 | """ 8 | Blockwhise 2D DCT 9 | """ 10 | def __init__(self, blocksize=8, interleaving=False): 11 | """ 12 | Parameters: 13 | blocksize: int, size of the Blocks for discrete cosine transform 14 | interleaving: bool, should the blocks interleave? 15 | """ 16 | super().__init__() # call super constructor 17 | 18 | self.blocksize = blocksize 19 | self.interleaving = interleaving 20 | 21 | if interleaving: 22 | self.stride = self.blocksize // 2 23 | else: 24 | self.stride = self.blocksize 25 | 26 | # precompute DCT weight matrix 27 | A = np.zeros((blocksize,blocksize)) 28 | for i in range(blocksize): 29 | c_i = 1/np.sqrt(2) if i == 0 else 1. 30 | for n in range(blocksize): 31 | A[i,n] = np.sqrt(2/blocksize) * c_i * np.cos((2*n+ 1)/(blocksize*2) * i * np.pi) 32 | 33 | # set up conv layer 34 | self.A = nn.Parameter(torch.tensor(A, dtype=torch.float32), requires_grad=False) 35 | self.unfold = torch.nn.Unfold(kernel_size=blocksize, padding=0, stride=self.stride) 36 | return 37 | 38 | def forward(self, x): 39 | """ 40 | performs 2D blockwhise DCT 41 | 42 | Parameters: 43 | x: tensor of dimension (N, 1, h, w) 44 | 45 | Return: 46 | tensor of dimension (N, k, blocksize, blocksize) 47 | where the 2nd dimension indexes the block. 
Dimensions 3 and 4 are the block DCT coefficients 48 | """ 49 | 50 | (N, C, H, W) = x.shape 51 | assert (C == 1), "DCT is only implemented for a single channel" 52 | assert (H >= self.blocksize), "Input too small for blocksize" 53 | assert (W >= self.blocksize), "Input too small for blocksize" 54 | assert (H % self.stride == 0) and (W % self.stride == 0), "FFT is only for dimensions divisible by the blocksize" 55 | 56 | # unfold to blocks 57 | x = self.unfold(x) 58 | # now shape (N, blocksize**2, k) 59 | (N, _, k) = x.shape 60 | x = x.view(-1,self.blocksize,self.blocksize,k).permute(0,3,1,2) 61 | # now shape (N, #k, blocksize, blocksize) 62 | # perform DCT 63 | coeff = self.A.matmul(x).matmul(self.A.transpose(0,1)) 64 | 65 | return coeff 66 | 67 | def inverse(self, coeff, output_shape): 68 | """ 69 | performs 2D blockwhise iDCT 70 | 71 | Parameters: 72 | coeff: tensor of dimension (N, k, blocksize, blocksize) 73 | where the 2nd dimension indexes the block. Dimensions 3 and 4 are the block DCT coefficients 74 | output_shape: (h, w) dimensions of the reconstructed image 75 | 76 | Return: 77 | tensor of dimension (N, 1, h, w) 78 | """ 79 | if self.interleaving: 80 | raise Exception('Inverse block DCT is not implemented for interleaving blocks!') 81 | 82 | # perform iDCT 83 | x = self.A.transpose(0,1).matmul(coeff).matmul(self.A) 84 | (N, k, _, _) = x.shape 85 | x = x.permute(0,2,3,1).view(-1, self.blocksize**2, k) 86 | x = F.fold(x, output_size=(output_shape[-2], output_shape[-1]), kernel_size=self.blocksize, padding=0, stride=self.blocksize) 87 | return x 88 | -------------------------------------------------------------------------------- /src/loss/loss_provider.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | from collections import OrderedDict 5 | 6 | from loss.color_wrapper import ColorWrapper, GreyscaleWrapper 7 | from loss.shift_wrapper import ShiftWrapper 8 | from loss.watson import WatsonDistance 9 | from loss.watson_fft import WatsonDistanceFft 10 | from loss.watson_vgg import WatsonDistanceVgg 11 | from loss.deep_loss import PNetLin 12 | from loss.ssim import SSIM 13 | 14 | 15 | class LossProvider(): 16 | def __init__(self): 17 | self.loss_functions = ['L1', 'L2', 'SSIM', 'Watson-dct', 'Watson-fft', 'Watson-vgg', 'Deeploss-vgg', 'Deeploss-squeeze', 'Adaptive'] 18 | self.color_models = ['LA', 'RGB'] 19 | 20 | def load_state_dict(self, filename): 21 | current_dir = os.path.dirname(__file__) 22 | path = os.path.join(current_dir, 'losses', filename) 23 | return torch.load(path, map_location='cpu') 24 | 25 | def get_loss_function(self, model, colorspace='RGB', reduction='sum', deterministic=False, pretrained=True, image_size=None): 26 | """ 27 | returns a trained loss class. 28 | model: one of the values returned by self.loss_functions 29 | colorspace: 'LA' or 'RGB' 30 | deterministic: bool, if false (default) uses shifting of image blocks for watson-fft 31 | image_size: tuple, size of input images. Only required for adaptive loss. 
Eg: [3, 64, 64] 32 | """ 33 | is_greyscale = colorspace in ['grey', 'Grey', 'LA', 'greyscale', 'grey-scale'] 34 | 35 | 36 | if model.lower() in ['l2']: 37 | loss = nn.MSELoss(reduction=reduction) 38 | elif model.lower() in ['l1']: 39 | loss = nn.L1Loss(reduction=reduction) 40 | elif model.lower() in ['ssim']: 41 | loss = SSIM(size_average=(reduction in ['sum', 'mean'])) 42 | elif model.lower() in ['watson', 'watson-dct']: 43 | if is_greyscale: 44 | if deterministic: 45 | loss = WatsonDistance(reduction=reduction) 46 | if pretrained: 47 | loss.load_state_dict(self.load_state_dict('gray_watson_dct_trial0.pth')) 48 | else: 49 | loss = ShiftWrapper(WatsonDistance, (), {'reduction': reduction}) 50 | if pretrained: 51 | loss.loss.load_state_dict(self.load_state_dict('gray_watson_dct_trial0.pth')) 52 | else: 53 | if deterministic: 54 | loss = ColorWrapper(WatsonDistance, (), {'reduction': reduction}) 55 | if pretrained: 56 | loss.load_state_dict(self.load_state_dict('rgb_watson_dct_trial0.pth')) 57 | else: 58 | loss = ShiftWrapper(ColorWrapper, (WatsonDistance, (), {'reduction': reduction}), {}) 59 | if pretrained: 60 | loss.loss.load_state_dict(self.load_state_dict('rgb_watson_dct_trial0.pth')) 61 | elif model.lower() in ['watson-fft', 'watson-dft']: 62 | if is_greyscale: 63 | if deterministic: 64 | loss = WatsonDistanceFft(reduction=reduction) 65 | if pretrained: 66 | loss.load_state_dict(self.load_state_dict('gray_watson_fft_trial0.pth')) 67 | else: 68 | loss = ShiftWrapper(WatsonDistanceFft, (), {'reduction': reduction}) 69 | if pretrained: 70 | loss.loss.load_state_dict(self.load_state_dict('gray_watson_fft_trial0.pth')) 71 | else: 72 | if deterministic: 73 | loss = ColorWrapper(WatsonDistanceFft, (), {'reduction': reduction}) 74 | if pretrained: 75 | loss.load_state_dict(self.load_state_dict('rgb_watson_fft_trial0.pth')) 76 | else: 77 | loss = ShiftWrapper(ColorWrapper, (WatsonDistanceFft, (), {'reduction': reduction}), {}) 78 | if pretrained: 79 | loss.loss.load_state_dict(self.load_state_dict('rgb_watson_fft_trial0.pth')) 80 | elif model.lower() in ['watson-vgg', 'watson-deep']: 81 | if is_greyscale: 82 | loss = GreyscaleWrapper(WatsonDistanceVgg, (), {'reduction': reduction}) 83 | if pretrained: 84 | loss.loss.load_state_dict(self.load_state_dict('gray_watson_vgg_trial0.pth')) 85 | else: 86 | loss = WatsonDistanceVgg(reduction=reduction) 87 | if pretrained: 88 | loss.load_state_dict(self.load_state_dict('rgb_watson_vgg_trial0.pth')) 89 | elif model.lower() in ['deeploss-vgg']: 90 | if is_greyscale: 91 | loss = GreyscaleWrapper(PNetLin, (), {'pnet_type': 'vgg', 'reduction': reduction, 'use_dropout': False}) 92 | if pretrained: 93 | loss.loss.load_state_dict(self.load_state_dict('gray_pnet_lin_vgg_trial0.pth')) 94 | else: 95 | loss = PNetLin(pnet_type='vgg', reduction=reduction, use_dropout=False) 96 | if pretrained: 97 | loss.load_state_dict(self.load_state_dict('rgb_pnet_lin_vgg_trial0.pth')) 98 | elif model.lower() in ['deeploss-squeeze']: 99 | if is_greyscale: 100 | loss = GreyscaleWrapper(PNetLin, (), {'pnet_type': 'squeeze', 'reduction': reduction, 'use_dropout': False}) 101 | if pretrained: 102 | loss.loss.load_state_dict(self.load_state_dict('gray_pnet_lin_squeeze_trial0.pth')) 103 | else: 104 | loss = PNetLin(pnet_type='squeeze', reduction=reduction, use_dropout=False) 105 | if pretrained: 106 | loss.load_state_dict(self.load_state_dict('rgb_pnet_lin_squeeze_trial0.pth')) 107 | else: 108 | raise Exception('Metric "{}" not implemented'.format(model)) 109 | 110 | # freeze all 
training of the loss functions 111 | if pretrained: 112 | for param in loss.parameters(): 113 | param.requires_grad = False 114 | 115 | return loss 116 | -------------------------------------------------------------------------------- /src/loss/rfft2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.fft as fft 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | 8 | class Rfft2d(nn.Module): 9 | """ 10 | Blockwhise 2D FFT 11 | for fixed blocksize of 8x8 12 | """ 13 | def __init__(self, blocksize=8, interleaving=False): 14 | """ 15 | Parameters: 16 | """ 17 | super().__init__() # call super constructor 18 | 19 | self.blocksize = blocksize 20 | self.interleaving = interleaving 21 | if interleaving: 22 | self.stride = self.blocksize // 2 23 | else: 24 | self.stride = self.blocksize 25 | 26 | self.unfold = torch.nn.Unfold(kernel_size=self.blocksize, padding=0, stride=self.stride) 27 | return 28 | 29 | def forward(self, x): 30 | """ 31 | performs 2D blockwhise DCT 32 | 33 | Parameters: 34 | x: tensor of dimension (N, 1, h, w) 35 | 36 | Return: 37 | tensor of dimension (N, k, b, b/2, 2) 38 | where the 2nd dimension indexes the block. Dimensions 3 and 4 are the block real FFT coefficients. 39 | The last dimension is pytorches representation of complex values 40 | """ 41 | 42 | (N, C, H, W) = x.shape 43 | assert (C == 1), "FFT is only implemented for a single channel" 44 | assert (H >= self.blocksize), "Input too small for blocksize" 45 | assert (W >= self.blocksize), "Input too small for blocksize" 46 | assert (H % self.stride == 0) and (W % self.stride == 0), "FFT is only for dimensions divisible by the blocksize" 47 | 48 | # unfold to blocks 49 | x = self.unfold(x) 50 | # now shape (N, 64, k) 51 | (N, _, k) = x.shape 52 | x = x.view(-1,self.blocksize,self.blocksize,k).permute(0,3,1,2) 53 | # now shape (N, #k, b, b) 54 | # perform DCT 55 | coeff = fft.rfft(x) 56 | coeff = torch.view_as_real(coeff) 57 | 58 | return coeff / self.blocksize**2 59 | 60 | def inverse(self, coeff, output_shape): 61 | """ 62 | performs 2D blockwhise inverse rFFT 63 | 64 | Parameters: 65 | output_shape: Tuple, dimensions of the outpus sample 66 | """ 67 | if self.interleaving: 68 | raise Exception('Inverse block FFT is not implemented for interleaving blocks!') 69 | 70 | # perform iRFFT 71 | x = fft.irfft(coeff, dim=2, signal_sizes=(self.blocksize, self.blocksize)) 72 | (N, k, _, _) = x.shape 73 | x = x.permute(0,2,3,1).view(-1, self.blocksize**2, k) 74 | x = F.fold(x, output_size=(output_shape[-2], output_shape[-1]), kernel_size=self.blocksize, padding=0, stride=self.blocksize) 75 | return x * (self.blocksize**2) 76 | -------------------------------------------------------------------------------- /src/loss/shift_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | class ShiftWrapper(nn.Module): 6 | """ 7 | Extension for 2-dimensional inout loss functions. 8 | Shifts the inputs by up to 4 pixels. Uses replication padding. 9 | """ 10 | def __init__(self, lossclass, args, kwargs): 11 | """ 12 | Parameters: 13 | lossclass: class of the individual loss functions 14 | trainable: bool, if True parameters of the loss are trained. 
15 | args: tuple, arguments for instantiation of loss fun 16 | kwargs: dict, key word arguments for instantiation of loss fun 17 | """ 18 | super().__init__() 19 | 20 | # submodules 21 | self.add_module('loss', lossclass(*args, **kwargs)) 22 | 23 | # shift amount 24 | self.max_shift = 8 25 | 26 | # padding 27 | self.pad = nn.ReplicationPad2d(self.max_shift // 2) 28 | 29 | def forward(self, input, target): 30 | # convert color space 31 | input = self.pad(input) 32 | target = self.pad(target) 33 | 34 | shift_x = np.random.randint(self.max_shift) 35 | shift_y = np.random.randint(self.max_shift) 36 | 37 | input = input[:,:,shift_x:-(self.max_shift - shift_x),shift_y:-(self.max_shift - shift_y)] 38 | target = target[:,:,shift_x:-(self.max_shift - shift_x),shift_y:-(self.max_shift - shift_y)] 39 | 40 | return self.loss(input, target) 41 | -------------------------------------------------------------------------------- /src/loss/ssim.py: -------------------------------------------------------------------------------- 1 | # SSIM implementation from https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import numpy as np 7 | from math import exp 8 | 9 | def gaussian(window_size, sigma): 10 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 11 | return gauss/gauss.sum() 12 | 13 | def create_window(window_size, channel): 14 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 15 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 16 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 17 | return window 18 | 19 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 20 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 21 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 22 | 23 | mu1_sq = mu1.pow(2) 24 | mu2_sq = mu2.pow(2) 25 | mu1_mu2 = mu1*mu2 26 | 27 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 28 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 29 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 30 | 31 | C1 = 0.01**2 32 | C2 = 0.03**2 33 | 34 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 35 | 36 | if size_average: 37 | return ssim_map.mean() 38 | else: 39 | return ssim_map.mean(1).mean(1).mean(1) 40 | 41 | class SSIM(torch.nn.Module): 42 | def __init__(self, window_size = 11, size_average = True): 43 | super(SSIM, self).__init__() 44 | self.window_size = window_size 45 | self.size_average = size_average 46 | self.channel = 1 47 | self.window = create_window(window_size, self.channel) 48 | 49 | def forward(self, img1, img2): 50 | (_, channel, _, _) = img1.size() 51 | 52 | if channel == self.channel and self.window.data.type() == img1.data.type(): 53 | window = self.window 54 | else: 55 | window = create_window(self.window_size, channel) 56 | 57 | if img1.is_cuda: 58 | window = window.cuda(img1.get_device()) 59 | window = window.type_as(img1) 60 | 61 | self.window = window 62 | self.channel = channel 63 | 64 | 65 | return 1 - _ssim(img1, img2, window, self.window_size, channel, self.size_average) 66 | 67 | def ssim(img1, img2, window_size = 11, size_average = True): 68 | (_, channel, _, _) = 
img1.size() 69 | window = create_window(window_size, channel) 70 | 71 | if img1.is_cuda: 72 | window = window.cuda(img1.get_device()) 73 | window = window.type_as(img1) 74 | 75 | return _ssim(img1, img2, window, window_size, channel, size_average) 76 | -------------------------------------------------------------------------------- /src/loss/watson.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from loss.dct2d import Dct2d 5 | 6 | EPS = 1e-10 7 | 8 | def softmax(a, b, factor=1): 9 | concat = torch.cat([a.unsqueeze(-1), b.unsqueeze(-1)], dim=-1) 10 | softmax_factors = F.softmax(concat * factor, dim=-1) 11 | return a * softmax_factors[:,:,:,:,0] + b * softmax_factors[:,:,:,:,1] 12 | 13 | class WatsonDistance(nn.Module): 14 | """ 15 | Loss function based on Watsons perceptual distance. 16 | Based on DCT quantization 17 | """ 18 | def __init__(self, blocksize=8, trainable=False, reduction='sum'): 19 | """ 20 | Parameters: 21 | blocksize: int, size of the Blocks for discrete cosine transform 22 | trainable: bool, if True parameters of the loss are trained and dropout is enabled. 23 | reduction: 'sum' or 'none', determines return format 24 | """ 25 | super().__init__() 26 | 27 | # input mapping 28 | blocksize = torch.as_tensor(blocksize) 29 | 30 | # module to perform 2D blockwise DCT 31 | self.add_module('dct', Dct2d(blocksize=blocksize.item(), interleaving=False)) 32 | 33 | # parameters, initialized with values from watson paper 34 | self.blocksize = nn.Parameter(blocksize, requires_grad=False) 35 | if self.blocksize == 8: 36 | # init with Jpeg QM 37 | self.t_tild = nn.Parameter(torch.log(torch.tensor( # log-scaled weights 38 | [[1.40, 1.01, 1.16, 1.66, 2.40, 3.43, 4.79, 6.56], 39 | [1.01, 1.45, 1.32, 1.52, 2.00, 2.71, 3.67, 4.93], 40 | [1.16, 1.32, 2.24, 2.59, 2.98, 3.64, 4.60, 5.88], 41 | [1.66, 1.52, 2.59, 3.77, 4.55, 5.30, 6.28, 7.60], 42 | [2.40, 2.00, 2.98, 4.55, 6.15, 7.46, 8.71, 10.17], 43 | [3.43, 2.71, 3.64, 5.30, 7.46, 9.62, 11.58, 13.51], 44 | [4.79, 3.67, 4.60, 6.28, 8.71, 11.58, 14.50, 17.29], 45 | [6.56, 4.93, 5.88, 7.60, 10.17, 13.51, 17.29, 21.15]] 46 | )), requires_grad=trainable) 47 | else: 48 | # init with uniform QM 49 | self.t_tild = nn.Parameter(torch.zeros((self.blocksize, self.blocksize)), requires_grad=trainable) 50 | 51 | # other default parameters 52 | self.alpha = nn.Parameter(torch.tensor(0.649), requires_grad=trainable) # luminance masking 53 | w = torch.tensor(0.7) # contrast masking 54 | self.w_tild = nn.Parameter(torch.log(w / (1- w)), requires_grad=trainable) # inverse of sigmoid 55 | self.beta = nn.Parameter(torch.tensor(4.), requires_grad=trainable) # pooling 56 | 57 | # dropout for training 58 | self.dropout = nn.Dropout(0.5 if trainable else 0) 59 | 60 | # reduction 61 | self.reduction = reduction 62 | if reduction not in ['sum', 'none']: 63 | raise Exception('Reduction "{}" not supported. 
Valid values are: "sum", "none".'.format(reduction)) 64 | 65 | @property 66 | def t(self): 67 | # returns QM 68 | qm = torch.exp(self.t_tild) 69 | return qm 70 | 71 | @property 72 | def w(self): 73 | # return luminance masking parameter 74 | return torch.sigmoid(self.w_tild) 75 | 76 | def forward(self, input, target): 77 | # dct 78 | c0 = self.dct(target) 79 | c1 = self.dct(input) 80 | 81 | N, K, B, B = c0.shape 82 | 83 | # luminance masking 84 | avg_lum = torch.mean(c0[:,:,0,0]) 85 | t_l = self.t.view(1, 1, B, B).expand(N, K, B, B) 86 | t_l = t_l * (((c0[:,:,0,0] + EPS) / (avg_lum + EPS)) ** self.alpha).view(N, K, 1, 1) 87 | 88 | # contrast masking 89 | s = softmax(t_l, (c0.abs() + EPS)**self.w * t_l**(1 - self.w)) 90 | 91 | # pooling 92 | watson_dist = (((c0 - c1) / s).abs() + EPS) ** self.beta 93 | watson_dist = self.dropout(watson_dist) + EPS 94 | watson_dist = torch.sum(watson_dist, dim=(1,2,3)) 95 | watson_dist = watson_dist ** (1 / self.beta) 96 | 97 | # reduction 98 | if self.reduction == 'sum': 99 | watson_dist = torch.sum(watson_dist) 100 | 101 | return watson_dist 102 | 103 | -------------------------------------------------------------------------------- /src/loss/watson_fft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from loss.rfft2d import Rfft2d 5 | 6 | EPS = 1e-10 7 | 8 | def softmax(a, b, factor=1): 9 | concat = torch.cat([a.unsqueeze(-1), b.unsqueeze(-1)], dim=-1) 10 | softmax_factors = F.softmax(concat * factor, dim=-1) 11 | return a * softmax_factors[:,:,:,:,0] + b * softmax_factors[:,:,:,:,1] 12 | 13 | class WatsonDistanceFft(nn.Module): 14 | """ 15 | Loss function based on Watsons perceptual distance. 16 | Based on FFT quantization 17 | """ 18 | def __init__(self, blocksize=8, trainable=False, reduction='sum'): 19 | """ 20 | Parameters: 21 | blocksize: int, size of the Blocks for discrete cosine transform 22 | trainable: bool, if True parameters of the loss are trained and dropout is enabled. 23 | reduction: 'sum' or 'none', determines return format 24 | """ 25 | super().__init__() 26 | self.trainable = trainable 27 | 28 | # input mapping 29 | blocksize = torch.as_tensor(blocksize) 30 | 31 | # module to perform 2D blockwise rFFT 32 | self.add_module('fft', Rfft2d(blocksize=blocksize.item(), interleaving=False)) 33 | 34 | # parameters 35 | self.weight_size = (blocksize, blocksize // 2 + 1) 36 | self.blocksize = nn.Parameter(blocksize, requires_grad=False) 37 | # init with uniform QM 38 | self.t_tild = nn.Parameter(torch.zeros(self.weight_size), requires_grad=trainable) 39 | self.alpha = nn.Parameter(torch.tensor(0.1), requires_grad=trainable) # luminance masking 40 | w = torch.tensor(0.2) # contrast masking 41 | self.w_tild = nn.Parameter(torch.log(w / (1- w)), requires_grad=trainable) # inverse of sigmoid 42 | self.beta = nn.Parameter(torch.tensor(1.), requires_grad=trainable) # pooling 43 | 44 | # phase weights 45 | self.w_phase_tild = nn.Parameter(torch.zeros(self.weight_size) -2., requires_grad=trainable) 46 | 47 | # dropout for training 48 | self.dropout = nn.Dropout(0.5 if trainable else 0) 49 | 50 | # reduction 51 | self.reduction = reduction 52 | if reduction not in ['sum', 'none']: 53 | raise Exception('Reduction "{}" not supported. 
Valid values are: "sum", "none".'.format(reduction)) 54 | 55 | @property 56 | def t(self): 57 | # returns QM 58 | qm = torch.exp(self.t_tild) 59 | return qm 60 | 61 | @property 62 | def w(self): 63 | # return luminance masking parameter 64 | return torch.sigmoid(self.w_tild) 65 | 66 | @property 67 | def w_phase(self): 68 | # return weights for phase 69 | w_phase = torch.exp(self.w_phase_tild) 70 | # set weights of non-phases to 0 71 | if not self.trainable: 72 | w_phase[0,0] = 0. 73 | w_phase[0,self.weight_size[1] - 1] = 0. 74 | w_phase[self.weight_size[1] - 1,self.weight_size[1] - 1] = 0. 75 | w_phase[self.weight_size[1] - 1, 0] = 0. 76 | return w_phase 77 | 78 | def forward(self, input, target): 79 | # fft 80 | c0 = self.fft(target) 81 | c1 = self.fft(input) 82 | 83 | N, K, H, W, _ = c0.shape 84 | 85 | # get amplitudes 86 | c0_amp = torch.norm(c0 + EPS, p='fro', dim=4) 87 | c1_amp = torch.norm(c1 + EPS, p='fro', dim=4) 88 | 89 | # luminance masking 90 | avg_lum = torch.mean(c0_amp[:,:,0,0]) 91 | t_l = self.t.view(1, 1, H, W).expand(N, K, H, W) 92 | t_l = t_l * (((c0_amp[:,:,0,0] + EPS) / (avg_lum + EPS)) ** self.alpha).view(N, K, 1, 1) 93 | 94 | # contrast masking 95 | s = softmax(t_l, (c0_amp.abs() + EPS)**self.w * t_l**(1 - self.w)) 96 | 97 | # pooling 98 | watson_dist = (((c0_amp - c1_amp) / s).abs() + EPS) ** self.beta 99 | watson_dist = self.dropout(watson_dist) + EPS 100 | watson_dist = torch.sum(watson_dist, dim=(1,2,3)) 101 | watson_dist = watson_dist ** (1 / self.beta) 102 | 103 | # get phases 104 | c0_phase = torch.atan2( c0[:,:,:,:,1], c0[:,:,:,:,0] + EPS) 105 | c1_phase = torch.atan2( c1[:,:,:,:,1], c1[:,:,:,:,0] + EPS) 106 | 107 | # angular distance 108 | phase_dist = torch.acos(torch.cos(c0_phase - c1_phase)*(1 - EPS*10**3)) * self.w_phase # we multiply with a factor ->1 to prevent taking the gradient of acos(-1) or acos(1). 
The gradient in this case would be -/+ inf 109 | phase_dist = self.dropout(phase_dist) 110 | phase_dist = torch.sum(phase_dist, dim=(1,2,3)) 111 | 112 | # perceptual distance 113 | distance = watson_dist + phase_dist 114 | 115 | # reduce 116 | if self.reduction == 'sum': 117 | distance = torch.sum(distance) 118 | 119 | return distance 120 | 121 | -------------------------------------------------------------------------------- /src/loss/watson_vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | 6 | EPS = 1e-10 7 | 8 | class VggFeatureExtractor(nn.Module): 9 | def __init__(self): 10 | super(VggFeatureExtractor, self).__init__() 11 | 12 | # download vgg 13 | vgg16 = torchvision.models.vgg16(pretrained=True).features 14 | 15 | # set non trainable 16 | for param in vgg16.parameters(): 17 | param.requires_grad = False 18 | 19 | # slice model 20 | self.slice1 = torch.nn.Sequential() 21 | self.slice2 = torch.nn.Sequential() 22 | self.slice3 = torch.nn.Sequential() 23 | self.slice4 = torch.nn.Sequential() 24 | self.slice5 = torch.nn.Sequential() 25 | 26 | for x in range(4): # conv relu conv relu 27 | self.slice1.add_module(str(x), vgg16[x]) 28 | for x in range(4, 9): # max conv relu conv relu 29 | self.slice2.add_module(str(x), vgg16[x]) 30 | for x in range(9, 16): # max cov relu conv relu conv relu 31 | self.slice3.add_module(str(x), vgg16[x]) 32 | for x in range(16, 23): # conv relu max conv relu conv relu 33 | self.slice4.add_module(str(x), vgg16[x]) 34 | for x in range(23, 30): # conv relu conv relu max conv relu 35 | self.slice5.add_module(str(x), vgg16[x]) 36 | 37 | def forward(self, X): 38 | h = self.slice1(X) 39 | h_relu1_2 = h 40 | h = self.slice2(h) 41 | h_relu2_2 = h 42 | h = self.slice3(h) 43 | h_relu3_3 = h 44 | h = self.slice4(h) 45 | h_relu4_3 = h 46 | h = self.slice5(h) 47 | h_relu5_3 = h 48 | 49 | return [h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3] 50 | 51 | 52 | def normalize_tensor(t): 53 | # norms a tensor over the channel dimension to an euclidean length of 1. 54 | N, C, H, W = t.shape 55 | norm_factor = torch.sqrt(torch.sum(t**2,dim=1)).view(N,1,H,W) 56 | return t/(norm_factor.expand_as(t)+EPS) 57 | 58 | def softmax(a, b, factor=1): 59 | concat = torch.cat([a.unsqueeze(-1), b.unsqueeze(-1)], dim=-1) 60 | softmax_factors = F.softmax(concat * factor, dim=-1) 61 | return a * softmax_factors[:,:,:,:,0] + b * softmax_factors[:,:,:,:,1] 62 | 63 | class WatsonDistanceVgg(nn.Module): 64 | """ 65 | Loss function based on Watsons perceptual distance. 66 | Based on deep feature extraction 67 | """ 68 | def __init__(self, trainable=False, reduction='sum'): 69 | """ 70 | Parameters: 71 | trainable: bool, if True parameters of the loss are trained and dropout is enabled. 
72 | reduction: 'sum' or 'none', determines return format 73 | """ 74 | super().__init__() 75 | 76 | # module to perform feature extraction 77 | self.add_module('vgg', VggFeatureExtractor()) 78 | 79 | # imagenet-normalization 80 | self.shift = nn.Parameter(torch.Tensor([-.030, -.088, -.188]).view(1,3,1,1), requires_grad=False) 81 | self.scale = nn.Parameter(torch.Tensor([.458, .448, .450]).view(1,3,1,1), requires_grad=False) 82 | 83 | # channel dimensions 84 | self.L = 5 85 | self.channels = [64,128,256,512,512] 86 | 87 | # sensitivity parameters 88 | self.t0_tild = nn.Parameter(torch.zeros((self.channels[0])), requires_grad=trainable) 89 | self.t1_tild = nn.Parameter(torch.zeros((self.channels[1])), requires_grad=trainable) 90 | self.t2_tild = nn.Parameter(torch.zeros((self.channels[2])), requires_grad=trainable) 91 | self.t3_tild = nn.Parameter(torch.zeros((self.channels[3])), requires_grad=trainable) 92 | self.t4_tild = nn.Parameter(torch.zeros((self.channels[4])), requires_grad=trainable) 93 | 94 | # other default parameters 95 | w = torch.tensor(0.2) # contrast masking 96 | self.w0_tild = nn.Parameter(torch.log(w / (1- w)), requires_grad=trainable) # inverse of sigmoid 97 | self.w1_tild = nn.Parameter(torch.log(w / (1- w)), requires_grad=trainable) 98 | self.w2_tild = nn.Parameter(torch.log(w / (1- w)), requires_grad=trainable) 99 | self.w3_tild = nn.Parameter(torch.log(w / (1- w)), requires_grad=trainable) 100 | self.w4_tild = nn.Parameter(torch.log(w / (1- w)), requires_grad=trainable) 101 | self.beta = nn.Parameter(torch.tensor(1.), requires_grad=trainable) # pooling 102 | 103 | # dropout for training 104 | self.dropout = nn.Dropout(0.5 if trainable else 0) 105 | 106 | # reduction 107 | self.reduction = reduction 108 | if reduction not in ['sum', 'none']: 109 | raise Exception('Reduction "{}" not supported. 
Valid values are: "sum", "none".'.format(reduction)) 110 | 111 | @property 112 | def t(self): 113 | return [torch.exp(t) for t in [self.t0_tild, self.t1_tild, self.t2_tild, self.t3_tild, self.t4_tild]] 114 | 115 | @property 116 | def w(self): 117 | # return luminance masking parameter 118 | return [torch.sigmoid(w) for w in [self.w0_tild, self.w1_tild, self.w2_tild, self.w3_tild, self.w4_tild]] 119 | 120 | def forward(self, input, target): 121 | # normalization 122 | input = (input - self.shift.expand_as(input))/self.scale.expand_as(input) 123 | target = (target - self.shift.expand_as(target))/self.scale.expand_as(target) 124 | 125 | # feature extraction 126 | c0 = self.vgg(target) 127 | c1 = self.vgg(input) 128 | 129 | # norm over channels 130 | for l in range(self.L): 131 | c0[l] = normalize_tensor(c0[l]) 132 | c1[l] = normalize_tensor(c1[l]) 133 | 134 | # contrast masking 135 | t = self.t 136 | w = self.w 137 | s = [] 138 | for l in range(self.L): 139 | N, C_l, H_l, W_l = c0[l].shape 140 | t_l = t[l].view(1,C_l,1,1).expand(N, C_l, H_l, W_l) 141 | s.append(softmax(t_l, (c0[l].abs() + EPS)**w[l] * t_l**(1 - w[l]))) 142 | 143 | # pooling 144 | watson_dist = 0 145 | for l in range(self.L): 146 | _, _, H_l, W_l = c0[l].shape 147 | layer_dist = (((c0[l] - c1[l]) / s[l]).abs() + EPS) ** self.beta 148 | layer_dist = self.dropout(layer_dist) + EPS 149 | layer_dist = torch.sum(layer_dist, dim=(1,2,3)) # sum over dimensions of layer 150 | layer_dist = (1 / (H_l * W_l)) * layer_dist # normalize by layer size 151 | watson_dist += layer_dist # sum over layers 152 | watson_dist = watson_dist ** (1 / self.beta) 153 | 154 | # reduction 155 | if self.reduction == 'sum': 156 | watson_dist = torch.sum(watson_dist) 157 | 158 | return watson_dist 159 | 160 | -------------------------------------------------------------------------------- /src/taming/data/ade20k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import albumentations 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | from taming.data.sflckr import SegmentationBase # for examples included in repo 9 | 10 | 11 | class Examples(SegmentationBase): 12 | def __init__(self, size=256, random_crop=False, interpolation="bicubic"): 13 | super().__init__(data_csv="data/ade20k_examples.txt", 14 | data_root="data/ade20k_images", 15 | segmentation_root="data/ade20k_segmentations", 16 | size=size, random_crop=random_crop, 17 | interpolation=interpolation, 18 | n_labels=151, shift_segmentation=False) 19 | 20 | 21 | # With semantic map and scene label 22 | class ADE20kBase(Dataset): 23 | def __init__(self, config=None, size=None, random_crop=False, interpolation="bicubic", crop_size=None): 24 | self.split = self.get_split() 25 | self.n_labels = 151 # unknown + 150 26 | self.data_csv = {"train": "data/ade20k_train.txt", 27 | "validation": "data/ade20k_test.txt"}[self.split] 28 | self.data_root = "data/ade20k_root" 29 | with open(os.path.join(self.data_root, "sceneCategories.txt"), "r") as f: 30 | self.scene_categories = f.read().splitlines() 31 | self.scene_categories = dict(line.split() for line in self.scene_categories) 32 | with open(self.data_csv, "r") as f: 33 | self.image_paths = f.read().splitlines() 34 | self._length = len(self.image_paths) 35 | self.labels = { 36 | "relative_file_path_": [l for l in self.image_paths], 37 | "file_path_": [os.path.join(self.data_root, "images", l) 38 | for l in self.image_paths], 39 | 
"relative_segmentation_path_": [l.replace(".jpg", ".png") 40 | for l in self.image_paths], 41 | "segmentation_path_": [os.path.join(self.data_root, "annotations", 42 | l.replace(".jpg", ".png")) 43 | for l in self.image_paths], 44 | "scene_category": [self.scene_categories[l.split("/")[1].replace(".jpg", "")] 45 | for l in self.image_paths], 46 | } 47 | 48 | size = None if size is not None and size<=0 else size 49 | self.size = size 50 | if crop_size is None: 51 | self.crop_size = size if size is not None else None 52 | else: 53 | self.crop_size = crop_size 54 | if self.size is not None: 55 | self.interpolation = interpolation 56 | self.interpolation = { 57 | "nearest": cv2.INTER_NEAREST, 58 | "bilinear": cv2.INTER_LINEAR, 59 | "bicubic": cv2.INTER_CUBIC, 60 | "area": cv2.INTER_AREA, 61 | "lanczos": cv2.INTER_LANCZOS4}[self.interpolation] 62 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 63 | interpolation=self.interpolation) 64 | self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 65 | interpolation=cv2.INTER_NEAREST) 66 | 67 | if crop_size is not None: 68 | self.center_crop = not random_crop 69 | if self.center_crop: 70 | self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) 71 | else: 72 | self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size) 73 | self.preprocessor = self.cropper 74 | 75 | def __len__(self): 76 | return self._length 77 | 78 | def __getitem__(self, i): 79 | example = dict((k, self.labels[k][i]) for k in self.labels) 80 | image = Image.open(example["file_path_"]) 81 | if not image.mode == "RGB": 82 | image = image.convert("RGB") 83 | image = np.array(image).astype(np.uint8) 84 | if self.size is not None: 85 | image = self.image_rescaler(image=image)["image"] 86 | segmentation = Image.open(example["segmentation_path_"]) 87 | segmentation = np.array(segmentation).astype(np.uint8) 88 | if self.size is not None: 89 | segmentation = self.segmentation_rescaler(image=segmentation)["image"] 90 | if self.size is not None: 91 | processed = self.preprocessor(image=image, mask=segmentation) 92 | else: 93 | processed = {"image": image, "mask": segmentation} 94 | example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32) 95 | segmentation = processed["mask"] 96 | onehot = np.eye(self.n_labels)[segmentation] 97 | example["segmentation"] = onehot 98 | return example 99 | 100 | 101 | class ADE20kTrain(ADE20kBase): 102 | # default to random_crop=True 103 | def __init__(self, config=None, size=None, random_crop=True, interpolation="bicubic", crop_size=None): 104 | super().__init__(config=config, size=size, random_crop=random_crop, 105 | interpolation=interpolation, crop_size=crop_size) 106 | 107 | def get_split(self): 108 | return "train" 109 | 110 | 111 | class ADE20kValidation(ADE20kBase): 112 | def get_split(self): 113 | return "validation" 114 | 115 | 116 | if __name__ == "__main__": 117 | dset = ADE20kValidation() 118 | ex = dset[0] 119 | for k in ["image", "scene_category", "segmentation"]: 120 | print(type(ex[k])) 121 | try: 122 | print(ex[k].shape) 123 | except: 124 | print(ex[k]) 125 | -------------------------------------------------------------------------------- /src/taming/data/annotated_objects_coco.py: -------------------------------------------------------------------------------- 1 | import json 2 | from itertools import chain 3 | from pathlib import Path 4 | from typing import Iterable, Dict, List, Callable, Any 5 | from collections import 
defaultdict 6 | 7 | from tqdm import tqdm 8 | 9 | from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset 10 | from taming.data.helper_types import Annotation, ImageDescription, Category 11 | 12 | COCO_PATH_STRUCTURE = { 13 | 'train': { 14 | 'top_level': '', 15 | 'instances_annotations': 'annotations/instances_train2017.json', 16 | 'stuff_annotations': 'annotations/stuff_train2017.json', 17 | 'files': 'train2017' 18 | }, 19 | 'validation': { 20 | 'top_level': '', 21 | 'instances_annotations': 'annotations/instances_val2017.json', 22 | 'stuff_annotations': 'annotations/stuff_val2017.json', 23 | 'files': 'val2017' 24 | } 25 | } 26 | 27 | 28 | def load_image_descriptions(description_json: List[Dict]) -> Dict[str, ImageDescription]: 29 | return { 30 | str(img['id']): ImageDescription( 31 | id=img['id'], 32 | license=img.get('license'), 33 | file_name=img['file_name'], 34 | coco_url=img['coco_url'], 35 | original_size=(img['width'], img['height']), 36 | date_captured=img.get('date_captured'), 37 | flickr_url=img.get('flickr_url') 38 | ) 39 | for img in description_json 40 | } 41 | 42 | 43 | def load_categories(category_json: Iterable) -> Dict[str, Category]: 44 | return {str(cat['id']): Category(id=str(cat['id']), super_category=cat['supercategory'], name=cat['name']) 45 | for cat in category_json if cat['name'] != 'other'} 46 | 47 | 48 | def load_annotations(annotations_json: List[Dict], image_descriptions: Dict[str, ImageDescription], 49 | category_no_for_id: Callable[[str], int], split: str) -> Dict[str, List[Annotation]]: 50 | annotations = defaultdict(list) 51 | total = sum(len(a) for a in annotations_json) 52 | for ann in tqdm(chain(*annotations_json), f'Loading {split} annotations', total=total): 53 | image_id = str(ann['image_id']) 54 | if image_id not in image_descriptions: 55 | raise ValueError(f'image_id [{image_id}] has no image description.') 56 | category_id = ann['category_id'] 57 | try: 58 | category_no = category_no_for_id(str(category_id)) 59 | except KeyError: 60 | continue 61 | 62 | width, height = image_descriptions[image_id].original_size 63 | bbox = (ann['bbox'][0] / width, ann['bbox'][1] / height, ann['bbox'][2] / width, ann['bbox'][3] / height) 64 | 65 | annotations[image_id].append( 66 | Annotation( 67 | id=ann['id'], 68 | area=bbox[2]*bbox[3], # use bbox area 69 | is_group_of=ann['iscrowd'], 70 | image_id=ann['image_id'], 71 | bbox=bbox, 72 | category_id=str(category_id), 73 | category_no=category_no 74 | ) 75 | ) 76 | return dict(annotations) 77 | 78 | 79 | class AnnotatedObjectsCoco(AnnotatedObjectsDataset): 80 | def __init__(self, use_things: bool = True, use_stuff: bool = True, **kwargs): 81 | """ 82 | @param data_path: is the path to the following folder structure: 83 | coco/ 84 | ├── annotations 85 | │ ├── instances_train2017.json 86 | │ ├── instances_val2017.json 87 | │ ├── stuff_train2017.json 88 | │ └── stuff_val2017.json 89 | ├── train2017 90 | │ ├── 000000000009.jpg 91 | │ ├── 000000000025.jpg 92 | │ └── ... 93 | ├── val2017 94 | │ ├── 000000000139.jpg 95 | │ ├── 000000000285.jpg 96 | │ └── ... 
97 | @param: split: one of 'train' or 'validation' 98 | @param: desired image size (give square images) 99 | """ 100 | super().__init__(**kwargs) 101 | self.use_things = use_things 102 | self.use_stuff = use_stuff 103 | 104 | with open(self.paths['instances_annotations']) as f: 105 | inst_data_json = json.load(f) 106 | with open(self.paths['stuff_annotations']) as f: 107 | stuff_data_json = json.load(f) 108 | 109 | category_jsons = [] 110 | annotation_jsons = [] 111 | if self.use_things: 112 | category_jsons.append(inst_data_json['categories']) 113 | annotation_jsons.append(inst_data_json['annotations']) 114 | if self.use_stuff: 115 | category_jsons.append(stuff_data_json['categories']) 116 | annotation_jsons.append(stuff_data_json['annotations']) 117 | 118 | self.categories = load_categories(chain(*category_jsons)) 119 | self.filter_categories() 120 | self.setup_category_id_and_number() 121 | 122 | self.image_descriptions = load_image_descriptions(inst_data_json['images']) 123 | annotations = load_annotations(annotation_jsons, self.image_descriptions, self.get_category_number, self.split) 124 | self.annotations = self.filter_object_number(annotations, self.min_object_area, 125 | self.min_objects_per_image, self.max_objects_per_image) 126 | self.image_ids = list(self.annotations.keys()) 127 | self.clean_up_annotations_and_image_descriptions() 128 | 129 | def get_path_structure(self) -> Dict[str, str]: 130 | if self.split not in COCO_PATH_STRUCTURE: 131 | raise ValueError(f'Split [{self.split} does not exist for COCO data.]') 132 | return COCO_PATH_STRUCTURE[self.split] 133 | 134 | def get_image_path(self, image_id: str) -> Path: 135 | return self.paths['files'].joinpath(self.image_descriptions[str(image_id)].file_name) 136 | 137 | def get_image_description(self, image_id: str) -> Dict[str, Any]: 138 | # noinspection PyProtectedMember 139 | return self.image_descriptions[image_id]._asdict() 140 | -------------------------------------------------------------------------------- /src/taming/data/annotated_objects_open_images.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from csv import DictReader, reader as TupleReader 3 | from pathlib import Path 4 | from typing import Dict, List, Any 5 | import warnings 6 | 7 | from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset 8 | from taming.data.helper_types import Annotation, Category 9 | from tqdm import tqdm 10 | 11 | OPEN_IMAGES_STRUCTURE = { 12 | 'train': { 13 | 'top_level': '', 14 | 'class_descriptions': 'class-descriptions-boxable.csv', 15 | 'annotations': 'oidv6-train-annotations-bbox.csv', 16 | 'file_list': 'train-images-boxable.csv', 17 | 'files': 'train' 18 | }, 19 | 'validation': { 20 | 'top_level': '', 21 | 'class_descriptions': 'class-descriptions-boxable.csv', 22 | 'annotations': 'validation-annotations-bbox.csv', 23 | 'file_list': 'validation-images.csv', 24 | 'files': 'validation' 25 | }, 26 | 'test': { 27 | 'top_level': '', 28 | 'class_descriptions': 'class-descriptions-boxable.csv', 29 | 'annotations': 'test-annotations-bbox.csv', 30 | 'file_list': 'test-images.csv', 31 | 'files': 'test' 32 | } 33 | } 34 | 35 | 36 | def load_annotations(descriptor_path: Path, min_object_area: float, category_mapping: Dict[str, str], 37 | category_no_for_id: Dict[str, int]) -> Dict[str, List[Annotation]]: 38 | annotations: Dict[str, List[Annotation]] = defaultdict(list) 39 | with open(descriptor_path) as file: 40 | reader = DictReader(file) 41 | for 
i, row in tqdm(enumerate(reader), total=14620000, desc='Loading OpenImages annotations'): 42 | width = float(row['XMax']) - float(row['XMin']) 43 | height = float(row['YMax']) - float(row['YMin']) 44 | area = width * height 45 | category_id = row['LabelName'] 46 | if category_id in category_mapping: 47 | category_id = category_mapping[category_id] 48 | if area >= min_object_area and category_id in category_no_for_id: 49 | annotations[row['ImageID']].append( 50 | Annotation( 51 | id=i, 52 | image_id=row['ImageID'], 53 | source=row['Source'], 54 | category_id=category_id, 55 | category_no=category_no_for_id[category_id], 56 | confidence=float(row['Confidence']), 57 | bbox=(float(row['XMin']), float(row['YMin']), width, height), 58 | area=area, 59 | is_occluded=bool(int(row['IsOccluded'])), 60 | is_truncated=bool(int(row['IsTruncated'])), 61 | is_group_of=bool(int(row['IsGroupOf'])), 62 | is_depiction=bool(int(row['IsDepiction'])), 63 | is_inside=bool(int(row['IsInside'])) 64 | ) 65 | ) 66 | if 'train' in str(descriptor_path) and i < 14000000: 67 | warnings.warn(f'Running with subset of Open Images. Train dataset has length [{len(annotations)}].') 68 | return dict(annotations) 69 | 70 | 71 | def load_image_ids(csv_path: Path) -> List[str]: 72 | with open(csv_path) as file: 73 | reader = DictReader(file) 74 | return [row['image_name'] for row in reader] 75 | 76 | 77 | def load_categories(csv_path: Path) -> Dict[str, Category]: 78 | with open(csv_path) as file: 79 | reader = TupleReader(file) 80 | return {row[0]: Category(id=row[0], name=row[1], super_category=None) for row in reader} 81 | 82 | 83 | class AnnotatedObjectsOpenImages(AnnotatedObjectsDataset): 84 | def __init__(self, use_additional_parameters: bool, **kwargs): 85 | """ 86 | @param data_path: is the path to the following folder structure: 87 | open_images/ 88 | │ oidv6-train-annotations-bbox.csv 89 | ├── class-descriptions-boxable.csv 90 | ├── oidv6-train-annotations-bbox.csv 91 | ├── test 92 | │ ├── 000026e7ee790996.jpg 93 | │ ├── 000062a39995e348.jpg 94 | │ └── ... 95 | ├── test-annotations-bbox.csv 96 | ├── test-images.csv 97 | ├── train 98 | │ ├── 000002b66c9c498e.jpg 99 | │ ├── 000002b97e5471a0.jpg 100 | │ └── ... 101 | ├── train-images-boxable.csv 102 | ├── validation 103 | │ ├── 0001eeaf4aed83f9.jpg 104 | │ ├── 0004886b7d043cfd.jpg 105 | │ └── ... 
106 | ├── validation-annotations-bbox.csv 107 | └── validation-images.csv 108 | @param: split: one of 'train', 'validation' or 'test' 109 | @param: desired image size (returns square images) 110 | """ 111 | 112 | super().__init__(**kwargs) 113 | self.use_additional_parameters = use_additional_parameters 114 | 115 | self.categories = load_categories(self.paths['class_descriptions']) 116 | self.filter_categories() 117 | self.setup_category_id_and_number() 118 | 119 | self.image_descriptions = {} 120 | annotations = load_annotations(self.paths['annotations'], self.min_object_area, self.category_mapping, 121 | self.category_number) 122 | self.annotations = self.filter_object_number(annotations, self.min_object_area, self.min_objects_per_image, 123 | self.max_objects_per_image) 124 | self.image_ids = list(self.annotations.keys()) 125 | self.clean_up_annotations_and_image_descriptions() 126 | 127 | def get_path_structure(self) -> Dict[str, str]: 128 | if self.split not in OPEN_IMAGES_STRUCTURE: 129 | raise ValueError(f'Split [{self.split} does not exist for Open Images data.]') 130 | return OPEN_IMAGES_STRUCTURE[self.split] 131 | 132 | def get_image_path(self, image_id: str) -> Path: 133 | return self.paths['files'].joinpath(f'{image_id:0>16}.jpg') 134 | 135 | def get_image_description(self, image_id: str) -> Dict[str, Any]: 136 | image_path = self.get_image_path(image_id) 137 | return {'file_path': str(image_path), 'file_name': image_path.name} 138 | -------------------------------------------------------------------------------- /src/taming/data/base.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import numpy as np 3 | import albumentations 4 | from PIL import Image 5 | from torch.utils.data import Dataset, ConcatDataset 6 | 7 | 8 | class ConcatDatasetWithIndex(ConcatDataset): 9 | """Modified from original pytorch code to return dataset idx""" 10 | def __getitem__(self, idx): 11 | if idx < 0: 12 | if -idx > len(self): 13 | raise ValueError("absolute value of index should not exceed dataset length") 14 | idx = len(self) + idx 15 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 16 | if dataset_idx == 0: 17 | sample_idx = idx 18 | else: 19 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 20 | return self.datasets[dataset_idx][sample_idx], dataset_idx 21 | 22 | 23 | class ImagePaths(Dataset): 24 | def __init__(self, paths, size=None, random_crop=False, labels=None): 25 | self.size = size 26 | self.random_crop = random_crop 27 | 28 | self.labels = dict() if labels is None else labels 29 | self.labels["file_path_"] = paths 30 | self._length = len(paths) 31 | 32 | if self.size is not None and self.size > 0: 33 | self.rescaler = albumentations.SmallestMaxSize(max_size = self.size) 34 | if not self.random_crop: 35 | self.cropper = albumentations.CenterCrop(height=self.size,width=self.size) 36 | else: 37 | self.cropper = albumentations.RandomCrop(height=self.size,width=self.size) 38 | self.preprocessor = albumentations.Compose([self.rescaler, self.cropper]) 39 | else: 40 | self.preprocessor = lambda **kwargs: kwargs 41 | 42 | def __len__(self): 43 | return self._length 44 | 45 | def preprocess_image(self, image_path): 46 | image = Image.open(image_path) 47 | if not image.mode == "RGB": 48 | image = image.convert("RGB") 49 | image = np.array(image).astype(np.uint8) 50 | image = self.preprocessor(image=image)["image"] 51 | image = (image/127.5 - 1.0).astype(np.float32) 52 | return image 53 | 54 | def __getitem__(self, 
i): 55 | example = dict() 56 | example["image"] = self.preprocess_image(self.labels["file_path_"][i]) 57 | for k in self.labels: 58 | example[k] = self.labels[k][i] 59 | return example 60 | 61 | 62 | class NumpyPaths(ImagePaths): 63 | def preprocess_image(self, image_path): 64 | image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024 65 | image = np.transpose(image, (1,2,0)) 66 | image = Image.fromarray(image, mode="RGB") 67 | image = np.array(image).astype(np.uint8) 68 | image = self.preprocessor(image=image)["image"] 69 | image = (image/127.5 - 1.0).astype(np.float32) 70 | return image 71 | -------------------------------------------------------------------------------- /src/taming/data/conditional_builder/objects_bbox.py: -------------------------------------------------------------------------------- 1 | from itertools import cycle 2 | from typing import List, Tuple, Callable, Optional 3 | 4 | from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont 5 | from more_itertools.recipes import grouper 6 | from taming.data.image_transforms import convert_pil_to_tensor 7 | from torch import LongTensor, Tensor 8 | 9 | from taming.data.helper_types import BoundingBox, Annotation 10 | from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder 11 | from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \ 12 | pad_list, get_plot_font_size, absolute_bbox 13 | 14 | 15 | class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder): 16 | @property 17 | def object_descriptor_length(self) -> int: 18 | return 3 19 | 20 | def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]: 21 | object_triples = [ 22 | (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox)) 23 | for ann in annotations 24 | ] 25 | empty_triple = (self.none, self.none, self.none) 26 | object_triples = pad_list(object_triples, empty_triple, self.no_max_objects) 27 | return object_triples 28 | 29 | def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]: 30 | conditional_list = conditional.tolist() 31 | crop_coordinates = None 32 | if self.encode_crop: 33 | crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1]) 34 | conditional_list = conditional_list[:-2] 35 | object_triples = grouper(conditional_list, 3) 36 | assert conditional.shape[0] == self.embedding_dim 37 | return [ 38 | (object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2])) 39 | for object_triple in object_triples if object_triple[0] != self.none 40 | ], crop_coordinates 41 | 42 | def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int], 43 | line_width: int = 3, font_size: Optional[int] = None) -> Tensor: 44 | plot = pil_image.new('RGB', figure_size, WHITE) 45 | draw = pil_img_draw.Draw(plot) 46 | font = ImageFont.truetype( 47 | "/usr/share/fonts/truetype/lato/Lato-Regular.ttf", 48 | size=get_plot_font_size(font_size, figure_size) 49 | ) 50 | width, height = plot.size 51 | description, crop_coordinates = self.inverse_build(conditional) 52 | for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)): 53 | annotation = self.representation_to_annotation(representation) 54 | class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation) 55 | bbox = 
absolute_bbox(bbox, width, height) 56 | draw.rectangle(bbox, outline=color, width=line_width) 57 | draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font) 58 | if crop_coordinates is not None: 59 | draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width) 60 | return convert_pil_to_tensor(plot) / 127.5 - 1. 61 | -------------------------------------------------------------------------------- /src/taming/data/conditional_builder/utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from typing import List, Any, Tuple, Optional 3 | 4 | from taming.data.helper_types import BoundingBox, Annotation 5 | 6 | # source: seaborn, color palette tab10 7 | COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188), 8 | (139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)] 9 | BLACK = (0, 0, 0) 10 | GRAY_75 = (63, 63, 63) 11 | GRAY_50 = (127, 127, 127) 12 | GRAY_25 = (191, 191, 191) 13 | WHITE = (255, 255, 255) 14 | FULL_CROP = (0., 0., 1., 1.) 15 | 16 | 17 | def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float: 18 | """ 19 | Give intersection area of two rectangles. 20 | @param rectangle1: (x0, y0, w, h) of first rectangle 21 | @param rectangle2: (x0, y0, w, h) of second rectangle 22 | """ 23 | rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3] 24 | rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3] 25 | x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0])) 26 | y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1])) 27 | return x_overlap * y_overlap 28 | 29 | 30 | def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox: 31 | return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3] 32 | 33 | 34 | def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]: 35 | bbox = relative_bbox 36 | bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height 37 | return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]) 38 | 39 | 40 | def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List: 41 | return list_ + [pad_element for _ in range(pad_to_length - len(list_))] 42 | 43 | 44 | def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \ 45 | List[Annotation]: 46 | def clamp(x: float): 47 | return max(min(x, 1.), 0.) 
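    # Editor's note (added comment, not in the upstream file): the boxes handled in this
    # module are relative (x0, y0, w, h) tuples in [0, 1] -- see BoundingBox in
    # taming.data.helper_types -- and crop_coordinates is a box in the same space.
    # rescale_bbox below re-expresses an annotation box in the crop's own frame: shift by
    # the crop origin, divide by the crop size, clamp the origin into [0, 1], cap w/h so
    # the box stays inside the crop, and mirror x0 as 1 - (x0 + w) when flip is True.
    # Worked example with made-up numbers: crop_coordinates = (0.25, 0.25, 0.5, 0.5) and
    # bbox = (0.5, 0.5, 0.2, 0.2) give x0 = (0.5 - 0.25) / 0.5 = 0.5, y0 = 0.5,
    # w = min(0.2 / 0.5, 1 - 0.5) = 0.4, h = 0.4; with flip=True, x0 = 1 - (0.5 + 0.4) = 0.1.
    # absolute_bbox above converts such relative boxes back to integer pixel corners for plotting.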
48 | 49 | def rescale_bbox(bbox: BoundingBox) -> BoundingBox: 50 | x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) 51 | y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) 52 | w = min(bbox[2] / crop_coordinates[2], 1 - x0) 53 | h = min(bbox[3] / crop_coordinates[3], 1 - y0) 54 | if flip: 55 | x0 = 1 - (x0 + w) 56 | return x0, y0, w, h 57 | 58 | return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations] 59 | 60 | 61 | def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List: 62 | return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0] 63 | 64 | 65 | def additional_parameters_string(annotation: Annotation, short: bool = True) -> str: 66 | sl = slice(1) if short else slice(None) 67 | string = '' 68 | if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside): 69 | return string 70 | if annotation.is_group_of: 71 | string += 'group'[sl] + ',' 72 | if annotation.is_occluded: 73 | string += 'occluded'[sl] + ',' 74 | if annotation.is_depiction: 75 | string += 'depiction'[sl] + ',' 76 | if annotation.is_inside: 77 | string += 'inside'[sl] 78 | return '(' + string.strip(",") + ')' 79 | 80 | 81 | def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int: 82 | if font_size is None: 83 | font_size = 10 84 | if max(figure_size) >= 256: 85 | font_size = 12 86 | if max(figure_size) >= 512: 87 | font_size = 15 88 | return font_size 89 | 90 | 91 | def get_circle_size(figure_size: Tuple[int, int]) -> int: 92 | circle_size = 2 93 | if max(figure_size) >= 256: 94 | circle_size = 3 95 | if max(figure_size) >= 512: 96 | circle_size = 4 97 | return circle_size 98 | 99 | 100 | def load_object_from_string(object_string: str) -> Any: 101 | """ 102 | Source: https://stackoverflow.com/a/10773699 103 | """ 104 | module_name, class_name = object_string.rsplit(".", 1) 105 | return getattr(importlib.import_module(module_name), class_name) 106 | -------------------------------------------------------------------------------- /src/taming/data/custom.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import albumentations 4 | from torch.utils.data import Dataset 5 | 6 | from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex 7 | 8 | 9 | class CustomBase(Dataset): 10 | def __init__(self, *args, **kwargs): 11 | super().__init__() 12 | self.data = None 13 | 14 | def __len__(self): 15 | return len(self.data) 16 | 17 | def __getitem__(self, i): 18 | example = self.data[i] 19 | return example 20 | 21 | 22 | 23 | class CustomTrain(CustomBase): 24 | def __init__(self, size, training_images_list_file): 25 | super().__init__() 26 | with open(training_images_list_file, "r") as f: 27 | paths = f.read().splitlines() 28 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 29 | 30 | 31 | class CustomTest(CustomBase): 32 | def __init__(self, size, test_images_list_file): 33 | super().__init__() 34 | with open(test_images_list_file, "r") as f: 35 | paths = f.read().splitlines() 36 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 37 | 38 | 39 | -------------------------------------------------------------------------------- /src/taming/data/faceshq.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import albumentations 4 | from torch.utils.data import Dataset 
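# Editor's note (added comment, not in the upstream file): the dataset classes below wrap
# ImagePaths / NumpyPaths with path lists read from data/*.txt, and FacesHQTrain /
# FacesHQValidation concatenate CelebA-HQ and FFHQ via ConcatDatasetWithIndex, so every
# example also carries a "class" field (0 = CelebA-HQ, 1 = FFHQ). Images come back as
# H x W x 3 float32 arrays scaled to [-1, 1]; an optional crop_size adds a Random/Center
# crop, and coord=True additionally returns a normalized coordinate map.
# Usage sketch (assumes the data/ file lists and image folders referenced below exist):
#     dset = FacesHQTrain(size=256, crop_size=256, coord=True)
#     ex = dset[0]
#     ex["image"].shape  # (256, 256, 3)
#     ex["class"]        # 0 or 1
#     ex["coord"].shape  # (256, 256, 1)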
5 | 6 | from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex 7 | 8 | 9 | class FacesBase(Dataset): 10 | def __init__(self, *args, **kwargs): 11 | super().__init__() 12 | self.data = None 13 | self.keys = None 14 | 15 | def __len__(self): 16 | return len(self.data) 17 | 18 | def __getitem__(self, i): 19 | example = self.data[i] 20 | ex = {} 21 | if self.keys is not None: 22 | for k in self.keys: 23 | ex[k] = example[k] 24 | else: 25 | ex = example 26 | return ex 27 | 28 | 29 | class CelebAHQTrain(FacesBase): 30 | def __init__(self, size, keys=None): 31 | super().__init__() 32 | root = "data/celebahq" 33 | with open("data/celebahqtrain.txt", "r") as f: 34 | relpaths = f.read().splitlines() 35 | paths = [os.path.join(root, relpath) for relpath in relpaths] 36 | self.data = NumpyPaths(paths=paths, size=size, random_crop=False) 37 | self.keys = keys 38 | 39 | 40 | class CelebAHQValidation(FacesBase): 41 | def __init__(self, size, keys=None): 42 | super().__init__() 43 | root = "data/celebahq" 44 | with open("data/celebahqvalidation.txt", "r") as f: 45 | relpaths = f.read().splitlines() 46 | paths = [os.path.join(root, relpath) for relpath in relpaths] 47 | self.data = NumpyPaths(paths=paths, size=size, random_crop=False) 48 | self.keys = keys 49 | 50 | 51 | class FFHQTrain(FacesBase): 52 | def __init__(self, size, keys=None): 53 | super().__init__() 54 | root = "data/ffhq" 55 | with open("data/ffhqtrain.txt", "r") as f: 56 | relpaths = f.read().splitlines() 57 | paths = [os.path.join(root, relpath) for relpath in relpaths] 58 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 59 | self.keys = keys 60 | 61 | 62 | class FFHQValidation(FacesBase): 63 | def __init__(self, size, keys=None): 64 | super().__init__() 65 | root = "data/ffhq" 66 | with open("data/ffhqvalidation.txt", "r") as f: 67 | relpaths = f.read().splitlines() 68 | paths = [os.path.join(root, relpath) for relpath in relpaths] 69 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 70 | self.keys = keys 71 | 72 | 73 | class FacesHQTrain(Dataset): 74 | # CelebAHQ [0] + FFHQ [1] 75 | def __init__(self, size, keys=None, crop_size=None, coord=False): 76 | d1 = CelebAHQTrain(size=size, keys=keys) 77 | d2 = FFHQTrain(size=size, keys=keys) 78 | self.data = ConcatDatasetWithIndex([d1, d2]) 79 | self.coord = coord 80 | if crop_size is not None: 81 | self.cropper = albumentations.RandomCrop(height=crop_size,width=crop_size) 82 | if self.coord: 83 | self.cropper = albumentations.Compose([self.cropper], 84 | additional_targets={"coord": "image"}) 85 | 86 | def __len__(self): 87 | return len(self.data) 88 | 89 | def __getitem__(self, i): 90 | ex, y = self.data[i] 91 | if hasattr(self, "cropper"): 92 | if not self.coord: 93 | out = self.cropper(image=ex["image"]) 94 | ex["image"] = out["image"] 95 | else: 96 | h,w,_ = ex["image"].shape 97 | coord = np.arange(h*w).reshape(h,w,1)/(h*w) 98 | out = self.cropper(image=ex["image"], coord=coord) 99 | ex["image"] = out["image"] 100 | ex["coord"] = out["coord"] 101 | ex["class"] = y 102 | return ex 103 | 104 | 105 | class FacesHQValidation(Dataset): 106 | # CelebAHQ [0] + FFHQ [1] 107 | def __init__(self, size, keys=None, crop_size=None, coord=False): 108 | d1 = CelebAHQValidation(size=size, keys=keys) 109 | d2 = FFHQValidation(size=size, keys=keys) 110 | self.data = ConcatDatasetWithIndex([d1, d2]) 111 | self.coord = coord 112 | if crop_size is not None: 113 | self.cropper = albumentations.CenterCrop(height=crop_size,width=crop_size) 114 | if 
self.coord: 115 | self.cropper = albumentations.Compose([self.cropper], 116 | additional_targets={"coord": "image"}) 117 | 118 | def __len__(self): 119 | return len(self.data) 120 | 121 | def __getitem__(self, i): 122 | ex, y = self.data[i] 123 | if hasattr(self, "cropper"): 124 | if not self.coord: 125 | out = self.cropper(image=ex["image"]) 126 | ex["image"] = out["image"] 127 | else: 128 | h,w,_ = ex["image"].shape 129 | coord = np.arange(h*w).reshape(h,w,1)/(h*w) 130 | out = self.cropper(image=ex["image"], coord=coord) 131 | ex["image"] = out["image"] 132 | ex["coord"] = out["coord"] 133 | ex["class"] = y 134 | return ex 135 | -------------------------------------------------------------------------------- /src/taming/data/helper_types.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple, Optional, NamedTuple, Union 2 | from PIL.Image import Image as pil_image 3 | from torch import Tensor 4 | 5 | try: 6 | from typing import Literal 7 | except ImportError: 8 | from typing_extensions import Literal 9 | 10 | Image = Union[Tensor, pil_image] 11 | BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h 12 | CropMethodType = Literal['none', 'random', 'center', 'random-2d'] 13 | SplitType = Literal['train', 'validation', 'test'] 14 | 15 | 16 | class ImageDescription(NamedTuple): 17 | id: int 18 | file_name: str 19 | original_size: Tuple[int, int] # w, h 20 | url: Optional[str] = None 21 | license: Optional[int] = None 22 | coco_url: Optional[str] = None 23 | date_captured: Optional[str] = None 24 | flickr_url: Optional[str] = None 25 | flickr_id: Optional[str] = None 26 | coco_id: Optional[str] = None 27 | 28 | 29 | class Category(NamedTuple): 30 | id: str 31 | super_category: Optional[str] 32 | name: str 33 | 34 | 35 | class Annotation(NamedTuple): 36 | area: float 37 | image_id: str 38 | bbox: BoundingBox 39 | category_no: int 40 | category_id: str 41 | id: Optional[int] = None 42 | source: Optional[str] = None 43 | confidence: Optional[float] = None 44 | is_group_of: Optional[bool] = None 45 | is_truncated: Optional[bool] = None 46 | is_occluded: Optional[bool] = None 47 | is_depiction: Optional[bool] = None 48 | is_inside: Optional[bool] = None 49 | segmentation: Optional[Dict] = None 50 | -------------------------------------------------------------------------------- /src/taming/data/image_transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import warnings 3 | from typing import Union 4 | 5 | import torch 6 | from torch import Tensor 7 | from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor 8 | from torchvision.transforms.functional import _get_image_size as get_image_size 9 | 10 | from taming.data.helper_types import BoundingBox, Image 11 | 12 | pil_to_tensor = PILToTensor() 13 | 14 | 15 | def convert_pil_to_tensor(image: Image) -> Tensor: 16 | with warnings.catch_warnings(): 17 | # to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194 18 | warnings.simplefilter("ignore") 19 | return pil_to_tensor(image) 20 | 21 | 22 | class RandomCrop1dReturnCoordinates(RandomCrop): 23 | def forward(self, img: Image) -> (BoundingBox, Image): 24 | """ 25 | Additionally to cropping, returns the relative coordinates of the crop bounding box. 26 | Args: 27 | img (PIL Image or Tensor): Image to be cropped. 
28 | 29 | Returns: 30 | Bounding box: x0, y0, w, h 31 | PIL Image or Tensor: Cropped image. 32 | 33 | Based on: 34 | torchvision.transforms.RandomCrop, torchvision 1.7.0 35 | """ 36 | if self.padding is not None: 37 | img = F.pad(img, self.padding, self.fill, self.padding_mode) 38 | 39 | width, height = get_image_size(img) 40 | # pad the width if needed 41 | if self.pad_if_needed and width < self.size[1]: 42 | padding = [self.size[1] - width, 0] 43 | img = F.pad(img, padding, self.fill, self.padding_mode) 44 | # pad the height if needed 45 | if self.pad_if_needed and height < self.size[0]: 46 | padding = [0, self.size[0] - height] 47 | img = F.pad(img, padding, self.fill, self.padding_mode) 48 | 49 | i, j, h, w = self.get_params(img, self.size) 50 | bbox = (j / width, i / height, w / width, h / height) # x0, y0, w, h 51 | return bbox, F.crop(img, i, j, h, w) 52 | 53 | 54 | class Random2dCropReturnCoordinates(torch.nn.Module): 55 | """ 56 | Additionally to cropping, returns the relative coordinates of the crop bounding box. 57 | Args: 58 | img (PIL Image or Tensor): Image to be cropped. 59 | 60 | Returns: 61 | Bounding box: x0, y0, w, h 62 | PIL Image or Tensor: Cropped image. 63 | 64 | Based on: 65 | torchvision.transforms.RandomCrop, torchvision 1.7.0 66 | """ 67 | 68 | def __init__(self, min_size: int): 69 | super().__init__() 70 | self.min_size = min_size 71 | 72 | def forward(self, img: Image) -> (BoundingBox, Image): 73 | width, height = get_image_size(img) 74 | max_size = min(width, height) 75 | if max_size <= self.min_size: 76 | size = max_size 77 | else: 78 | size = random.randint(self.min_size, max_size) 79 | top = random.randint(0, height - size) 80 | left = random.randint(0, width - size) 81 | bbox = left / width, top / height, size / width, size / height 82 | return bbox, F.crop(img, top, left, size, size) 83 | 84 | 85 | class CenterCropReturnCoordinates(CenterCrop): 86 | @staticmethod 87 | def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox: 88 | if width > height: 89 | w = height / width 90 | h = 1.0 91 | x0 = 0.5 - w / 2 92 | y0 = 0. 93 | else: 94 | w = 1.0 95 | h = width / height 96 | x0 = 0. 97 | y0 = 0.5 - h / 2 98 | return x0, y0, w, h 99 | 100 | def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]): 101 | """ 102 | Additionally to cropping, returns the relative coordinates of the crop bounding box. 103 | Args: 104 | img (PIL Image or Tensor): Image to be cropped. 105 | 106 | Returns: 107 | Bounding box: x0, y0, w, h 108 | PIL Image or Tensor: Cropped image. 109 | Based on: 110 | torchvision.transforms.RandomHorizontalFlip (version 1.7.0) 111 | """ 112 | width, height = get_image_size(img) 113 | return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size) 114 | 115 | 116 | class RandomHorizontalFlipReturn(RandomHorizontalFlip): 117 | def forward(self, img: Image) -> (bool, Image): 118 | """ 119 | Additionally to flipping, returns a boolean whether it was flipped or not. 120 | Args: 121 | img (PIL Image or Tensor): Image to be flipped. 122 | 123 | Returns: 124 | flipped: whether the image was flipped or not 125 | PIL Image or Tensor: Randomly flipped image. 
126 | 127 | Based on: 128 | torchvision.transforms.RandomHorizontalFlip (version 1.7.0) 129 | """ 130 | if torch.rand(1) < self.p: 131 | return True, F.hflip(img) 132 | return False, img 133 | -------------------------------------------------------------------------------- /src/taming/data/sflckr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import albumentations 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class SegmentationBase(Dataset): 10 | def __init__(self, 11 | data_csv, data_root, segmentation_root, 12 | size=None, random_crop=False, interpolation="bicubic", 13 | n_labels=182, shift_segmentation=False, 14 | ): 15 | self.n_labels = n_labels 16 | self.shift_segmentation = shift_segmentation 17 | self.data_csv = data_csv 18 | self.data_root = data_root 19 | self.segmentation_root = segmentation_root 20 | with open(self.data_csv, "r") as f: 21 | self.image_paths = f.read().splitlines() 22 | self._length = len(self.image_paths) 23 | self.labels = { 24 | "relative_file_path_": [l for l in self.image_paths], 25 | "file_path_": [os.path.join(self.data_root, l) 26 | for l in self.image_paths], 27 | "segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png")) 28 | for l in self.image_paths] 29 | } 30 | 31 | size = None if size is not None and size<=0 else size 32 | self.size = size 33 | if self.size is not None: 34 | self.interpolation = interpolation 35 | self.interpolation = { 36 | "nearest": cv2.INTER_NEAREST, 37 | "bilinear": cv2.INTER_LINEAR, 38 | "bicubic": cv2.INTER_CUBIC, 39 | "area": cv2.INTER_AREA, 40 | "lanczos": cv2.INTER_LANCZOS4}[self.interpolation] 41 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 42 | interpolation=self.interpolation) 43 | self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 44 | interpolation=cv2.INTER_NEAREST) 45 | self.center_crop = not random_crop 46 | if self.center_crop: 47 | self.cropper = albumentations.CenterCrop(height=self.size, width=self.size) 48 | else: 49 | self.cropper = albumentations.RandomCrop(height=self.size, width=self.size) 50 | self.preprocessor = self.cropper 51 | 52 | def __len__(self): 53 | return self._length 54 | 55 | def __getitem__(self, i): 56 | example = dict((k, self.labels[k][i]) for k in self.labels) 57 | image = Image.open(example["file_path_"]) 58 | if not image.mode == "RGB": 59 | image = image.convert("RGB") 60 | image = np.array(image).astype(np.uint8) 61 | if self.size is not None: 62 | image = self.image_rescaler(image=image)["image"] 63 | segmentation = Image.open(example["segmentation_path_"]) 64 | assert segmentation.mode == "L", segmentation.mode 65 | segmentation = np.array(segmentation).astype(np.uint8) 66 | if self.shift_segmentation: 67 | # used to support segmentations containing unlabeled==255 label 68 | segmentation = segmentation+1 69 | if self.size is not None: 70 | segmentation = self.segmentation_rescaler(image=segmentation)["image"] 71 | if self.size is not None: 72 | processed = self.preprocessor(image=image, 73 | mask=segmentation 74 | ) 75 | else: 76 | processed = {"image": image, 77 | "mask": segmentation 78 | } 79 | example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32) 80 | segmentation = processed["mask"] 81 | onehot = np.eye(self.n_labels)[segmentation] 82 | example["segmentation"] = onehot 83 | return example 84 | 85 | 86 | class Examples(SegmentationBase): 87 | def 
__init__(self, size=None, random_crop=False, interpolation="bicubic"): 88 | super().__init__(data_csv="data/sflckr_examples.txt", 89 | data_root="data/sflckr_images", 90 | segmentation_root="data/sflckr_segmentations", 91 | size=size, random_crop=random_crop, interpolation=interpolation) 92 | -------------------------------------------------------------------------------- /src/taming/data/utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | import tarfile 4 | import urllib 5 | import zipfile 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | import torch 10 | from taming.data.helper_types import Annotation 11 | from torch._six import string_classes 12 | from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format 13 | from tqdm import tqdm 14 | 15 | 16 | def unpack(path): 17 | if path.endswith("tar.gz"): 18 | with tarfile.open(path, "r:gz") as tar: 19 | tar.extractall(path=os.path.split(path)[0]) 20 | elif path.endswith("tar"): 21 | with tarfile.open(path, "r:") as tar: 22 | tar.extractall(path=os.path.split(path)[0]) 23 | elif path.endswith("zip"): 24 | with zipfile.ZipFile(path, "r") as f: 25 | f.extractall(path=os.path.split(path)[0]) 26 | else: 27 | raise NotImplementedError( 28 | "Unknown file extension: {}".format(os.path.splitext(path)[1]) 29 | ) 30 | 31 | 32 | def reporthook(bar): 33 | """tqdm progress bar for downloads.""" 34 | 35 | def hook(b=1, bsize=1, tsize=None): 36 | if tsize is not None: 37 | bar.total = tsize 38 | bar.update(b * bsize - bar.n) 39 | 40 | return hook 41 | 42 | 43 | def get_root(name): 44 | base = "data/" 45 | root = os.path.join(base, name) 46 | os.makedirs(root, exist_ok=True) 47 | return root 48 | 49 | 50 | def is_prepared(root): 51 | return Path(root).joinpath(".ready").exists() 52 | 53 | 54 | def mark_prepared(root): 55 | Path(root).joinpath(".ready").touch() 56 | 57 | 58 | def prompt_download(file_, source, target_dir, content_dir=None): 59 | targetpath = os.path.join(target_dir, file_) 60 | while not os.path.exists(targetpath): 61 | if content_dir is not None and os.path.exists( 62 | os.path.join(target_dir, content_dir) 63 | ): 64 | break 65 | print( 66 | "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath) 67 | ) 68 | if content_dir is not None: 69 | print( 70 | "Or place its content into '{}'.".format( 71 | os.path.join(target_dir, content_dir) 72 | ) 73 | ) 74 | input("Press Enter when done...") 75 | return targetpath 76 | 77 | 78 | def download_url(file_, url, target_dir): 79 | targetpath = os.path.join(target_dir, file_) 80 | os.makedirs(target_dir, exist_ok=True) 81 | with tqdm( 82 | unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_ 83 | ) as bar: 84 | urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar)) 85 | return targetpath 86 | 87 | 88 | def download_urls(urls, target_dir): 89 | paths = dict() 90 | for fname, url in urls.items(): 91 | outpath = download_url(fname, url, target_dir) 92 | paths[fname] = outpath 93 | return paths 94 | 95 | 96 | def quadratic_crop(x, bbox, alpha=1.0): 97 | """bbox is xmin, ymin, xmax, ymax""" 98 | im_h, im_w = x.shape[:2] 99 | bbox = np.array(bbox, dtype=np.float32) 100 | bbox = np.clip(bbox, 0, max(im_h, im_w)) 101 | center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3]) 102 | w = bbox[2] - bbox[0] 103 | h = bbox[3] - bbox[1] 104 | l = int(alpha * max(w, h)) 105 | l = max(l, 2) 106 | 107 | required_padding = -1 * min( 
108 | center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l) 109 | ) 110 | required_padding = int(np.ceil(required_padding)) 111 | if required_padding > 0: 112 | padding = [ 113 | [required_padding, required_padding], 114 | [required_padding, required_padding], 115 | ] 116 | padding += [[0, 0]] * (len(x.shape) - 2) 117 | x = np.pad(x, padding, "reflect") 118 | center = center[0] + required_padding, center[1] + required_padding 119 | xmin = int(center[0] - l / 2) 120 | ymin = int(center[1] - l / 2) 121 | return np.array(x[ymin : ymin + l, xmin : xmin + l, ...]) 122 | 123 | 124 | def custom_collate(batch): 125 | r"""source: pytorch 1.9.0, only one modification to original code """ 126 | 127 | elem = batch[0] 128 | elem_type = type(elem) 129 | if isinstance(elem, torch.Tensor): 130 | out = None 131 | if torch.utils.data.get_worker_info() is not None: 132 | # If we're in a background process, concatenate directly into a 133 | # shared memory tensor to avoid an extra copy 134 | numel = sum([x.numel() for x in batch]) 135 | storage = elem.storage()._new_shared(numel) 136 | out = elem.new(storage) 137 | return torch.stack(batch, 0, out=out) 138 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 139 | and elem_type.__name__ != 'string_': 140 | if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': 141 | # array of string classes and object 142 | if np_str_obj_array_pattern.search(elem.dtype.str) is not None: 143 | raise TypeError(default_collate_err_msg_format.format(elem.dtype)) 144 | 145 | return custom_collate([torch.as_tensor(b) for b in batch]) 146 | elif elem.shape == (): # scalars 147 | return torch.as_tensor(batch) 148 | elif isinstance(elem, float): 149 | return torch.tensor(batch, dtype=torch.float64) 150 | elif isinstance(elem, int): 151 | return torch.tensor(batch) 152 | elif isinstance(elem, string_classes): 153 | return batch 154 | elif isinstance(elem, collections.abc.Mapping): 155 | return {key: custom_collate([d[key] for d in batch]) for key in elem} 156 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple 157 | return elem_type(*(custom_collate(samples) for samples in zip(*batch))) 158 | if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation): # added 159 | return batch # added 160 | elif isinstance(elem, collections.abc.Sequence): 161 | # check to make sure that the elements in batch have consistent size 162 | it = iter(batch) 163 | elem_size = len(next(it)) 164 | if not all(len(elem) == elem_size for elem in it): 165 | raise RuntimeError('each element in list of batch should be of equal size') 166 | transposed = zip(*batch) 167 | return [custom_collate(samples) for samples in transposed] 168 | 169 | raise TypeError(default_collate_err_msg_format.format(elem_type)) 170 | -------------------------------------------------------------------------------- /src/taming/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 
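        # Editor's note (added comment, not in the upstream file): schedule(n) below returns a
        # learning-rate multiplier (the class docstring assumes base_lr = 1.0): it ramps
        # linearly from lr_start to lr_max over warm_up_steps, then follows a cosine decay
        # from lr_max down to lr_min, which is reached at max_decay_steps and held afterwards
        # (t is clamped to 1). Worked example with made-up numbers: warm_up_steps=10,
        # lr_start=0., lr_max=1., lr_min=0.1, max_decay_steps=100 gives schedule(5) = 0.5,
        # schedule(55) = 0.1 + 0.45 * (1 + cos(0.5 * pi)) = 0.55, and schedule(n >= 100) = 0.1.
        # Since the class is callable via __call__, it is typically passed as lr_lambda to
        # torch.optim.lr_scheduler.LambdaLR.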
15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n): 33 | return self.schedule(n) 34 | 35 | -------------------------------------------------------------------------------- /src/taming/models/dummy_cond_stage.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | 4 | class DummyCondStage: 5 | def __init__(self, conditional_key): 6 | self.conditional_key = conditional_key 7 | self.train = None 8 | 9 | def eval(self): 10 | return self 11 | 12 | @staticmethod 13 | def encode(c: Tensor): 14 | return c, None, (None, None, c) 15 | 16 | @staticmethod 17 | def decode(c: Tensor): 18 | return c 19 | 20 | @staticmethod 21 | def to_rgb(c: Tensor): 22 | return c 23 | -------------------------------------------------------------------------------- /src/taming/modules/discriminator/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | 4 | 5 | from taming.modules.util import ActNorm 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | if classname.find('Conv') != -1: 11 | nn.init.normal_(m.weight.data, 0.0, 0.02) 12 | elif classname.find('BatchNorm') != -1: 13 | nn.init.normal_(m.weight.data, 1.0, 0.02) 14 | nn.init.constant_(m.bias.data, 0) 15 | 16 | 17 | class NLayerDiscriminator(nn.Module): 18 | """Defines a PatchGAN discriminator as in Pix2Pix 19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 20 | """ 21 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 22 | """Construct a PatchGAN discriminator 23 | Parameters: 24 | input_nc (int) -- the number of channels in input images 25 | ndf (int) -- the number of filters in the last conv layer 26 | n_layers (int) -- the number of conv layers in the discriminator 27 | norm_layer -- normalization layer 28 | """ 29 | super(NLayerDiscriminator, self).__init__() 30 | if not use_actnorm: 31 | norm_layer = nn.BatchNorm2d 32 | else: 33 | norm_layer = ActNorm 34 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 35 | use_bias = norm_layer.func != nn.BatchNorm2d 36 | else: 37 | use_bias = norm_layer != nn.BatchNorm2d 38 | 39 | kw = 4 40 | padw = 1 41 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 42 | nf_mult = 1 43 | nf_mult_prev = 1 44 | for n in range(1, n_layers): # gradually increase the number of filters 45 | nf_mult_prev = nf_mult 46 | nf_mult = min(2 ** n, 8) 47 | sequence += [ 48 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 49 | norm_layer(ndf * nf_mult), 50 | nn.LeakyReLU(0.2, True) 51 | ] 52 | 53 | nf_mult_prev = nf_mult 54 | nf_mult = min(2 ** n_layers, 8) 55 | sequence += [ 56 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, 
stride=1, padding=padw, bias=use_bias), 57 | norm_layer(ndf * nf_mult), 58 | nn.LeakyReLU(0.2, True) 59 | ] 60 | 61 | sequence += [ 62 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 63 | self.main = nn.Sequential(*sequence) 64 | 65 | def forward(self, input): 66 | """Standard forward.""" 67 | return self.main(input) 68 | -------------------------------------------------------------------------------- /src/taming/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from taming.modules.losses.vqperceptual import DummyLoss 2 | 3 | -------------------------------------------------------------------------------- /src/taming/modules/losses/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import models 6 | from collections import namedtuple 7 | 8 | from taming.util import get_ckpt_path 9 | 10 | 11 | class LPIPS(nn.Module): 12 | # Learned perceptual metric 13 | def __init__(self, use_dropout=True): 14 | super().__init__() 15 | self.scaling_layer = ScalingLayer() 16 | self.chns = [64, 128, 256, 512, 512] # vg16 features 17 | self.net = vgg16(pretrained=True, requires_grad=False) 18 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 19 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 20 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 21 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 22 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 23 | self.load_from_pretrained() 24 | for param in self.parameters(): 25 | param.requires_grad = False 26 | 27 | def load_from_pretrained(self, name="vgg_lpips"): 28 | ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips") 29 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 30 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 31 | 32 | @classmethod 33 | def from_pretrained(cls, name="vgg_lpips"): 34 | if name != "vgg_lpips": 35 | raise NotImplementedError 36 | model = cls() 37 | ckpt = get_ckpt_path(name) 38 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 39 | return model 40 | 41 | def forward(self, input, target): 42 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 43 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 44 | feats0, feats1, diffs = {}, {}, {} 45 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 46 | for kk in range(len(self.chns)): 47 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 48 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 49 | 50 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 51 | val = res[0] 52 | for l in range(1, len(self.chns)): 53 | val += res[l] 54 | return val 55 | 56 | 57 | class ScalingLayer(nn.Module): 58 | def __init__(self): 59 | super(ScalingLayer, self).__init__() 60 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 61 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 62 | 63 | def forward(self, inp): 64 | return (inp - self.shift) / self.scale 65 | 66 | 67 | class NetLinLayer(nn.Module): 68 | """ A single linear layer 
which does a 1x1 conv """ 69 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 70 | super(NetLinLayer, self).__init__() 71 | layers = [nn.Dropout(), ] if (use_dropout) else [] 72 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 73 | self.model = nn.Sequential(*layers) 74 | 75 | 76 | class vgg16(torch.nn.Module): 77 | def __init__(self, requires_grad=False, pretrained=True): 78 | super(vgg16, self).__init__() 79 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 80 | self.slice1 = torch.nn.Sequential() 81 | self.slice2 = torch.nn.Sequential() 82 | self.slice3 = torch.nn.Sequential() 83 | self.slice4 = torch.nn.Sequential() 84 | self.slice5 = torch.nn.Sequential() 85 | self.N_slices = 5 86 | for x in range(4): 87 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 88 | for x in range(4, 9): 89 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 90 | for x in range(9, 16): 91 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 92 | for x in range(16, 23): 93 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 94 | for x in range(23, 30): 95 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 96 | if not requires_grad: 97 | for param in self.parameters(): 98 | param.requires_grad = False 99 | 100 | def forward(self, X): 101 | h = self.slice1(X) 102 | h_relu1_2 = h 103 | h = self.slice2(h) 104 | h_relu2_2 = h 105 | h = self.slice3(h) 106 | h_relu3_3 = h 107 | h = self.slice4(h) 108 | h_relu4_3 = h 109 | h = self.slice5(h) 110 | h_relu5_3 = h 111 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 112 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 113 | return out 114 | 115 | 116 | def normalize_tensor(x,eps=1e-10): 117 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 118 | return x/(norm_factor+eps) 119 | 120 | 121 | def spatial_average(x, keepdim=True): 122 | return x.mean([2,3],keepdim=keepdim) 123 | 124 | -------------------------------------------------------------------------------- /src/taming/modules/losses/segmentation.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class BCELoss(nn.Module): 6 | def forward(self, prediction, target): 7 | loss = F.binary_cross_entropy_with_logits(prediction,target) 8 | return loss, {} 9 | 10 | 11 | class BCELossWithQuant(nn.Module): 12 | def __init__(self, codebook_weight=1.): 13 | super().__init__() 14 | self.codebook_weight = codebook_weight 15 | 16 | def forward(self, qloss, target, prediction, split): 17 | bce_loss = F.binary_cross_entropy_with_logits(prediction,target) 18 | loss = bce_loss + self.codebook_weight*qloss 19 | return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(), 20 | "{}/bce_loss".format(split): bce_loss.detach().mean(), 21 | "{}/quant_loss".format(split): qloss.detach().mean() 22 | } 23 | -------------------------------------------------------------------------------- /src/taming/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from taming.modules.losses.lpips import LPIPS 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | 8 | 9 | class DummyLoss(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | 14 | def 
adopt_weight(weight, global_step, threshold=0, value=0.): 15 | if global_step < threshold: 16 | weight = value 17 | return weight 18 | 19 | 20 | def hinge_d_loss(logits_real, logits_fake): 21 | loss_real = torch.mean(F.relu(1. - logits_real)) 22 | loss_fake = torch.mean(F.relu(1. + logits_fake)) 23 | d_loss = 0.5 * (loss_real + loss_fake) 24 | return d_loss 25 | 26 | 27 | def vanilla_d_loss(logits_real, logits_fake): 28 | d_loss = 0.5 * ( 29 | torch.mean(torch.nn.functional.softplus(-logits_real)) + 30 | torch.mean(torch.nn.functional.softplus(logits_fake))) 31 | return d_loss 32 | 33 | 34 | class VQLPIPSWithDiscriminator(nn.Module): 35 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 36 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 37 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 38 | disc_ndf=64, disc_loss="hinge"): 39 | super().__init__() 40 | assert disc_loss in ["hinge", "vanilla"] 41 | self.codebook_weight = codebook_weight 42 | self.pixel_weight = pixelloss_weight 43 | self.perceptual_loss = LPIPS().eval() 44 | self.perceptual_weight = perceptual_weight 45 | 46 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 47 | n_layers=disc_num_layers, 48 | use_actnorm=use_actnorm, 49 | ndf=disc_ndf 50 | ).apply(weights_init) 51 | self.discriminator_iter_start = disc_start 52 | if disc_loss == "hinge": 53 | self.disc_loss = hinge_d_loss 54 | elif disc_loss == "vanilla": 55 | self.disc_loss = vanilla_d_loss 56 | else: 57 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 58 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 59 | self.disc_factor = disc_factor 60 | self.discriminator_weight = disc_weight 61 | self.disc_conditional = disc_conditional 62 | 63 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 64 | if last_layer is not None: 65 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 66 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 67 | else: 68 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 69 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 70 | 71 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 72 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 73 | d_weight = d_weight * self.discriminator_weight 74 | return d_weight 75 | 76 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 77 | global_step, last_layer=None, cond=None, split="train"): 78 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 79 | if self.perceptual_weight > 0: 80 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 81 | rec_loss = rec_loss + self.perceptual_weight * p_loss 82 | else: 83 | p_loss = torch.tensor([0.0]) 84 | 85 | nll_loss = rec_loss 86 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 87 | nll_loss = torch.mean(nll_loss) 88 | 89 | # now the GAN part 90 | if optimizer_idx == 0: 91 | # generator update 92 | if cond is None: 93 | assert not self.disc_conditional 94 | logits_fake = self.discriminator(reconstructions.contiguous()) 95 | else: 96 | assert self.disc_conditional 97 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 98 | g_loss = -torch.mean(logits_fake) 99 | 100 | try: 101 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 102 | except RuntimeError: 103 | 
assert not self.training 104 | d_weight = torch.tensor(0.0) 105 | 106 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 107 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 108 | 109 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 110 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 111 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 112 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 113 | "{}/p_loss".format(split): p_loss.detach().mean(), 114 | "{}/d_weight".format(split): d_weight.detach(), 115 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 116 | "{}/g_loss".format(split): g_loss.detach().mean(), 117 | } 118 | return loss, log 119 | 120 | if optimizer_idx == 1: 121 | # second pass for discriminator update 122 | if cond is None: 123 | logits_real = self.discriminator(inputs.contiguous().detach()) 124 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 125 | else: 126 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 127 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 128 | 129 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 130 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 131 | 132 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 133 | "{}/logits_real".format(split): logits_real.detach().mean(), 134 | "{}/logits_fake".format(split): logits_fake.detach().mean() 135 | } 136 | return d_loss, log 137 | -------------------------------------------------------------------------------- /src/taming/modules/misc/coord.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class CoordStage(object): 4 | def __init__(self, n_embed, down_factor): 5 | self.n_embed = n_embed 6 | self.down_factor = down_factor 7 | 8 | def eval(self): 9 | return self 10 | 11 | def encode(self, c): 12 | """fake vqmodel interface""" 13 | assert 0.0 <= c.min() and c.max() <= 1.0 14 | b,ch,h,w = c.shape 15 | assert ch == 1 16 | 17 | c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor, 18 | mode="area") 19 | c = c.clamp(0.0, 1.0) 20 | c = self.n_embed*c 21 | c_quant = c.round() 22 | c_ind = c_quant.to(dtype=torch.long) 23 | 24 | info = None, None, c_ind 25 | return c_quant, None, info 26 | 27 | def decode(self, c): 28 | c = c/self.n_embed 29 | c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor, 30 | mode="nearest") 31 | return c 32 | -------------------------------------------------------------------------------- /src/taming/modules/transformer/permuter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class AbstractPermuter(nn.Module): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__() 9 | def forward(self, x, reverse=False): 10 | raise NotImplementedError 11 | 12 | 13 | class Identity(AbstractPermuter): 14 | def __init__(self): 15 | super().__init__() 16 | 17 | def forward(self, x, reverse=False): 18 | return x 19 | 20 | 21 | class Subsample(AbstractPermuter): 22 | def __init__(self, H, W): 23 | super().__init__() 24 | C = 1 25 | indices = np.arange(H*W).reshape(C,H,W) 26 | while min(H, W) > 1: 27 | indices = 
indices.reshape(C,H//2,2,W//2,2) 28 | indices = indices.transpose(0,2,4,1,3) 29 | indices = indices.reshape(C*4,H//2, W//2) 30 | H = H//2 31 | W = W//2 32 | C = C*4 33 | assert H == W == 1 34 | idx = torch.tensor(indices.ravel()) 35 | self.register_buffer('forward_shuffle_idx', 36 | nn.Parameter(idx, requires_grad=False)) 37 | self.register_buffer('backward_shuffle_idx', 38 | nn.Parameter(torch.argsort(idx), requires_grad=False)) 39 | 40 | def forward(self, x, reverse=False): 41 | if not reverse: 42 | return x[:, self.forward_shuffle_idx] 43 | else: 44 | return x[:, self.backward_shuffle_idx] 45 | 46 | 47 | def mortonify(i, j): 48 | """(i,j) index to linear morton code""" 49 | i = np.uint64(i) 50 | j = np.uint64(j) 51 | 52 | z = np.uint(0) 53 | 54 | for pos in range(32): 55 | z = (z | 56 | ((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) | 57 | ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1)) 58 | ) 59 | return z 60 | 61 | 62 | class ZCurve(AbstractPermuter): 63 | def __init__(self, H, W): 64 | super().__init__() 65 | reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)] 66 | idx = np.argsort(reverseidx) 67 | idx = torch.tensor(idx) 68 | reverseidx = torch.tensor(reverseidx) 69 | self.register_buffer('forward_shuffle_idx', 70 | idx) 71 | self.register_buffer('backward_shuffle_idx', 72 | reverseidx) 73 | 74 | def forward(self, x, reverse=False): 75 | if not reverse: 76 | return x[:, self.forward_shuffle_idx] 77 | else: 78 | return x[:, self.backward_shuffle_idx] 79 | 80 | 81 | class SpiralOut(AbstractPermuter): 82 | def __init__(self, H, W): 83 | super().__init__() 84 | assert H == W 85 | size = W 86 | indices = np.arange(size*size).reshape(size,size) 87 | 88 | i0 = size//2 89 | j0 = size//2-1 90 | 91 | i = i0 92 | j = j0 93 | 94 | idx = [indices[i0, j0]] 95 | step_mult = 0 96 | for c in range(1, size//2+1): 97 | step_mult += 1 98 | # steps left 99 | for k in range(step_mult): 100 | i = i - 1 101 | j = j 102 | idx.append(indices[i, j]) 103 | 104 | # step down 105 | for k in range(step_mult): 106 | i = i 107 | j = j + 1 108 | idx.append(indices[i, j]) 109 | 110 | step_mult += 1 111 | if c < size//2: 112 | # step right 113 | for k in range(step_mult): 114 | i = i + 1 115 | j = j 116 | idx.append(indices[i, j]) 117 | 118 | # step up 119 | for k in range(step_mult): 120 | i = i 121 | j = j - 1 122 | idx.append(indices[i, j]) 123 | else: 124 | # end reached 125 | for k in range(step_mult-1): 126 | i = i + 1 127 | idx.append(indices[i, j]) 128 | 129 | assert len(idx) == size*size 130 | idx = torch.tensor(idx) 131 | self.register_buffer('forward_shuffle_idx', idx) 132 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) 133 | 134 | def forward(self, x, reverse=False): 135 | if not reverse: 136 | return x[:, self.forward_shuffle_idx] 137 | else: 138 | return x[:, self.backward_shuffle_idx] 139 | 140 | 141 | class SpiralIn(AbstractPermuter): 142 | def __init__(self, H, W): 143 | super().__init__() 144 | assert H == W 145 | size = W 146 | indices = np.arange(size*size).reshape(size,size) 147 | 148 | i0 = size//2 149 | j0 = size//2-1 150 | 151 | i = i0 152 | j = j0 153 | 154 | idx = [indices[i0, j0]] 155 | step_mult = 0 156 | for c in range(1, size//2+1): 157 | step_mult += 1 158 | # steps left 159 | for k in range(step_mult): 160 | i = i - 1 161 | j = j 162 | idx.append(indices[i, j]) 163 | 164 | # step down 165 | for k in range(step_mult): 166 | i = i 167 | j = j + 1 168 | idx.append(indices[i, j]) 169 | 170 | step_mult += 1 171 | if c < 
size//2: 172 | # step right 173 | for k in range(step_mult): 174 | i = i + 1 175 | j = j 176 | idx.append(indices[i, j]) 177 | 178 | # step up 179 | for k in range(step_mult): 180 | i = i 181 | j = j - 1 182 | idx.append(indices[i, j]) 183 | else: 184 | # end reached 185 | for k in range(step_mult-1): 186 | i = i + 1 187 | idx.append(indices[i, j]) 188 | 189 | assert len(idx) == size*size 190 | idx = idx[::-1] 191 | idx = torch.tensor(idx) 192 | self.register_buffer('forward_shuffle_idx', idx) 193 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) 194 | 195 | def forward(self, x, reverse=False): 196 | if not reverse: 197 | return x[:, self.forward_shuffle_idx] 198 | else: 199 | return x[:, self.backward_shuffle_idx] 200 | 201 | 202 | class Random(nn.Module): 203 | def __init__(self, H, W): 204 | super().__init__() 205 | indices = np.random.RandomState(1).permutation(H*W) 206 | idx = torch.tensor(indices.ravel()) 207 | self.register_buffer('forward_shuffle_idx', idx) 208 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) 209 | 210 | def forward(self, x, reverse=False): 211 | if not reverse: 212 | return x[:, self.forward_shuffle_idx] 213 | else: 214 | return x[:, self.backward_shuffle_idx] 215 | 216 | 217 | class AlternateParsing(AbstractPermuter): 218 | def __init__(self, H, W): 219 | super().__init__() 220 | indices = np.arange(W*H).reshape(H,W) 221 | for i in range(1, H, 2): 222 | indices[i, :] = indices[i, ::-1] 223 | idx = indices.flatten() 224 | assert len(idx) == H*W 225 | idx = torch.tensor(idx) 226 | self.register_buffer('forward_shuffle_idx', idx) 227 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) 228 | 229 | def forward(self, x, reverse=False): 230 | if not reverse: 231 | return x[:, self.forward_shuffle_idx] 232 | else: 233 | return x[:, self.backward_shuffle_idx] 234 | 235 | 236 | if __name__ == "__main__": 237 | p0 = AlternateParsing(16, 16) 238 | print(p0.forward_shuffle_idx) 239 | print(p0.backward_shuffle_idx) 240 | 241 | x = torch.randint(0, 768, size=(11, 256)) 242 | y = p0(x) 243 | xre = p0(y, reverse=True) 244 | assert torch.equal(x, xre) 245 | 246 | p1 = SpiralOut(2, 2) 247 | print(p1.forward_shuffle_idx) 248 | print(p1.backward_shuffle_idx) 249 | -------------------------------------------------------------------------------- /src/taming/modules/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def count_params(model): 6 | total_params = sum(p.numel() for p in model.parameters()) 7 | return total_params 8 | 9 | 10 | class ActNorm(nn.Module): 11 | def __init__(self, num_features, logdet=False, affine=True, 12 | allow_reverse_init=False): 13 | assert affine 14 | super().__init__() 15 | self.logdet = logdet 16 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 17 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 18 | self.allow_reverse_init = allow_reverse_init 19 | 20 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 21 | 22 | def initialize(self, input): 23 | with torch.no_grad(): 24 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 25 | mean = ( 26 | flatten.mean(1) 27 | .unsqueeze(1) 28 | .unsqueeze(2) 29 | .unsqueeze(3) 30 | .permute(1, 0, 2, 3) 31 | ) 32 | std = ( 33 | flatten.std(1) 34 | .unsqueeze(1) 35 | .unsqueeze(2) 36 | .unsqueeze(3) 37 | .permute(1, 0, 2, 3) 38 | ) 39 | 40 | self.loc.data.copy_(-mean) 41 | self.scale.data.copy_(1 / (std + 
1e-6)) 42 | 43 | def forward(self, input, reverse=False): 44 | if reverse: 45 | return self.reverse(input) 46 | if len(input.shape) == 2: 47 | input = input[:,:,None,None] 48 | squeeze = True 49 | else: 50 | squeeze = False 51 | 52 | _, _, height, width = input.shape 53 | 54 | if self.training and self.initialized.item() == 0: 55 | self.initialize(input) 56 | self.initialized.fill_(1) 57 | 58 | h = self.scale * (input + self.loc) 59 | 60 | if squeeze: 61 | h = h.squeeze(-1).squeeze(-1) 62 | 63 | if self.logdet: 64 | log_abs = torch.log(torch.abs(self.scale)) 65 | logdet = height*width*torch.sum(log_abs) 66 | logdet = logdet * torch.ones(input.shape[0]).to(input) 67 | return h, logdet 68 | 69 | return h 70 | 71 | def reverse(self, output): 72 | if self.training and self.initialized.item() == 0: 73 | if not self.allow_reverse_init: 74 | raise RuntimeError( 75 | "Initializing ActNorm in reverse direction is " 76 | "disabled by default. Use allow_reverse_init=True to enable." 77 | ) 78 | else: 79 | self.initialize(output) 80 | self.initialized.fill_(1) 81 | 82 | if len(output.shape) == 2: 83 | output = output[:,:,None,None] 84 | squeeze = True 85 | else: 86 | squeeze = False 87 | 88 | h = output / self.scale - self.loc 89 | 90 | if squeeze: 91 | h = h.squeeze(-1).squeeze(-1) 92 | return h 93 | 94 | 95 | class AbstractEncoder(nn.Module): 96 | def __init__(self): 97 | super().__init__() 98 | 99 | def encode(self, *args, **kwargs): 100 | raise NotImplementedError 101 | 102 | 103 | class Labelator(AbstractEncoder): 104 | """Net2Net Interface for Class-Conditional Model""" 105 | def __init__(self, n_classes, quantize_interface=True): 106 | super().__init__() 107 | self.n_classes = n_classes 108 | self.quantize_interface = quantize_interface 109 | 110 | def encode(self, c): 111 | c = c[:,None] 112 | if self.quantize_interface: 113 | return c, None, [None, None, c.long()] 114 | return c 115 | 116 | 117 | class SOSProvider(AbstractEncoder): 118 | # for unconditional training 119 | def __init__(self, sos_token, quantize_interface=True): 120 | super().__init__() 121 | self.sos_token = sos_token 122 | self.quantize_interface = quantize_interface 123 | 124 | def encode(self, x): 125 | # get batch size from data and replicate sos_token 126 | c = torch.ones(x.shape[0], 1)*self.sos_token 127 | c = c.long().to(x.device) 128 | if self.quantize_interface: 129 | return c, None, [None, None, c] 130 | return c 131 | -------------------------------------------------------------------------------- /src/taming/util.py: -------------------------------------------------------------------------------- 1 | import os, hashlib 2 | import requests 3 | from tqdm import tqdm 4 | 5 | URL_MAP = { 6 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" 7 | } 8 | 9 | CKPT_MAP = { 10 | "vgg_lpips": "vgg.pth" 11 | } 12 | 13 | MD5_MAP = { 14 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" 15 | } 16 | 17 | 18 | def download(url, local_path, chunk_size=1024): 19 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 20 | with requests.get(url, stream=True) as r: 21 | total_size = int(r.headers.get("content-length", 0)) 22 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 23 | with open(local_path, "wb") as f: 24 | for data in r.iter_content(chunk_size=chunk_size): 25 | if data: 26 | f.write(data) 27 | pbar.update(chunk_size) 28 | 29 | 30 | def md5_hash(path): 31 | with open(path, "rb") as f: 32 | content = f.read() 33 | return hashlib.md5(content).hexdigest() 34 | 35 | 36 | def 
get_ckpt_path(name, root, check=False): 37 | assert name in URL_MAP 38 | path = os.path.join(root, CKPT_MAP[name]) 39 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 40 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 41 | download(URL_MAP[name], path) 42 | md5 = md5_hash(path) 43 | assert md5 == MD5_MAP[name], md5 44 | return path 45 | 46 | 47 | class KeyNotFoundError(Exception): 48 | def __init__(self, cause, keys=None, visited=None): 49 | self.cause = cause 50 | self.keys = keys 51 | self.visited = visited 52 | messages = list() 53 | if keys is not None: 54 | messages.append("Key not found: {}".format(keys)) 55 | if visited is not None: 56 | messages.append("Visited: {}".format(visited)) 57 | messages.append("Cause:\n{}".format(cause)) 58 | message = "\n".join(messages) 59 | super().__init__(message) 60 | 61 | 62 | def retrieve( 63 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False 64 | ): 65 | """Given a nested list or dict return the desired value at key expanding 66 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion 67 | is done in-place. 68 | 69 | Parameters 70 | ---------- 71 | list_or_dict : list or dict 72 | Possibly nested list or dictionary. 73 | key : str 74 | key/to/value, path like string describing all keys necessary to 75 | consider to get to the desired value. List indices can also be 76 | passed here. 77 | splitval : str 78 | String that defines the delimiter between keys of the 79 | different depth levels in `key`. 80 | default : obj 81 | Value returned if :attr:`key` is not found. 82 | expand : bool 83 | Whether to expand callable nodes on the path or not. 84 | 85 | Returns 86 | ------- 87 | The desired value or if :attr:`default` is not ``None`` and the 88 | :attr:`key` is not found returns ``default``. 89 | 90 | Raises 91 | ------ 92 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is 93 | ``None``. 94 | """ 95 | 96 | keys = key.split(splitval) 97 | 98 | success = True 99 | try: 100 | visited = [] 101 | parent = None 102 | last_key = None 103 | for key in keys: 104 | if callable(list_or_dict): 105 | if not expand: 106 | raise KeyNotFoundError( 107 | ValueError( 108 | "Trying to get past callable node with expand=False." 
109 | ), 110 | keys=keys, 111 | visited=visited, 112 | ) 113 | list_or_dict = list_or_dict() 114 | parent[last_key] = list_or_dict 115 | 116 | last_key = key 117 | parent = list_or_dict 118 | 119 | try: 120 | if isinstance(list_or_dict, dict): 121 | list_or_dict = list_or_dict[key] 122 | else: 123 | list_or_dict = list_or_dict[int(key)] 124 | except (KeyError, IndexError, ValueError) as e: 125 | raise KeyNotFoundError(e, keys=keys, visited=visited) 126 | 127 | visited += [key] 128 | # final expansion of retrieved value 129 | if expand and callable(list_or_dict): 130 | list_or_dict = list_or_dict() 131 | parent[last_key] = list_or_dict 132 | except KeyNotFoundError as e: 133 | if default is None: 134 | raise e 135 | else: 136 | list_or_dict = default 137 | success = False 138 | 139 | if not pass_success: 140 | return list_or_dict 141 | else: 142 | return list_or_dict, success 143 | 144 | 145 | if __name__ == "__main__": 146 | config = {"keya": "a", 147 | "keyb": "b", 148 | "keyc": 149 | {"cc1": 1, 150 | "cc2": 2, 151 | } 152 | } 153 | from omegaconf import OmegaConf 154 | config = OmegaConf.create(config) 155 | print(config) 156 | retrieve(config, "keya") 157 | 158 | -------------------------------------------------------------------------------- /utils_img.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyright: reportMissingModuleSource=false 8 | 9 | import numpy as np 10 | from augly.image import functional as aug_functional 11 | import torch 12 | from torchvision import transforms 13 | from torchvision.transforms import functional 14 | 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | 17 | default_transform = transforms.Compose([ 18 | transforms.ToTensor(), 19 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 20 | ]) 21 | 22 | normalize_vqgan = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # Normalize (x - 0.5) / 0.5 23 | unnormalize_vqgan = transforms.Normalize(mean=[-1, -1, -1], std=[1/0.5, 1/0.5, 1/0.5]) # Unnormalize (x * 0.5) + 0.5 24 | normalize_img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalize (x - mean) / std 25 | unnormalize_img = transforms.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], std=[1/0.229, 1/0.224, 1/0.225]) # Unnormalize (x * std) + mean 26 | 27 | def psnr(x, y, img_space='vqgan'): 28 | """ 29 | Return PSNR 30 | Args: 31 | x: Image tensor with values approx. between [-1,1] 32 | y: Image tensor with values approx. 
between [-1,1], ex: original image
33 |     """
34 |     if img_space == 'vqgan':
35 |         delta = torch.clamp(unnormalize_vqgan(x), 0, 1) - torch.clamp(unnormalize_vqgan(y), 0, 1)
36 |     elif img_space == 'img':
37 |         delta = torch.clamp(unnormalize_img(x), 0, 1) - torch.clamp(unnormalize_img(y), 0, 1)
38 |     else:
39 |         delta = x - y
40 |     delta = 255 * delta
41 |     delta = delta.reshape(-1, x.shape[-3], x.shape[-2], x.shape[-1]) # BxCxHxW
42 |     psnr = 20*np.log10(255) - 10*torch.log10(torch.mean(delta**2, dim=(1,2,3))) # B
43 |     return psnr
44 | 
45 | def center_crop(x, scale):
46 |     """ Perform a center crop such that the cropped area is a given fraction of the original area
47 |     Args:
48 |         x: image tensor (... x c x h x w)
49 |         scale: target area scale
50 |     """
51 |     scale = np.sqrt(scale)
52 |     new_edges_size = [int(s*scale) for s in x.shape[-2:]][::-1]
53 |     return functional.center_crop(x, new_edges_size)
54 | 
55 | def resize(x, scale):
56 |     """ Resize the image such that the resized area is a given fraction of the original area
57 |     Args:
58 |         x: image tensor (... x c x h x w)
59 |         scale: target area scale
60 |     """
61 |     scale = np.sqrt(scale)
62 |     new_edges_size = [int(s*scale) for s in x.shape[-2:]][::-1]
63 |     return functional.resize(x, new_edges_size)
64 | 
65 | def rotate(x, angle):
66 |     """ Rotate image by angle
67 |     Args:
68 |         x: image (PIL or tensor)
69 |         angle: angle in degrees
70 |     """
71 |     return functional.rotate(x, angle)
72 | 
73 | def adjust_brightness(x, brightness_factor):
74 |     """ Adjust brightness of an image
75 |     Args:
76 |         x: image tensor, normalized with ImageNet stats
77 |         brightness_factor: brightness factor
78 |     """
79 |     return normalize_img(functional.adjust_brightness(unnormalize_img(x), brightness_factor))
80 | 
81 | def adjust_contrast(x, contrast_factor):
82 |     """ Adjust contrast of an image
83 |     Args:
84 |         x: image tensor, normalized with ImageNet stats
85 |         contrast_factor: contrast factor
86 |     """
87 |     return normalize_img(functional.adjust_contrast(unnormalize_img(x), contrast_factor))
88 | 
89 | def adjust_saturation(x, saturation_factor):
90 |     """ Adjust saturation of an image
91 |     Args:
92 |         x: image tensor, normalized with ImageNet stats
93 |         saturation_factor: saturation factor
94 |     """
95 |     return normalize_img(functional.adjust_saturation(unnormalize_img(x), saturation_factor))
96 | 
97 | def adjust_hue(x, hue_factor):
98 |     """ Adjust hue of an image
99 |     Args:
100 |         x: image tensor, normalized with ImageNet stats
101 |         hue_factor: hue factor
102 |     """
103 |     return normalize_img(functional.adjust_hue(unnormalize_img(x), hue_factor))
104 | 
105 | def adjust_gamma(x, gamma, gain=1):
106 |     """ Adjust gamma of an image
107 |     Args:
108 |         x: image tensor, normalized with ImageNet stats
109 |         gamma: gamma factor
110 |         gain: gain factor
111 |     """
112 |     return normalize_img(functional.adjust_gamma(unnormalize_img(x), gamma, gain))
113 | 
114 | def adjust_sharpness(x, sharpness_factor):
115 |     """ Adjust sharpness of an image
116 |     Args:
117 |         x: image tensor, normalized with ImageNet stats
118 |         sharpness_factor: sharpness factor
119 |     """
120 |     return normalize_img(functional.adjust_sharpness(unnormalize_img(x), sharpness_factor))
121 | 
122 | def overlay_text(x, text='Lorem Ipsum'):
123 |     """ Overlay text on a batch of images
124 |     The font, font size, color and position of the
125 |     overlay are not exposed by this wrapper; the
126 |     augly defaults are used for all of them.
127 |     Args:
128 |         x: batch of image tensors (b c h w),
129 |             normalized with ImageNet stats
130 |         text: text to overlay
131 |     """
132 |     to_pil = transforms.ToPILImage()
133 |     to_tensor = transforms.ToTensor()
134 |     img_aug = torch.zeros_like(x, device=x.device)
135 |     for ii,img in enumerate(x):
136 |         pil_img = to_pil(unnormalize_img(img))
137 |         img_aug[ii] = to_tensor(aug_functional.overlay_text(pil_img, text=text))
138 |     return normalize_img(img_aug)
139 | 
140 | def jpeg_compress(x, quality_factor):
141 |     """ Apply JPEG compression to a batch of images
142 |     Args:
143 |         x: batch of image tensors, normalized with ImageNet stats
144 |         quality_factor: JPEG quality factor (0-100)
145 |     """
146 |     to_pil = transforms.ToPILImage()
147 |     to_tensor = transforms.ToTensor()
148 |     img_aug = torch.zeros_like(x, device=x.device)
149 |     for ii,img in enumerate(x):
150 |         pil_img = to_pil(unnormalize_img(img))
151 |         img_aug[ii] = to_tensor(aug_functional.encoding_quality(pil_img, quality=quality_factor))
152 |     return normalize_img(img_aug)
153 | 
--------------------------------------------------------------------------------
/utils_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | import importlib
8 | import torch
9 | import torch.nn as nn
10 | 
11 | ### Load HiDDeN models
12 | 
13 | class ConvBNRelu(nn.Module):
14 |     """
15 |     Building block used in the HiDDeN network: a sequence of Convolution, Batch Normalization and an activation (the class keeps the original name, but the activation used here is GELU rather than ReLU)
16 |     """
17 |     def __init__(self, channels_in, channels_out):
18 | 
19 |         super(ConvBNRelu, self).__init__()
20 | 
21 |         self.layers = nn.Sequential(
22 |             nn.Conv2d(channels_in, channels_out, 3, stride=1, padding=1),
23 |             nn.BatchNorm2d(channels_out, eps=1e-3),
24 |             nn.GELU()
25 |         )
26 | 
27 |     def forward(self, x):
28 |         return self.layers(x)
29 | 
30 | class HiddenDecoder(nn.Module):
31 |     """
32 |     Decoder module. Receives a watermarked image and extracts the watermark.
33 |     """
34 |     def __init__(self, num_blocks, num_bits, channels, redundancy=1):
35 | 
36 |         super(HiddenDecoder, self).__init__()
37 | 
38 |         layers = [ConvBNRelu(3, channels)]
39 |         for _ in range(num_blocks - 1):
40 |             layers.append(ConvBNRelu(channels, channels))
41 | 
42 |         layers.append(ConvBNRelu(channels, num_bits*redundancy))
43 |         layers.append(nn.AdaptiveAvgPool2d(output_size=(1, 1)))
44 |         self.layers = nn.Sequential(*layers)
45 | 
46 |         self.linear = nn.Linear(num_bits*redundancy, num_bits*redundancy)
47 | 
48 |         self.num_bits = num_bits
49 |         self.redundancy = redundancy
50 | 
51 |     def forward(self, img_w):
52 | 
53 |         x = self.layers(img_w) # b d 1 1
54 |         x = x.squeeze(-1).squeeze(-1) # b d
55 |         x = self.linear(x)
56 | 
57 |         x = x.view(-1, self.num_bits, self.redundancy) # b k*r -> b k r
58 |         x = torch.sum(x, dim=-1) # b k r -> b k
59 | 
60 |         return x
61 | 
62 | class HiddenEncoder(nn.Module):
63 |     """
64 |     Inserts a watermark into an image.
65 | """ 66 | def __init__(self, num_blocks, num_bits, channels, last_tanh=True): 67 | super(HiddenEncoder, self).__init__() 68 | layers = [ConvBNRelu(3, channels)] 69 | 70 | for _ in range(num_blocks-1): 71 | layer = ConvBNRelu(channels, channels) 72 | layers.append(layer) 73 | 74 | self.conv_bns = nn.Sequential(*layers) 75 | self.after_concat_layer = ConvBNRelu(channels + 3 + num_bits, channels) 76 | 77 | self.final_layer = nn.Conv2d(channels, 3, kernel_size=1) 78 | 79 | self.last_tanh = last_tanh 80 | self.tanh = nn.Tanh() 81 | 82 | def forward(self, imgs, msgs): 83 | 84 | msgs = msgs.unsqueeze(-1).unsqueeze(-1) # b l 1 1 85 | msgs = msgs.expand(-1,-1, imgs.size(-2), imgs.size(-1)) # b l h w 86 | 87 | encoded_image = self.conv_bns(imgs) 88 | 89 | concat = torch.cat([msgs, encoded_image, imgs], dim=1) 90 | im_w = self.after_concat_layer(concat) 91 | im_w = self.final_layer(im_w) 92 | 93 | if self.last_tanh: 94 | im_w = self.tanh(im_w) 95 | 96 | return im_w 97 | 98 | def get_hidden_decoder(num_bits, redundancy=1, num_blocks=7, channels=64): 99 | decoder = HiddenDecoder(num_blocks=num_blocks, num_bits=num_bits, channels=channels, redundancy=redundancy) 100 | return decoder 101 | 102 | def get_hidden_decoder_ckpt(ckpt_path): 103 | ckpt = torch.load(ckpt_path, map_location="cpu") 104 | decoder_ckpt = { k.replace('module.', '').replace('decoder.', '') : v for k,v in ckpt['encoder_decoder'].items() if 'decoder' in k} 105 | return decoder_ckpt 106 | 107 | def get_hidden_encoder(num_bits, num_blocks=4, channels=64): 108 | encoder = HiddenEncoder(num_blocks=num_blocks, num_bits=num_bits, channels=channels) 109 | return encoder 110 | 111 | def get_hidden_encoder_ckpt(ckpt_path): 112 | ckpt = torch.load(ckpt_path, map_location="cpu") 113 | encoder_ckpt = { k.replace('module.', '').replace('encoder.', '') : v for k,v in ckpt['encoder_decoder'].items() if 'encoder' in k} 114 | return encoder_ckpt 115 | 116 | ### Load LDM models 117 | 118 | def instantiate_from_config(config): 119 | if not "target" in config: 120 | if config == '__is_first_stage__': 121 | return None 122 | elif config == "__is_unconditional__": 123 | return None 124 | raise KeyError("Expected key `target` to instantiate.") 125 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 126 | 127 | def get_obj_from_str(string, reload=False): 128 | module, cls = string.rsplit(".", 1) 129 | if reload: 130 | module_imp = importlib.import_module(module) 131 | importlib.reload(module_imp) 132 | return getattr(importlib.import_module(module, package=None), cls) 133 | 134 | def load_model_from_config(config, ckpt, verbose=False): 135 | print(f"Loading model from {ckpt}") 136 | pl_sd = torch.load(ckpt, map_location="cpu") 137 | if "global_step" in pl_sd: 138 | print(f"Global Step: {pl_sd['global_step']}") 139 | sd = pl_sd["state_dict"] 140 | model = instantiate_from_config(config.model) 141 | m, u = model.load_state_dict(sd, strict=False) 142 | if len(m) > 0 and verbose: 143 | print("missing keys:") 144 | print(m) 145 | if len(u) > 0 and verbose: 146 | print("unexpected keys:") 147 | print(u) 148 | 149 | model.cuda() 150 | model.eval() 151 | return model 152 | --------------------------------------------------------------------------------
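Usage sketch (appended for illustration; not a file from the repository): the snippet below shows one way the pieces above fit together — building the HiDDeN decoder from utils_model.py, loading it from a checkpoint such as hidden/ckpts/hidden_replicate.pth, and decoding a message from an image prepared with the transforms in utils_img.py. The image path, the 48-bit message length, and the sign-thresholding of the decoder logits are assumptions made for the example, not values taken from this repository.

# Hedged usage sketch: decode a watermark with the HiDDeN decoder defined in utils_model.py.
# The checkpoint path, num_bits and the +/-1 bit convention are assumptions for illustration.
from PIL import Image
import torch

import utils_img
import utils_model

ckpt_path = "hidden/ckpts/hidden_replicate.pth"   # checkpoint shipped under hidden/ckpts/
num_bits = 48                                     # assumed message length of that checkpoint

# Build the decoder with the defaults from utils_model.py and load the cleaned state dict.
decoder = utils_model.get_hidden_decoder(num_bits=num_bits)
state_dict = utils_model.get_hidden_decoder_ckpt(ckpt_path)
decoder.load_state_dict(state_dict, strict=False)   # strict=False in case key names differ slightly
decoder = decoder.to(utils_img.device).eval()

# Prepare the image with the same ImageNet normalization used throughout utils_img.py.
img = Image.open("watermarked.png").convert("RGB")   # hypothetical input image
img_t = utils_img.default_transform(img).unsqueeze(0).to(utils_img.device)

with torch.no_grad():
    logits = decoder(img_t)                 # shape (1, num_bits)
    bits = (logits > 0).long().squeeze(0)   # assumes bits were encoded as +/-1 during training

print("decoded bits:", bits.tolist())

# To compare a watermarked image against a reference in the same normalized space,
# the PSNR helper from utils_img.py can be used, e.g.:
# utils_img.psnr(img_t, ref_t, img_space='img')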