├── .gitignore ├── LICENSE.md ├── LPNet.py ├── README.md ├── assets ├── res0.png └── res1.png ├── data.py ├── ffhq.log ├── helpers ├── imle_helpers.py ├── train_helpers.py └── utils.py ├── hps.py ├── lpips └── weights │ └── v0.1 │ ├── alex.pth │ ├── squeeze.pth │ └── vgg.pth ├── mapping_network.py ├── metrics ├── ppl.py └── ppl_uniform.py ├── models.py ├── notebooks └── low-dim.ipynb ├── reproduce ├── 100-shot-grumpy_cat.sh ├── 100-shot-obama.sh ├── 100-shot-panda.sh ├── AnimalFace-cat.sh ├── AnimalFace-dog.sh └── ffhq.sh ├── sampler.py ├── setup_datasets.sh ├── test.py ├── train.py └── visual ├── generate_rnd.py ├── generate_rnd_nn.py ├── generate_sample_nn.py ├── interpolate.py ├── nn_interplate.py ├── spatial_visual.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | cifar-10* 3 | saved_models 4 | .DS_Store 5 | __pycache__ 6 | dciknn_cuda 7 | res/ 8 | experiments 9 | .vscode 10 | .datasets/ 11 | datasets 12 | wandb 13 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright (c) 2012-2023 Scott Chacon and others 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining 4 | a copy of this software and associated documentation files (the 5 | "Software"), to deal in the Software without restriction, including 6 | without limitation the rights to use, copy, modify, merge, publish, 7 | distribute, sublicense, and/or sell copies of the Software, and to 8 | permit persons to whom the Software is furnished to do so, subject to 9 | the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be 12 | included in all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 15 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 16 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 17 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 18 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 19 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 20 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /LPNet.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import namedtuple 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import models as tv 6 | 7 | 8 | def normalize_tensor(in_feat, eps=1e-10): 9 | norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True) + eps) 10 | return in_feat / (norm_factor + eps) 11 | 12 | 13 | class RerangeLayer(nn.Module): 14 | # Change the input from range [-1., 1.] to [0., 1.] 15 | def __init__(self): 16 | super(RerangeLayer, self).__init__() 17 | 18 | def forward(self, inp): 19 | return (inp + 1.) / 2. 
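As a quick illustration (not part of `LPNet.py`): `normalize_tensor` above rescales the channel vector at every spatial location to unit L2 norm, so the perceptual distance built on these features compares feature directions rather than magnitudes. A minimal, standalone demo of that operation:

```python
import torch

# Standalone demo of what normalize_tensor computes on a fake activation map.
feat = torch.randn(2, 64, 8, 8)  # (B, C, H, W) VGG-style features
norm = torch.sqrt(torch.sum(feat ** 2, dim=1, keepdim=True) + 1e-10)
unit = feat / (norm + 1e-10)  # same operation as normalize_tensor
lengths = unit.pow(2).sum(dim=1).sqrt()  # per-location L2 norm, shape (B, H, W)
print(torch.allclose(lengths, torch.ones_like(lengths), atol=1e-4))  # True
```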
20 | 21 | 22 | class NetLinLayer(nn.Module): 23 | ''' A single linear layer used as placeholder for LPIPS learnt weights ''' 24 | def __init__(self): 25 | super(NetLinLayer, self).__init__() 26 | self.weight = None 27 | 28 | def forward(self, inp): 29 | out = self.weight * inp 30 | return out 31 | 32 | 33 | class ScalingLayer(nn.Module): 34 | # For rescaling the input to vgg16 35 | def __init__(self): 36 | super(ScalingLayer, self).__init__() 37 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 38 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 39 | 40 | def forward(self, inp): 41 | return (inp - self.shift) / self.scale 42 | 43 | 44 | # Learned perceptual network, modified from https://github.com/richzhang/PerceptualSimilarity 45 | class LPNet(nn.Module): 46 | def __init__(self, pnet_type='vgg', version='0.1', path='.'): 47 | super(LPNet, self).__init__() 48 | 49 | self.scaling_layer = ScalingLayer() 50 | self.net = vgg16(pretrained=True, requires_grad=False) 51 | self.L = 5 52 | self.lins = [NetLinLayer() for _ in range(self.L)] 53 | 54 | model_path = os.path.abspath( 55 | os.path.join(path, 'weights/v%s/%s.pth' % (version, pnet_type))) 56 | print('Loading model from: %s' % model_path) 57 | weights = torch.load(model_path) 58 | for i in range(self.L): 59 | self.lins[i].weight = torch.sqrt(weights["lin%d.model.1.weight" % i]) 60 | 61 | def forward(self, in0, avg=False): 62 | in0_input = self.scaling_layer(in0) 63 | outs0 = self.net.forward(in0_input) 64 | feats0 = {} 65 | shapes = [] 66 | res = [] 67 | 68 | for kk in range(self.L): 69 | feats0[kk] = normalize_tensor(outs0[kk]) 70 | 71 | if avg: 72 | res = [self.lins[kk](feats0[kk]).mean([2,3],keepdim=False) for kk in range(self.L)] 73 | else: 74 | for kk in range(self.L): 75 | cur_res = self.lins[kk](feats0[kk]) 76 | shapes.append(cur_res.shape[-1]) 77 | res.append(cur_res.reshape(cur_res.shape[0], -1)) 78 | 79 | return res, shapes 80 | 81 | 82 | class vgg16(torch.nn.Module): 83 | def __init__(self, requires_grad=False, pretrained=True): 84 | super(vgg16, self).__init__() 85 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features 86 | self.slice1 = torch.nn.Sequential() 87 | self.slice2 = torch.nn.Sequential() 88 | self.slice3 = torch.nn.Sequential() 89 | self.slice4 = torch.nn.Sequential() 90 | self.slice5 = torch.nn.Sequential() 91 | self.N_slices = 5 92 | for x in range(4): 93 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 94 | for x in range(4, 9): 95 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 96 | for x in range(9, 16): 97 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 98 | for x in range(16, 23): 99 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 100 | for x in range(23, 30): 101 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 102 | if not requires_grad: 103 | for param in self.parameters(): 104 | param.requires_grad = False 105 | 106 | def forward(self, x): 107 | h = self.slice1(x) 108 | h_relu1_2 = h 109 | h = self.slice2(h) 110 | h_relu2_2 = h 111 | h = self.slice3(h) 112 | h_relu3_3 = h 113 | h = self.slice4(h) 114 | h_relu4_3 = h 115 | h = self.slice5(h) 116 | h_relu5_3 = h 117 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 118 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 119 | 120 | return out -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # Adaptive IMLE for Few-shot Pretraining-free Generative Modelling 2 | 3 | Official PyTorch implementation of the ICML 2023 paper 4 | 5 | 6 | 7 | ![img](./assets/res0.png) 8 | ![img](./assets/res1.png) 9 | 10 | ### Abstract 11 | 12 | _Despite their success on large datasets, GANs have been difficult to apply in the few-shot setting, where only a limited number of training examples are provided. Due to mode collapse, GANs tend to ignore some training examples, causing overfitting to a subset of the training dataset, which is small in the first place. A recent method called Implicit Maximum Likelihood Estimation (IMLE) is an alternative to GANs that tries to address this issue. It uses the same kind of generator as GANs but trains it with a different objective that encourages mode coverage. However, the theoretical guarantees of IMLE hold under a restrictive condition that the optimal likelihood at all data points is the same. In this paper, we present a more generalized formulation of IMLE which includes the original formulation as a special case, and we prove that the theoretical guarantees hold under weaker conditions. Using this generalized formulation, we further derive a new algorithm, which we dub Adaptive IMLE, which can adapt to the varying difficulty of different training examples. We demonstrate on multiple few-shot image synthesis datasets that our method significantly outperforms existing methods._ 13 | 14 | [Paper PDF](https://scholar.google.com/citations?view_op=view_citation&hl=en&user=sKWTHpsAAAAJ&citation_for_view=sKWTHpsAAAAJ:u5HHmVD_uO8C). 15 | 16 | ## Requirements 17 | 18 | Python 3.8 19 | 20 | ```bash 21 | virtualenv -p python3.8 venv 22 | pip3 install -r requirements.txt 23 | pip3 install dciknn_cuda-0.1.15.tar.gz 24 | ``` 25 | 26 | ## Pretrained Models 27 | 28 | Pretrained models can be downloaded from the following link: 29 | https://drive.google.com/file/d/1X8nl1TWjv2w_zk_8FoRhtht0au3jvI8k/view?usp=sharing 30 | 31 | ## Datasets 32 | 33 | Running the following will download and extract all datasets used in this project. 34 | 35 | ```bash 36 | bash ./setup_datasets.sh 37 | ``` 38 | 39 | Alternatively, you can download the datasets manually from the following link: 40 | https://drive.google.com/file/d/1VwFFzU8wJD1XJtfg60iLwnyBQ_cLZObL/view?usp=drive_link 41 | 42 | ## Reproducing Results 43 | 44 | The reported results can be reproduced using the scripts in the `reproduce` folder. For example, to reproduce the results on the `FFHQ` dataset, run the following command: 45 | 46 | ```bash 47 | bash ./reproduce/ffhq.sh 48 | ``` 49 | 50 | Make sure to set the `--data_root` flag to the path where the datasets are stored. 51 | 52 | ## Training on Custom Datasets 53 | 54 | Change the `--data_root` flag to the path where your dataset is stored. Also, see the [Important Hyperparameters](#important-hyperparameters) section for an appropriate set of hyperparameters. 55 | 56 | ## Important Hyperparameters 57 | 58 | `--data_root`: 59 | path to the image folder dataset 60 | 61 | `--change_coef`: the rate at which the per-example thresholds `\tau_i` are tightened before a subproblem is considered solved (`\gamma` in the paper). 62 | 63 | `--force_factor`: defines the number of random samples to be generated for the nearest-neighbour finding step as a multiple of the dataset length. E.g., `2` means 2 \* _dataset_length_ random samples will be generated. We have kept the number of random samples around 10k; see the sketch below.
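To make the role of `--force_factor` concrete, here is a minimal, hypothetical sketch of the IMLE sampling step it controls; `generator` and `perceptual_distance` are placeholder callables for illustration, not this repository's API.

```python
import torch

def sample_nearest_latents(images, generator, perceptual_distance,
                           force_factor=2.0, latent_dim=1024):
    """Hypothetical sketch: draw force_factor * len(images) random latents
    and keep, for each training image, the latent whose generated sample is
    perceptually nearest (not the repository's actual implementation)."""
    num_samples = int(force_factor * len(images))
    z = torch.randn(num_samples, latent_dim)
    with torch.no_grad():
        samples = generator(z)  # (num_samples, C, H, W)
    # dists[i, j]: perceptual distance from training image i to sample j
    dists = torch.stack([perceptual_distance(img, samples) for img in images])
    return z[dists.argmin(dim=1)]  # one nearest latent per training image
```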
64 | 65 | `--lr`: learning rate. 66 | 67 | ## Notebook 68 | 69 | A very simple (and not efficient) implementation of IMLE and AdaptiveIMLE, along with a toy training example, can be found in the `notebooks` folder. 70 | 71 | ## Wandb 72 | 73 | If you set the following wandb parameters, metrics including the loss and FID score will be logged to wandb throughout training. 74 | 75 | `--use_wandb`: set to `True` to log to wandb. 76 | 77 | `--wandb_project`: wandb project name. 78 | 79 | `--wandb_name`: run name. 80 | 81 | ## Citation 82 | ``` 83 | @inproceedings{aghabozorgi2023adaimle, 84 | title={Adaptive IMLE for Few-shot Pretraining-free Generative Modelling}, 85 | author={Mehran Aghabozorgi and Shichong Peng and Ke Li}, 86 | booktitle={International Conference on Machine Learning}, 87 | year={2023} 88 | } 89 | ``` -------------------------------------------------------------------------------- /assets/res0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mehranagh20/AdaIMLE/4f785e1d650aadee62003c3425e0fc5dd10915d6/assets/res0.png -------------------------------------------------------------------------------- /assets/res1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mehranagh20/AdaIMLE/4f785e1d650aadee62003c3425e0fc5dd10915d6/assets/res1.png -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import os 4 | import torch 5 | from torch.utils.data import TensorDataset, DataLoader 6 | from torchvision.datasets import ImageFolder 7 | import torchvision.transforms as transforms 8 | from sklearn.model_selection import train_test_split 9 | 10 | 11 | def set_up_data(H): 12 | shift_loss = -127.5 13 | scale_loss = 1. / 127.5 14 | if H.dataset == 'imagenet32': 15 | trX, vaX, teX = imagenet32(H.data_root) 16 | H.image_size = 32 17 | H.image_channels = 3 18 | shift = -116.2373 19 | scale = 1. / 69.37404 20 | elif H.dataset in ['fewshot', 'fewshot512']: 21 | trX, vaX, teX = few_shot_image_folder(H.data_root, H.image_size) 22 | H.image_channels = 3 23 | shift = -116.2373 24 | scale = 1. / 69.37404 25 | elif H.dataset == 'imagenet64': 26 | trX, vaX, teX = imagenet64(H.data_root) 27 | H.image_size = 64 28 | H.image_channels = 3 29 | shift = -115.92961967 30 | scale = 1. / 69.37404 31 | elif H.dataset == 'ffhq_256': 32 | trX, vaX, teX = ffhq256(H.data_root) 33 | H.image_size = 256 34 | H.image_channels = 3 35 | shift = -112.8666757481 36 | scale = 1. / 69.84780273 37 | elif H.dataset == 'ffhq_1024': 38 | trX, vaX, teX = ffhq1024(H.data_root) 39 | H.image_size = 1024 40 | H.image_channels = 3 41 | shift = -0.4387 42 | scale = 1.0 / 0.2743 43 | shift_loss = -0.5 44 | scale_loss = 2.0 45 | elif H.dataset == 'cifar10': 46 | (trX, _), (vaX, _), (teX, _) = cifar10(H.data_root, one_hot=False) 47 | H.image_size = 32 48 | H.image_channels = 3 49 | shift = -120.63838 50 | scale = 1. 
/ 64.16736 51 | else: 52 | raise ValueError('unknown dataset: ', H.dataset) 53 | 54 | do_low_bit = H.dataset in ['ffhq_256'] 55 | 56 | if H.test_eval: 57 | print('DOING TEST') 58 | eval_dataset = teX 59 | else: 60 | eval_dataset = vaX 61 | 62 | shift = torch.tensor([shift]).cuda().view(1, 1, 1, 1) 63 | scale = torch.tensor([scale]).cuda().view(1, 1, 1, 1) 64 | shift_loss = torch.tensor([shift_loss]).cuda().view(1, 1, 1, 1) 65 | scale_loss = torch.tensor([scale_loss]).cuda().view(1, 1, 1, 1) 66 | 67 | if H.dataset == 'ffhq_1024': 68 | train_data = ImageFolder(trX, transforms.ToTensor()) 69 | valid_data = ImageFolder(eval_dataset, transforms.ToTensor()) 70 | untranspose = True 71 | elif H.dataset not in ['fewshot', 'fewshot512']: 72 | train_data = TensorDataset(torch.as_tensor(trX)) 73 | valid_data = TensorDataset(torch.as_tensor(eval_dataset)) 74 | untranspose = False 75 | else: 76 | train_data = trX 77 | for data_train in DataLoader(train_data, batch_size=len(train_data)): 78 | ds = torch.tensor(data_train[0] * 255, dtype=torch.uint8) 79 | train_data = TensorDataset(ds.permute(0, 2, 3, 1)) 80 | break 81 | valid_data = train_data 82 | untranspose = False 83 | 84 | 85 | def preprocess_func(x): 86 | nonlocal shift 87 | nonlocal scale 88 | nonlocal shift_loss 89 | nonlocal scale_loss 90 | nonlocal do_low_bit 91 | nonlocal untranspose 92 | 'takes in a data example and returns the preprocessed input' 93 | 'as well as the input processed for the loss' 94 | if untranspose: 95 | x[0] = x[0].permute(0, 2, 3, 1) 96 | inp = x[0].cuda(non_blocking=True).float() 97 | inp.mul_(1./127.5).add_(-1) 98 | # out = inp.clone() 99 | # inp.add_(shift).mul_(scale) 100 | # if do_low_bit: 101 | # 5 bits of precision 102 | # out.mul_(1. / 8.).floor_().mul_(8.) 103 | # out.add_(shift_loss).mul_(scale_loss) 104 | return inp, inp 105 | 106 | return H, train_data, valid_data, preprocess_func 107 | 108 | 109 | def mkdir_p(path): 110 | os.makedirs(path, exist_ok=True) 111 | 112 | 113 | def flatten(outer): 114 | return [el for inner in outer for el in inner] 115 | 116 | 117 | def unpickle_cifar10(file): 118 | fo = open(file, 'rb') 119 | data = pickle.load(fo, encoding='bytes') 120 | fo.close() 121 | data = dict(zip([k.decode() for k in data.keys()], data.values())) 122 | return data 123 | 124 | 125 | def few_shot_image_folder(data_root, image_size): 126 | transform_list = [ 127 | transforms.Resize((int(image_size), int(image_size))), 128 | transforms.ToTensor(), 129 | ] 130 | trans = transforms.Compose(transform_list) 131 | train_data = ImageFolder(data_root, trans) 132 | return train_data, train_data, train_data 133 | 134 | 135 | def imagenet32(data_root): 136 | trX = np.load(os.path.join(data_root, 'imagenet32-train.npy'), mmap_mode='r') 137 | np.random.seed(42) 138 | tr_va_split_indices = np.random.permutation(trX.shape[0]) 139 | train = trX[tr_va_split_indices[:-5000]] 140 | valid = trX[tr_va_split_indices[-5000:]] 141 | test = np.load(os.path.join(data_root, 'imagenet32-valid.npy'), mmap_mode='r') 142 | return train, valid, test 143 | 144 | 145 | def imagenet64(data_root): 146 | trX = np.load(os.path.join(data_root, 'imagenet64-train.npy'), mmap_mode='r') 147 | np.random.seed(42) 148 | tr_va_split_indices = np.random.permutation(trX.shape[0]) 149 | train = trX[tr_va_split_indices[:-5000]] 150 | valid = trX[tr_va_split_indices[-5000:]] 151 | test = np.load(os.path.join(data_root, 'imagenet64-valid.npy'), mmap_mode='r') # this is test. 
152 | return train, valid, test 153 | 154 | 155 | def ffhq1024(data_root): 156 | # we did not significantly tune hyperparameters on ffhq-1024, and so simply evaluate on the test set 157 | return os.path.join(data_root, 'ffhq1024/train'), os.path.join(data_root, 'ffhq1024/valid'), os.path.join(data_root, 'ffhq1024/valid') 158 | 159 | 160 | def ffhq256(data_root): 161 | trX = np.load(os.path.join(data_root, 'ffhq-256.npy'), mmap_mode='r') 162 | np.random.seed(5) 163 | tr_va_split_indices = np.random.permutation(trX.shape[0]) 164 | train = trX[tr_va_split_indices[:-7000]] 165 | valid = trX[tr_va_split_indices[-7000:]] 166 | # we did not significantly tune hyperparameters on ffhq-256, and so simply evaluate on the test set 167 | return train, valid, valid 168 | 169 | 170 | def cifar10(data_root, one_hot=True): 171 | tr_data = [unpickle_cifar10(os.path.join(data_root, 'cifar-10-batches-py/', 'data_batch_%d' % i)) for i in range(1, 6)] 172 | trX = np.vstack([data['data'] for data in tr_data]) 173 | trY = np.asarray(flatten([data['labels'] for data in tr_data])) 174 | te_data = unpickle_cifar10(os.path.join(data_root, 'cifar-10-batches-py/', 'test_batch')) 175 | teX = np.asarray(te_data['data']) 176 | teY = np.asarray(te_data['labels']) 177 | trX = trX.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1) 178 | teX = teX.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1) 179 | trX, vaX, trY, vaY = train_test_split(trX, trY, test_size=5000, random_state=11172018) 180 | if one_hot: 181 | trY = np.eye(10, dtype=np.float32)[trY] 182 | vaY = np.eye(10, dtype=np.float32)[vaY] 183 | teY = np.eye(10, dtype=np.float32)[teY] 184 | else: 185 | trY = np.reshape(trY, [-1, 1]) 186 | vaY = np.reshape(vaY, [-1, 1]) 187 | teY = np.reshape(teY, [-1, 1]) 188 | return (trX, trY), (vaX, vaY), (teX, teY) 189 | -------------------------------------------------------------------------------- /ffhq.log: -------------------------------------------------------------------------------- 1 | Found 100 images in the folder /home/mehran/data/few-shot-images/ffhq 2 | Found 4992 images in the folder /home/mehran/sfu/res/few-shot/ours/ffhq/5000-rep 3 | ['../vdimle-scripts/fid/script.py', '/home/mehran/data/few-shot-images/ffhq', '/home/mehran/sfu/res/few-shot/ours/ffhq/5000-rep']: 33.15797959063741 4 | -------------------------------------------------------------------------------- /helpers/imle_helpers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.optim import AdamW 6 | import imageio 7 | from visual.utils import get_sample_for_visualization, generate_for_NN, generate_images_initial 8 | from torch.utils.data import DataLoader, TensorDataset 9 | from helpers.utils import ZippedDataset, get_cpu_stats_over_ranks 10 | 11 | 12 | @torch.jit.script 13 | def gaussian_analytical_kl(mu1, mu2, logsigma1, logsigma2): 14 | return -0.5 + logsigma2 - logsigma1 + 0.5 * (logsigma1.exp() ** 2 + (mu1 - mu2) ** 2) / (logsigma2.exp() ** 2) 15 | 16 | 17 | @torch.jit.script 18 | def draw_gaussian_diag_samples(mu, logsigma, eps): 19 | return torch.exp(logsigma) * eps + mu 20 | 21 | 22 | def get_conv(in_dim, out_dim, kernel_size, stride, padding, zero_bias=True, zero_weights=False, groups=1, scaled=False): 23 | c = nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, groups=groups) 24 | if zero_bias: 25 | c.bias.data *= 0.0 26 | if zero_weights: 27 | c.weight.data *= 0.0 28 | return c 29 | 30 | 31 | def 
get_3x3(in_dim, out_dim, zero_bias=True, zero_weights=False, groups=1, scaled=False): 32 | return get_conv(in_dim, out_dim, 3, 1, 1, zero_bias, zero_weights, groups=groups, scaled=scaled) 33 | 34 | 35 | def get_1x1(in_dim, out_dim, zero_bias=True, zero_weights=False, groups=1, scaled=False): 36 | return get_conv(in_dim, out_dim, 1, 1, 0, zero_bias, zero_weights, groups=groups, scaled=scaled) 37 | 38 | 39 | def log_prob_from_logits(x): 40 | """ numerically stable log_softmax implementation that prevents overflow """ 41 | axis = len(x.shape) - 1 42 | m = x.max(dim=axis, keepdim=True)[0] 43 | return x - m - torch.log(torch.exp(x - m).sum(dim=axis, keepdim=True)) 44 | 45 | 46 | def const_max(t, constant): 47 | other = torch.ones_like(t) * constant 48 | return torch.max(t, other) 49 | 50 | 51 | def const_min(t, constant): 52 | other = torch.ones_like(t) * constant 53 | return torch.min(t, other) 54 | 55 | 56 | def discretized_mix_logistic_loss(x, l, low_bit=False): 57 | """ log-likelihood for mixture of discretized logistics, assumes the data has been rescaled to [-1,1] interval """ 58 | # Adapted from https://github.com/openai/pixel-cnn/blob/master/pixel_cnn_pp/nn.py 59 | xs = [s for s in x.shape] # true image (i.e. labels) to regress to, e.g. (B,32,32,3) 60 | ls = [s for s in l.shape] # predicted distribution, e.g. (B,32,32,100) 61 | nr_mix = int(ls[-1] / 10) # here and below: unpacking the params of the mixture of logistics 62 | logit_probs = l[:, :, :, :nr_mix] 63 | l = torch.reshape(l[:, :, :, nr_mix:], xs + [nr_mix * 3]) 64 | means = l[:, :, :, :, :nr_mix] 65 | log_scales = const_max(l[:, :, :, :, nr_mix:2 * nr_mix], -7.) 66 | coeffs = torch.tanh(l[:, :, :, :, 2 * nr_mix:3 * nr_mix]) 67 | x = torch.reshape(x, xs + [1]) + torch.zeros(xs + [nr_mix]).to(x.device) # here and below: getting the means and adjusting them based on preceding sub-pixels 68 | m2 = torch.reshape(means[:, :, :, 1, :] + coeffs[:, :, :, 0, :] * x[:, :, :, 0, :], [xs[0], xs[1], xs[2], 1, nr_mix]) 69 | m3 = torch.reshape(means[:, :, :, 2, :] + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :] + coeffs[:, :, :, 2, :] * x[:, :, :, 1, :], [xs[0], xs[1], xs[2], 1, nr_mix]) 70 | means = torch.cat([torch.reshape(means[:, :, :, 0, :], [xs[0], xs[1], xs[2], 1, nr_mix]), m2, m3], dim=3) 71 | centered_x = x - means 72 | inv_stdv = torch.exp(-log_scales) 73 | if low_bit: 74 | plus_in = inv_stdv * (centered_x + 1. / 31.) 75 | cdf_plus = torch.sigmoid(plus_in) 76 | min_in = inv_stdv * (centered_x - 1. / 31.) 77 | else: 78 | plus_in = inv_stdv * (centered_x + 1. / 255.) 79 | cdf_plus = torch.sigmoid(plus_in) 80 | min_in = inv_stdv * (centered_x - 1. / 255.) 81 | cdf_min = torch.sigmoid(min_in) 82 | log_cdf_plus = plus_in - F.softplus(plus_in) # log probability for edge case of 0 (before scaling) 83 | log_one_minus_cdf_min = -F.softplus(min_in) # log probability for edge case of 255 (before scaling) 84 | cdf_delta = cdf_plus - cdf_min # probability for all other cases 85 | mid_in = inv_stdv * centered_x 86 | log_pdf_mid = mid_in - log_scales - 2. 
* F.softplus(mid_in) # log probability in the center of the bin, to be used in extreme cases (not actually used in our code) 87 | 88 | # now select the right output: left edge case, right edge case, normal case, extremely low prob case (doesn't actually happen for us) 89 | 90 | # this is what we are really doing, but using the robust version below for extreme cases in other applications and to avoid NaN issue with tf.select() 91 | # log_probs = tf.select(x < -0.999, log_cdf_plus, tf.select(x > 0.999, log_one_minus_cdf_min, tf.log(cdf_delta))) 92 | 93 | # robust version that still works if probabilities are below 1e-5 (which never happens in our code) 94 | # tensorflow backpropagates through tf.select() by multiplying with zero instead of selecting: this requires us to use some ugly tricks to avoid potential NaNs 95 | # the 1e-12 in tf.maximum(cdf_delta, 1e-12) is never actually used as output, it's purely there to get around the tf.select() gradient issue 96 | # if the probability on a sub-pixel is below 1e-5, we use an approximation based on the assumption that the log-density is constant in the bin of the observed sub-pixel value 97 | if low_bit: 98 | log_probs = torch.where(x < -0.999, 99 | log_cdf_plus, 100 | torch.where(x > 0.999, 101 | log_one_minus_cdf_min, 102 | torch.where(cdf_delta > 1e-5, 103 | torch.log(const_max(cdf_delta, 1e-12)), 104 | log_pdf_mid - np.log(15.5)))) 105 | else: 106 | log_probs = torch.where(x < -0.999, 107 | log_cdf_plus, 108 | torch.where(x > 0.999, 109 | log_one_minus_cdf_min, 110 | torch.where(cdf_delta > 1e-5, 111 | torch.log(const_max(cdf_delta, 1e-12)), 112 | log_pdf_mid - np.log(127.5)))) 113 | log_probs = log_probs.sum(dim=3) + log_prob_from_logits(logit_probs) 114 | mixture_probs = torch.logsumexp(log_probs, -1) 115 | res = -1. * mixture_probs.sum(dim=[1, 2]) / np.prod(xs[1:]) 116 | return res 117 | 118 | 119 | def sample_from_discretized_mix_logistic(l, nr_mix, eps=None, u=None): 120 | ls = [s for s in l.shape] 121 | xs = ls[:-1] + [3] 122 | # unpack parameters 123 | logit_probs = l[:, :, :, :nr_mix] 124 | l = torch.reshape(l[:, :, :, nr_mix:], xs + [nr_mix * 3]) 125 | # sample mixture indicator from softmax 126 | if eps is None: 127 | eps = torch.empty(logit_probs.shape, device=l.device).uniform_(1e-5, 1. - 1e-5) 128 | amax = torch.argmax(logit_probs - torch.log(-torch.log(eps)), dim=3) 129 | sel = F.one_hot(amax, num_classes=nr_mix).float() 130 | sel = torch.reshape(sel, xs[:-1] + [1, nr_mix]) 131 | # select logistic parameters 132 | means = (l[:, :, :, :, :nr_mix] * sel).sum(dim=4) 133 | log_scales = const_max((l[:, :, :, :, nr_mix:nr_mix * 2] * sel).sum(dim=4), -7.) 134 | coeffs = (torch.tanh(l[:, :, :, :, nr_mix * 2:nr_mix * 3]) * sel).sum(dim=4) 135 | # sample from logistic & clip to interval 136 | # we don't actually round to the nearest 8bit value when sampling 137 | if u is None: 138 | u = torch.empty(means.shape, device=means.device).uniform_(1e-5, 1. - 1e-5) 139 | x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u)) 140 | x0 = const_min(const_max(x[:, :, :, 0], -1.), 1.) 141 | x1 = const_min(const_max(x[:, :, :, 1] + coeffs[:, :, :, 0] * x0, -1.), 1.) 142 | x2 = const_min(const_max(x[:, :, :, 2] + coeffs[:, :, :, 1] * x0 + coeffs[:, :, :, 2] * x1, -1.), 1.)
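# x1 and x2 above realize the mixture's autoregressive colour coupling: the sampled red channel shifts the green mean through coeffs[..., 0], and red and green together shift the blue mean through coeffs[..., 1] and coeffs[..., 2], mirroring PixelCNN++'s sub-pixel conditioning.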
143 | return torch.cat([torch.reshape(x0, xs[:-1] + [1]), torch.reshape(x1, xs[:-1] + [1]), torch.reshape(x2, xs[:-1] + [1])], dim=3), eps, u 144 | 145 | 146 | def backtrack(H, sampler, imle, preprocess_fn, data, logprint, training_step_imle): 147 | latents = torch.randn([data.shape[0], H.latent_dim], requires_grad=True, dtype=torch.float32, device='cuda') 148 | snoise = [torch.randn([data.shape[0], s.shape[1], s.shape[2], s.shape[3]], dtype=torch.float32, device='cuda') for s in sampler.snoise_tmp] 149 | 150 | if H.restore_latent_path: 151 | logprint('restoring latent path') 152 | latents = torch.tensor(torch.load(f'{H.restore_latent_path}/latent-best.npy'), requires_grad=True, dtype=torch.float32, device='cuda') 153 | snoise = [torch.tensor(torch.load(f'{H.restore_latent_path}/snoise-best-{s.shape[2]}.npy'), requires_grad=True, dtype=torch.float32, device='cuda') for s in sampler.snoise_tmp] 154 | 155 | latent_optimizer = AdamW([latents], lr=H.latent_lr) 156 | if H.space == 'w': 157 | latent_optimizer = AdamW([latents] + snoise, lr=H.latent_lr) 158 | # latent_optimizer = SGD([latents] + snoise, lr=H.latent_lr) 159 | dists = torch.empty([data.shape[0]], dtype=torch.float32).cuda() 160 | 161 | sampler.calc_dists_existing(data, imle, dists=dists, latents=latents, snoise=snoise) 162 | print(f'initial dists: {dists.mean()}') 163 | 164 | best_loss = np.inf 165 | num_iters = 0 166 | 167 | while num_iters < H.reconstruct_iter_num: 168 | comb_dataset = ZippedDataset(data, TensorDataset(latents)) 169 | data_loader = DataLoader(comb_dataset, batch_size=H.n_batch) 170 | for cur, indices in data_loader: 171 | x = cur 172 | lat = cur[1][0] 173 | _, target = preprocess_fn(x) 174 | cur_snoise = [s[indices] for s in snoise] 175 | training_step_imle(H, target.shape[0], target, lat, cur_snoise, imle, None, latent_optimizer, sampler.calc_loss) 176 | latents.grad.zero_() 177 | [s.grad.zero_() for s in snoise] 178 | num_iters += len(data) 179 | 180 | logprint(f'iteration: {num_iters}') 181 | # torch.save(latents.detach(), f'{H.save_dir}/latent-latest.npy') 182 | # for s in snoise: 183 | # torch.save(s.detach(), f'{H.save_dir}/snoise-latest-{s.shape[2]}.npy') 184 | 185 | sampler.calc_dists_existing(data, imle, dists=dists, latents=latents, snoise=snoise) 186 | cur_mean = dists.mean() 187 | logprint(f'cur mean: {cur_mean}, best: {best_loss}') 188 | if cur_mean < best_loss: 189 | torch.save(latents.detach(), f'{H.save_dir}/latent-best.npy') 190 | for s in snoise: 191 | torch.save(s.detach(), f'{H.save_dir}/snoise-best-{s.shape[2]}.npy') 192 | logprint(f'improved: {cur_mean}') 193 | best_loss = cur_mean 194 | for i in range(data.shape[0]): 195 | samp = sampler.sample(latents[i:i+1], imle, [s[i:i+1] for s in snoise]) 196 | imageio.imwrite(f'{H.save_dir}/{i}.png', samp[0]) 197 | imageio.imwrite(f'{H.save_dir}/{i}-real.png', data[i]) 198 | 199 | if num_iters >= H.reconstruct_iter_num: 200 | break 201 | 202 | 203 | def reconstruct(H, sampler, imle, preprocess_fn, images, latents, snoise, name, logprint, training_step_imle): 204 | latent_optimizer = AdamW([latents], lr=H.latent_lr) 205 | generate_for_NN(sampler, images, latents.detach(), snoise, images.shape, imle, 206 | f'{H.save_dir}/{name}-initial.png', logprint) 207 | for i in range(H.latent_epoch): 208 | for iter in range(H.reconstruct_iter_num): 209 | _, target = preprocess_fn([images]) 210 | stat = training_step_imle(H, target.shape[0], target, latents, snoise, imle, None, latent_optimizer, sampler.calc_loss) 211 | 212 | latents.grad.zero_() 213 | if iter % 50 
== 0: 214 | print('loss is: ', stat['loss']) 215 | generate_for_NN(sampler, images, latents.detach(), snoise, images.shape, imle, 216 | f'{H.save_dir}/{name}-{iter}.png', logprint) 217 | 218 | torch.save(latents.detach(), '{}/reconstruct-latest.npy'.format(H.save_dir)) -------------------------------------------------------------------------------- /helpers/train_helpers.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | import numpy as np 5 | import socket 6 | import argparse 7 | import os 8 | import json 9 | import subprocess 10 | from hps import Hyperparams, parse_args_and_update_hparams, add_imle_arguments 11 | from helpers.utils import (logger, maybe_download) 12 | from data import mkdir_p 13 | from contextlib import contextmanager 14 | import torch.distributed as dist 15 | # from apex.optimizers import FusedAdam as AdamW 16 | from torch.optim import AdamW 17 | from models import IMLE 18 | from torch.nn.parallel.distributed import DistributedDataParallel 19 | 20 | 21 | def update_ema(imle, ema_imle, ema_rate): 22 | for p1, p2 in zip(imle.parameters(), ema_imle.parameters()): 23 | p2.data.mul_(ema_rate) 24 | p2.data.add_(p1.data * (1 - ema_rate)) 25 | 26 | 27 | def save_model(path, imle, ema_imle, optimizer, H): 28 | torch.save(imle.state_dict(), f'{path}-model.th') 29 | torch.save(ema_imle.state_dict(), f'{path}-model-ema.th') 30 | torch.save(optimizer.state_dict(), f'{path}-opt.th') 31 | from_log = os.path.join(H.save_dir, 'log.jsonl') 32 | to_log = f'{os.path.dirname(path)}/{os.path.basename(path)}-log.jsonl' 33 | subprocess.check_output(['cp', from_log, to_log]) 34 | 35 | 36 | def accumulate_stats(stats, frequency): 37 | z = {} 38 | for k in stats[-1]: 39 | if k in ['distortion_nans', 'rate_nans', 'skipped_updates', 'gcskip', 'loss_nans']: 40 | z[k] = np.sum([a[k] for a in stats[-frequency:]]) 41 | elif k == 'grad_norm': 42 | vals = [a[k] for a in stats[-frequency:]] 43 | finites = np.array(vals)[np.isfinite(vals)] 44 | if len(finites) == 0: 45 | z[k] = 0.0 46 | else: 47 | z[k] = np.max(finites) 48 | elif k == 'loss': 49 | vals = [a[k] for a in stats[-frequency:]] 50 | finites = np.array(vals)[np.isfinite(vals)] 51 | z['loss'] = np.mean(vals) 52 | z['loss_filtered'] = np.mean(finites) 53 | elif k == 'iter_time': 54 | z[k] = stats[-1][k] if len(stats) < frequency else np.mean([a[k] for a in stats[-frequency:]]) 55 | else: 56 | z[k] = np.mean([a[k] for a in stats[-frequency:]]) 57 | return z 58 | 59 | 60 | def linear_warmup(warmup_iters): 61 | def f(iteration): 62 | return 1.0 if iteration >= warmup_iters else iteration / warmup_iters 63 | return f 64 | 65 | 66 | 67 | def distributed_maybe_download(path, local_rank, mpi_size): 68 | if not path.startswith('gs://'): 69 | return path 70 | filename = path[5:].replace('/', '-') 71 | with first_rank_first(local_rank, mpi_size): 72 | fp = maybe_download(path, filename) 73 | return fp 74 | 75 | 76 | @contextmanager 77 | def first_rank_first(local_rank, mpi_size): 78 | if mpi_size > 1 and local_rank > 0: 79 | dist.barrier() 80 | 81 | try: 82 | yield 83 | finally: 84 | if mpi_size > 1 and local_rank == 0: 85 | dist.barrier() 86 | 87 | 88 | def setup_save_dirs(H): 89 | H.save_dir = os.path.join(H.save_dir, H.desc) 90 | mkdir_p(H.save_dir) 91 | mkdir_p(f'{H.save_dir}/fid') 92 | H.logdir = os.path.join(H.save_dir, 'log') 93 | 94 | 95 | def set_up_hyperparams(s=None): 96 | H = Hyperparams() 97 | parser = argparse.ArgumentParser() 98 | parser = 
add_imle_arguments(parser) 99 | parse_args_and_update_hparams(H, parser, s=s) 100 | setup_save_dirs(H) 101 | logprint = logger(H.logdir) 102 | for i, k in enumerate(sorted(H)): 103 | logprint(type='hparam', key=k, value=H[k]) 104 | np.random.seed(H.seed) 105 | torch.manual_seed(H.seed) 106 | torch.cuda.manual_seed(H.seed) 107 | logprint('training model', H.desc, 'on', H.dataset) 108 | return H, logprint 109 | 110 | 111 | def restore_params(model, path, local_rank, mpi_size, map_ddp=True, map_cpu=False, strict=True): 112 | state_dict = torch.load(distributed_maybe_download(path, local_rank, mpi_size), map_location='cpu' if map_cpu else None) 113 | if map_ddp: 114 | new_state_dict = {} 115 | l = len('module.') 116 | for k in state_dict: 117 | if k.startswith('module.'): 118 | new_state_dict[k[l:]] = state_dict[k] 119 | else: 120 | new_state_dict[k] = state_dict[k] 121 | state_dict = new_state_dict 122 | model.load_state_dict(state_dict, strict=strict) 123 | 124 | 125 | def restore_log(path, local_rank, mpi_size): 126 | loaded = [json.loads(l) for l in open(distributed_maybe_download(path, local_rank, mpi_size))] 127 | try: 128 | cur_eval_loss = min([z['elbo'] for z in loaded if 'type' in z and z['type'] == 'eval_loss']) 129 | except ValueError: 130 | cur_eval_loss = float('inf') 131 | starting_epoch = max([z['epoch'] for z in loaded if 'type' in z and z['type'] == 'train_loss']) 132 | iterate = max([z['step'] for z in loaded if 'type' in z and z['type'] == 'train_loss']) 133 | return cur_eval_loss, iterate, starting_epoch 134 | 135 | 136 | def load_imle(H, logprint): 137 | imle = IMLE(H) 138 | if H.restore_path: 139 | logprint(f'Restoring imle from {H.restore_path}') 140 | restore_params(imle, H.restore_path, map_cpu=True, local_rank=H.local_rank, mpi_size=H.mpi_size, strict=H.load_strict) 141 | 142 | ema_imle = IMLE(H) 143 | if H.restore_ema_path: 144 | logprint(f'Restoring ema imle from {H.restore_ema_path}') 145 | restore_params(ema_imle, H.restore_ema_path, map_cpu=True, local_rank=H.local_rank, mpi_size=H.mpi_size, strict=H.load_strict) 146 | else: 147 | ema_imle.load_state_dict(imle.state_dict()) 148 | ema_imle.requires_grad_(False) 149 | 150 | ema_imle = ema_imle.cuda() 151 | 152 | imle = imle.cuda() 153 | imle = torch.nn.DataParallel(imle) 154 | 155 | if len(list(imle.named_parameters())) != len(list(imle.parameters())): 156 | raise ValueError('Some params are not named. 
Please name all params.') 157 | total_params = 0 158 | for name, p in imle.named_parameters(): 159 | total_params += np.prod(p.shape) 160 | logprint(total_params=total_params, readable=f'{total_params:,}') 161 | return imle, ema_imle 162 | 163 | 164 | def load_opt(H, imle, logprint): 165 | optimizer = AdamW(imle.parameters(), weight_decay=H.wd, lr=H.lr, betas=(H.adam_beta1, H.adam_beta2)) 166 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=linear_warmup(H.warmup_iters)) 167 | 168 | if H.restore_optimizer_path: 169 | optimizer.load_state_dict( 170 | torch.load(distributed_maybe_download(H.restore_optimizer_path, H.local_rank, H.mpi_size), map_location='cpu')) 171 | if H.restore_log_path: 172 | cur_eval_loss, iterate, starting_epoch = restore_log(H.restore_log_path, H.local_rank, H.mpi_size) 173 | else: 174 | cur_eval_loss, iterate, starting_epoch = float('inf'), 0, 0 175 | logprint('starting at epoch', starting_epoch, 'iterate', iterate, 'eval loss', cur_eval_loss) 176 | return optimizer, scheduler, cur_eval_loss, iterate, starting_epoch 177 | 178 | 179 | def save_latents(H, outer, split_ind, latents, name='latents'): 180 | Path("{}/latent/".format(H.save_dir)).mkdir(parents=True, exist_ok=True) 181 | # for ind, z in enumerate(latents): 182 | torch.save(latents, '{}/latent/{}-{}-{}.npy'.format(H.save_dir, outer, split_ind, name)) 183 | 184 | 185 | def save_snoise(H, outer, snoise): 186 | Path("{}/latent/".format(H.save_dir)).mkdir(parents=True, exist_ok=True) 187 | for sn in snoise: 188 | torch.save(sn, '{}/latent/snoise-{}-{}.npy'.format(H.save_dir, outer, sn.shape[2])) 189 | 190 | 191 | def save_latents_latest(H, split_ind, latents, name='latest'): 192 | Path("{}/latent/".format(H.save_dir)).mkdir(parents=True, exist_ok=True) 193 | # for ind, z in enumerate(latents): 194 | torch.save(latents, '{}/latent/{}-{}.npy'.format(H.save_dir, split_ind, name)) 195 | -------------------------------------------------------------------------------- /helpers/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import tempfile 4 | import numpy as np 5 | import torch 6 | import time 7 | import subprocess 8 | import torch.distributed as dist 9 | import torch.utils.data as data 10 | 11 | 12 | 13 | def allreduce(x, average): 14 | if mpi_size() > 1: 15 | dist.all_reduce(x, dist.ReduceOp.SUM) 16 | return x / mpi_size() if average else x 17 | 18 | 19 | def get_cpu_stats_over_ranks(stat_dict): 20 | keys = sorted(stat_dict.keys()) 21 | allreduced = allreduce(torch.stack([torch.as_tensor(stat_dict[k]).detach().cpu().float() for k in keys]), average=True).cpu() 22 | return {k: allreduced[i].item() for (i, k) in enumerate(keys)} 23 | 24 | 25 | class Hyperparams(dict): 26 | def __getattr__(self, attr): 27 | try: 28 | return self[attr] 29 | except KeyError: 30 | return None 31 | 32 | def __setattr__(self, attr, value): 33 | self[attr] = value 34 | 35 | 36 | def logger(log_prefix): 37 | 'Prints the arguments out to stdout, .txt, and .jsonl files' 38 | 39 | jsonl_path = f'{log_prefix}.jsonl' 40 | txt_path = f'{log_prefix}.txt' 41 | 42 | def log(*args, pprint=False, **kwargs): 43 | if mpi_rank() != 0: 44 | return 45 | t = time.ctime() 46 | argdict = {'time': t} 47 | if len(args) > 0: 48 | argdict['message'] = ' '.join([str(x) for x in args]) 49 | argdict.update(kwargs) 50 | 51 | txt_str = [] 52 | args_iter = sorted(argdict) if pprint else argdict 53 | for k in args_iter: 54 | val = argdict[k] 55 | if isinstance(val, np.ndarray): 
56 | val = val.tolist() 57 | elif isinstance(val, np.integer): 58 | val = int(val) 59 | elif isinstance(val, np.floating): 60 | val = float(val) 61 | argdict[k] = val 62 | if isinstance(val, float): 63 | val = f'{val:.5f}' 64 | txt_str.append(f'{k}: {val}') 65 | txt_str = ', '.join(txt_str) 66 | 67 | if pprint: 68 | json_str = json.dumps(argdict, sort_keys=True) 69 | txt_str = json.dumps(argdict, sort_keys=True, indent=4) 70 | else: 71 | json_str = json.dumps(argdict) 72 | 73 | print(txt_str, flush=True) 74 | 75 | with open(txt_path, "a+") as f: 76 | print(txt_str, file=f, flush=True) 77 | with open(jsonl_path, "a+") as f: 78 | print(json_str, file=f, flush=True) 79 | 80 | return log 81 | 82 | 83 | def maybe_download(path, filename=None): 84 | '''If a path is a gsutil path, download it and return the local link, 85 | otherwise return link''' 86 | if not path.startswith('gs://'): 87 | return path 88 | if filename: 89 | local_dest = f'/tmp/' 90 | out_path = f'/tmp/{filename}' 91 | if os.path.isfile(out_path): 92 | return out_path 93 | subprocess.check_output(['gsutil', '-m', 'cp', '-R', path, out_path]) 94 | return out_path 95 | else: 96 | local_dest = tempfile.mkstemp()[1] 97 | subprocess.check_output(['gsutil', '-m', 'cp', path, local_dest]) 98 | return local_dest 99 | 100 | 101 | def tile_images(images, d1=4, d2=4, border=1): 102 | id1, id2, c = images[0].shape 103 | out = np.ones([d1 * id1 + border * (d1 + 1), 104 | d2 * id2 + border * (d2 + 1), 105 | c], dtype=np.uint8) 106 | out *= 255 107 | if len(images) != d1 * d2: 108 | raise ValueError('Wrong num of images') 109 | for imgnum, im in enumerate(images): 110 | num_d1 = imgnum // d2 111 | num_d2 = imgnum % d2 112 | start_d1 = num_d1 * id1 + border * (num_d1 + 1) 113 | start_d2 = num_d2 * id2 + border * (num_d2 + 1) 114 | out[start_d1:start_d1 + id1, start_d2:start_d2 + id2, :] = im 115 | return out 116 | 117 | 118 | def mpi_size(): 119 | return 0 120 | 121 | 122 | def mpi_rank(): 123 | return 0 124 | 125 | 126 | def num_nodes(): 127 | nn = mpi_size() 128 | if nn % 8 == 0: 129 | return nn // 8 130 | return nn // 8 + 1 131 | 132 | 133 | def gpus_per_node(): 134 | size = mpi_size() 135 | if size > 1: 136 | return max(size // num_nodes(), 1) 137 | return 1 138 | 139 | 140 | def local_mpi_rank(): 141 | return mpi_rank() % gpus_per_node() 142 | 143 | 144 | # def printGPUInfo(prefix=""): 145 | # print(prefix, end=" ") 146 | # deviceCount = pynvml.nvmlDeviceGetCount() 147 | # for i in range(deviceCount): 148 | # handle = pynvml.nvmlDeviceGetHandleByIndex(i) 149 | # meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle) 150 | # print("GPU %d used: %d MB" % (i, meminfo.used/1048576), end=" ") 151 | # print() 152 | 153 | 154 | class ZippedDataset(data.Dataset): 155 | 156 | def __init__(self, *datasets): 157 | assert all(len(datasets[0]) == len(dataset) for dataset in datasets) 158 | self.datasets = datasets 159 | 160 | def __getitem__(self, index): 161 | # print(index, [len(x) for x in self.datasets]) 162 | return tuple(dataset[index] for dataset in self.datasets), index 163 | 164 | def __len__(self): 165 | return len(self.datasets[0]) 166 | 167 | -------------------------------------------------------------------------------- /hps.py: -------------------------------------------------------------------------------- 1 | HPARAMS_REGISTRY = {} 2 | 3 | 4 | class Hyperparams(dict): 5 | def __getattr__(self, attr): 6 | try: 7 | return self[attr] 8 | except KeyError: 9 | return None 10 | 11 | def __setattr__(self, attr, value): 12 | self[attr] = value 13 | 14 
| 15 | fewshot = Hyperparams() 16 | fewshot.width = 384 17 | fewshot.lr = 0.0002 18 | fewshot.wd = 0.01 19 | fewshot.dec_blocks = '1x4,4m1,4x4,8m4,8x4,16m8,16x3,32m16,32x2,64m32,64x2,128m64,128x2,256m128' 20 | fewshot.warmup_iters = 100 21 | fewshot.dataset = 'fewshot' 22 | fewshot.n_batch = 4 23 | fewshot.ema_rate = 0.9999 24 | HPARAMS_REGISTRY['fewshot'] = fewshot 25 | 26 | def parse_args_and_update_hparams(H, parser, s=None): 27 | args = parser.parse_args(s) 28 | valid_args = set(args.__dict__.keys()) 29 | hparam_sets = [x for x in args.hparam_sets.split(',') if x] 30 | for hp_set in hparam_sets: 31 | hps = HPARAMS_REGISTRY[hp_set] 32 | for k in hps: 33 | if k not in valid_args: 34 | raise ValueError(f"{k} not in default args") 35 | parser.set_defaults(**hps) 36 | H.update(parser.parse_args(s).__dict__) 37 | 38 | 39 | def add_imle_arguments(parser): 40 | parser.add_argument('--seed', type=int, default=0) 41 | parser.add_argument('--save_dir', type=str, default='./saved_models') 42 | parser.add_argument('--data_root', type=str, default='./') 43 | parser.add_argument('--desc', type=str, default='train') 44 | parser.add_argument('--dataset', type=str, default='cifar10') # path to dataset 45 | parser.add_argument('--hparam_sets', '--hps', type=str) # e.g. 'fewshot' 46 | parser.add_argument('--enc_blocks', type=str, default=None) # specify encoder blocks, e.g. '1x2,4m1,4x4,8m4,8x5,16m8,16x8,32m16,32x5,64m32,64x4,128m64,128x4,256m128' 47 | parser.add_argument('--dec_blocks', type=str, default=None) # specify decoder blocks, e.g. '256x4,128m64,128x4,64m32,64x4,32m16,32x5,16m8,16x8,8m4,8x5,4m1,4x4,1x2' 48 | parser.add_argument('--width', type=int, default=512) # width of encoder and decoder convs 49 | parser.add_argument('--custom_width_str', type=str, default='') # custom width for each block 50 | parser.add_argument('--bottleneck_multiple', type=float, default=0.25) # coefficient width of bottleneck layers, e.g. 
0.25 means 1/4 of width 51 | 52 | parser.add_argument('--restore_path', type=str, default=None) # restore from checkpoint 53 | parser.add_argument('--restore_ema_path', type=str, default=None) # restore ema from checkpoint 54 | parser.add_argument('--restore_log_path', type=str, default=None) # restore log from checkpoint 55 | parser.add_argument('--restore_optimizer_path', type=str, default=None) # restore optimizer from checkpoint 56 | parser.add_argument('--restore_latent_path', type=str, default=None) # restore nearest neighbour latent codes from checkpoint 57 | parser.add_argument('--restore_threshold_path', type=str, default=None) # restore nearest neighbour thresholds, i.e., \tau_i, from checkpoint 58 | parser.add_argument('--ema_rate', type=float, default=0.999) # exponential moving average rate 59 | parser.add_argument('--warmup_iters', type=float, default=0) # number of iterations for warmup for scheduler 60 | 61 | parser.add_argument('--lr', type=float, default=0.00015) # learning rate 62 | parser.add_argument('--wd', type=float, default=0.00) # weight decay 63 | parser.add_argument('--num_epochs', type=int, default=10000) # number of epochs 64 | parser.add_argument('--n_batch', type=int, default=4) # batch size 65 | parser.add_argument('--adam_beta1', type=float, default=0.9) 66 | parser.add_argument('--adam_beta2', type=float, default=0.9) 67 | 68 | parser.add_argument('--iters_per_ckpt', type=int, default=5000) # number of iterations per checkpoint 69 | parser.add_argument('--iters_per_save', type=int, default=1000) # number of iterations per saving the latest models 70 | parser.add_argument('--iters_per_images', type=int, default=1000) # number of iterations per sample save 71 | parser.add_argument('--num_images_visualize', type=int, default=8) # number of images to visualize 72 | parser.add_argument('--num_rows_visualize', type=int, default=3) # number of rows to visualize, e.g. 3 means 3x8=24 images 73 | 74 | parser.add_argument('--num_comp_indices', type=int, default=2) # dci number of components 75 | parser.add_argument('--num_simp_indices', type=int, default=7) # dci number of simplices 76 | parser.add_argument('--imle_db_size', type=int, default=1024) # imle database size 77 | parser.add_argument('--imle_factor', type=float, default=0.) # imle soft-sampling factor -- not used in the paper 78 | parser.add_argument('--imle_staleness', type=int, default=7) # imle staleness, i.e., number of iterations to wait before considering the thresholds, tau_i 79 | parser.add_argument('--imle_batch', type=int, default=16) # imle batch size used for sampling 80 | parser.add_argument('--subset_len', type=int, default=-1) # subset length for training -- random subset of the dataset. 
-1 means full dataset 81 | parser.add_argument('--latent_dim', type=int, default=1024) # latent code dimension 82 | parser.add_argument('--imle_perturb_coef', type=float, default=0.001) # imle perturbation coefficient to avoid identical latent codes 83 | parser.add_argument('--lpips_net', type=str, default='vgg') # lpips network type 84 | parser.add_argument('--proj_dim', type=int, default=800) # projection dimension for nearest neighbour search 85 | parser.add_argument('--proj_proportion', type=int, default=1) # whether to use projection proportional to the lpips feature dimensions for nearest neighbour search 86 | parser.add_argument('--lpips_coef', type=float, default=1.0) # lpips loss coefficient 87 | parser.add_argument('--l2_coef', type=float, default=0.1) # l2 loss coefficient 88 | parser.add_argument('--force_factor', type=float, default=1.5) # sampling factor for imle, i.e., force_factor * len(dataset) 89 | parser.add_argument('--change_coef', type=float, default=0.04) # \gamma in the paper, rate of change of the thresholds, tau_i 90 | parser.add_argument('--change_threshold', type=float, default=1) # starting threshold 91 | parser.add_argument('--n_mpl', type=int, default=8) # mapping network layers 92 | parser.add_argument('--latent_lr', type=float, default=0.0001) # learning rate for optimizing latent codes -- not used 93 | parser.add_argument('--latent_decay', type=float, default=0.0) # learning rate decay for optimizing latent codes -- not used 94 | parser.add_argument('--latent_epoch', type=int, default=0) # number of epochs for optimizing latent codes -- not used 95 | parser.add_argument('--reconstruct_iter_num', type=int, default=100000) # number of iterations for reconstructing images using backtracking 96 | parser.add_argument('--imle_force_resample', type=int, default=30) # number of iterations to wait before ignoring the threshold and resampling anyway 97 | parser.add_argument('--snoise_factor', type=int, default=8) # spatial noise factor 98 | parser.add_argument('--max_hierarchy', type=int, default=256) # maximum hierarchy level for spatial noise, i.e., 64 means up to 64x64 spatial noise but not higher resolution 99 | parser.add_argument('--load_strict', type=int, default=1) # whether to load checkpoints strictly 100 | parser.add_argument('--lpips_path', type=str, default='./lpips') # path to lpips weights 101 | parser.add_argument('--image_size', type=int, default=256) # image size of dataset -- possible to downsample the dataset 102 | parser.add_argument('--num_images_to_generate', type=int, default=100) 103 | parser.add_argument('--mode', type=str, default='train') # mode of running: train, eval, reconstruct, generate 104 | parser.add_argument('--wandb_name', type=str, default='AdaptiveIMLE') # used for wandb 105 | parser.add_argument('--wandb_project', type=str, default='AdaptiveIMLE') # used for wandb 106 | parser.add_argument('--use_wandb', type=int, default=0) 107 | parser.add_argument('--wandb_mode', type=str, default='online') 108 | 109 | # some metric args 110 | parser.add_argument("--space", choices=["z", "w"], help="space that PPL is calculated in") 111 | parser.add_argument("--batch", type=int, default=16, help="batch size for the models") 112 | parser.add_argument("--n_sample", type=int, default=5000, help="number of samples for calculating PPL",) 113 | parser.add_argument("--size", type=int, default=256, help="output image sizes of the generator") 114 | parser.add_argument("--eps", type=float, default=1e-4, help="epsilon for numerical stability")
parser.add_argument("--ppl_snoise", type=int, default=0, help="whether to interpolate spatial noise in PPL") 116 | parser.add_argument("--sampling", default="end", choices=["end", "full"], help="set endpoint sampling method",) 117 | parser.add_argument("--step", type=float, default=0.1, help="step size for interpolation") 118 | parser.add_argument('--ppl_save_name', type=str, default='ppl') 119 | parser.add_argument("--fid_factor", type=int, default=5, help="number of the samples for calculating FID") 120 | parser.add_argument("--fid_freq", type=int, default=5, help="frequency of calculating fid") 121 | return parser 122 | -------------------------------------------------------------------------------- /lpips/weights/v0.1/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mehranagh20/AdaIMLE/4f785e1d650aadee62003c3425e0fc5dd10915d6/lpips/weights/v0.1/alex.pth -------------------------------------------------------------------------------- /lpips/weights/v0.1/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mehranagh20/AdaIMLE/4f785e1d650aadee62003c3425e0fc5dd10915d6/lpips/weights/v0.1/squeeze.pth -------------------------------------------------------------------------------- /lpips/weights/v0.1/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mehranagh20/AdaIMLE/4f785e1d650aadee62003c3425e0fc5dd10915d6/lpips/weights/v0.1/vgg.pth -------------------------------------------------------------------------------- /mapping_network.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class PixelNorm(nn.Module): 8 | def __init__(self): 9 | super().__init__() 10 | 11 | def forward(self, input): 12 | return input / torch.sqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) 13 | 14 | 15 | class EqualLR: 16 | def __init__(self, name): 17 | self.name = name 18 | 19 | def compute_weight(self, module): 20 | weight = getattr(module, self.name + '_orig') 21 | fan_in = weight.data.size(1) * weight.data[0][0].numel() 22 | 23 | return weight * sqrt(2 / fan_in) 24 | 25 | @staticmethod 26 | def apply(module, name): 27 | fn = EqualLR(name) 28 | 29 | weight = getattr(module, name) 30 | del module._parameters[name] 31 | module.register_parameter(name + '_orig', nn.Parameter(weight.data)) 32 | module.register_forward_pre_hook(fn) 33 | 34 | return fn 35 | 36 | def __call__(self, module, input): 37 | weight = self.compute_weight(module) 38 | setattr(module, self.name, weight) 39 | 40 | 41 | def equal_lr(module, name='weight'): 42 | EqualLR.apply(module, name) 43 | 44 | return module 45 | 46 | 47 | class EqualLinear(nn.Module): 48 | def __init__(self, in_dim, out_dim): 49 | super().__init__() 50 | 51 | linear = nn.Linear(in_dim, out_dim) 52 | # linear.weight.data.normal_() 53 | linear.bias.data.zero_() 54 | 55 | self.linear = linear 56 | 57 | def forward(self, input): 58 | return self.linear(input) 59 | 60 | 61 | class MappingNetowrk(nn.Module): 62 | def __init__(self, code_dim=512, n_mlp=8): 63 | super().__init__() 64 | 65 | layers = [PixelNorm()] 66 | for i in range(n_mlp): 67 | layers.append(EqualLinear(code_dim, code_dim)) 68 | layers.append(nn.LeakyReLU(0.2)) 69 | 70 | self.style = nn.Sequential(*layers) 71 | 72 | def forward( 73 | self, 74 | input, 75 | noise=None, 76 | 
step=0, 77 | alpha=-1, 78 | mean_style=None, 79 | style_weight=0, 80 | mixing_range=(-1, -1), 81 | ): 82 | styles = [] 83 | if type(input) not in (list, tuple): 84 | input = [input] 85 | 86 | for i in input: 87 | x = self.style(i) 88 | styles.append(x) 89 | 90 | # batch = input[0].shape[0] 91 | # 92 | # if noise is None: 93 | # noise = [] 94 | # 95 | # for i in range(step + 1): 96 | # size = 4 * 2 ** i 97 | # noise.append(torch.randn(batch, 1, size, size, device=input[0].device)) 98 | 99 | # if mean_style is not None: 100 | # styles_norm = [] 101 | # 102 | # for style in styles: 103 | # styles_norm.append(mean_style + style_weight * (style - mean_style)) 104 | # 105 | # styles = styles_norm 106 | 107 | return styles 108 | 109 | # def mean_style(self, input): 110 | # style = self.style(input).mean(0, keepdim=True) 111 | # 112 | # return style 113 | 114 | 115 | class AdaptiveInstanceNorm(nn.Module): 116 | def __init__(self, in_channel, style_dim): 117 | super().__init__() 118 | 119 | self.norm = nn.InstanceNorm2d(in_channel) 120 | self.style = EqualLinear(style_dim, in_channel * 2) 121 | 122 | self.style.linear.bias.data[:in_channel] = 1 123 | self.style.linear.bias.data[in_channel:] = 0 124 | 125 | def forward(self, input, style): 126 | style = self.style(style).unsqueeze(2).unsqueeze(3) 127 | gamma, beta = style.chunk(2, 1) 128 | 129 | out = input 130 | if input.shape[3] > 1: 131 | out = self.norm(input) 132 | out = gamma * out + beta 133 | return out 134 | 135 | 136 | class NoiseInjection(nn.Module): 137 | def __init__(self, channel): 138 | super().__init__() 139 | 140 | self.weight = nn.Parameter(torch.randn(1, channel, 1, 1)) 141 | 142 | def forward(self, image, spatial_noise): 143 | return image + self.weight * spatial_noise 144 | -------------------------------------------------------------------------------- /metrics/ppl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | import numpy as np 6 | from tqdm import tqdm 7 | # from LPNet import LPNet 8 | 9 | from models import parse_layer_string 10 | 11 | 12 | def normalize(x): 13 | return x / torch.sqrt(x.pow(2).sum(-1, keepdim=True)) 14 | 15 | 16 | def slerp(a, b, t): 17 | a = normalize(a) 18 | b = normalize(b) 19 | d = (a * b).sum(-1, keepdim=True) 20 | p = t * torch.acos(d) 21 | c = normalize(b - d * a) 22 | d = a * torch.cos(p) + c * torch.sin(p) 23 | 24 | return normalize(d) 25 | 26 | 27 | def lerp(a, b, t): 28 | return a + (b - a) * t 29 | 30 | 31 | def calc_ppl(args, g, sampler): 32 | device = "cuda" 33 | 34 | latent_dim = args.latent_dim 35 | g.eval() 36 | 37 | percept = sampler.calc_loss 38 | 39 | distances = [] 40 | 41 | n_batch = args.n_sample // args.n_batch 42 | resid = args.n_sample - (n_batch * args.n_batch) 43 | batch_sizes = [args.n_batch] * n_batch 44 | if resid: 45 | batch_sizes.append(resid) 46 | 47 | blocks = parse_layer_string(args.dec_blocks) 48 | res = sorted(set([s[0] for s in blocks if s[0] <= args.max_hierarchy])) 49 | 50 | with torch.no_grad(): 51 | for batch in tqdm(batch_sizes): 52 | snoise = [torch.randn([batch * 2, 1, s, s], dtype=torch.float32).cuda() for s in res] 53 | snoise_e = [torch.randn([batch * 2, 1, s, s], dtype=torch.float32) for s in res] 54 | 55 | latent = torch.randn([batch * 2, latent_dim], device=device) 56 | if args.sampling == "full": 57 | lerp_t = torch.rand(batch, device=device) 58 | else: 59 | lerp_t = torch.zeros(batch, device=device) 60 | 61 | if args.space == "w": 
62 | latent = g.module.decoder.mapping_network(latent)[0] 63 | 64 | # snoise_t0, snoise_t1 = [sn[::2] for sn in snoise], [sn[1::2] for sn in snoise] 65 | latent_t0, latent_t1 = latent[::2], latent[1::2] 66 | latent_e0 = lerp(latent_t0, latent_t1, lerp_t[:, None]) 67 | latent_e1 = lerp(latent_t0, latent_t1, lerp_t[:, None] + args.eps) 68 | latent_e = torch.stack([latent_e0, latent_e1], 1).view(*latent.shape) 69 | 70 | for i in range(len(snoise)): 71 | snoise_t0, snoise_t1 = snoise[i][::2], snoise[i][1::2] 72 | snoise_e0 = lerp(snoise_t0, snoise_t1, lerp_t[:, None, None, None]) 73 | snoise_e1 = lerp(snoise_t0, snoise_t1, lerp_t[:, None, None, None] + args.eps) 74 | snoise_p = torch.stack([snoise_e0, snoise_e1], 1).view(*snoise[i].shape) 75 | snoise_e[i] = snoise_p 76 | 77 | if args.ppl_snoise == 0: 78 | for i in range(len(snoise)): 79 | snoise[i][::2] = snoise[i][1::2] 80 | snoise_e = snoise 81 | 82 | 83 | image = g(latent_e, snoise_e, input_is_w=args.space=="w") 84 | 85 | if args.crop: 86 | c = image.shape[2] // 8 87 | image = image[:, :, c * 3 : c * 7, c * 2 : c * 6] 88 | 89 | factor = image.shape[2] // 256 90 | 91 | if factor > 1: 92 | image = F.interpolate( 93 | image, size=(256, 256), mode="bilinear", align_corners=False 94 | ) 95 | 96 | dist = percept(image[::2], image[1::2], use_mean=False).view(image.shape[0] // 2) / ( 97 | args.eps ** 2 98 | ) 99 | distances.append(dist.to("cpu").numpy()) 100 | 101 | distances = np.concatenate(distances, 0) 102 | 103 | lo = np.percentile(distances, 1, interpolation="lower") 104 | hi = np.percentile(distances, 99, interpolation="higher") 105 | filtered_dist = np.extract( 106 | np.logical_and(lo <= distances, distances <= hi), distances 107 | ) 108 | 109 | print(f"{args.restore_path}, ppl:", filtered_dist.mean()) 110 | -------------------------------------------------------------------------------- /metrics/ppl_uniform.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | import numpy as np 6 | from tqdm import tqdm 7 | # from LPNet import LPNet 8 | from collections import defaultdict 9 | 10 | from models import parse_layer_string 11 | import pandas as pd 12 | 13 | 14 | def normalize(x): 15 | return x / torch.sqrt(x.pow(2).sum(-1, keepdim=True)) 16 | 17 | 18 | def slerp(a, b, t): 19 | a = normalize(a) 20 | b = normalize(b) 21 | d = (a * b).sum(-1, keepdim=True) 22 | p = t * torch.acos(d) 23 | c = normalize(b - d * a) 24 | d = a * torch.cos(p) + c * torch.sin(p) 25 | 26 | return normalize(d) 27 | 28 | 29 | def lerp(a, b, t): 30 | return a + (b - a) * t 31 | 32 | 33 | def calc_ppl_uniform(args, g, sampler): 34 | device = "cuda" 35 | 36 | latent_dim = args.latent_dim 37 | g.eval() 38 | 39 | percept = sampler.calc_loss 40 | 41 | distances = [] 42 | 43 | n_batch = args.n_sample // args.n_batch 44 | resid = args.n_sample - (n_batch * args.n_batch) 45 | batch_sizes = [args.n_batch] * n_batch 46 | if resid: 47 | batch_sizes.append(resid) 48 | 49 | steps = np.arange(0, 1+args.step, args.step) 50 | print(steps) 51 | output_dists = defaultdict(list) 52 | 53 | blocks = parse_layer_string(args.dec_blocks) 54 | res = sorted(set([s[0] for s in blocks if s[0] <= args.max_hierarchy])) 55 | 56 | with torch.no_grad(): 57 | for batch in tqdm(batch_sizes): 58 | snoise = [torch.randn([batch * 2, 1, s, s], dtype=torch.float32).cuda() for s in res] 59 | snoise_e = [torch.randn([batch * 2, 1, s, s], dtype=torch.float32) for s in res] 60 | 61 | latent = 
torch.randn([batch * 2, latent_dim], device=device)
62 |         if args.sampling == "full":
63 |             lerp_t = torch.rand(batch, device=device)
64 |         else:
65 |             lerp_t = torch.zeros(batch, device=device)
66 | 
67 |         if args.space == "w":
68 |             latent = g.module.decoder.mapping_network(latent)[0]
69 | 
70 |         for i in range(len(steps) - 1):
71 |             prev_step, step = steps[i], steps[i + 1]
72 | 
73 |             latent_t0, latent_t1 = latent[::2], latent[1::2]
74 |             latent_e0 = lerp(latent_t0, latent_t1, lerp_t[:, None] + prev_step)
75 |             latent_e1 = lerp(latent_t0, latent_t1, lerp_t[:, None] + step)
76 |             latent_e = torch.stack([latent_e0, latent_e1], 1).view(*latent.shape)
77 | 
78 |             if args.ppl_snoise > 0:
79 |                 for j in range(len(snoise)):
80 |                     snoise_t0, snoise_t1 = snoise[j][::2], snoise[j][1::2]
81 |                     snoise_e0 = lerp(snoise_t0, snoise_t1, lerp_t[:, None, None, None] + prev_step)
82 |                     snoise_e1 = lerp(snoise_t0, snoise_t1, lerp_t[:, None, None, None] + step)
83 |                     snoise_p = torch.stack([snoise_e0, snoise_e1], 1).view(*snoise[j].shape)
84 |                     snoise_e[j] = snoise_p
85 | 
86 |             if args.ppl_snoise == 0:
87 |                 for j in range(len(snoise)):
88 |                     snoise[j][::2] = snoise[j][1::2]
89 |                 snoise_e = snoise
90 | 
91 |             image = g(latent_e, snoise_e, input_is_w=args.space=="w")
92 | 
93 |             if args.crop:
94 |                 c = image.shape[2] // 8
95 |                 image = image[:, :, c * 3 : c * 7, c * 2 : c * 6]
96 | 
97 |             if i==0:
98 |                 start = image[::2]
99 |             if i==len(steps)-2:
100 |                 end = image[1::2]
101 |                 endpoint_dist = percept(start, end, use_mean=False).view(image.shape[0] // 2) / (
102 |                     args.step ** 2
103 |                 )
104 |                 output_dists[str(i+1)].append(endpoint_dist.to("cpu").numpy())
105 | 
106 | 
107 |             factor = image.shape[2] // 256
108 | 
109 |             if factor > 1:
110 |                 image = F.interpolate(
111 |                     image, size=(256, 256), mode="bilinear", align_corners=False
112 |                 )
113 | 
114 |             dist = percept(image[::2], image[1::2], use_mean=False).view(image.shape[0] // 2) / (
115 |                 args.step ** 2
116 |             )
117 |             output_dists[str(i)].append(dist.to("cpu").numpy())
118 | 
119 |     distances = dict()
120 |     for i in range(len(steps)):
121 |         temp = np.concatenate(output_dists[str(i)], 0)
122 | 
123 |         # lo = np.percentile(temp, 1, interpolation="lower")
124 |         # hi = np.percentile(temp, 99, interpolation="higher")
125 |         # filtered_dist = np.extract(
126 |         #     np.logical_and(lo <= temp, temp <= hi), temp
127 |         # )
128 |         distances[str(i)] = temp
129 |     # print(distances)
130 |     distances['endpoint'] = distances[f"{len(steps)-1}"]
131 |     del distances[f"{len(steps)-1}"]
132 |     # output_dists.append(filtered_dist.mean())
133 |     dist_df = pd.DataFrame(distances, index=list(range(len(distances['0']))))
134 |     # print(f"ppls: {dist_df}")
135 |     means = dist_df.drop(columns=['endpoint']).mean(axis=1)
136 |     stds = dist_df.drop(columns=['endpoint']).std(axis=1)
137 |     assert len(stds) == len(distances['0'])
138 |     # print(sum(stds)/len(stds))
139 |     print(means.shape)
140 |     print("Mean: ", means.mean())
141 |     print("Std.Dev: ", stds.mean())
142 |     print("Endpoint Mean: ", dist_df['endpoint'].mean())
143 |     # ckpt_num = int(args.ckpt.split("/")[-1][:-3])
144 |     # save_dir = "/".join(args.ckpt.split("/")[:-2])+f"/ppl_uniform_at_{ckpt_num}.csv"
145 |     dist_df.to_csv(args.save_dir + f'/{args.ppl_save_name}.csv')
146 |     # output_dists = [dist.astype(np.float64) for dist in distances.items()]
147 | 
148 |     # print("ppl:", filtered_dist.mean())
149 | 
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2
| from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from mapping_network import MappingNetowrk, AdaptiveInstanceNorm, NoiseInjection 6 | from helpers.imle_helpers import get_1x1, get_3x3, draw_gaussian_diag_samples, gaussian_analytical_kl 7 | from collections import defaultdict 8 | import numpy as np 9 | import itertools 10 | 11 | 12 | class Block(nn.Module): 13 | def __init__(self, in_width, middle_width, out_width, down_rate=None, residual=False, use_3x3=True, zero_last=False): 14 | super().__init__() 15 | self.down_rate = down_rate 16 | self.residual = residual 17 | self.c1 = get_1x1(in_width, middle_width) 18 | self.c2 = get_3x3(middle_width, middle_width) if use_3x3 else get_1x1(middle_width, middle_width) 19 | self.c3 = get_3x3(middle_width, middle_width) if use_3x3 else get_1x1(middle_width, middle_width) 20 | self.c4 = get_1x1(middle_width, out_width, zero_weights=zero_last) 21 | 22 | def forward(self, x): 23 | xhat = self.c1(F.gelu(x)) 24 | xhat = self.c2(F.gelu(xhat)) 25 | xhat = self.c3(F.gelu(xhat)) 26 | xhat = self.c4(F.gelu(xhat)) 27 | out = x + xhat if self.residual else xhat 28 | if self.down_rate is not None: 29 | out = F.avg_pool2d(out, kernel_size=self.down_rate, stride=self.down_rate) 30 | return out 31 | 32 | 33 | def parse_layer_string(s): 34 | layers = [] 35 | for ss in s.split(','): 36 | if 'x' in ss: 37 | res, num = ss.split('x') 38 | count = int(num) 39 | layers += [(int(res), None) for _ in range(count)] 40 | elif 'm' in ss: 41 | res, mixin = [int(a) for a in ss.split('m')] 42 | layers.append((res, mixin)) 43 | elif 'd' in ss: 44 | res, down_rate = [int(a) for a in ss.split('d')] 45 | layers.append((res, down_rate)) 46 | else: 47 | res = int(ss) 48 | layers.append((res, None)) 49 | return layers 50 | 51 | 52 | def pad_channels(t, width): 53 | d1, d2, d3, d4 = t.shape 54 | empty = torch.zeros(d1, width, d3, d4, device=t.device) 55 | empty[:, :d2, :, :] = t 56 | return empty 57 | 58 | 59 | def get_width_settings(width, s): 60 | mapping = defaultdict(lambda: width) 61 | if s: 62 | s = s.split(',') 63 | for ss in s: 64 | k, v = ss.split(':') 65 | mapping[int(k)] = int(v) 66 | return mapping 67 | 68 | 69 | class DecBlock(nn.Module): 70 | def __init__(self, H, res, mixin, n_blocks): 71 | super().__init__() 72 | self.base = res 73 | self.mixin = mixin 74 | self.H = H 75 | self.widths = get_width_settings(H.width, H.custom_width_str) 76 | width = self.widths[res] 77 | if res <= H.max_hierarchy: 78 | self.noise = NoiseInjection(width) 79 | self.adaIN = AdaptiveInstanceNorm(width, H.latent_dim) 80 | use_3x3 = res > 2 81 | cond_width = int(width * H.bottleneck_multiple) 82 | self.resnet = Block(width, cond_width, width, residual=True, use_3x3=use_3x3) 83 | self.resnet.c4.weight.data *= np.sqrt(1 / n_blocks) 84 | 85 | def forward(self, x, w, spatial_noise): 86 | if self.mixin is not None: 87 | x = F.interpolate(x, scale_factor=self.base // self.mixin) 88 | if self.base <= self.H.max_hierarchy: 89 | x = self.noise(x, spatial_noise) 90 | x = self.adaIN(x, w) 91 | x = self.resnet(x) 92 | return x 93 | 94 | 95 | class Decoder(nn.Module): 96 | def __init__(self, H): 97 | super().__init__() 98 | self.H = H 99 | self.mapping_network = MappingNetowrk(code_dim=H.latent_dim, n_mlp=H.n_mpl) 100 | resos = set() 101 | cond_width = int(H.width * H.bottleneck_multiple) 102 | dec_blocks = [] 103 | self.widths = get_width_settings(H.width, H.custom_width_str) 104 | blocks = parse_layer_string(H.dec_blocks) 105 | for idx, (res, mixin) in enumerate(blocks): 106 | 
dec_blocks.append(DecBlock(H, res, mixin, n_blocks=len(blocks))) 107 | resos.add(res) 108 | self.resolutions = sorted(resos) 109 | self.dec_blocks = nn.ModuleList(dec_blocks) 110 | first_res = self.resolutions[0] 111 | self.constant = nn.Parameter(torch.randn(1, self.widths[first_res], first_res, first_res)) 112 | self.resnet = get_1x1(H.width, H.image_channels) 113 | self.gain = nn.Parameter(torch.ones(1, H.image_channels, 1, 1)) 114 | self.bias = nn.Parameter(torch.zeros(1, H.image_channels, 1, 1)) 115 | 116 | def forward(self, latent_code, spatial_noise, input_is_w=False): 117 | if not input_is_w: 118 | w = self.mapping_network(latent_code)[0] 119 | else: 120 | w = latent_code 121 | 122 | x = self.constant.repeat(latent_code.shape[0], 1, 1, 1) 123 | if spatial_noise: 124 | res_to_noise = {x.shape[3]: x for x in spatial_noise} 125 | for idx, block in enumerate(self.dec_blocks): 126 | noise = None 127 | if block.base <= self.H.max_hierarchy: 128 | noise = res_to_noise[block.base] 129 | x = block(x, w, noise) 130 | x = self.resnet(x) 131 | x = self.gain * x + self.bias 132 | return x 133 | 134 | 135 | class IMLE(nn.Module): 136 | def __init__(self, H): 137 | super().__init__() 138 | self.dci_db = None 139 | self.decoder = Decoder(H) 140 | 141 | def forward(self, latents, spatial_noise=None, input_is_w=False): 142 | return self.decoder.forward(latents, spatial_noise, input_is_w) 143 | 144 | -------------------------------------------------------------------------------- /reproduce/100-shot-grumpy_cat.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python train.py --hps fewshot \ 4 | --data_root ./datasets/100-shot-grumpy_cat \ 5 | --change_coef 0.02 \ 6 | --force_factor 100 \ 7 | --imle_staleness 5 \ 8 | --imle_force_resample 25 \ 9 | --lr 0.0001 \ -------------------------------------------------------------------------------- /reproduce/100-shot-obama.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python train.py --hps fewshot \ 4 | --data_root ./datasets/100-shot-obama \ 5 | --change_coef 0.01 \ 6 | --force_factor 100 \ 7 | --imle_staleness 5 \ 8 | --imle_force_resample 15 \ 9 | --lr 0.00005 \ -------------------------------------------------------------------------------- /reproduce/100-shot-panda.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python train.py --hps fewshot \ 4 | --data_root ./datasets/100-shot-panda \ 5 | --change_coef 0.01 \ 6 | --force_factor 200 \ 7 | --imle_staleness 5 \ 8 | --imle_force_resample 30 \ 9 | --lr 0.0001 \ -------------------------------------------------------------------------------- /reproduce/AnimalFace-cat.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python train.py --hps fewshot \ 4 | --data_root ./datasets/AnimalFace-cat \ 5 | --change_coef 0.01 \ 6 | --force_factor 100 \ 7 | --imle_staleness 5 \ 8 | --imle_force_resample 25 \ 9 | --lr 0.00005 \ -------------------------------------------------------------------------------- /reproduce/AnimalFace-dog.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python train.py --hps fewshot \ 4 | --data_root ./datasets/AnimalFace-dog \ 5 | --change_coef 0.01 \ 6 | --force_factor 100 \ 7 | --imle_staleness 5 \ 8 | --imle_force_resample 15 \ 9 | --lr 0.0001 \ 
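# Flag summary (as consumed by train.py and sampler.py): --force_factor sizes the
# candidate latent pool relative to the dataset, --imle_staleness is the minimum
# number of epochs before a matched latent may be re-searched, --imle_force_resample
# forces a re-search after that many epochs, and --change_coef tightens each
# sample's distance threshold after every update.
# Example invocation from the repo root, assuming setup_datasets.sh has
# populated ./datasets:
#   bash reproduce/AnimalFace-dog.sh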
-------------------------------------------------------------------------------- /reproduce/ffhq.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python train.py --hps fewshot \ 4 | --data_root ./datasets/ffhq \ 5 | --change_coef 0.02 \ 6 | --force_factor 100 \ 7 | --imle_staleness 5 \ 8 | --imle_force_resample 15 \ 9 | --lr 0.0001 \ -------------------------------------------------------------------------------- /sampler.py: -------------------------------------------------------------------------------- 1 | from curses import update_lines_cols 2 | from math import comb 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from torch.utils.data import DataLoader, TensorDataset 9 | 10 | from LPNet import LPNet 11 | from dciknn_cuda import DCI, MDCI 12 | from torch.optim import AdamW 13 | from helpers.utils import ZippedDataset 14 | from models import parse_layer_string 15 | 16 | 17 | class Sampler: 18 | def __init__(self, H, sz, preprocess_fn): 19 | self.pool_size = int(H.force_factor * sz) 20 | self.preprocess_fn = preprocess_fn 21 | self.l2_loss = torch.nn.MSELoss(reduce=False).cuda() 22 | self.H = H 23 | self.latent_lr = H.latent_lr 24 | self.entire_ds = torch.arange(sz) 25 | self.selected_latents = torch.empty([sz, H.latent_dim], dtype=torch.float32) 26 | self.selected_latents_tmp = torch.empty([sz, H.latent_dim], dtype=torch.float32) 27 | 28 | blocks = parse_layer_string(H.dec_blocks) 29 | self.block_res = [s[0] for s in blocks] 30 | self.res = sorted(set([s[0] for s in blocks if s[0] <= H.max_hierarchy])) 31 | self.neutral_snoise = [torch.zeros([self.H.imle_db_size, 1, s, s], dtype=torch.float32) for s in self.res] 32 | self.snoise_tmp = [torch.randn([self.H.imle_db_size, 1, s, s], dtype=torch.float32) for s in self.res] 33 | self.selected_snoise = [torch.randn([sz, 1, s, s,], dtype=torch.float32) for s in self.res] 34 | self.snoise_pool = [torch.randn([self.pool_size, 1, s, s], dtype=torch.float32) for s in self.res] 35 | 36 | self.selected_dists = torch.empty([sz], dtype=torch.float32).cuda() 37 | self.selected_dists[:] = np.inf 38 | self.selected_dists_tmp = torch.empty([sz], dtype=torch.float32).cuda() 39 | self.temp_latent_rnds = torch.empty([self.H.imle_db_size, self.H.latent_dim], dtype=torch.float32) 40 | self.temp_samples = torch.empty([self.H.imle_db_size, H.image_channels, self.H.image_size, self.H.image_size], 41 | dtype=torch.float32) 42 | 43 | self.pool_latents = torch.randn([self.pool_size, H.latent_dim], dtype=torch.float32) 44 | self.sample_pool_usage = torch.ones([sz], dtype=torch.bool) 45 | 46 | self.projections = [] 47 | self.lpips_net = LPNet(pnet_type=H.lpips_net, path=H.lpips_path).cuda() 48 | 49 | fake = torch.zeros(1, 3, H.image_size, H.image_size).cuda() 50 | out, shapes = self.lpips_net(fake) 51 | dims = [int(H.proj_dim * 1. 
/ len(out)) for _ in range(len(out))] 52 | if H.proj_proportion: 53 | sm = sum([dim.shape[1] for dim in out]) 54 | dims = [int(out[feat_ind].shape[1] * (H.proj_dim / sm)) for feat_ind in range(len(out) - 1)] 55 | dims.append(H.proj_dim - sum(dims)) 56 | print(dims) 57 | for ind, feat in enumerate(out): 58 | print(feat.shape) 59 | self.projections.append(F.normalize(torch.randn(feat.shape[1], dims[ind]), p=2, dim=1).cuda()) 60 | 61 | self.temp_samples_proj = torch.empty([self.H.imle_db_size, sum(dims)], dtype=torch.float32).cuda() 62 | self.dataset_proj = torch.empty([sz, sum(dims)], dtype=torch.float32) 63 | self.pool_samples_proj = torch.empty([self.pool_size, sum(dims)], dtype=torch.float32) 64 | self.snoise_pool_samples_proj = torch.empty([sz * H.snoise_factor, sum(dims)], dtype=torch.float32) 65 | 66 | def get_projected(self, inp, permute=True): 67 | if permute: 68 | out, _ = self.lpips_net(inp.permute(0, 3, 1, 2).cuda()) 69 | else: 70 | out, _ = self.lpips_net(inp.cuda()) 71 | gen_feat = [] 72 | for i in range(len(out)): 73 | gen_feat.append(torch.mm(out[i], self.projections[i])) 74 | # TODO divide? 75 | return torch.cat(gen_feat, dim=1) 76 | 77 | def init_projection(self, dataset): 78 | for proj_mat in self.projections: 79 | proj_mat[:] = F.normalize(torch.randn(proj_mat.shape), p=2, dim=1) 80 | 81 | for ind, x in enumerate(DataLoader(TensorDataset(dataset), batch_size=self.H.n_batch)): 82 | batch_slice = slice(ind * self.H.n_batch, ind * self.H.n_batch + x[0].shape[0]) 83 | self.dataset_proj[batch_slice] = self.get_projected(self.preprocess_fn(x)[1]) 84 | 85 | def sample(self, latents, gen, snoise=None): 86 | with torch.no_grad(): 87 | nm = latents.shape[0] 88 | if snoise is None: 89 | for i in range(len(self.res)): 90 | self.snoise_tmp[i].normal_() 91 | snoise = [s[:nm] for s in self.snoise_tmp] 92 | px_z = gen(latents, snoise).permute(0, 2, 3, 1) 93 | xhat = (px_z + 1.0) * 127.5 94 | xhat = xhat.detach().cpu().numpy() 95 | xhat = np.minimum(np.maximum(0.0, xhat), 255.0).astype(np.uint8) 96 | return xhat 97 | 98 | def sample_from_out(self, px_z): 99 | with torch.no_grad(): 100 | px_z = px_z.permute(0, 2, 3, 1) 101 | xhat = (px_z + 1.0) * 127.5 102 | xhat = xhat.detach().cpu().numpy() 103 | xhat = np.minimum(np.maximum(0.0, xhat), 255.0).astype(np.uint8) 104 | return xhat 105 | 106 | def calc_loss(self, inp, tar, use_mean=True): 107 | inp_feat, inp_shape = self.lpips_net(inp) 108 | tar_feat, _ = self.lpips_net(tar) 109 | res = 0 110 | for i, g_feat in enumerate(inp_feat): 111 | res += torch.sum((g_feat - tar_feat[i]) ** 2, dim=1) / (inp_shape[i] ** 2) 112 | if use_mean: 113 | return self.H.lpips_coef * res.mean() + self.H.l2_coef * self.l2_loss(inp, tar).mean() 114 | else: 115 | return self.H.lpips_coef * res + self.H.l2_coef * torch.mean(self.l2_loss(inp, tar), dim=[1, 2, 3]) 116 | 117 | def calc_dists_existing(self, dataset_tensor, gen, dists=None, latents=None, to_update=None, snoise=None): 118 | if dists is None: 119 | dists = self.selected_dists 120 | if latents is None: 121 | latents = self.selected_latents 122 | if snoise is None: 123 | snoise = self.selected_snoise 124 | 125 | if to_update is not None: 126 | latents = latents[to_update] 127 | dists = dists[to_update] 128 | dataset_tensor = dataset_tensor[to_update] 129 | snoise = [s[to_update] for s in snoise] 130 | 131 | for ind, x in enumerate(DataLoader(TensorDataset(dataset_tensor), batch_size=self.H.n_batch)): 132 | _, target = self.preprocess_fn(x) 133 | batch_slice = slice(ind * self.H.n_batch, ind * self.H.n_batch + 
target.shape[0]) 134 | cur_latents = latents[batch_slice] 135 | cur_snoise = [s[batch_slice] for s in snoise] 136 | with torch.no_grad(): 137 | out = gen(cur_latents, cur_snoise) 138 | dist = self.calc_loss(target.permute(0, 3, 1, 2), out, use_mean=False) 139 | dists[batch_slice] = torch.squeeze(dist) 140 | return dists 141 | 142 | def imle_sample(self, dataset, gen, factor=None): 143 | if factor is None: 144 | factor = self.H.imle_factor 145 | imle_pool_size = int(len(dataset) * factor) 146 | t1 = time.time() 147 | self.selected_dists_tmp[:] = self.selected_dists[:] 148 | for i in range(imle_pool_size // self.H.imle_db_size): 149 | self.temp_latent_rnds.normal_() 150 | for j in range(len(self.res)): 151 | self.snoise_tmp[j].normal_() 152 | for j in range(self.H.imle_db_size // self.H.imle_batch): 153 | batch_slice = slice(j * self.H.imle_batch, (j + 1) * self.H.imle_batch) 154 | cur_latents = self.temp_latent_rnds[batch_slice] 155 | cur_snoise = [x[batch_slice] for x in self.snoise_tmp] 156 | with torch.no_grad(): 157 | self.temp_samples[batch_slice] = gen(cur_latents, cur_snoise) 158 | self.temp_samples_proj[batch_slice] = self.get_projected(self.temp_samples[batch_slice], False) 159 | 160 | if not gen.module.dci_db: 161 | device_count = torch.cuda.device_count() 162 | gen.module.dci_db = MDCI(self.temp_samples_proj.shape[1], num_comp_indices=self.H.num_comp_indices, 163 | num_simp_indices=self.H.num_simp_indices, devices=[i for i in range(device_count)], ts=device_count) 164 | 165 | # gen.module.dci_db = DCI(self.temp_samples_proj.shape[1], num_comp_indices=self.H.num_comp_indices, 166 | # num_simp_indices=self.H.num_simp_indices) 167 | gen.module.dci_db.add(self.temp_samples_proj) 168 | 169 | t0 = time.time() 170 | for ind, y in enumerate(DataLoader(dataset, batch_size=self.H.imle_batch)): 171 | # t2 = time.time() 172 | _, target = self.preprocess_fn(y) 173 | x = self.dataset_proj[ind * self.H.imle_batch:ind * self.H.imle_batch + target.shape[0]] 174 | cur_batch_data_flat = x.float() 175 | nearest_indices, _ = gen.module.dci_db.query(cur_batch_data_flat, num_neighbours=1) 176 | nearest_indices = nearest_indices.long()[:, 0] 177 | 178 | batch_slice = slice(ind * self.H.imle_batch, ind * self.H.imle_batch + x.size()[0]) 179 | actual_selected_dists = self.calc_loss(target.permute(0, 3, 1, 2), 180 | self.temp_samples[nearest_indices].cuda(), use_mean=False) 181 | # actual_selected_dists = torch.squeeze(actual_selected_dists) 182 | 183 | to_update = torch.nonzero(actual_selected_dists < self.selected_dists[batch_slice], as_tuple=False) 184 | to_update = torch.squeeze(to_update) 185 | self.selected_dists[ind * self.H.imle_batch + to_update] = actual_selected_dists[to_update].clone() 186 | self.selected_latents[ind * self.H.imle_batch + to_update] = self.temp_latent_rnds[nearest_indices[to_update]].clone() 187 | for k in range(len(self.res)): 188 | self.selected_snoise[k][ind * self.H.imle_batch + to_update] = self.snoise_tmp[k][nearest_indices[to_update]].clone() 189 | 190 | del cur_batch_data_flat 191 | 192 | gen.module.dci_db.clear() 193 | 194 | # adding perturbation 195 | changed = torch.sum(self.selected_dists_tmp != self.selected_dists).item() 196 | print("Samples and NN are calculated, time: {}, mean: {} # changed: {}, {}%".format(time.time() - t1, 197 | self.selected_dists.mean(), 198 | changed, (changed / len( 199 | dataset)) * 100)) 200 | 201 | def resample_pool(self, gen, ds): 202 | # self.init_projection(ds) 203 | self.pool_latents.normal_() 204 | for i in range(len(self.res)): 
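            # Refresh the spatial-noise pool at every hierarchy resolution; together
            # with pool_latents, the projected samples computed below form the
            # candidate database that imle_sample_force queries via DCI
            # nearest-neighbour search.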
205 | self.snoise_pool[i].normal_() 206 | 207 | for j in range(self.pool_size // self.H.imle_batch): 208 | batch_slice = slice(j * self.H.imle_batch, (j + 1) * self.H.imle_batch) 209 | cur_latents = self.pool_latents[batch_slice] 210 | cur_snosie = [s[batch_slice] for s in self.snoise_pool] 211 | with torch.no_grad(): 212 | self.pool_samples_proj[batch_slice] = self.get_projected(gen(cur_latents, cur_snosie), False) 213 | 214 | def imle_sample_force(self, dataset, gen, to_update=None): 215 | if to_update is None: 216 | to_update = self.entire_ds 217 | if to_update.shape[0] == 0: 218 | return 219 | 220 | t1 = time.time() 221 | print(torch.any(self.sample_pool_usage[to_update]), torch.any(self.sample_pool_usage)) 222 | if torch.any(self.sample_pool_usage[to_update]): 223 | self.resample_pool(gen, dataset) 224 | self.sample_pool_usage[:] = False 225 | print(f'resampling took {time.time() - t1}') 226 | 227 | self.selected_dists_tmp[:] = np.inf 228 | self.sample_pool_usage[to_update] = True 229 | 230 | with torch.no_grad(): 231 | for i in range(self.pool_size // self.H.imle_db_size): 232 | pool_slice = slice(i * self.H.imle_db_size, (i + 1) * self.H.imle_db_size) 233 | if not gen.module.dci_db: 234 | device_count = torch.cuda.device_count() 235 | gen.module.dci_db = MDCI(self.H.proj_dim, num_comp_indices=self.H.num_comp_indices, 236 | num_simp_indices=self.H.num_simp_indices, devices=[i for i in range(device_count)]) 237 | gen.module.dci_db.add(self.pool_samples_proj[pool_slice]) 238 | pool_latents = self.pool_latents[pool_slice] 239 | snoise_pool = [b[pool_slice] for b in self.snoise_pool] 240 | 241 | t0 = time.time() 242 | for ind, y in enumerate(DataLoader(TensorDataset(dataset[to_update]), batch_size=self.H.imle_batch)): 243 | _, target = self.preprocess_fn(y) 244 | batch_slice = slice(ind * self.H.imle_batch, ind * self.H.imle_batch + target.shape[0]) 245 | indices = to_update[batch_slice] 246 | x = self.dataset_proj[indices] 247 | nearest_indices, dci_dists = gen.module.dci_db.query(x.float(), num_neighbours=1) 248 | nearest_indices = nearest_indices.long()[:, 0] 249 | dci_dists = dci_dists[:, 0] 250 | 251 | need_update = dci_dists < self.selected_dists_tmp[indices] 252 | global_need_update = indices[need_update] 253 | 254 | self.selected_dists_tmp[global_need_update] = dci_dists[need_update].clone() 255 | self.selected_latents_tmp[global_need_update] = pool_latents[nearest_indices[need_update]].clone() + self.H.imle_perturb_coef * torch.randn((need_update.sum(), self.H.latent_dim)) 256 | for j in range(len(self.res)): 257 | self.selected_snoise[j][global_need_update] = snoise_pool[j][nearest_indices[need_update]].clone() 258 | 259 | gen.module.dci_db.clear() 260 | 261 | if i % 100 == 0: 262 | print("NN calculated for {} out of {} - {}".format((i + 1) * self.H.imle_db_size, self.pool_size, time.time() - t0)) 263 | 264 | 265 | if self.H.latent_epoch > 0: 266 | for param in gen.parameters(): 267 | param.requires_grad = False 268 | updatable_latents = self.selected_latents_tmp[to_update].clone().requires_grad_(True) 269 | latent_optimizer = AdamW([updatable_latents], lr=self.latent_lr) 270 | comb_dataset = ZippedDataset(TensorDataset(dataset[to_update]), TensorDataset(updatable_latents)) 271 | 272 | for gd_epoch in range(self.H.latent_epoch): 273 | losses = [] 274 | for cur, _ in DataLoader(comb_dataset, batch_size=self.H.n_batch): 275 | x = cur[0] 276 | latents = cur[1][0] 277 | _, target = self.preprocess_fn(x) 278 | gen.zero_grad() 279 | px_z = gen(latents) # TODO fix this 280 | loss = 
self.calc_loss(px_z, target.permute(0, 3, 1, 2)) 281 | loss.backward() 282 | latent_optimizer.step() 283 | updatable_latents.grad.zero_() 284 | 285 | losses.append(loss.detach()) 286 | print('avg loss', gd_epoch, sum(losses) / len(losses)) 287 | self.selected_latents[to_update] = updatable_latents.detach().clone() 288 | 289 | if self.H.latent_epoch > 0: 290 | for param in gen.parameters(): 291 | param.requires_grad = True 292 | self.latent_lr = self.latent_lr * (1 - self.H.latent_decay) 293 | -------------------------------------------------------------------------------- /setup_datasets.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gdown https://drive.google.com/file/d/1VwFFzU8wJD1XJtfg60iLwnyBQ_cLZObL/view\?usp\=sharing --fuzzy 4 | tar -xvf datasets.tar -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | if __name__ == '__main__': 4 | print(sys.argv) 5 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import imageio 5 | import torch 6 | import wandb 7 | from cleanfid import fid 8 | from torch.utils.data import DataLoader, TensorDataset 9 | 10 | from data import set_up_data 11 | from helpers.imle_helpers import backtrack, reconstruct 12 | from helpers.train_helpers import (load_imle, load_opt, save_latents, 13 | save_latents_latest, save_model, 14 | save_snoise, set_up_hyperparams, update_ema) 15 | from helpers.utils import ZippedDataset, get_cpu_stats_over_ranks 16 | from metrics.ppl import calc_ppl 17 | from metrics.ppl_uniform import calc_ppl_uniform 18 | from sampler import Sampler 19 | from visual.generate_rnd import generate_rnd 20 | from visual.generate_rnd_nn import generate_rnd_nn 21 | from visual.generate_sample_nn import generate_sample_nn 22 | from visual.interpolate import random_interp 23 | from visual.nn_interplate import nn_interp 24 | from visual.spatial_visual import spatial_vissual 25 | from visual.utils import (generate_and_save, generate_for_NN, 26 | generate_images_initial, 27 | get_sample_for_visualization) 28 | 29 | 30 | def training_step_imle(H, n, targets, latents, snoise, imle, ema_imle, optimizer, loss_fn): 31 | t0 = time.time() 32 | imle.zero_grad() 33 | px_z = imle(latents, snoise) 34 | loss = loss_fn(px_z, targets.permute(0, 3, 1, 2)) 35 | loss.backward() 36 | optimizer.step() 37 | if ema_imle is not None: 38 | update_ema(imle, ema_imle, H.ema_rate) 39 | 40 | stats = get_cpu_stats_over_ranks(dict(loss_nans=0, loss=loss)) 41 | stats.update(skipped_updates=0, iter_time=time.time() - t0, grad_norm=0) 42 | return stats 43 | 44 | 45 | def train_loop_imle(H, data_train, data_valid, preprocess_fn, imle, ema_imle, logprint): 46 | subset_len = len(data_train) 47 | if H.subset_len != -1: 48 | subset_len = H.subset_len 49 | for data_train in DataLoader(data_train, batch_size=subset_len): 50 | data_train = TensorDataset(data_train[0]) 51 | break 52 | 53 | optimizer, scheduler, _, iterate, _ = load_opt(H, imle, logprint) 54 | 55 | stats = [] 56 | H.ema_rate = torch.as_tensor(H.ema_rate) 57 | 58 | subset_len = H.subset_len 59 | if subset_len == -1: 60 | subset_len = len(data_train) 61 | 62 | sampler = Sampler(H, subset_len, preprocess_fn) 63 | 64 | last_updated = torch.zeros(subset_len, 
dtype=torch.int16).cuda()
65 |     times_updated = torch.zeros(subset_len, dtype=torch.int8).cuda()
66 |     change_thresholds = torch.empty(subset_len).cuda()
67 |     change_thresholds[:] = H.change_threshold
68 |     best_fid = 100000
69 | 
70 | 
71 |     epoch = -1
72 |     for outer in range(H.num_epochs):
73 |         for split_ind, split_x_tensor in enumerate(DataLoader(data_train, batch_size=subset_len, pin_memory=True)):
74 |             split_x_tensor = split_x_tensor[0].contiguous()
75 |             split_x = TensorDataset(split_x_tensor)
76 |             sampler.init_projection(split_x_tensor)
77 |             viz_batch_original, _ = get_sample_for_visualization(split_x, preprocess_fn, H.num_images_visualize, H.dataset)
78 | 
79 |             print('Outer batch - {}, size {}'.format(split_ind, len(split_x)))
80 | 
81 |             while True:
82 |                 epoch += 1
83 |                 last_updated[:] = last_updated + 1
84 | 
85 |                 sampler.selected_dists[:] = sampler.calc_dists_existing(split_x_tensor, imle, dists=sampler.selected_dists)
86 |                 dists_in_threshold = sampler.selected_dists < change_thresholds
87 |                 updated_enough = last_updated >= H.imle_staleness
88 |                 updated_too_much = last_updated >= H.imle_force_resample
89 |                 in_threshold = torch.logical_and(dists_in_threshold, updated_enough)
90 |                 all_conditions = torch.logical_or(in_threshold, updated_too_much)
91 |                 to_update = torch.nonzero(all_conditions, as_tuple=False).squeeze(1)
92 | 
93 |                 if epoch == 0:
94 |                     if os.path.isfile(str(H.restore_latent_path)):
95 |                         latents = torch.load(H.restore_latent_path)
96 |                         sampler.selected_latents[:] = latents[:]
97 |                         for x in DataLoader(split_x, batch_size=H.num_images_visualize, pin_memory=True):
98 |                             break
99 |                         batch_slice = slice(0, x[0].size()[0])
100 |                         latents = sampler.selected_latents[batch_slice]
101 |                         with torch.no_grad():
102 |                             snoise = [s[batch_slice] for s in sampler.selected_snoise]
103 |                             generate_for_NN(sampler, x[0], latents, snoise, viz_batch_original.shape, imle,
104 |                                             f'{H.save_dir}/NN-samples_{outer}-{split_ind}-imle.png', logprint)
105 |                         print('loaded latest latents')
106 | 
107 |                     if os.path.isfile(str(H.restore_threshold_path)):
108 |                         threshold = torch.load(H.restore_threshold_path)
109 |                         change_thresholds[:] = threshold[:]
110 |                         print('loaded thresholds', torch.mean(change_thresholds))
111 |                     else:
112 |                         to_update = sampler.entire_ds
113 | 
114 | 
115 |                 change_thresholds[to_update] = sampler.selected_dists[to_update].clone() * (1 - H.change_coef)
116 | 
117 |                 sampler.imle_sample_force(split_x_tensor, imle, to_update)
118 | 
119 |                 last_updated[to_update] = 0
120 |                 times_updated[to_update] = times_updated[to_update] + 1
121 | 
122 |                 save_latents_latest(H, split_ind, sampler.selected_latents)
123 |                 save_latents_latest(H, split_ind, change_thresholds, name='threshold_latest')
124 | 
125 |                 if to_update.shape[0] >= H.num_images_visualize:
126 |                     latents = sampler.selected_latents[to_update[:H.num_images_visualize]]
127 |                     with torch.no_grad():
128 |                         generate_for_NN(sampler, split_x_tensor[to_update[:H.num_images_visualize]], latents,
129 |                                         [s[to_update[:H.num_images_visualize]] for s in sampler.selected_snoise],
130 |                                         viz_batch_original.shape, imle,
131 |                                         f'{H.save_dir}/NN-samples_{epoch}-imle.png', logprint)
132 | 
133 | 
134 | 
135 |                 comb_dataset = ZippedDataset(split_x, TensorDataset(sampler.selected_latents))
136 |                 data_loader = DataLoader(comb_dataset, batch_size=H.n_batch, pin_memory=True, shuffle=True)
137 |                 for cur, indices in data_loader:
138 |                     x = cur[0]
139 |                     latents = cur[1][0]
140 |                     _, target = preprocess_fn(x)
141 |                     cur_snoise = [s[indices] for s in sampler.selected_snoise]
142 |                     stat = training_step_imle(H, 
target.shape[0], target, latents, cur_snoise, imle, ema_imle, optimizer, sampler.calc_loss) 143 | stats.append(stat) 144 | scheduler.step() 145 | 146 | if iterate % H.iters_per_images == 0: 147 | with torch.no_grad(): 148 | generate_images_initial(H, sampler, viz_batch_original, 149 | sampler.selected_latents[0: H.num_images_visualize], 150 | [s[0: H.num_images_visualize] for s in sampler.selected_snoise], 151 | viz_batch_original.shape, imle, ema_imle, 152 | f'{H.save_dir}/samples-{iterate}.png', logprint) 153 | 154 | iterate += 1 155 | if iterate % H.iters_per_save == 0: 156 | fp = os.path.join(H.save_dir, 'latest') 157 | logprint(f'Saving model@ {iterate} to {fp}') 158 | save_model(fp, imle, ema_imle, optimizer, H) 159 | save_latents_latest(H, split_ind, sampler.selected_latents) 160 | save_latents_latest(H, split_ind, change_thresholds, name='threshold_latest') 161 | 162 | if iterate % H.iters_per_ckpt == 0: 163 | save_model(os.path.join(H.save_dir, f'iter-{iterate}'), imle, ema_imle, optimizer, H) 164 | save_latents(H, iterate, split_ind, sampler.selected_latents) 165 | save_latents(H, iterate, split_ind, change_thresholds, name='threshold') 166 | save_snoise(H, iterate, sampler.selected_snoise) 167 | 168 | cur_dists = torch.empty([subset_len], dtype=torch.float32).cuda() 169 | cur_dists[:] = sampler.calc_dists_existing(split_x_tensor, imle, dists=cur_dists) 170 | torch.save(cur_dists, f'{H.save_dir}/latent/dists-{epoch}.npy') 171 | 172 | metrics = { 173 | 'mean_loss': torch.mean(cur_dists).item(), 174 | 'std_loss': torch.std(cur_dists).item(), 175 | 'max_loss': torch.max(cur_dists).item(), 176 | 'min_loss': torch.min(cur_dists).item(), 177 | } 178 | 179 | if epoch % H.fid_freq == 0: 180 | generate_and_save(H, imle, sampler, subset_len * H.fid_factor) 181 | print(f'{H.data_root}/img', f'{H.save_dir}/fid/') 182 | cur_fid = fid.compute_fid(f'{H.data_root}/img', f'{H.save_dir}/fid/', verbose=False) 183 | if cur_fid < best_fid: 184 | best_fid = cur_fid 185 | # save models 186 | fp = os.path.join(H.save_dir, 'best_fid') 187 | logprint(f'Saving model best fid {best_fid} @ {iterate} to {fp}') 188 | save_model(fp, imle, ema_imle, optimizer, H) 189 | 190 | metrics['fid'] = cur_fid 191 | metrics['best_fid'] = best_fid 192 | 193 | 194 | logprint(model=H.desc, type='train_loss', epoch=epoch, step=iterate, **metrics) 195 | 196 | if H.use_wandb: 197 | wandb.log(metrics, step=iterate) 198 | 199 | 200 | 201 | def main(H=None): 202 | H_cur, logprint = set_up_hyperparams() 203 | if not H: 204 | H = H_cur 205 | H, data_train, data_valid_or_test, preprocess_fn = set_up_data(H) 206 | imle, ema_imle = load_imle(H, logprint) 207 | 208 | if H.use_wandb: 209 | wandb.init( 210 | name=H.wandb_name, 211 | project=H.wandb_project, 212 | config=H, 213 | mode=H.wandb_mode, 214 | ) 215 | 216 | os.makedirs(f'{H.save_dir}/fid', exist_ok=True) 217 | 218 | if H.mode == 'eval': 219 | with torch.no_grad(): 220 | # Generating 221 | sampler = Sampler(H, len(data_train), preprocess_fn) 222 | n_samp = H.n_batch 223 | temp_latent_rnds = torch.randn([n_samp, H.latent_dim], dtype=torch.float32).cuda() 224 | for i in range(0, H.num_images_to_generate // n_samp): 225 | if (i % 10 == 0): 226 | print(i * n_samp) 227 | temp_latent_rnds.normal_() 228 | tmp_snoise = [s[:n_samp].normal_() for s in sampler.snoise_tmp] 229 | samp = sampler.sample(temp_latent_rnds, imle, tmp_snoise) 230 | for j in range(n_samp): 231 | imageio.imwrite(f'{H.save_dir}/{i * n_samp + j}.png', samp[j]) 232 | 233 | 234 | elif H.mode == 'reconstruct': 235 | 236 | 
subset_len = H.subset_len 237 | if subset_len == -1: 238 | subset_len = len(data_train) 239 | ind = 0 240 | for split_ind, split_x_tensor in enumerate(DataLoader(data_train, batch_size=H.subset_len, pin_memory=True)): 241 | if (ind == 14): 242 | break 243 | split_x = TensorDataset(split_x_tensor[0]) 244 | ind += 1 245 | 246 | for param in imle.parameters(): 247 | param.requires_grad = False 248 | viz_batch_original, _ = get_sample_for_visualization(split_x, preprocess_fn, 249 | H.num_images_visualize, H.dataset) 250 | if os.path.isfile(str(H.restore_latent_path)): 251 | latents = torch.tensor(torch.load(H.restore_latent_path), requires_grad=True) 252 | else: 253 | latents = torch.randn([viz_batch_original.shape[0], H.latent_dim], requires_grad=True) 254 | sampler = Sampler(H, subset_len, preprocess_fn) 255 | reconstruct(H, sampler, imle, preprocess_fn, viz_batch_original, latents, 'reconstruct', logprint, training_step_imle) 256 | 257 | elif H.mode == 'backtrack': 258 | for param in imle.parameters(): 259 | param.requires_grad = False 260 | for split_x in DataLoader(data_train, batch_size=H.subset_len): 261 | split_x = split_x[0] 262 | pass 263 | print(f'split shape is {split_x.shape}') 264 | sampler = Sampler(H, H.subset_len, preprocess_fn) 265 | backtrack(H, sampler, imle, preprocess_fn, split_x, logprint, training_step_imle) 266 | 267 | 268 | elif H.mode == 'train': 269 | train_loop_imle(H, data_train, data_valid_or_test, preprocess_fn, imle, ema_imle, logprint) 270 | 271 | elif H.mode == 'ppl': 272 | sampler = Sampler(H, H.subset_len, preprocess_fn) 273 | calc_ppl(H, imle, sampler) 274 | 275 | elif H.mode == 'ppl_uniform': 276 | sampler = Sampler(H, H.subset_len, preprocess_fn) 277 | calc_ppl_uniform(H, imle, sampler) 278 | 279 | elif H.mode == 'interpolate': 280 | with torch.no_grad(): 281 | for split_x in DataLoader(data_train, batch_size=H.subset_len): 282 | split_x = split_x[0] 283 | viz_batch_original, _ = get_sample_for_visualization(split_x, preprocess_fn, 284 | H.num_images_visualize, H.dataset) 285 | sampler = Sampler(H, H.subset_len, preprocess_fn) 286 | for i in range(H.num_images_to_generate): 287 | random_interp(H, sampler, (0, 256, 256, 3), imle, f'{H.save_dir}/interp-{i}.png', logprint) 288 | 289 | elif H.mode == 'spatial_visual': 290 | with torch.no_grad(): 291 | for split_x in DataLoader(data_train, batch_size=H.subset_len): 292 | split_x = split_x[0] 293 | viz_batch_original, _ = get_sample_for_visualization(split_x, preprocess_fn, 294 | H.num_images_visualize, H.dataset) 295 | sampler = Sampler(H, H.subset_len, preprocess_fn) 296 | for i in range(H.num_images_to_generate): 297 | print(H.num_images_to_generate, i) 298 | spatial_vissual(H, sampler, (0, 256, 256, 3), imle, f'{H.save_dir}/interp-{i}.png', logprint) 299 | 300 | elif H.mode == 'generate_rnd': 301 | with torch.no_grad(): 302 | for split_x in DataLoader(data_train, batch_size=H.subset_len): 303 | split_x = split_x[0] 304 | viz_batch_original, _ = get_sample_for_visualization(split_x, preprocess_fn, 305 | H.num_images_visualize, H.dataset) 306 | sampler = Sampler(H, H.subset_len, preprocess_fn) 307 | generate_rnd(H, sampler, (0, 256, 256, 3), imle, f'{H.save_dir}/rnd.png', logprint) 308 | 309 | elif H.mode == 'generate_rnd_nn': 310 | with torch.no_grad(): 311 | for split_x in DataLoader(data_train, batch_size=len(data_train)): 312 | split_x = split_x[0] 313 | viz_batch_original, _ = get_sample_for_visualization(split_x, preprocess_fn, 314 | H.num_images_visualize, H.dataset) 315 | sampler = Sampler(H, 
H.subset_len, preprocess_fn) 316 | generate_rnd_nn(H, split_x, sampler, (0, 256, 256, 3), imle, f'{H.save_dir}', logprint, preprocess_fn) 317 | 318 | elif H.mode == 'nn_interp': 319 | with torch.no_grad(): 320 | for split_x in DataLoader(data_train, batch_size=len(data_train)): 321 | split_x = split_x[0] 322 | viz_batch_original, _ = get_sample_for_visualization(split_x, preprocess_fn, 323 | H.num_images_visualize, H.dataset) 324 | sampler = Sampler(H, H.subset_len, preprocess_fn) 325 | nn_interp(H, split_x, sampler, (0, 256, 256, 3), imle, f'{H.save_dir}', logprint, preprocess_fn) 326 | 327 | elif H.mode == 'generate_sample_nn': 328 | with torch.no_grad(): 329 | for split_x in DataLoader(data_train, batch_size=len(data_train)): 330 | split_x = split_x[0] 331 | viz_batch_original, _ = get_sample_for_visualization(split_x, preprocess_fn, 332 | H.num_images_visualize, H.dataset) 333 | sampler = Sampler(H, H.subset_len, preprocess_fn) 334 | generate_sample_nn(H, split_x, sampler, (0, 256, 256, 3), imle, f'{H.save_dir}/rnd2.png', logprint, preprocess_fn) 335 | 336 | elif H.mode == 'backtrack_interpolate': 337 | with torch.no_grad(): 338 | for split_x in DataLoader(data_train, batch_size=H.subset_len): 339 | split_x = split_x[0] 340 | viz_batch_original, _ = get_sample_for_visualization(split_x, preprocess_fn, 341 | H.num_images_visualize, H.dataset) 342 | sampler = Sampler(H, H.subset_len, preprocess_fn) 343 | latents = torch.tensor(torch.load(f'{H.restore_latent_path}'), requires_grad=True, dtype=torch.float32, device='cuda') 344 | for i in range(latents.shape[0] - 1): 345 | lat0 = latents[i:i+1] 346 | lat1 = latents[i+1:i+2] 347 | sn1 = None 348 | sn2 = None 349 | random_interp(H, sampler, (0, 256, 256, 3), imle, f'test/interp-{i}.png', logprint, lat0, lat1, sn1, sn2) 350 | 351 | 352 | if __name__ == "__main__": 353 | main() 354 | -------------------------------------------------------------------------------- /visual/generate_rnd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import imageio 4 | 5 | def generate_rnd(H, sampler, shape, ema_imle, fname, logprint): 6 | mb = H.num_rows_visualize 7 | batches = [] 8 | n_rows = mb 9 | temp_latent_rnds = torch.randn([mb, H.latent_dim], dtype=torch.float32).cuda() 10 | for t in range(H.num_rows_visualize): 11 | temp_latent_rnds.normal_() 12 | tmp_snoise = [s[:mb].normal_() for s in sampler.snoise_tmp] 13 | out = ema_imle(temp_latent_rnds, tmp_snoise) 14 | batches.append(sampler.sample_from_out(out)) 15 | 16 | im = np.concatenate(batches, axis=0).reshape((n_rows, mb, *shape[1:])).transpose([0, 2, 1, 3, 4]).reshape( 17 | [n_rows * shape[1], mb * shape[2], 3]) 18 | 19 | logprint(f'printing samples to {fname}') 20 | imageio.imwrite(fname, im) 21 | -------------------------------------------------------------------------------- /visual/generate_rnd_nn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import imageio 4 | 5 | def generate_rnd_nn(H, data, sampler, shape, imle, fname, logprint, preprocess_fn): 6 | mb = 10 7 | batches = [] 8 | temp_latent_rnds = torch.randn([mb, H.latent_dim], dtype=torch.float32).cuda() 9 | for t in range(H.num_rows_visualize): 10 | temp_latent_rnds.normal_() 11 | tmp_snoise = [s[:mb].normal_() for s in sampler.snoise_tmp] 12 | out = imle(temp_latent_rnds, tmp_snoise) 13 | batches.append(out) 14 | to_s = [] 15 | nns = [] 16 | nns_pairs = [] 17 | for b in batches: 18 | 
for i in range(mb): 19 | to_s.append(b[i:i+1]) 20 | print(len(to_s)) 21 | for i in range(data.shape[0]): 22 | x = data[i:i+1] 23 | _, target = preprocess_fn([x]) 24 | bst_loss = np.inf 25 | bst_ind = -1 26 | for j, d in enumerate(to_s): 27 | cur = sampler.calc_loss(target.permute(0, 3, 1, 2).cuda(), d.cuda()).item() 28 | if cur < bst_loss: 29 | bst_loss = cur 30 | bst_ind = j 31 | real = sampler.sample_from_out(target.permute(0, 3, 1, 2).cpu()) 32 | nn = sampler.sample_from_out(to_s[bst_ind].cpu()) 33 | nns_pairs.append((bst_loss, real, nn, bst_ind)) 34 | 35 | print(len(nns)) 36 | nns_pairs = sorted(nns_pairs)[::-1] 37 | for a in nns_pairs: 38 | nns.append(a[1]) 39 | nns.append(a[2]) 40 | batches = nns 41 | mb = 10 42 | n_rows = 20 43 | im = np.concatenate(batches, axis=0).reshape((n_rows, mb, *shape[1:])).transpose([0, 2, 1, 3, 4]).reshape( 44 | [n_rows * shape[1], mb * shape[2], 3]) 45 | logprint(f'printing samples to {fname}/rnd-nn.png') 46 | imageio.imwrite(f'{fname}/rnd-nn.png', im) 47 | 48 | # used = [x[3] for x in nns_pairs] 49 | # others = [(np.inf, x, None) for i, x in enumerate(to_s) if i not in used] 50 | # print('others', len(others)) 51 | # for i in range(data.shape[0]): 52 | # x = data[i:i+1] 53 | # _, target = preprocess_fn([x]) 54 | # for j, x in enumerate(others): 55 | # d = x[1] 56 | # cur = sampler.calc_loss(target.permute(0, 3, 1, 2).cuda(), d.cuda()).item() 57 | # if cur < x[0]: 58 | # others[j] = (cur, d, target) 59 | # others = sorted(others)[::-1] 60 | 61 | # for i in range(len(others)//10): 62 | # nns = [] 63 | # for a in others[i*10:(i+1)*10]: 64 | # nn = sampler.sample_from_out(a[1].cpu()) 65 | # nns.append(nn) 66 | # for a in others[i*10:(i+1)*10]: 67 | # real = sampler.sample_from_out(a[2].permute(0, 3, 1, 2).cpu()) 68 | # nns.append(real) 69 | 70 | # print(len(nns)) 71 | # batches = nns 72 | # mb = 10 73 | # n_rows = 2 74 | # im = np.concatenate(batches, axis=0).reshape((n_rows, mb, *shape[1:])).transpose([0, 2, 1, 3, 4]).reshape( 75 | # [n_rows * shape[1], mb * shape[2], 3]) 76 | # logprint(f'printing samples to {fname}') 77 | # imageio.imwrite(f'{fname}/rnd-nn-rem-{i}.png', im) 78 | -------------------------------------------------------------------------------- /visual/generate_sample_nn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import imageio 4 | from torch import nn as nnn 5 | 6 | def generate_sample_nn(H, data, sampler, shape, ema_imle, fname, logprint, preprocess_fn): 7 | mb = H.num_rows_visualize 8 | batches = [] 9 | n_rows = mb 10 | temp_latent_rnds = torch.randn([mb, H.latent_dim], dtype=torch.float32).cuda() 11 | tmp_snoise = [s[:mb].normal_() for s in sampler.snoise_tmp] 12 | temp_latent_rnds = torch.randn([mb, H.latent_dim], dtype=torch.float32).cuda() 13 | tmp_snoise = [s[:mb].normal_() for s in sampler.snoise_tmp] 14 | out = ema_imle(temp_latent_rnds, tmp_snoise) 15 | batches.append(out) 16 | to_s = [] 17 | nns = [] 18 | for b in batches: 19 | for i in range(mb): 20 | to_s.append((b[i:i+1], None, np.inf)) 21 | print(data.shape, len(to_s)) 22 | loss = nnn.MSELoss() 23 | for i in range(data.shape[0]): 24 | x = data[i:i+1] 25 | _, target = preprocess_fn([x]) 26 | for j, x in enumerate(to_s): 27 | d = x[0] 28 | # cur = sampler.calc_loss(target.permute(0, 3, 1, 2).cuda(), d.cuda()).item() 29 | cur = loss(target.permute(0, 3, 1, 2).cuda(), d.cuda()).item() 30 | if cur < x[2]: 31 | to_s[j] = (d, target, cur) 32 | 33 | for a in to_s: 34 | nn = 
sampler.sample_from_out(a[0].cpu()) 35 | nns.append(nn) 36 | for a in to_s: 37 | real = sampler.sample_from_out(a[1].permute(0, 3, 1, 2).cpu()) 38 | nns.append(real) 39 | 40 | print(len(nns)) 41 | batches = nns 42 | n_rows = 2 43 | im = np.concatenate(batches, axis=0).reshape((n_rows, mb, *shape[1:])).transpose([0, 2, 1, 3, 4]).reshape( 44 | [n_rows * shape[1], mb * shape[2], 3]) 45 | 46 | logprint(f'printing samples to {fname}') 47 | imageio.imwrite(fname, im) 48 | -------------------------------------------------------------------------------- /visual/interpolate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import imageio 4 | 5 | def random_interp(H, sampler, shape, imle, fname, logprint, lat1=None, lat2=None, sn1=None, sn2=None): 6 | num_lin = 1 7 | mb = 8 8 | 9 | batches = [] 10 | # step = (-f_latent + s_latent) / num_lin 11 | for t in range(num_lin): 12 | f_latent = torch.randn([1, H.latent_dim], dtype=torch.float32).cuda() 13 | s_latent = torch.randn([1, H.latent_dim], dtype=torch.float32).cuda() 14 | f_snoise = [torch.randn([1, 1, s, s], dtype=torch.float32).cuda() for s in sampler.res] 15 | s_snoise = [torch.randn([1, 1, s, s], dtype=torch.float32).cuda() for s in sampler.res] 16 | if lat1 is not None: 17 | print('loading from input') 18 | f_latent = lat1 19 | s_latent = lat2 20 | # if sn1 is not None: 21 | # f_snoise = sn1 22 | # s_snoise = sn2 23 | f_latent = imle.module.decoder.mapping_network(f_latent)[0] 24 | s_latent = imle.module.decoder.mapping_network(s_latent)[0] 25 | sample_w = torch.cat([torch.lerp(f_latent, s_latent, v) for v in torch.linspace(0, 1, mb).cuda()], dim=0) 26 | snoise = [torch.cat([f_snoise[i] for v in torch.linspace(0, 1, mb).cuda()], dim=0) for i in range(len(f_snoise))] 27 | 28 | # for i in range(len(snoise)): 29 | # snoise[i][:] = f_snoise[i][0] 30 | 31 | out = imle(sample_w, spatial_noise=snoise, input_is_w=True) 32 | batches.append(sampler.sample_from_out(out)) 33 | 34 | n_rows = len(batches) 35 | im = np.concatenate(batches, axis=0).reshape((n_rows, mb, *shape[1:])).transpose([0, 2, 1, 3, 4]).reshape( 36 | [n_rows * shape[1], mb * shape[2], 3]) 37 | 38 | logprint(f'printing samples to {fname}') 39 | imageio.imwrite(fname, im) -------------------------------------------------------------------------------- /visual/nn_interplate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import imageio 4 | 5 | def nn_interp(H, data, sampler, shape, ema_imle, fname, logprint, preprocess_fn): 6 | mb = 10 7 | batches = [] 8 | temp_latent_rnds = torch.randn([mb, H.latent_dim], dtype=torch.float32).cuda() 9 | for t in range(H.num_rows_visualize): 10 | temp_latent_rnds.normal_() 11 | tmp_snoise = [s[:mb].normal_() for s in sampler.snoise_tmp] 12 | out = ema_imle(temp_latent_rnds, tmp_snoise) 13 | batches.append((out, torch.tensor(temp_latent_rnds))) 14 | to_s = [] 15 | nns = [] 16 | nns_pairs = [] 17 | for bb in batches: 18 | for i in range(mb): 19 | b = bb[0] 20 | to_s.append((b[i:i+1], bb[1][i:i+1])) 21 | print(len(to_s)) 22 | for i in range(data.shape[0]): 23 | x = data[i:i+1] 24 | _, target = preprocess_fn([x]) 25 | bst_loss = np.inf 26 | bst_ind = -1 27 | for j, dd in enumerate(to_s): 28 | d = dd[0] 29 | cur = sampler.calc_loss(target.permute(0, 3, 1, 2).cuda(), d.cuda()).item() 30 | if cur < bst_loss: 31 | bst_loss = cur 32 | bst_ind = j 33 | nns_pairs.append((bst_loss, bst_ind)) 34 | nnss = 
torch.cat([to_s[x[1]][1] for x in nns_pairs], dim=0) 35 | torch.save(nnss.detach(), f'best-nns.npy') 36 | 37 | # used = [x[3] for x in nns_pairs] 38 | # others = [(np.inf, x, None) for i, x in enumerate(to_s) if i not in used] 39 | # print('others', len(others)) 40 | # for i in range(data.shape[0]): 41 | # x = data[i:i+1] 42 | # _, target = preprocess_fn([x]) 43 | # for j, x in enumerate(others): 44 | # d = x[1] 45 | # cur = sampler.calc_loss(target.permute(0, 3, 1, 2).cuda(), d.cuda()).item() 46 | # if cur < x[0]: 47 | # others[j] = (cur, d, target) 48 | # others = sorted(others)[::-1] 49 | 50 | # for i in range(len(others)//10): 51 | # nns = [] 52 | # for a in others[i*10:(i+1)*10]: 53 | # nn = sampler.sample_from_out(a[1].cpu()) 54 | # nns.append(nn) 55 | # for a in others[i*10:(i+1)*10]: 56 | # real = sampler.sample_from_out(a[2].permute(0, 3, 1, 2).cpu()) 57 | # nns.append(real) 58 | 59 | # print(len(nns)) 60 | # batches = nns 61 | # mb = 10 62 | # n_rows = 2 63 | # im = np.concatenate(batches, axis=0).reshape((n_rows, mb, *shape[1:])).transpose([0, 2, 1, 3, 4]).reshape( 64 | # [n_rows * shape[1], mb * shape[2], 3]) 65 | # logprint(f'printing samples to {fname}') 66 | # imageio.imwrite(f'{fname}/rnd-nn-rem-{i}.png', im) 67 | -------------------------------------------------------------------------------- /visual/spatial_visual.py: -------------------------------------------------------------------------------- 1 | from selectors import BaseSelector 2 | import torch 3 | import numpy as np 4 | import imageio 5 | 6 | def spatial_vissual(H, sampler, shape, imle, fname, logprint, lat1=None, lat2=None, sn1=None, sn2=None): 7 | num_lin = 1 8 | mb = 8 9 | 10 | batches = [] 11 | # step = (-f_latent + s_latent) / num_lin 12 | base_latent = torch.randn([1, H.latent_dim], dtype=torch.float32).cuda() 13 | base_snosie = [torch.randn([1, 1, s, s], dtype=torch.float32).cuda() for s in sampler.res] 14 | # for t in range(len(sampler.res)): 15 | # snoise = [base_snosie[i] for i in range(t)] + [torch.ones([1, 1, s, s], dtype=torch.float32).cuda() for s in sampler.res[t:]] 16 | # out = imle(base_latent, spatial_noise=snoise) 17 | # batches.append(sampler.sample_from_out(out)) 18 | 19 | # for t in range(10): 20 | # snoise = [base_snosie[i] for i in range(len(sampler.res) - 1)] + [torch.randn([1, 1, s, s], dtype=torch.float32).cuda() for s in sampler.res[len(sampler.res) - 1:]] 21 | # out = imle(base_latent, spatial_noise=snoise) 22 | # batches.append(sampler.sample_from_out(out)) 23 | 24 | # for t in range(10): 25 | # snoise = [base_snosie[i] for i in range(len(sampler.res) - 1)] + [torch.randn([1, 1, s, s], dtype=torch.float32).cuda() for s in sampler.res[len(sampler.res) - 1:]] 26 | # base_latent.normal_() 27 | # out = imle(base_latent, spatial_noise=base_snosie) 28 | # batches.append(sampler.sample_from_out(out)) 29 | 30 | base_latent = [torch.zeros([1, H.latent_dim], dtype=torch.float32).cuda() for i in range(10)] 31 | base_snosie = [torch.zeros([1, 1, s, s], dtype=torch.float32).cuda() for s in sampler.res] 32 | out = imle(base_latent[0], spatial_noise=base_snosie) 33 | batches.append(sampler.sample_from_out(out)) 34 | 35 | 36 | base_latent = [torch.randn([1, H.latent_dim], dtype=torch.float32).cuda() for i in range(10)] 37 | base_snosie = [torch.randn([1, 1, s, s], dtype=torch.float32).cuda() for s in sampler.res] 38 | for i in range(10): 39 | out = imle(base_latent[i], spatial_noise=base_snosie) 40 | batches.append(sampler.sample_from_out(out)) 41 | 42 | lat_dim = base_latent[0].shape[1] 43 | 
print(lat_dim) 44 | for i in range(10): 45 | out = imle(base_latent[i], spatial_noise=base_snosie) 46 | batches.append(sampler.sample_from_out(out)) 47 | lat = base_latent[i] 48 | for j in range(10): 49 | lat2 = base_latent[j] 50 | cur = torch.cat((lat[:, 0:lat_dim//2], lat2[:, lat_dim//2:]), dim=1) 51 | out = imle(cur, spatial_noise=base_snosie) 52 | batches.append(sampler.sample_from_out(out)) 53 | 54 | print(len(batches)) 55 | mb = 11 56 | n_rows = 11 57 | im = np.concatenate(batches, axis=0).reshape((n_rows, mb, *shape[1:])).transpose([0, 2, 1, 3, 4]).reshape( 58 | [n_rows * shape[1], mb * shape[2], 3]) 59 | 60 | logprint(f'printing samples to {fname}') 61 | imageio.imwrite(fname, im) -------------------------------------------------------------------------------- /visual/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | import numpy as np 4 | import imageio 5 | 6 | def get_sample_for_visualization(data, preprocess_fn, num, dataset): 7 | for x in DataLoader(data, batch_size=num): 8 | break 9 | orig_image = (x[0] * 255.0).to(torch.uint8).permute(0, 2, 3, 1) if dataset == 'ffhq_1024' else x[0] 10 | preprocessed = preprocess_fn(x)[0] 11 | return orig_image, preprocessed 12 | 13 | 14 | 15 | def generate_for_NN(sampler, orig, initial, snoise, shape, ema_imle, fname, logprint): 16 | mb = shape[0] 17 | initial = initial[:mb].cuda() 18 | nns = sampler.sample(initial, ema_imle, snoise) 19 | batches = [orig[:mb], nns] 20 | n_rows = len(batches) 21 | im = np.concatenate(batches, axis=0).reshape((n_rows, mb, *shape[1:])).transpose([0, 2, 1, 3, 4]).reshape( 22 | [n_rows * shape[1], mb * shape[2], 3]) 23 | 24 | logprint(f'printing samples to {fname}') 25 | imageio.imwrite(fname, im) 26 | 27 | 28 | def generate_images_initial(H, sampler, orig, initial, snoise, shape, imle, ema_imle, fname, logprint): 29 | mb = shape[0] 30 | initial = initial[:mb] 31 | batches = [orig[:mb], sampler.sample(initial, imle, snoise)] 32 | 33 | temp_latent_rnds = torch.randn([mb, H.latent_dim], dtype=torch.float32).cuda() 34 | for t in range(H.num_rows_visualize): 35 | temp_latent_rnds.normal_() 36 | tmp_snoise = [s[:mb].normal_() for s in sampler.snoise_tmp] 37 | batches.append(sampler.sample(temp_latent_rnds, imle, tmp_snoise)) 38 | 39 | tmp_snoise = [s[:mb].normal_() for s in sampler.snoise_tmp] 40 | batches.append(sampler.sample(temp_latent_rnds, imle, tmp_snoise)) 41 | 42 | tmp_snoise = [s[:mb].normal_() for s in sampler.snoise_tmp] 43 | batches.append(sampler.sample(temp_latent_rnds, imle, tmp_snoise)) 44 | 45 | tmp_snoise = [s[:mb] for s in sampler.neutral_snoise] 46 | batches.append(sampler.sample(temp_latent_rnds, imle, tmp_snoise)) 47 | 48 | tmp_snoise = [s[:mb] for s in sampler.neutral_snoise] 49 | temp_latent_rnds.normal_() 50 | batches.append(sampler.sample(temp_latent_rnds, imle, tmp_snoise)) 51 | 52 | n_rows = len(batches) 53 | im = np.concatenate(batches, axis=0).reshape((n_rows, mb, *shape[1:])).transpose([0, 2, 1, 3, 4]).reshape( 54 | [n_rows * shape[1], mb * shape[2], 3]) 55 | 56 | logprint(f'printing samples to {fname}') 57 | imageio.imwrite(fname, im) 58 | 59 | def generate_and_save(H, imle, sampler, n_samp, subdir='fid'): 60 | with torch.no_grad(): 61 | temp_latent_rnds = torch.randn([H.imle_batch, H.latent_dim], dtype=torch.float32).cuda() 62 | for i in range(0, n_samp // H.imle_batch): 63 | temp_latent_rnds.normal_() 64 | tmp_snoise = [s[:H.imle_batch].normal_() for s in sampler.snoise_tmp] 65 | 
samp = sampler.sample(temp_latent_rnds, imle, tmp_snoise)
66 |             for j in range(H.imle_batch):
67 |                 imageio.imwrite(f'{H.save_dir}/{subdir}/{i * H.imle_batch + j}.png', samp[j])
68 | 
69 | 
--------------------------------------------------------------------------------