├── README.rst ├── colorization_cINN ├── README.md ├── compare_methods.py ├── config.py ├── data.py ├── eval.py ├── feature_net.py ├── joint_bilateral.c ├── joint_bilateral_filter.pyx ├── model.py ├── model_no_cond.py ├── output │ └── .gitkeep ├── pts_in_hull.npy ├── setup.py ├── subnet_coupling.py ├── train.py └── viz.py ├── colorization_minimal_example ├── .gitignore ├── data.py ├── eval.py ├── images │ └── .gitkeep ├── model.py ├── output │ └── .gitkeep ├── train.py └── train_data_128 ├── mnist_cINN ├── .run ├── README.rst ├── color_mnist_data │ └── color_mnist.py ├── cond_net.py ├── config.py ├── data.py ├── eval.py ├── extra_modules.py ├── losses.py ├── model.py ├── opts.py ├── output │ └── .gitkeep ├── train.py └── viz.py └── mnist_minimal_example ├── .run ├── data.py ├── eval.py ├── images └── .gitkeep ├── model.py ├── output └── .gitkeep └── train.py /README.rst: -------------------------------------------------------------------------------- 1 | "Guided Image Generation with Conditional Invertible Neural Networks" (2019) 2 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 3 | 4 | Paper: https://arxiv.org/abs/1907.02392 5 | 6 | Supplement: https://drive.google.com/file/d/1_OoiIGhLeVJGaZFeBt0OWOq8ZCtiI7li 7 | 8 | Contents 9 | ^^^^^^^^^^^^^^^^ 10 | 11 | Each subdirectory has its own README file containing the details. 12 | 13 | * ``experiments/mnist_minimal_example`` contains code to produce class-conditional MNIST samples in <150 lines total 14 | * ``experiments/colorization_minimal_example`` contains code to colorize LSUN bedrooms in <200 lines total 15 | * ``experiments/colorization_cINN`` contains the full research code used to produce all colorization figures in the paper 16 | * ``experiments/mnist_cINN`` contains the full research code used to produce all mnist figures in the paper 17 | 18 | Dependencies 19 | ^^^^^^^^^^^^^^^^ 20 | 21 | Except for pytorch, any fairly recent version will probably work, 22 | these are just the confirmed ones: 23 | 24 | +---------------------------+-------------------------------+ 25 | | **Package** | **Version** | 26 | +---------------------------+-------------------------------+ 27 | | Pytorch | >= 1.0.0 | 28 | +---------------------------+-------------------------------+ 29 | | Numpy | >= 1.15.0 | 30 | +---------------------------+-------------------------------+ 31 | | *Optionally for the experiments:* | 32 | +---------------------------+-------------------------------+ 33 | | Matplotlib | 2.2.3 | 34 | +---------------------------+-------------------------------+ 35 | | Visdom | 0.1.8.5 | 36 | +---------------------------+-------------------------------+ 37 | | Torchvision | 0.2.1 | 38 | +---------------------------+-------------------------------+ 39 | | scikit-learn | 0.20.3 | 40 | +---------------------------+-------------------------------+ 41 | | scikit-image | 0.14.2 | 42 | +---------------------------+-------------------------------+ 43 | | Pillow | 6.0.0 | 44 | +---------------------------+-------------------------------+ 45 | -------------------------------------------------------------------------------- /colorization_cINN/README.md: -------------------------------------------------------------------------------- 1 | Compiling the Joint Bilateral Filter 2 | -------------------------------------- 3 | 4 | Simply run 5 | ``` 6 | python setup.py build_ext --inplace 7 | ``` 8 | 9 | Checkpoints 10 | -------------------------------------- 11 | 12 | * Model checkpoint used for the paper: 13 | 14 | 
`https://drive.google.com/open?id=1gpHHtT7EcCoEqTzaUmDImp_tyB7vSKIN` 15 | 16 | * Pretrained cond. net (only necessary if you want to train yourself, but not do the pretraining): 17 | 18 | `https://drive.google.com/open?id=1YQkOf03kK7-ZNDGJmVFloF_hZvnmjc0_` 19 | -------------------------------------------------------------------------------- /colorization_cINN/compare_methods.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from multiprocessing import Pool 3 | 4 | from tqdm import tqdm 5 | import numpy as np 6 | #import skimage.io as io 7 | import matplotlib.pyplot as plt 8 | 9 | def bootstrap_mean(x): 10 | '''calculate mean with error estimate through statistical bootsrapping''' 11 | mean = np.mean(x) 12 | mean_sampled = [] 13 | for i in range(256): 14 | x_resamp = x[np.random.randint(0, len(x), size=x.shape)] 15 | mean_sampled.append(np.mean(x_resamp)) 16 | return mean, np.std(mean_sampled) 17 | 18 | # For index i, each function returns a list of filenames of one or more (diverse) colorization results 19 | def cinn_imgs(i): 20 | return ['./images/val_set/%.6i/%.2i.png' % (i, j) for j in range(2, 10)] 21 | 22 | def vae_imgs(i): 23 | return ['/home/diz/code/colorization_baselines/vae_diverse_colorization/data/output/testimgs/%.6i.png/divcolor_%.3i.png' % (i, j) for j in range(8)] 24 | 25 | def cgan_imgs(i): 26 | return ['/home/diz/single_images_cgan/%.6i.png' % (i)] 27 | 28 | def cnn_imgs(i): 29 | return ['/home/diz/code/colorization_baselines/siggraph2016_colorization/single_images_cnn/%.6i.png' % (i)] 30 | 31 | def gt_img(i): 32 | return '/home/diz/data/imagenet/val_cropped/%.6i.png' % (i) 33 | 34 | def abl_no_imgs(i): 35 | return ['./images/val_set_ablation_no_cond/%.6i/%.2i.png' % (i, j) for j in range(2, 10)] 36 | 37 | def abl_fixed_imgs(i): 38 | return ['./cond_fixed/imgs/%.6i_%.3i.png' % (i, j) for j in range(0, 8)] 39 | 40 | def variance(functs): 41 | '''For the methods given in *functs (returning image filenames), compute the variance of 42 | colorizations per image, averaged over the 5k val. set''' 43 | for f in functs: 44 | var = [] 45 | for i in tqdm(range(5120)): 46 | imgs = np.stack([plt.imread(fname)[:, :, :3] for fname in f(i)], axis=0) 47 | var.append(np.mean(np.var(imgs, axis=0))) 48 | 49 | print('variance', f.__name__, *bootstrap_mean(np.array(var))) 50 | 51 | def err_individual_image(args): 52 | '''wrapper for multiprocessing: args is a tuple (f,i) of filename-returning-function f, and 53 | val. 
index i.''' 54 | 55 | i, f = args 56 | gt_im = plt.imread(gt_img(i))[np.newaxis, :, :, :3] 57 | imgs = np.stack([plt.imread(fname)[:, :, :3] for fname in f(i)], axis=0) 58 | imgs_1 = np.stack([plt.imread(fname)[:, :, :3] for fname in f(i)[:1]], axis=0) 59 | 60 | err = np.sqrt(np.mean(np.min((imgs - gt_im)**2, axis=0))) 61 | err_1 = np.sqrt(np.mean((imgs_1 - gt_im)**2)) 62 | 63 | return [err, err_1] 64 | 65 | def err(functs): 66 | '''Compute the RMS best-of-8 and best-of-1 error on the 5k val set''' 67 | for f in functs: 68 | args = [(i,f) for i in range(5120)] 69 | with Pool(16) as p: 70 | errs = np.array(p.map(err_individual_image, args)) 71 | 72 | print('of 8', f.__name__, *bootstrap_mean(errs[:,0])) 73 | print('of 1', f.__name__, *bootstrap_mean(errs[:,1])) 74 | 75 | functs = [cgan_imgs, cinn_imgs, vae_imgs, cnn_imgs, abl_no_imgs, abl_fixed_imgs] 76 | functs_diverse = [cinn_imgs, vae_imgs, abl_no_imgs, abl_fixed_imgs] 77 | 78 | err(functs) 79 | # of 8: 80 | # cinn 3.53 0.04 81 | # vae 4.06 0.04 82 | # cnn 6.77 0.05 83 | # cgan 9.75 0.06 84 | # abl no 85 | # abl fixed 86 | # of 1: 87 | # cinn 9.52 0.06 88 | # vae 8.40 0.07 89 | # cnn 6.77 0.05 90 | # cgan 9.75 0.06 91 | # abl no 92 | # abl fixed 93 | 94 | variance(functs_diverse) 95 | # cinn 0.00352 2.71e-05 96 | # vae 0.00210 2.13e-05 97 | # abl no 98 | # abl fixed 99 | -------------------------------------------------------------------------------- /colorization_cINN/config.py: -------------------------------------------------------------------------------- 1 | ################# 2 | # Architecture: # 3 | ################# 4 | 5 | # Image size of L, and ab channels respectively: 6 | img_dims_orig = (256, 256) 7 | img_dims = (img_dims_orig[0] // 4, img_dims_orig[0] // 4) 8 | # Clamping parameter in the coupling blocks (higher = less stable but more expressive) 9 | clamping = 1.5 10 | 11 | ############################# 12 | # Training hyperparameters: # 13 | ############################# 14 | 15 | seed = 9287 16 | batch_size = 48 17 | device_ids = [0,1,2] # GPU ids. Set to [0] for single GPU 18 | 19 | log10_lr = -4.0 # Log learning rate 20 | lr = 10**log10_lr 21 | lr_feature_net = lr # lr of the cond. network 22 | 23 | n_epochs = 120 * 4 24 | n_its_per_epoch = 32 * 8 # In case the epochs should be cut short after n iterations 25 | 26 | weight_decay = 1e-5 27 | betas = (0.9, 0.999) # concerning adam optimizer 28 | 29 | init_scale = 0.030 # initialization std. dev. of weights (0.03 is approx xavier) 30 | pre_low_lr = 0 # for the first n epochs, lower the lr by a factor of 20 31 | 32 | ####################### 33 | # Dataset parameters: # 34 | ####################### 35 | 36 | dataset = 'imagenet' # also 'coco' is possible. Todo: places365 37 | validation_images = './imagenet/validation_images.txt' 38 | shuffle_val = False 39 | val_start = 0 # use a slice [start:stop] of the entire val set 40 | val_stop = 5120 41 | 42 | end_to_end = True # Whether to leave the cond. net fixed 43 | no_cond_net = False # Whether to use a cond. net at all 44 | pretrain_epochs = 0 # Train only the inn for n epochs before end-to-end 45 | 46 | ######################## 47 | # Display and logging: # 48 | ######################## 49 | 50 | sampling_temperature = 1.0 # latent std. dev. 
for preview images 51 | loss_display_cutoff = 10 # cut off the loss so the plot isn't ruined 52 | loss_names = ['L', 'lr'] 53 | preview_upscale = 256 // img_dims_orig[0] 54 | img_folder = './images' 55 | silent = False 56 | live_visualization = False 57 | progress_bar = False 58 | 59 | ####################### 60 | # Saving checkpoints: # 61 | ####################### 62 | 63 | load_inn_only = '' # only load the inn part of the architecture 64 | load_feature_net = '' # only load the cond. net part 65 | load_file = '' # load entire architecture (overwrites the prev. 2 options) 66 | filename = 'output/full_model.pt' # output filename 67 | 68 | checkpoint_save_interval = 60 69 | checkpoint_save_overwrite = False # Whether to overwrite the old checkpoint with the new one 70 | checkpoint_on_error = True # Wheter to make a checkpoint with suffix _ABORT if an error occurs 71 | -------------------------------------------------------------------------------- /colorization_cINN/data.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import glob 3 | from os.path import join 4 | from multiprocessing import Pool 5 | 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | from skimage import io, color 9 | from PIL import Image, ImageEnhance 10 | import torch 11 | from torch.utils.data import Dataset, DataLoader, TensorDataset 12 | import torch.nn.functional as F 13 | import torchvision.transforms as T 14 | from tqdm import tqdm 15 | 16 | import joint_bilateral_filter as jbf 17 | import config as c 18 | 19 | offsets = (47.5, 2.4, 7.4) 20 | scales = (25.6, 11.2, 16.8) 21 | 22 | def apply_filt(args): 23 | '''multiprocessing wrapper for applying the joint bilateral filter''' 24 | L_i, ab_i = args 25 | return jbf.upsample(L_i[0], ab_i, s_x=6, s_l=0.10) 26 | 27 | def norm_lab_to_rgb(L, ab, norm=True, filt=False, bw=False): 28 | '''given an Nx1xWxH Tensor L and an Nx2xwxh Tensor ab, normalized accoring to offsets and 29 | scales above, upsample the ab channels and combine with L, and form an RGB image. 30 | 31 | norm: If false, assume that L, ab are not normalized and already in the correct range 32 | filt: Use joint bilateral upsamling to do the upsampling. Slow, but improves image quality. 33 | bw: Simply produce a grayscale RGB, ignoring the ab channels''' 34 | 35 | if bw: 36 | filt=False 37 | 38 | if filt: 39 | with Pool(12) as p: 40 | ab_up_list = p.map(apply_filt, [(L[i], ab[i]) for i in range(len(L))]) 41 | 42 | ab = np.stack(ab_up_list, axis=0) 43 | ab = torch.Tensor(ab) 44 | else: 45 | ab = F.interpolate(ab, size=L.shape[2], mode='bilinear') 46 | 47 | lab = torch.cat([L, ab], dim=1) 48 | 49 | for i in range(1 + 2*norm): 50 | lab[:, i] = lab[:, i] * scales[i] + offsets[i] 51 | 52 | lab[:, 0].clamp_(0., 100.) 
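    # clamp to the nominal CIELAB ranges before converting to RGB:
    # L was clamped to [0, 100] above; the a/b channels are clamped to [-128, 128] next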
53 | lab[:, 1:].clamp_(-128, 128) 54 | if bw: 55 | lab[:, 1:].zero_() 56 | 57 | lab = lab.cpu().data.numpy() 58 | rgb = [color.lab2rgb(np.transpose(l, (1, 2, 0))).transpose(2, 0, 1) for l in lab] 59 | return np.array(rgb) 60 | 61 | class LabColorDataset(Dataset): 62 | def __init__(self, file_list, transform=None): 63 | 64 | self.files = file_list 65 | self.transform = transform 66 | self.to_tensor = T.ToTensor() 67 | 68 | def __len__(self): 69 | return len(self.files) 70 | 71 | def __getitem__(self, idx): 72 | 73 | im = Image.open(self.files[idx]) 74 | if self.transform: 75 | im = self.transform(im) 76 | im = self.to_tensor(im).numpy() 77 | 78 | try: 79 | if im.shape[0] == 1: 80 | im = np.concatenate([im]*3, axis=0) 81 | if im.shape[0] == 4: 82 | im = im[:3] 83 | 84 | im = np.transpose(im, (1,2,0)) 85 | im = color.rgb2lab(im).transpose((2, 0, 1)) 86 | for i in range(3): 87 | im[i] = (im[i] - offsets[i]) / scales[i] 88 | return torch.Tensor(im) 89 | 90 | except: 91 | return self.__getitem__(idx+1) 92 | 93 | 94 | # Data transforms for training and test/validation set 95 | transf = T.Compose([T.RandomHorizontalFlip(), 96 | T.RandomResizedCrop(c.img_dims_orig[0], scale=(0.2, 1.))]) 97 | transf_test = T.Compose([T.Resize(c.img_dims_orig[0]), 98 | T.CenterCrop(c.img_dims_orig[0])]) 99 | 100 | if c.dataset == 'imagenet': 101 | with open('./imagenet/training_images.txt') as f: 102 | train_list = [join('./imagenet', fname[2:]) for fname in f.read().splitlines()] 103 | with open(c.validation_images) as f: 104 | test_list = [ t for t in f.read().splitlines()if t[0] != '#'] 105 | test_list = [join('./imagenet', fname) for fname in test_list] 106 | if c.val_start is not None: 107 | test_list = test_list[c.val_start:c.val_stop] 108 | else: 109 | data_dir = '/home/diz/data/coco17' 110 | complete_list = sorted(glob.glob(join(data_dir, '*.jpg'))) 111 | train_list = complete_list[64:] 112 | test_list = complete_list[64:] 113 | 114 | 115 | train_data = LabColorDataset(train_list,transf) 116 | test_data = LabColorDataset(test_list, transf_test) 117 | 118 | train_loader = DataLoader(train_data, batch_size=c.batch_size, shuffle=True, num_workers=8, pin_memory=True, drop_last=True) 119 | test_loader = DataLoader(test_data, batch_size=min(64, len(test_list)), shuffle=c.shuffle_val, num_workers=4, pin_memory=True, drop_last=False) 120 | 121 | if __name__ == '__main__': 122 | # Determine mean and standard deviation of RGB channels 123 | # (i.e. 
set global variables scale and offsets to 1., then use the results as new scale and offset) 124 | 125 | for x in test_loader: 126 | x_l, x_ab, _, x_ab_pred = model.prepare_batch(x) 127 | #continue 128 | img_gt = norm_lab_to_rgb(x_l, x_ab) 129 | img_pred = norm_lab_to_rgb(x_l, x_ab_pred) 130 | for i in range(c.batch_size): 131 | plt.figure() 132 | plt.subplot(2,2,1) 133 | plt.imshow(img_gt[i].transpose(1,2,0)) 134 | plt.subplot(2,2,2) 135 | plt.scatter(x_ab[i, 0].cpu().numpy().flatten() * scales[1] + offsets[1], 136 | x_ab[i, 1].cpu().numpy().flatten() * scales[2] + offsets[2], label='gt') 137 | 138 | plt.scatter(x_ab_pred[i, 0].cpu().numpy().flatten() * scales[1] + offsets[1], 139 | x_ab_pred[i, 1].cpu().numpy().flatten() * scales[2] + offsets[2], label='pred') 140 | 141 | plt.legend() 142 | plt.subplot(2,2,3) 143 | plt.imshow(img_pred[i].transpose(1,2,0)) 144 | 145 | plt.show() 146 | sys.exit() 147 | 148 | means = [] 149 | stds = [] 150 | 151 | for i, x in enumerate(train_loader): 152 | print('\r', '%i / %i' % (i, len(train_loader)), end='') 153 | mean = [] 154 | std = [] 155 | for i in range(3): 156 | mean.append(x[:, i].mean().item()) 157 | std.append(x[:, i].std().item()) 158 | 159 | means.append(mean) 160 | stds.append(std) 161 | 162 | if i >= 1000: 163 | break 164 | 165 | means, stds = np.array(means), np.array(stds) 166 | 167 | print() 168 | print('Mean ', means.mean(axis=0)) 169 | print('Std dev', stds.mean(axis=0)) 170 | 171 | #[-0.04959071 0.03768991 0.11539354] 172 | #[0.51175581 0.17507738 0.26179135] 173 | -------------------------------------------------------------------------------- /colorization_cINN/eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | ''' 3 | Usage: ./eval.py model_checkpoint_file [val_start_index, val_stop_index] 4 | model_checkpoint_file: Path of the checkpoint 5 | optional val_start/stop_index: Only use validation images between these indexes 6 | (Useful for GNU-parallel etc.) 7 | ''' 8 | import glob 9 | import sys 10 | from os.path import join 11 | import os 12 | 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from skimage import color 17 | from PIL import Image 18 | from skimage import color 19 | from sklearn.decomposition import PCA 20 | import matplotlib 21 | matplotlib.use('Agg') 22 | import matplotlib.pyplot as plt 23 | import torchvision.transforms as T 24 | from tqdm import tqdm 25 | from scipy.ndimage.filters import uniform_filter, gaussian_filter 26 | 27 | import config as c 28 | if len(sys.argv) > 2: 29 | c.val_start = int(sys.argv[2]) 30 | c.val_stop = int(sys.argv[3]) 31 | 32 | if c.no_cond_net: 33 | import model_no_cond as model 34 | else: 35 | import model 36 | 37 | import data 38 | from data import test_loader 39 | 40 | # Some global definitions: 41 | # ========================= 42 | # Whether to use the joint bilateral filter for upsampling (slow but better quality) 43 | JBF_FILTER = True 44 | # Use only a selection of val images, e.g. 
45 | # VAL_SELECTION = [0,1,5,15] 46 | # per default uses all: 47 | VAL_SELECTION = list(range(len(data.test_list))) 48 | 49 | if len(sys.argv) > 1: 50 | model_name = sys.argv[1] 51 | else: 52 | model_name = c.filename 53 | 54 | model.load(model_name) 55 | 56 | model.combined_model.eval() 57 | model.combined_model.module.inn.eval() 58 | 59 | if not c.no_cond_net: 60 | model.combined_model.module.feature_network.eval() 61 | model.combined_model.module.fc_cond_network.eval() 62 | 63 | def show_imgs(imgs, save_as): 64 | '''Save a set of images in a directory (e.g. a set of diverse colorizations 65 | for a single grayscale image) 66 | imgs: List of 3xWxH images (numpy or torch tensors), or Nx3xWxH torch tensor 67 | save_as: directory name to save the images in''' 68 | 69 | 70 | imgs_np = [] 71 | 72 | for im in imgs: 73 | try: 74 | im_np = im.data.cpu().numpy() 75 | imgs_np.append(im_np) 76 | except: 77 | imgs_np.append(im) 78 | 79 | try: 80 | os.mkdir(join(c.img_folder, save_as)) 81 | except OSError: 82 | pass 83 | 84 | for i, im in enumerate(imgs_np): 85 | im = np.transpose(im, (1,2,0)) 86 | 87 | if im.shape[2] == 1: 88 | im = np.concatenate([im]*3, axis=2) 89 | 90 | plt.imsave(join(c.img_folder, save_as, '%.2i' % (i)), im) 91 | 92 | # Run a single batch to infer the shapes etc.: 93 | 94 | for x in test_loader: 95 | test_set = x 96 | break 97 | 98 | with torch.no_grad(): 99 | x_l, x_ab, cond, ab_pred = model.prepare_batch(test_set) 100 | 101 | outputs = model.cinn(x_ab, cond) 102 | jac = model.cinn.jacobian(run_forward=False) 103 | tot_output_size = 2 * c.img_dims[0] * c.img_dims[1] 104 | 105 | def sample_z(N, T=1.0): 106 | ''' Sample N latent vectors, with a sampling temperature T''' 107 | sampled_z = [] 108 | for o in outputs: 109 | shape = list(o.shape) 110 | shape[0] = N 111 | sampled_z.append(torch.randn(shape).cuda()) 112 | 113 | return sampled_z 114 | 115 | def sample_resolution_levels(level, z_fixed, N=8, temp=1.): 116 | '''Generate images with latent code `z_fixed`, but replace the latent dimensions 117 | at resolution level `level` with random ones. 118 | N: number of random samples 119 | temp: sampling temperature 120 | naming of output files: __.png''' 121 | 122 | assert len(test_loader) == 1, "please use only one batch worth of images" 123 | 124 | for n in range(N): 125 | counter = 0 126 | for x in tqdm(test_loader): 127 | with torch.no_grad(): 128 | 129 | z = sample_z(x.shape[0], temp) 130 | z_fixed[3-level] = z[3-level] 131 | 132 | x_l, x_ab, cond, ab_pred = model.prepare_batch(x) 133 | 134 | ab_gen = model.combined_model.module.reverse_sample(z_fixed, cond) 135 | rgb_gen = data.norm_lab_to_rgb(x_l.cpu(), ab_gen.cpu(), filt=True) 136 | 137 | for im in rgb_gen: 138 | im = np.transpose(im, (1,2,0)) 139 | plt.imsave(join(c.img_folder, '%.6i_%i_%.3i.png' % (counter, level, n)), im) 140 | counter += 1 141 | 142 | def colorize_batches(temp=1., postfix=0, filt=True): 143 | '''Colorize the whole validation set once. 144 | temp: Sampling temperature 145 | postfix: Has to be int. Append to file name (e.g. make 10 diverse colorizations of val. 
set) 146 | filt: Whether to use JBF 147 | ''' 148 | counter = 0 149 | for x in tqdm(test_loader): 150 | with torch.no_grad(): 151 | z = sample_z(x.shape[0], temp) 152 | x_l, x_ab, cond, ab_pred = model.prepare_batch(x) 153 | 154 | ab_gen = model.combined_model.module.reverse_sample(z, cond) 155 | rgb_gen = data.norm_lab_to_rgb(x_l.cpu(), ab_gen.cpu(), filt=filt) 156 | 157 | for im in rgb_gen: 158 | im = np.transpose(im, (1,2,0)) 159 | plt.imsave(join(c.img_folder, '%.6i_%.3i.png' % (counter, postfix)), im) 160 | counter += 1 161 | 162 | def interpolation_grid(val_ind=0, grid_size=5, max_temp=0.9, interp_power=2): 163 | ''' 164 | Make a grid of a 2D latent space interpolation. 165 | val_ind: Which image to use (index in current val. set) 166 | grid_size: Grid size in each direction 167 | max_temp: Maximum temperature to scale to in each direction (note that the corners 168 | will have temperature sqrt(2)*max_temp 169 | interp_power: Interpolate with (linspace(-lim**p, +lim**p))**(1/p) instead of linear. 170 | Because little happens between t = 0.0...0.7, we don't want this to take up the 171 | whole grid. p>1 gives more space to the temperatures closer to 1. 172 | ''' 173 | steps = np.linspace(-(max_temp**interp_power), max_temp**interp_power, grid_size, endpoint=True) 174 | steps = np.sign(steps) * np.abs(steps)**(1./interp_power) 175 | 176 | test_im = [] 177 | for i,x in enumerate(test_loader): 178 | test_im.append(x) 179 | 180 | test_im = torch.cat(test_im, dim=0) 181 | test_im = torch.stack([test_im[i] for i in VAL_SELECTION], dim=0) 182 | test_im = torch.cat([test_im[val_ind:val_ind+1]]*grid_size**2, dim=0).cuda() 183 | 184 | 185 | def interp_z(z0, z1, a0, a1): 186 | z_out = [] 187 | for z0_i, z1_i in zip(z0, z1): 188 | z_out.append(a0 * z0_i + a1 * z1_i) 189 | return z_out 190 | 191 | torch.manual_seed(c.seed+val_ind) 192 | z0 = sample_z(1, 1.) 193 | z1 = sample_z(1, 1.) 194 | 195 | z_grid = [] 196 | for dk in steps: 197 | for dl in steps: 198 | z_grid.append(interp_z(z0, z1, dk, dl)) 199 | 200 | z_grid = [torch.cat(z_i, dim=0) for z_i in list(map(list, zip(*z_grid)))] 201 | 202 | with torch.no_grad(): 203 | x_l, x_ab, cond, ab_pred = model.prepare_batch(test_im) 204 | ab_gen = model.combined_model.module.reverse_sample(z_grid, cond) 205 | 206 | rgb_gen = data.norm_lab_to_rgb(x_l.cpu(), ab_gen.cpu(), filt=True) 207 | 208 | for i,im in enumerate(rgb_gen): 209 | im = np.transpose(im, (1,2,0)) 210 | plt.imsave(join(c.img_folder, '%.6i_%.3i.png' % (val_ind, i)), im) 211 | 212 | def flow_visualization(val_ind=0, n_samples=2): 213 | 214 | test_im = [] 215 | for i,x in enumerate(test_loader): 216 | test_im.append(x) 217 | 218 | test_im = torch.cat(test_im, dim=0) 219 | test_im = torch.stack([test_im[i] for i in VAL_SELECTION], dim=0) 220 | test_im = torch.cat([test_im[val_ind:val_ind+1]]*n_samples, dim=0).cuda() 221 | 222 | torch.manual_seed(c.seed) 223 | z = sample_z(n_samples, 1.) 224 | 225 | block_idxs = [(1,7), (11,13), (14,18), (19,24), (28,32), 226 | (34,44), (48,52), (54,64), (68,90)] 227 | block_steps = [12, 10, 10, 10, 12, 12, 10, 16, 12] 228 | 229 | #scales = [0.9, 0.9, 0.7, 0.5, 0.5, 0.2] 230 | z_levels = [3,5,7] 231 | min_max_final = None 232 | 233 | def rescale_min_max(ab, new_min, new_max, soft_factor=0.): 234 | min_ab = torch.min(torch.min(ab, 3, keepdim=True)[0], 2, keepdim=True)[0] 235 | max_ab = torch.max(torch.max(ab, 3, keepdim=True)[0], 2, keepdim=True)[0] 236 | 237 | new_min = (1. - soft_factor) * new_min - soft_factor * 6 238 | new_max = (1. 
- soft_factor) * new_max + soft_factor * 6 239 | 240 | ab = (ab - min_ab) / (max_ab - min_ab) 241 | return ab * (new_max - new_min) + new_min 242 | 243 | with torch.no_grad(): 244 | x_l, x_ab, cond, ab_pred = model.prepare_batch(test_im) 245 | x_l_flat = torch.zeros(x_l.shape) 246 | #x_l_flat *= x_l.mean().item() 247 | 248 | frame_counter = 0 249 | 250 | for level, (k_start, k_stop) in enumerate(block_idxs): 251 | print('level', level) 252 | interp_steps = block_steps[level] 253 | scales = np.linspace(1., 1e-3, interp_steps + 1) 254 | scales = scales[1:] / scales[:-1] 255 | 256 | for i_interp in tqdm(range(interp_steps)): 257 | 258 | ab_gen = model.combined_model.module.reverse_sample(z, cond).cpu() 259 | ab_gen = torch.Tensor([[gaussian_filter(x, sigma=2. * (frame_counter / sum(block_steps))) for x in ab] for ab in ab_gen]) 260 | 261 | if min_max_final is None: 262 | min_max_final = (torch.min(torch.min(ab_gen, 3, keepdim=True)[0], 2, keepdim=True)[0], 263 | torch.max(torch.max(ab_gen, 3, keepdim=True)[0], 2, keepdim=True)[0]) 264 | else: 265 | ab_gen = rescale_min_max(ab_gen, *min_max_final, 266 | soft_factor=(frame_counter/sum(block_steps))**2) 267 | 268 | if frame_counter == 0: 269 | rgb_gen = data.norm_lab_to_rgb(x_l.cpu(), ab_gen, filt=True) 270 | for j in range(rgb_gen.shape[0]): 271 | im = rgb_gen[j] 272 | im = np.transpose(im, (1,2,0)) 273 | plt.imsave(join(c.img_folder, 'flow/%.6i_%.3i_final_merged.png' % (val_ind, j+12)), im) 274 | 275 | colors_gen = data.norm_lab_to_rgb(x_l_flat, (1. + 0.2 * (frame_counter / sum(block_steps))) * ab_gen, filt=False) 276 | 277 | for j,im in enumerate(colors_gen): 278 | im = np.transpose(im, (1,2,0)) 279 | im_color = np.transpose(colors_gen[j], (1,2,0)) 280 | #plt.imsave(join(c.img_folder, 'flow/%.6i_%.3i_%.3i.png' % (val_ind, j, frame_counter)), im) 281 | plt.imsave(join(c.img_folder, 'flow/%.6i_%.3i_%.3i_c.png' % (val_ind, j+12, frame_counter)), im_color) 282 | frame_counter += 1 283 | 284 | #if level in z_levels: 285 | #z[z_levels.index(level)] *= scales[i_interp] 286 | #z[-1] *= 1.1 287 | 288 | for k_block in range(k_start,k_stop+1): 289 | for key,p in model.combined_model.module.inn.named_parameters(): 290 | split = key.split('.') 291 | if f'module_list.{k_block}.' in key and p.requires_grad: 292 | split = key.split('.') 293 | if len(split) > 3 and split[3][-1] == '3' and split[2] != 'subnet': 294 | p.data *= scales[i_interp] 295 | 296 | for k in range(k_start,k_stop+1): 297 | for k,p in model.combined_model.module.inn.named_parameters(): 298 | if f'module_list.{i}.' in k and p.requires_grad: 299 | p.data *= 0.0 300 | 301 | #if level in z_levels: 302 | #z[z_levels.index(level)] *= 0 303 | 304 | state_dict = torch.load(model_name)['net'] 305 | orig_state = model.combined_model.state_dict() 306 | for name, param in state_dict.items(): 307 | if 'tmp_var' in name: 308 | continue 309 | if isinstance(param, nn.Parameter): 310 | param = param.data 311 | try: 312 | orig_state[name].copy_(param) 313 | except RuntimeError: 314 | print() 315 | print(name) 316 | print() 317 | raise 318 | 319 | 320 | def colorize_test_set(): 321 | '''This function is deprecated, for the sake of `colorize_batches`. 
322 | It loops over the image index at the outer level and diverse samples at inner level, 323 | so it may be useful if you want to adapt it.''' 324 | test_set = [] 325 | for i,x in enumerate(test_loader): 326 | test_set.append(x) 327 | 328 | test_set = torch.cat(test_set, dim=0) 329 | test_set = torch.stack([test_set[i] for i in VAL_SELECTION], dim=0) 330 | 331 | with torch.no_grad(): 332 | temperatures = [] 333 | 334 | rgb_bw = data.norm_lab_to_rgb(x_l.cpu(), 0.*x_ab.cpu(), filt=False) 335 | rgb_gt = data.norm_lab_to_rgb(x_l.cpu(), x_ab.cpu(), filt=JBF_FILTER) 336 | 337 | for i, o in enumerate(outputs): 338 | std = torch.std(o).item() 339 | temperatures.append(1.0) 340 | 341 | zz = sum(torch.sum(o**2, dim=1) for o in outputs) 342 | log_likeli = 0.5 * zz - jac 343 | log_likeli /= tot_output_size 344 | print() 345 | print(torch.mean(log_likeli).item()) 346 | print() 347 | 348 | def sample_z(N, temps=temperatures): 349 | sampled_z = [] 350 | for o, t in zip(outputs, temps): 351 | shape = list(o.shape) 352 | shape[0] = N 353 | sampled_z.append(t * torch.randn(shape).cuda()) 354 | 355 | return sampled_z 356 | 357 | N = 9 358 | sample_new = True 359 | 360 | for i,n in enumerate(VAL_SELECTION): 361 | print(i) 362 | x_i = torch.cat([test_set[i:i+1]]*N, dim=0) 363 | x_l_i, x_ab_i, cond_i, ab_pred_i = model.prepare_batch(x_i) 364 | if sample_new: 365 | z = sample_z(N) 366 | 367 | ab_gen = model.combined_model.module.reverse_sample(z, cond_i) 368 | rgb_gen = data.norm_lab_to_rgb(x_l_i.cpu(), ab_gen.cpu(), filt=JBF_FILTER) 369 | 370 | i_save = n 371 | if c.val_start: 372 | i_save += c.val_start 373 | show_imgs([rgb_gt[i], rgb_bw[i]] + list(rgb_gen), '%.6i_%.3i' % (i_save, i)) 374 | 375 | def color_transfer(): 376 | '''Transfers latent code from images to some new conditioning image (see paper Fig. 13) 377 | Uses images from the directory ./transfer. 
See code for changing which images are used.''' 378 | 379 | with torch.no_grad(): 380 | cond_images = [] 381 | ref_images = [] 382 | images = ['00', '01', '02'] 383 | for im in images: 384 | cond_images += [F'./transfer/{im}_c.jpg']*3 385 | ref_images += [F'./transfer/{im}_{j}.jpg' for j in range(3)] 386 | 387 | def load_image(fname): 388 | im = Image.open(fname) 389 | im = data.transf_test(im) 390 | im = data.test_data.to_tensor(im).numpy() 391 | im = np.transpose(im, (1,2,0)) 392 | im = color.rgb2lab(im).transpose((2, 0, 1)) 393 | 394 | for i in range(3): 395 | im[i] = (im[i] - data.offsets[i]) / data.scales[i] 396 | return torch.Tensor(im) 397 | 398 | cond_inputs = torch.stack([load_image(f) for f in cond_images], dim=0) 399 | ref_inputs = torch.stack([load_image(f) for f in ref_images], dim=0) 400 | 401 | L, x, cond, _ = model.prepare_batch(ref_inputs) 402 | L_new, _, cond_new, _ = model.prepare_batch(cond_inputs) 403 | 404 | z = model.combined_model.module.inn(x, cond) 405 | z_rand = sample_z(len(ref_images)) 406 | 407 | for zi in z: 408 | print(zi.shape) 409 | 410 | for i, (s,t) in enumerate([(1.0,1), (0.7,1), (0.0,1.0), (0,1.0)]): 411 | z_rand[i] = np.sqrt(s) * z_rand[i] + np.sqrt(1.-s) * z[i] 412 | 413 | x_new = model.combined_model.module.reverse_sample(z_rand, cond_new) 414 | 415 | im_ref = data.norm_lab_to_rgb(L.cpu(), x.cpu(), filt=True) 416 | im_cond = data.norm_lab_to_rgb(L_new.cpu(), 0*x_new.cpu(), bw=True) 417 | im_new = data.norm_lab_to_rgb(L_new.cpu(), x_new.cpu(), filt=True) 418 | 419 | for i, im in enumerate(ref_images): 420 | show_imgs([im_ref[i], im_cond[i], im_new[i]], im.split('/')[-1].split('.')[0]) 421 | 422 | def find_map(): 423 | '''For a given conditioning, try to find the maximum likelihood colorization. 424 | It doesn't work, but I left in the function to play around with''' 425 | 426 | import torch.nn as nn 427 | import torch.optim 428 | z_optim = [] 429 | parameters = [] 430 | 431 | z_random = sample_z(4*len(VAL_SELECTION)) 432 | for i, opt in enumerate([False]*2 + [True]*2): 433 | if opt: 434 | z_optim.append(nn.Parameter(z_random[i])) 435 | parameters.append(z_optim[-1]) 436 | else: 437 | z_optim.append(z_random[i]) 438 | 439 | optimizer = torch.optim.Adam(parameters, lr = 0.1)#, momentum=0.0, weight_decay=0) 440 | 441 | cond_4 = [torch.cat([c]*4, dim=0) for c in cond] 442 | for i in range(100): 443 | for k in range(10): 444 | optimizer.zero_grad() 445 | zz = sum(torch.sum(o**2, dim=1) for o in z_optim) 446 | x_new = model.combined_model.module.reverse_sample(z_optim, cond_4) 447 | jac = model.combined_model.module.inn.jacobian(run_forward=False, rev=True) 448 | 449 | log_likeli = 0.5 * zz + jac 450 | log_likeli /= tot_output_size 451 | 452 | log_likeli = (torch.mean(log_likeli) 453 | # Regularizer: variance within image 454 | + 0.1 * torch.mean(torch.log(torch.std(x_new[:, 0].view(4*len(VAL_SELECTION), -1), dim=1))**2 455 | + torch.log(torch.std(x_new[:, 1].view(4*len(VAL_SELECTION), -1), dim=1))**2) 456 | # Regularizer: variance across images 457 | + 0.1 * torch.mean(torch.log(torch.std(x_new, dim=0))**2)) 458 | 459 | log_likeli.backward() 460 | optimizer.step() 461 | 462 | if (i%10) == 0: 463 | show_imgs(list(data.norm_lab_to_rgb(torch.cat([x_l]*4, 0), x_new, filt=False)), '%.4i' % i) 464 | 465 | print(i, '\t', log_likeli.item(), '\t', 0.25 * sum(torch.std(z_optim[k]).item() for k in range(4))) 466 | 467 | def latent_space_pca(img_names = ['zebra']): 468 | '''This wasn't used in the paper or worked on in a while. 
469 | Perform PCA on latent space to see where images lie in relation to each other. 470 | See code for details.''' 471 | 472 | image_characteristics = [] 473 | 474 | for img_name in img_names: 475 | img_base = './demo_images/' + img_name 476 | high_sat = sorted(glob.glob(img_base + '_???.png')) 477 | #low_sat = sorted(glob.glob(img_base + '_b_???.png')) 478 | low_sat = [] 479 | 480 | to_tensor = T.ToTensor() 481 | 482 | demo_imgs = [] 483 | repr_colors = [] 484 | 485 | for fname in high_sat + low_sat: 486 | print(fname) 487 | 488 | im = plt.imread(fname) 489 | if img_name == 'zebra': 490 | repr_colors.append(np.mean(im[0:50, -50:, :], axis=(0,1))) 491 | elif img_name == 'zebra_blurred': 492 | repr_colors.append(np.mean(im[0:50, -50:, :], axis=(0,1))) 493 | elif img_name == 'snowboards': 494 | repr_colors.append(np.mean(im[50:60, 130:140, :], axis=(0,1))) 495 | else: 496 | raise ValueError 497 | 498 | im = color.rgb2lab(im).transpose((2, 0, 1)) 499 | for i in range(3): 500 | im[i] = (im[i] - data.offsets[i]) / data.scales[i] 501 | 502 | demo_imgs.append(torch.Tensor(im).expand(1, -1, -1, -1)) 503 | 504 | demo_imgs = torch.cat(demo_imgs, dim=0) 505 | x_l, x_ab, cond, ab_pred = model.prepare_batch(demo_imgs) 506 | 507 | outputs = model.cinn(x_ab, cond) 508 | jac = model.cinn.jacobian(run_forward=False) 509 | 510 | if c.n_downsampling < 2: 511 | outputs = [outputs] 512 | 513 | outputs_cat = torch.cat(outputs, dim=1) 514 | outputs_cat = outputs_cat.cpu().numpy() 515 | jac = jac.cpu().numpy() 516 | 517 | zz = np.sum(outputs_cat**2, axis=1) 518 | log_likeli = - zz / 2. + np.abs(jac) 519 | log_likeli /= outputs_cat.shape[1] 520 | print(log_likeli) 521 | repr_colors = np.array(repr_colors) 522 | 523 | image_characteristics.append([log_likeli, outputs_cat, repr_colors]) 524 | 525 | 526 | log_likeli_combined = np.concatenate([C[0] for C in image_characteristics], axis=0) 527 | outputs_combined = np.concatenate([C[1] for C in image_characteristics], axis=0) 528 | 529 | pca = PCA(n_components=2) 530 | pca.fit(outputs_combined) 531 | 532 | for i, img_name in enumerate(img_names): 533 | log_likeli, outputs_cat, repr_colors = image_characteristics[i] 534 | 535 | 536 | size = 10 + (40 * (log_likeli - np.min(log_likeli_combined)) / (np.max(log_likeli_combined) - np.min(log_likeli_combined)))**2 537 | outputs_pca = pca.transform(outputs_cat) 538 | center = pca.transform(np.zeros((2, outputs_cat.shape[1]))) 539 | 540 | plt.figure(figsize=(9,9)) 541 | plt.scatter(outputs_pca[:len(high_sat), 0], outputs_pca[:len(high_sat), 1], s=size[:len(high_sat)], c=repr_colors[:len(high_sat)]) 542 | #plt.scatter(outputs_pca[len(high_sat):, 0], outputs_pca[len(high_sat):, 1], s=size[len(high_sat):], c=repr_colors[len(high_sat):]) 543 | #plt.colorbar() 544 | #plt.scatter(center[:, 0], center[:, 1], c='black', marker='+', s=150) 545 | plt.xlim(-100, 100) 546 | plt.ylim(-100, 100) 547 | plt.savefig(F'colorspace_{img_name}.png', dpi=200) 548 | 549 | if __name__ == '__main__': 550 | pass 551 | 552 | # Comment in which ever you want to run: 553 | # ======================================== 554 | 555 | #for i in tqdm(range(len(data.test_list))): 556 | for i in [110, 122]: 557 | print(i) 558 | flow_visualization(i, n_samples=10) 559 | 560 | #for i in tqdm(range(len(data.test_list))): 561 | #interpolation_grid(i) 562 | 563 | #latent_space_pca() 564 | 565 | #colorize_test_set() 566 | 567 | #for i in range(8): 568 | #torch.manual_seed(i+c.seed) 569 | #colorize_batches(postfix=i, temp=1.0, filt=False) 570 | 571 | #for i in range(6): 572 
| #torch.manual_seed(c.seed) 573 | #z_fixed = sample_z(outputs[0].shape[0], 0.0000001) 574 | #sample_resolution_levels(i, z_fixed) 575 | 576 | #color_transfer() 577 | 578 | #find_map() 579 | -------------------------------------------------------------------------------- /colorization_cINN/feature_net.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | __weights_dict = dict() 7 | 8 | def load_weights(weight_file): 9 | if weight_file == None: 10 | return 11 | 12 | try: 13 | weights_dict = np.load(weight_file).item() 14 | except: 15 | weights_dict = np.load(weight_file, encoding='bytes').item() 16 | 17 | return weights_dict 18 | 19 | class KitModel(nn.Module): 20 | 21 | 22 | def __init__(self, weight_file): 23 | super(KitModel, self).__init__() 24 | global __weights_dict 25 | __weights_dict = load_weights(weight_file) 26 | 27 | self.bw_conv1_1 = self.__conv(2, name='bw_conv1_1', in_channels=1, out_channels=64, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True, dilation=1, padding=1) 28 | self.conv1_2 = self.__conv(2, name='conv1_2', in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(2, 2), groups=1, bias=True, dilation=1, padding=1) 29 | self.conv1_2norm = self.__batch_normalization(2, 'conv1_2norm', num_features=64, eps=9.999999747378752e-06, momentum=0.1) 30 | self.conv2_1 = self.__conv(2, name='conv2_1', in_channels=64, out_channels=128, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True, dilation=1, padding=1) 31 | self.conv2_2 = self.__conv(2, name='conv2_2', in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(2, 2), groups=1, bias=True, dilation=1, padding=1) 32 | self.conv2_2norm = self.__batch_normalization(2, 'conv2_2norm', num_features=128, eps=9.999999747378752e-06, momentum=0.1) 33 | self.conv3_1 = self.__conv(2, name='conv3_1', in_channels=128, out_channels=256, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True, dilation=1, padding=1) 34 | self.conv3_2 = self.__conv(2, name='conv3_2', in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True, dilation=1, padding=1) 35 | self.conv3_3 = self.__conv(2, name='conv3_3', in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(2, 2), groups=1, bias=True, dilation=1, padding=1) 36 | self.conv3_3norm = self.__batch_normalization(2, 'conv3_3norm', num_features=256, eps=9.999999747378752e-06, momentum=0.1) 37 | self.conv4_1 = self.__conv(2, name='conv4_1', in_channels=256, out_channels=512, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True, dilation=1, padding=1) 38 | self.conv4_2 = self.__conv(2, name='conv4_2', in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True, dilation=1, padding=1) 39 | self.conv4_3 = self.__conv(2, name='conv4_3', in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True, dilation=1, padding=1) 40 | self.conv4_3norm = self.__batch_normalization(2, 'conv4_3norm', num_features=512, eps=9.999999747378752e-06, momentum=0.1) 41 | self.conv5_1 = self.__conv(2, name='conv5_1', in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True, dilation=2, padding=2) 42 | self.conv5_2 = self.__conv(2, name='conv5_2', in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True, dilation=2, padding=2) 43 | self.conv5_3 = self.__conv(2, name='conv5_3', in_channels=512, out_channels=512, 
kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True, dilation=2, padding=2) 44 | self.conv5_3norm = self.__batch_normalization(2, 'conv5_3norm', num_features=512, eps=9.999999747378752e-06, momentum=0.1) 45 | self.conv6_1 = self.__conv(2, name='conv6_1', in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True, dilation=2, padding=2) 46 | self.conv6_2 = self.__conv(2, name='conv6_2', in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True, dilation=2, padding=2) 47 | self.conv6_3 = self.__conv(2, name='conv6_3', in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True, dilation=2, padding=2) 48 | self.conv6_3norm = self.__batch_normalization(2, 'conv6_3norm', num_features=512, eps=9.999999747378752e-06, momentum=0.1) 49 | self.conv7_1 = self.__conv(2, name='conv7_1', in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True, dilation=1, padding=1) 50 | self.conv7_2 = self.__conv(2, name='conv7_2', in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True, dilation=1, padding=1) 51 | self.conv7_3 = self.__conv(2, name='conv7_3', in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True, dilation=1, padding=1) 52 | self.conv7_3norm = self.__batch_normalization(2, 'conv7_3norm', num_features=512, eps=9.999999747378752e-06, momentum=0.1) 53 | self.conv8_1 = self.__conv_transpose(2, name='conv8_1', in_channels=512, out_channels=256, kernel_size=(4, 4), stride=(2, 2), groups=1, bias=True) 54 | self.conv8_2 = self.__conv(2, name='conv8_2', in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True, dilation=1, padding=1) 55 | self.conv8_3 = self.__conv(2, name='conv8_3', in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True, dilation=1, padding=1) 56 | self.conv8_313 = self.__conv(2, name='conv8_313', in_channels=256, out_channels=313, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True, dilation=1, padding=0) 57 | self.class8_ab = self.__conv(2, name='class8_ab', in_channels=313, out_channels=2, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True, dilation=1, padding=0) 58 | 59 | def features(self, x): 60 | out = self.bw_conv1_1(x) 61 | out = F.relu(out) 62 | out = self.conv1_2(out) 63 | out = F.relu(out) 64 | out = self.conv1_2norm(out) 65 | out = self.conv2_1(out) 66 | out = F.relu(out) 67 | out = self.conv2_2(out) 68 | out = F.relu(out) 69 | out = self.conv2_2norm(out) 70 | out = self.conv3_1(out) 71 | out = F.relu(out) 72 | out = self.conv3_2(out) 73 | out = F.relu(out) 74 | out = self.conv3_3(out) 75 | out = F.relu(out) 76 | out = self.conv3_3norm(out) 77 | out = self.conv4_1(out) 78 | out = F.relu(out) 79 | out = self.conv4_2(out) 80 | out = F.relu(out) 81 | out = self.conv4_3(out) 82 | out = F.relu(out) 83 | out = self.conv4_3norm(out) 84 | out = self.conv5_1(out) 85 | out = F.relu(out) 86 | out = self.conv5_2(out) 87 | out = F.relu(out) 88 | out = self.conv5_3(out) 89 | out = F.relu(out) 90 | out = self.conv5_3norm(out) 91 | out = self.conv6_1(out) 92 | out = F.relu(out) 93 | out = self.conv6_2(out) 94 | out = F.relu(out) 95 | out = self.conv6_3(out) 96 | out = F.relu(out) 97 | out = self.conv6_3norm(out) 98 | out = self.conv7_1(out) 99 | out = F.relu(out) 100 | out = self.conv7_2(out) 101 | out = F.relu(out) 102 | out = self.conv7_3(out) 103 | out = F.relu(out) 104 | out = self.conv7_3norm(out) 105 | out = 
self.conv8_1(out) 106 | out = F.relu(out) 107 | out = self.conv8_2(out) 108 | out = F.relu(out) 109 | out = self.conv8_3(out) 110 | 111 | return out 112 | 113 | def forward(self, x): 114 | out = self.features(x) 115 | out = F.relu(out) 116 | out = self.conv8_313(out) 117 | out = 2.606 * out 118 | out = F.softmax(out, dim=1) 119 | out = self.class8_ab(out) 120 | 121 | return out 122 | 123 | def fwd_from_features(self, f): 124 | out = F.relu(f) 125 | out = self.conv8_313(out) 126 | out = 2.606 * out 127 | out = F.softmax(out, dim=1) 128 | out = self.class8_ab(out) 129 | 130 | return out 131 | 132 | @staticmethod 133 | def __batch_normalization(dim, name, **kwargs): 134 | if dim == 1: layer = nn.BatchNorm1d(**kwargs) 135 | elif dim == 2: layer = nn.BatchNorm2d(**kwargs) 136 | elif dim == 3: layer = nn.BatchNorm3d(**kwargs) 137 | else: raise NotImplementedError() 138 | 139 | try: 140 | if 'scale' in __weights_dict[name]: 141 | layer.state_dict()['weight'].copy_(torch.from_numpy(__weights_dict[name]['scale'])) 142 | else: 143 | layer.weight.data.fill_(1) 144 | 145 | if 'bias' in __weights_dict[name]: 146 | layer.state_dict()['bias'].copy_(torch.from_numpy(__weights_dict[name]['bias'])) 147 | else: 148 | layer.bias.data.fill_(0) 149 | 150 | layer.state_dict()['running_mean'].copy_(torch.from_numpy(__weights_dict[name]['mean'])) 151 | layer.state_dict()['running_var'].copy_(torch.from_numpy(__weights_dict[name]['var'])) 152 | except: 153 | pass 154 | return layer 155 | 156 | @staticmethod 157 | def __conv(dim, name, **kwargs): 158 | if dim == 1: layer = nn.Conv1d(**kwargs) 159 | elif dim == 2: layer = nn.Conv2d(**kwargs) 160 | elif dim == 3: layer = nn.Conv3d(**kwargs) 161 | else: raise NotImplementedError() 162 | 163 | try: 164 | layer.state_dict()['weight'].copy_(torch.from_numpy(__weights_dict[name]['weights'])) 165 | if 'bias' in __weights_dict[name]: 166 | layer.state_dict()['bias'].copy_(torch.from_numpy(__weights_dict[name]['bias'])) 167 | except: 168 | pass 169 | return layer 170 | 171 | @staticmethod 172 | def __conv_transpose(dim, name, **kwargs): 173 | if dim == 1: layer = nn.ConvTranspose1d(**kwargs) 174 | elif dim == 2: layer = nn.ConvTranspose2d(**kwargs) 175 | elif dim == 3: layer = nn.ConvTranspose3d(**kwargs) 176 | else: raise NotImplementedError() 177 | 178 | try: 179 | layer.state_dict()['weight'].copy_(torch.from_numpy(__weights_dict[name]['weights'])) 180 | if 'bias' in __weights_dict[name]: 181 | layer.state_dict()['bias'].copy_(torch.from_numpy(__weights_dict[name]['bias'])) 182 | except: 183 | pass 184 | return layer 185 | -------------------------------------------------------------------------------- /colorization_cINN/joint_bilateral.c: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | void joint_filter(float* l_up, float* a_dw, float* a_up, 4 | int n_up, int m_up, int n_dw, int m_dw, 5 | double s_x, double s_l) { 6 | 7 | int range_x = ceil(s_x * 3); 8 | double c_x = 1./(s_x*s_x); 9 | double c_l = 1./(s_l*s_l); 10 | double scaling = (double) n_dw / n_up; 11 | 12 | double px_result, px_filt_norm, l, l0, w_x, w_l; 13 | int i, j, di, dj, di_dw, dj_dw; 14 | 15 | for (i = 0; i < n_up; i++){ 16 | for (j = 0; j < m_up; j++){ 17 | 18 | px_result = 0.; 19 | px_filt_norm = 0.; 20 | l0 = (double) l_up[i*m_up + j]; 21 | 22 | for(di = i-range_x; di < i+range_x; di++){ 23 | if (di < 0 || di >= n_up) continue; 24 | for(dj = j-range_x; dj < j+range_x; dj++){ 25 | if (dj < 0 || dj >= m_up) continue; 26 | 27 | l = (double) l_up[di*m_up 
+ dj]; 28 | w_x = exp(-0.5 * ((di-i)*(di-i) + (dj-j)*(dj-j)) * c_x); 29 | w_l = exp(-0.5 * (l-l0)*(l-l0) * c_l); 30 | w_x *= w_l; 31 | 32 | di_dw = floor(di * scaling); 33 | dj_dw = floor(dj * scaling); 34 | px_result += w_x * a_dw[di_dw * m_dw + dj_dw]; 35 | px_filt_norm += w_x; 36 | 37 | } 38 | } 39 | a_up[i*m_up + j] = (float) (px_result / px_filt_norm); 40 | } 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /colorization_cINN/joint_bilateral_filter.pyx: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | cimport numpy as np 3 | 4 | cdef extern void joint_filter(float*, float*, float*, int, int, int, int, double, double) 5 | 6 | def upsample(x_l, x_ab, s_x, s_l): 7 | n_up, m_up = x_l.shape[0], x_l.shape[1] 8 | n_dw, m_dw = x_ab.shape[1], x_ab.shape[2] 9 | 10 | cdef np.ndarray[float, ndim=2, mode="c"] l_up = np.ascontiguousarray(x_l) 11 | cdef np.ndarray[float, ndim=2, mode="c"] a_dw = np.ascontiguousarray(x_ab[0]) 12 | cdef np.ndarray[float, ndim=2, mode="c"] b_dw = np.ascontiguousarray(x_ab[1]) 13 | 14 | cdef np.ndarray[float, ndim=2, mode="c"] a_up = np.empty((n_up, m_up), dtype=np.float32) 15 | cdef np.ndarray[float, ndim=2, mode="c"] b_up = np.empty((n_up, m_up), dtype=np.float32) 16 | 17 | joint_filter(&l_up[0,0], &a_dw[0,0], &a_up[0,0], n_up, m_up, n_dw, m_dw, s_x, s_l) 18 | joint_filter(&l_up[0,0], &b_dw[0,0], &b_up[0,0], n_up, m_up, n_dw, m_dw, s_x, s_l) 19 | 20 | return np.stack([a_up, b_up], axis=0) 21 | 22 | 23 | -------------------------------------------------------------------------------- /colorization_cINN/model.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch.optim 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | 8 | from FrEIA.framework import * 9 | from FrEIA.modules import * 10 | from subnet_coupling import * 11 | import data 12 | import config as c 13 | 14 | feature_channels = 256 15 | fc_cond_length = 512 16 | n_blocks_fc = 8 17 | outputs = [] 18 | 19 | conditions = [ConditionNode(feature_channels, c.img_dims[0], c.img_dims[1]), 20 | ConditionNode(fc_cond_length)] 21 | 22 | def random_orthog(n): 23 | w = np.random.randn(n, n) 24 | w = w + w.T 25 | w, S, V = np.linalg.svd(w) 26 | return torch.FloatTensor(w) 27 | 28 | def cond_subnet(level, c_out, extra_conv=False): 29 | c_intern = [feature_channels, 128, 128, 256] 30 | modules = [] 31 | 32 | for i in range(level): 33 | modules.extend([nn.Conv2d(c_intern[i], c_intern[i+1], 3, stride=2, padding=1), 34 | nn.LeakyReLU() ]) 35 | 36 | if extra_conv: 37 | modules.extend([ 38 | nn.Conv2d(c_intern[level], 128, 3, padding=1), 39 | nn.LeakyReLU(), 40 | nn.Conv2d(128, 2*c_out, 3, padding=1), 41 | ]) 42 | else: 43 | modules.append(nn.Conv2d(c_intern[level], 2*c_out, 3, padding=1)) 44 | 45 | modules.append(nn.BatchNorm2d(2*c_out)) 46 | 47 | return nn.Sequential(*modules) 48 | 49 | fc_cond_net = nn.Sequential(*[ 50 | nn.Conv2d(feature_channels, 128, 3, stride=2, padding=1), # 32 x 32 51 | nn.LeakyReLU(), 52 | nn.Conv2d(128, 256, 3, stride=2, padding=1), # 16 x 16 53 | nn.LeakyReLU(), 54 | nn.Conv2d(256, 256, 3, stride=2, padding=1), # 8 x 8 55 | nn.LeakyReLU(), 56 | nn.Conv2d(256, fc_cond_length, 3, stride=2, padding=1), # 4 x 4 57 | nn.AvgPool2d(4), 58 | nn.BatchNorm2d(fc_cond_length), 59 | ]) 60 | 61 | def _add_conditioned_section(nodes, depth, channels_in, channels, cond_level): 62 | 63 | for k in 
range(depth): 64 | nodes.append(Node([nodes[-1].out0], 65 | subnet_coupling_layer, 66 | {'clamp':c.clamping, 'F_class':F_conv, 67 | 'subnet':cond_subnet(cond_level, channels//2), 'sub_len':channels, 68 | 'F_args':{'leaky_slope': 5e-2, 'channels_hidden':channels}}, 69 | conditions=[conditions[0]], name=F'conv_{k}')) 70 | 71 | #else: 72 | #nodes.append(Node([nodes[-1].out0], 73 | #glow_coupling_layer, 74 | #{'clamp':c.clamping, 'F_class':F_conv, 75 | #'F_args':{'leaky_slope': 1e-2, 'channels_hidden':channels}}, 76 | #conditions=[], name=F'conv_{k}')) 77 | 78 | 79 | #nodes.append(Node([nodes[-1].out0], 80 | #cbn_direct, 81 | #{'clamp':c.clamping, 'subnet':cond_subnet(cond_level, channels_in, (cond_level==0))}, 82 | #conditions=[conditions[0]], name=F'cbn_{k}')) 83 | 84 | nodes.append(Node([nodes[-1].out0], conv_1x1, {'M':random_orthog(channels_in)})) 85 | 86 | 87 | def _add_split_downsample(nodes, split, downsample, channels_in, channels): 88 | if downsample=='haar': 89 | nodes.append(Node([nodes[-1].out0], haar_multiplex_layer, {'rebalance':0.5, 'order_by_wavelet':True}, name='haar')) 90 | if downsample=='reshape': 91 | nodes.append(Node([nodes[-1].out0], i_revnet_downsampling, {}, name='reshape')) 92 | 93 | for i in range(2): 94 | nodes.append(Node([nodes[-1].out0], conv_1x1, {'M':random_orthog(channels_in*4)})) 95 | nodes.append(Node([nodes[-1].out0], 96 | glow_coupling_layer, 97 | {'clamp':c.clamping, 'F_class':F_conv, 98 | 'F_args':{'kernel_size':1, 'leaky_slope': 1e-2, 'channels_hidden':channels}}, 99 | conditions=[])) 100 | 101 | if split: 102 | nodes.append(Node([nodes[-1].out0], split_layer, 103 | {'split_size_or_sections': split, 'dim':0}, name='split')) 104 | 105 | output = Node([nodes[-1].out1], flattening_layer, {}, name='flatten') 106 | nodes.insert(-2, output) 107 | nodes.insert(-2, OutputNode([output.out0], name='out')) 108 | 109 | def _add_fc_section(nodes): 110 | nodes.append(Node([nodes[-1].out0], flattening_layer, {}, name='flatten')) 111 | for k in range(n_blocks_fc): 112 | nodes.append(Node([nodes[-1].out0], permute_layer, {'seed':k}, name=F'permute_{k}')) 113 | nodes.append(Node([nodes[-1].out0], glow_coupling_layer, 114 | {'clamp':c.clamping, 'F_class':F_fully_connected, 'F_args':{'internal_size':512}}, 115 | conditions=[conditions[1]], name=F'fc_{k}')) 116 | 117 | nodes.append(OutputNode([nodes[-1].out0], name='out')) 118 | 119 | nodes = [InputNode(2, *c.img_dims, name='inp')] 120 | # 2x64x64 px 121 | _add_conditioned_section(nodes, depth=4, channels_in=2, channels=32, cond_level=0) 122 | _add_split_downsample(nodes, split=False, downsample='reshape', channels_in=2, channels=64) 123 | 124 | # 8x32x32 px 125 | _add_conditioned_section(nodes, depth=6, channels_in=8, channels=64, cond_level=1) 126 | _add_split_downsample(nodes, split=(16, 16), downsample='reshape', channels_in=8, channels=128) 127 | 128 | # 16x16x16 px 129 | _add_conditioned_section(nodes, depth=6, channels_in=16, channels=128, cond_level=2) 130 | _add_split_downsample(nodes, split=(32, 32), downsample='reshape', channels_in=16, channels=256) 131 | 132 | # 32x8x8 px 133 | _add_conditioned_section(nodes, depth=6, channels_in=32, channels=256, cond_level=3) 134 | _add_split_downsample(nodes, split=(32, 3*32), downsample='haar', channels_in=32, channels=256) 135 | 136 | # 32x4x4 = 512 px 137 | _add_fc_section(nodes) 138 | 139 | def init_model(mod): 140 | for key, param in mod.named_parameters(): 141 | split = key.split('.') 142 | if param.requires_grad: 143 | param.data = c.init_scale * 
torch.randn(param.data.shape).cuda() 144 | if len(split) > 3 and split[3][-1] == '3': # last convolution in the coeff func 145 | param.data.fill_(0.) 146 | 147 | 148 | cinn = ReversibleGraphNet(nodes + conditions, verbose=False) 149 | output_dimensions = [] 150 | for o in nodes: 151 | if type(o) is OutputNode: 152 | output_dimensions.append(o.input_dims[0][0]) 153 | 154 | cinn.cuda() 155 | init_model(cinn) 156 | #init_model(fc_cond_net) 157 | 158 | if c.load_inn_only: 159 | cinn.load_state_dict(torch.load(c.load_inn_only)['net']) 160 | 161 | import feature_net 162 | efros_net = feature_net.KitModel(None) 163 | 164 | try: 165 | pretrained_dict = torch.load('./features_pretrained.pt') 166 | pretrained_dict = {k:v for k,v in pretrained_dict.items() if 'num_batches_tracked' not in k} 167 | efros_net.load_state_dict(pretrained_dict) 168 | except FileNotFoundError: 169 | warnings.warn("No loading pretrained weights for conditioning network (./features_pretrained.pt)") 170 | 171 | efros_net.cuda() 172 | efros_net.class8_ab.state_dict()['weight'].copy_(torch.from_numpy(np.load('./pts_in_hull.npy').T).view(2, 313, 1, 1)) 173 | 174 | def prepare_batch(x): 175 | 176 | net_feat = combined_model.module.feature_network 177 | net_inn = combined_model.module.inn 178 | net_cond = combined_model.module.fc_cond_network 179 | 180 | with torch.no_grad(): 181 | x = x.cuda() 182 | x_l, x_ab = x[:, 0:1], x[:, 1:] 183 | 184 | x_ab = F.interpolate(x_ab, size=c.img_dims) 185 | x_ab += 5e-2 * torch.cuda.FloatTensor(x_ab.shape).normal_() 186 | 187 | if c.end_to_end: 188 | features = net_feat.features(x_l * data.scales[0] + data.offsets[0] - 50) 189 | features = features[:, :, 1:-1, 1:-1] 190 | else: 191 | with torch.no_grad(): 192 | features = net_feat.features(x_l * data.scales[0] + data.offsets[0] - 50) 193 | features = features[:, :, 1:-1, 1:-1] 194 | 195 | with torch.no_grad(): 196 | ab_pred = net_feat.fwd_from_features(features) 197 | for i in [0,1]: 198 | ab_pred[:, i] = (ab_pred[:, i] - data.offsets[i+1]) / data.scales[i+1] 199 | 200 | ab_pred += 5e-2 * torch.cuda.FloatTensor(ab_pred.shape).normal_() 201 | ab_pred += 0.10 * torch.randn(ab_pred.shape[0], 2, 1, 1).cuda().expand_as(ab_pred) 202 | ab_pred *= 0.95 + 0.18 * np.random.randn() 203 | 204 | cond = [features, net_cond(features).squeeze()] 205 | 206 | return x_l.detach(), x_ab.detach(), cond, ab_pred 207 | 208 | class WrappedModel(nn.Module): 209 | def __init__(self, feature_network, fc_cond_network, inn): 210 | super().__init__() 211 | 212 | self.feature_network = feature_network 213 | self.fc_cond_network = fc_cond_network 214 | self.inn = inn 215 | 216 | def forward(self, x): 217 | 218 | x_l, x_ab = x[:, 0:1], x[:, 1:] 219 | 220 | x_ab = F.interpolate(x_ab, size=c.img_dims) 221 | x_ab += 5e-2 * torch.cuda.FloatTensor(x_ab.shape).normal_() 222 | 223 | if c.end_to_end: 224 | features = self.feature_network.features(x_l * data.scales[0] + data.offsets[0] - 50) 225 | features = features[:, :, 1:-1, 1:-1] 226 | else: 227 | with torch.no_grad(): 228 | features = self.feature_network.features(x_l * data.scales[0] + data.offsets[0] - 50) 229 | features = features[:, :, 1:-1, 1:-1] 230 | 231 | cond = [features, self.fc_cond_network(features).squeeze()] 232 | 233 | z = self.inn(x_ab, cond) 234 | zz = sum(torch.sum(o**2, dim=1) for o in z) 235 | jac = self.inn.jacobian(run_forward=False) 236 | 237 | return zz, jac 238 | 239 | def reverse_sample(self, z, cond): 240 | return self.inn(z, cond, rev=True) 241 | 242 | combined_model = WrappedModel(efros_net, 
fc_cond_net, cinn) 243 | combined_model.cuda() 244 | combined_model = nn.DataParallel(combined_model, device_ids=c.device_ids) 245 | 246 | params_trainable = (list(filter(lambda p: p.requires_grad, combined_model.module.inn.parameters())) 247 | + list(combined_model.module.fc_cond_network.parameters())) 248 | 249 | optim = torch.optim.Adam(params_trainable, lr=c.lr, betas=c.betas, eps=1e-6, weight_decay=c.weight_decay) 250 | #optim = torch.optim.SGD(params_trainable, lr=c.lr, weight_decay=c.weight_decay) 251 | 252 | sched_factor = 0.2 253 | sched_patience = 8 254 | sched_trehsh = 0.001 255 | sched_cooldown = 2 256 | 257 | weight_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, 258 | factor=sched_factor, 259 | patience=sched_patience, 260 | threshold=sched_trehsh, 261 | min_lr=0, eps=1e-08, 262 | cooldown=sched_cooldown, 263 | verbose = True) 264 | 265 | weight_scheduler_fixed = torch.optim.lr_scheduler.StepLR(optim, 120, gamma=0.2) 266 | 267 | class DummyOptim: 268 | def __init__(self): 269 | self.param_groups = [] 270 | def state_dict(self): 271 | return {} 272 | def load_state_dict(self, *args, **kwargs): 273 | pass 274 | def step(self, *args, **kwargs): 275 | pass 276 | def zero_grad(self): 277 | pass 278 | 279 | efros_net.train() 280 | 281 | if c.end_to_end: 282 | feature_optim = torch.optim.Adam(combined_model.module.feature_network.parameters(), lr=c.lr_feature_net, betas=c.betas, eps=1e-4) 283 | feature_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(feature_optim, 284 | factor=sched_factor, 285 | patience=sched_patience, 286 | threshold=sched_trehsh, 287 | min_lr=0, eps=1e-08, 288 | cooldown=sched_cooldown, 289 | verbose = True) 290 | else: 291 | #efros_net.eval() 292 | feature_optim = DummyOptim() 293 | feature_scheduler = DummyOptim() 294 | 295 | def optim_step(): 296 | optim.step() 297 | optim.zero_grad() 298 | 299 | feature_optim.step() 300 | feature_optim.zero_grad() 301 | 302 | def save(name): 303 | torch.save({'opt':optim.state_dict(), 304 | 'opt_f':feature_optim.state_dict(), 305 | 'net':combined_model.state_dict()}, name) 306 | 307 | def load(name): 308 | state_dicts = torch.load(name) 309 | network_state_dict = {k:v for k,v in state_dicts['net'].items() if 'tmp_var' not in k} 310 | combined_model.load_state_dict(network_state_dict) 311 | try: 312 | optim.load_state_dict(state_dicts['opt']) 313 | feature_optim.load_state_dict(state_dicts['opt_f']) 314 | except: 315 | print('Cannot load optimizer for some reason or other') 316 | -------------------------------------------------------------------------------- /colorization_cINN/model_no_cond.py: -------------------------------------------------------------------------------- 1 | import torch.optim 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from FrEIA.framework import * 7 | from FrEIA.modules import * 8 | from cbn_layer import * 9 | from subnet_coupling import * 10 | import data 11 | import config as c 12 | 13 | n_blocks_fc = 8 14 | outputs = [] 15 | 16 | conditions = [ConditionNode(1, c.img_dims[0], c.img_dims[1])] 17 | 18 | def random_orthog(n): 19 | w = np.random.randn(n, n) 20 | w = w + w.T 21 | w, S, V = np.linalg.svd(w) 22 | return torch.FloatTensor(w) 23 | 24 | class HaarConv(nn.Module): 25 | 26 | def __init__(self, level): 27 | super().__init__() 28 | 29 | self.in_channels = 4**level 30 | self.fac_fwd = 0.25 31 | self.haar_weights = torch.ones(4,1,2,2) 32 | 33 | self.haar_weights[1, 0, 0, 1] = -1 34 | self.haar_weights[1, 0, 1, 1] = -1 35 | 36 | 
self.haar_weights[2, 0, 1, 0] = -1 37 | self.haar_weights[2, 0, 1, 1] = -1 38 | 39 | self.haar_weights[3, 0, 1, 0] = -1 40 | self.haar_weights[3, 0, 0, 1] = -1 41 | 42 | self.haar_weights = torch.cat([self.haar_weights]*self.in_channels, 0) 43 | self.haar_weights = nn.Parameter(self.haar_weights) 44 | self.haar_weights.requires_grad = False 45 | 46 | def forward(self, x): 47 | out = F.conv2d(x, self.haar_weights, 48 | bias=None, stride=2, groups=self.in_channels) 49 | return out * self.fac_fwd 50 | 51 | def cond_subnet(level): 52 | return nn.Sequential(*[HaarConv(i) for i in range(level+2)]) 53 | 54 | def _add_conditioned_section(nodes, depth, channels_in, channels, cond_level): 55 | 56 | for k in range(depth): 57 | nodes.append(Node([nodes[-1].out0], 58 | subnet_coupling_layer, 59 | {'clamp':c.clamping, 'F_class':F_conv, 60 | 'subnet':cond_subnet(cond_level), 'sub_len':4**(cond_level+2), 61 | 'F_args':{'leaky_slope': 5e-2, 'channels_hidden':channels}}, 62 | conditions=[conditions[0]], name=F'conv_{k}')) 63 | 64 | nodes.append(Node([nodes[-1].out0], conv_1x1, {'M':random_orthog(channels_in)})) 65 | 66 | 67 | def _add_split_downsample(nodes, split, downsample, channels_in, channels): 68 | if downsample=='haar': 69 | nodes.append(Node([nodes[-1].out0], haar_multiplex_layer, {'rebalance':0.5, 'order_by_wavelet':True}, name='haar')) 70 | if downsample=='reshape': 71 | nodes.append(Node([nodes[-1].out0], i_revnet_downsampling, {}, name='reshape')) 72 | 73 | for i in range(2): 74 | nodes.append(Node([nodes[-1].out0], conv_1x1, {'M':random_orthog(channels_in*4)})) 75 | nodes.append(Node([nodes[-1].out0], 76 | glow_coupling_layer, 77 | {'clamp':c.clamping, 'F_class':F_conv, 78 | 'F_args':{'kernel_size':1, 'leaky_slope': 1e-2, 'channels_hidden':channels}}, 79 | conditions=[])) 80 | 81 | if split: 82 | nodes.append(Node([nodes[-1].out0], split_layer, 83 | {'split_size_or_sections': split, 'dim':0}, name='split')) 84 | 85 | output = Node([nodes[-1].out1], flattening_layer, {}, name='flatten') 86 | nodes.insert(-2, output) 87 | nodes.insert(-2, OutputNode([output.out0], name='out')) 88 | 89 | def _add_fc_section(nodes): 90 | nodes.append(Node([nodes[-1].out0], flattening_layer, {}, name='flatten')) 91 | for k in range(n_blocks_fc): 92 | nodes.append(Node([nodes[-1].out0], permute_layer, {'seed':k}, name=F'permute_{k}')) 93 | nodes.append(Node([nodes[-1].out0], glow_coupling_layer, 94 | {'clamp':c.clamping, 'F_class':F_fully_connected, 'F_args':{'internal_size':512}}, 95 | conditions=[], name=F'fc_{k}')) 96 | 97 | nodes.append(OutputNode([nodes[-1].out0], name='out')) 98 | 99 | nodes = [InputNode(2, *c.img_dims, name='inp')] 100 | # 2x64x64 px 101 | _add_conditioned_section(nodes, depth=4, channels_in=2, channels=32, cond_level=0) 102 | _add_split_downsample(nodes, split=False, downsample='reshape', channels_in=2, channels=64) 103 | 104 | # 8x32x32 px 105 | _add_conditioned_section(nodes, depth=6, channels_in=8, channels=64, cond_level=1) 106 | _add_split_downsample(nodes, split=(16, 16), downsample='reshape', channels_in=8, channels=128) 107 | 108 | # 16x16x16 px 109 | _add_conditioned_section(nodes, depth=6, channels_in=16, channels=128, cond_level=2) 110 | _add_split_downsample(nodes, split=(32, 32), downsample='reshape', channels_in=16, channels=256) 111 | 112 | # 32x8x8 px 113 | _add_conditioned_section(nodes, depth=6, channels_in=32, channels=256, cond_level=3) 114 | _add_split_downsample(nodes, split=(32, 3*32), downsample='haar', channels_in=32, channels=256) 115 | 116 | # 32x4x4 = 512 px 117 | 
_add_fc_section(nodes) 118 | 119 | def init_model(mod): 120 | for key, param in mod.named_parameters(): 121 | split = key.split('.') 122 | if param.requires_grad: 123 | param.data = c.init_scale * torch.randn(param.data.shape).cuda() 124 | if len(split) > 3 and split[3][-1] == '3': # last convolution in the coeff func 125 | param.data.fill_(0.) 126 | 127 | 128 | cinn = ReversibleGraphNet(nodes + conditions, verbose=False) 129 | output_dimensions = [] 130 | for o in nodes: 131 | if type(o) is OutputNode: 132 | output_dimensions.append(o.input_dims[0][0]) 133 | 134 | cinn.cuda() 135 | init_model(cinn) 136 | 137 | if c.load_inn_only: 138 | cinn.load_state_dict(torch.load(c.load_inn_only)['net']) 139 | 140 | class DummyFeatureNet(nn.Module): 141 | def __init__(self, *args, **kwargs): 142 | super().__init__() 143 | self.dumm_param = nn.Parameter(torch.zeros(10)) 144 | 145 | def forward(self, x): 146 | return x 147 | def features(self, x): 148 | return x 149 | 150 | efros_net = DummyFeatureNet() 151 | 152 | def prepare_batch(x): 153 | 154 | net_feat = combined_model.module.feature_network 155 | net_inn = combined_model.module.inn 156 | net_cond = combined_model.module.fc_cond_network 157 | 158 | with torch.no_grad(): 159 | x = x.cuda() 160 | x_l, x_ab = x[:, 0:1], x[:, 1:] 161 | 162 | x_ab = F.interpolate(x_ab, size=c.img_dims) 163 | x_ab += 5e-2 * torch.cuda.FloatTensor(x_ab.shape).normal_() 164 | 165 | cond = [x_l] 166 | ab_pred = None 167 | 168 | return x_l.detach(), x_ab.detach(), cond, ab_pred 169 | 170 | class WrappedModel(nn.Module): 171 | def __init__(self, feature_network, fc_cond_network, inn): 172 | super().__init__() 173 | 174 | self.feature_network = feature_network 175 | self.fc_cond_network = fc_cond_network 176 | self.inn = inn 177 | 178 | def forward(self, x): 179 | 180 | x_l, x_ab = x[:, 0:1], x[:, 1:] 181 | 182 | x_ab = F.interpolate(x_ab, size=c.img_dims) 183 | x_ab += 5e-2 * torch.cuda.FloatTensor(x_ab.shape).normal_() 184 | 185 | cond = [x_l] 186 | 187 | z = self.inn(x_ab, cond) 188 | zz = sum(torch.sum(o**2, dim=1) for o in z) 189 | jac = self.inn.jacobian(run_forward=False) 190 | 191 | return zz, jac 192 | 193 | def reverse_sample(self, z, cond): 194 | return self.inn(z, cond, rev=True) 195 | 196 | combined_model = WrappedModel(efros_net, None, cinn) 197 | combined_model.cuda() 198 | combined_model = nn.DataParallel(combined_model, device_ids=c.device_ids) 199 | 200 | params_trainable = list(filter(lambda p: p.requires_grad, combined_model.module.inn.parameters())) 201 | 202 | optim = torch.optim.Adam(params_trainable, lr=c.lr, betas=c.betas, eps=1e-6, weight_decay=c.weight_decay) 203 | 204 | sched_factor = 0.2 205 | sched_patience = 8 206 | sched_trehsh = 0.001 207 | sched_cooldown = 2 208 | 209 | weight_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, 210 | factor=sched_factor, 211 | patience=sched_patience, 212 | threshold=sched_trehsh, 213 | min_lr=0, eps=1e-08, 214 | cooldown=sched_cooldown, 215 | verbose = True) 216 | 217 | weight_scheduler_fixed = torch.optim.lr_scheduler.torch.optim.lr_scheduler.StepLR(optim, 120, gamma=0.2) 218 | 219 | class DummyOptim: 220 | def __init__(self): 221 | self.param_groups = [] 222 | def state_dict(self): 223 | return {} 224 | def load_state_dict(self, *args, **kwargs): 225 | pass 226 | def step(self, *args, **kwargs): 227 | pass 228 | def zero_grad(self): 229 | pass 230 | 231 | efros_net.train() 232 | 233 | if c.end_to_end: 234 | feature_optim = torch.optim.Adam(combined_model.module.feature_network.parameters(), 
lr=c.lr_feature_net, betas=c.betas, eps=1e-4) 235 | feature_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(feature_optim, 236 | factor=sched_factor, 237 | patience=sched_patience, 238 | threshold=sched_trehsh, 239 | min_lr=0, eps=1e-08, 240 | cooldown=sched_cooldown, 241 | verbose = True) 242 | else: 243 | feature_optim = DummyOptim() 244 | feature_scheduler = DummyOptim() 245 | 246 | def optim_step(): 247 | optim.step() 248 | optim.zero_grad() 249 | 250 | feature_optim.step() 251 | feature_optim.zero_grad() 252 | 253 | def save(name): 254 | torch.save({'opt':optim.state_dict(), 255 | 'opt_f':feature_optim.state_dict(), 256 | 'net':combined_model.state_dict()}, name) 257 | 258 | def load(name): 259 | state_dicts = torch.load(name) 260 | network_state_dict = {k:v for k,v in state_dicts['net'].items() if 'tmp_var' not in k} 261 | combined_model.load_state_dict(network_state_dict) 262 | try: 263 | optim.load_state_dict(state_dicts['opt']) 264 | feature_optim.load_state_dict(state_dicts['opt_f']) 265 | except: 266 | print('Cannot load optimizer for some reason or other') 267 | -------------------------------------------------------------------------------- /colorization_cINN/output/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vislearn/conditional_INNs/19e316c606cae24815efa51305ce8d3a6476f819/colorization_cINN/output/.gitkeep -------------------------------------------------------------------------------- /colorization_cINN/pts_in_hull.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vislearn/conditional_INNs/19e316c606cae24815efa51305ce8d3a6476f819/colorization_cINN/pts_in_hull.npy -------------------------------------------------------------------------------- /colorization_cINN/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from distutils.core import setup 4 | from distutils.extension import Extension 5 | from Cython.Distutils import build_ext 6 | 7 | import numpy as np 8 | 9 | setup( 10 | cmdclass = {'build_ext': build_ext}, 11 | ext_modules = [Extension("joint_bilateral_filter", 12 | sources=["joint_bilateral_filter.pyx", "joint_bilateral.c"], 13 | include_dirs=[np.get_include()])], 14 | ) 15 | 16 | def demo(): 17 | import joint_bilateral_filter as jbf 18 | import matplotlib.pyplot as plt 19 | from skimage.color import lab2rgb 20 | from skimage.transform import resize 21 | 22 | L = plt.imread('./demo_images/L.png')[:, :, 0] 23 | 24 | ab = np.stack([plt.imread('./demo_images/a.png')[::2, ::2, 0], 25 | plt.imread('./demo_images/b.png')[::2, ::2, 0]], axis=0) 26 | 27 | s_x = 4. 
28 | s_l = 0.05 29 | 30 | print('===='*20 + '\n\n' + '===='*20) 31 | ab_up = jbf.upsample(L, ab, s_x, s_l) 32 | ab_naive0 = resize(ab.transpose((1,2,0)), (ab_up.shape[1], ab_up.shape[2]), order=0).transpose((2,0,1)) 33 | ab_naive1 = resize(ab.transpose((1,2,0)), (ab_up.shape[1], ab_up.shape[2]), order=1).transpose((2,0,1)) 34 | 35 | Lab_up = np.concatenate([L[np.newaxis, :, :], ab_up], axis=0).transpose((1,2,0)) 36 | Lab_naive0 = np.concatenate([L[np.newaxis, :, :], ab_naive0], axis=0).transpose((1,2,0)) 37 | Lab_naive1 = np.concatenate([L[np.newaxis, :, :], ab_naive1], axis=0).transpose((1,2,0)) 38 | 39 | for i, s, o in [(0, 100, 0), 40 | (1, 255, -128), 41 | (2, 255, -128)]: 42 | 43 | Lab_up[:, :, i] = Lab_up[:, :, i] * s + o 44 | Lab_naive0[:, :, i] = Lab_naive0[:, :, i] * s + o 45 | Lab_naive1[:, :, i] = Lab_naive1[:, :, i] * s + o 46 | 47 | rgb_up = lab2rgb(Lab_up) 48 | rgb_naive0 = lab2rgb(Lab_naive0) 49 | rgb_naive1 = lab2rgb(Lab_naive1) 50 | 51 | plt.figure() 52 | plt.title('nearest') 53 | plt.imshow(rgb_naive0) 54 | plt.figure() 55 | plt.title('bilinear') 56 | plt.imshow(rgb_naive1) 57 | plt.figure() 58 | plt.title('joint bilateral') 59 | plt.imshow(rgb_up) 60 | 61 | plt.show() 62 | -------------------------------------------------------------------------------- /colorization_cINN/subnet_coupling.py: -------------------------------------------------------------------------------- 1 | from math import exp 2 | import torch 3 | import torch.nn as nn 4 | 5 | class subnet_coupling_layer(nn.Module): 6 | def __init__(self, dims_in, dims_c, F_class, subnet, sub_len, F_args={}, clamp=5.): 7 | super().__init__() 8 | 9 | channels = dims_in[0][0] 10 | self.ndims = len(dims_in[0]) 11 | self.split_len1 = channels // 2 12 | self.split_len2 = channels - channels // 2 13 | 14 | self.clamp = clamp 15 | self.max_s = exp(clamp) 16 | self.min_s = exp(-clamp) 17 | 18 | self.conditional = True 19 | condition_length = sub_len 20 | self.subnet = subnet 21 | 22 | self.s1 = F_class(self.split_len1 + condition_length, self.split_len2*2, **F_args) 23 | self.s2 = F_class(self.split_len2 + condition_length, self.split_len1*2, **F_args) 24 | 25 | def e(self, s): 26 | return torch.exp(self.clamp * 0.636 * torch.atan(s / self.clamp)) 27 | 28 | def log_e(self, s): 29 | return self.clamp * 0.636 * torch.atan(s / self.clamp) 30 | 31 | def forward(self, x, c=[], rev=False): 32 | x1, x2 = (x[0].narrow(1, 0, self.split_len1), 33 | x[0].narrow(1, self.split_len1, self.split_len2)) 34 | c_star = self.subnet(torch.cat(c, 1)) 35 | 36 | if not rev: 37 | r2 = self.s2(torch.cat([x2, c_star], 1) if self.conditional else x2) 38 | s2, t2 = r2[:, :self.split_len1], r2[:, self.split_len1:] 39 | y1 = self.e(s2) * x1 + t2 40 | 41 | r1 = self.s1(torch.cat([y1, c_star], 1) if self.conditional else y1) 42 | s1, t1 = r1[:, :self.split_len2], r1[:, self.split_len2:] 43 | y2 = self.e(s1) * x2 + t1 44 | self.last_jac = self.log_e(s1) + self.log_e(s2) 45 | 46 | else: # names of x and y are swapped! 
47 | r1 = self.s1(torch.cat([x1, c_star], 1) if self.conditional else x1) 48 | s1, t1 = r1[:, :self.split_len2], r1[:, self.split_len2:] 49 | y2 = (x2 - t1) / self.e(s1) 50 | 51 | r2 = self.s2(torch.cat([y2, c_star], 1) if self.conditional else y2) 52 | s2, t2 = r2[:, :self.split_len1], r2[:, self.split_len1:] 53 | y1 = (x1 - t2) / self.e(s2) 54 | self.last_jac = - self.log_e(s1) - self.log_e(s2) 55 | 56 | return [torch.cat((y1, y2), 1)] 57 | 58 | def jacobian(self, x, c=[], rev=False): 59 | return torch.sum(self.last_jac, dim=tuple(range(1, self.ndims+1))) 60 | 61 | def output_dims(self, input_dims): 62 | return input_dims 63 | -------------------------------------------------------------------------------- /colorization_cINN/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | 4 | import torch 5 | import torch.nn 6 | import torch.optim 7 | from torch.nn.functional import avg_pool2d#, interpolate 8 | from torch.autograd import Variable 9 | import numpy as np 10 | import tqdm 11 | 12 | import config as c 13 | 14 | if c.no_cond_net: 15 | import model_no_cond as model 16 | else: 17 | import model 18 | 19 | import data 20 | import viz 21 | 22 | if c.load_file: 23 | model.load(c.load_file) 24 | 25 | class dummy_loss(object): 26 | def item(self): 27 | return 1. 28 | 29 | def sample_outputs(sigma, out_shape): 30 | return [sigma * torch.cuda.FloatTensor(torch.Size((4, o))).normal_() for o in out_shape] 31 | 32 | tot_output_size = 2 * c.img_dims[0] * c.img_dims[1] 33 | 34 | try: 35 | for i_epoch in range(-c.pre_low_lr, c.n_epochs): 36 | 37 | loss_history = [] 38 | data_iter = iter(data.train_loader) 39 | 40 | if i_epoch < 0: 41 | for param_group in model.optim.param_groups: 42 | param_group['lr'] = c.lr * 2e-2 43 | if i_epoch == 0: 44 | for param_group in model.optim.param_groups: 45 | param_group['lr'] = c.lr 46 | 47 | if c.end_to_end and i_epoch <= c.pretrain_epochs: 48 | for param_group in model.feature_optim.param_groups: 49 | param_group['lr'] = 0 50 | if i_epoch == c.pretrain_epochs: 51 | for param_group in model.feature_optim.param_groups: 52 | param_group['lr'] = 1e-4 53 | 54 | iterator = tqdm.tqdm(enumerate(data_iter), 55 | total=min(len(data.train_loader), c.n_its_per_epoch), 56 | leave=False, 57 | mininterval=1., 58 | disable=(not c.progress_bar), 59 | ncols=83) 60 | 61 | for i_batch , x in iterator: 62 | 63 | zz, jac = model.combined_model(x) 64 | 65 | neg_log_likeli = 0.5 * zz - jac 66 | 67 | l = torch.mean(neg_log_likeli) / tot_output_size 68 | l.backward() 69 | 70 | model.optim_step() 71 | loss_history.append([l.item(), 0.]) 72 | 73 | if i_batch+1 >= c.n_its_per_epoch: 74 | # somehow the data loader workers don't shut down automatically 75 | try: 76 | data_iter._shutdown_workers() 77 | except: 78 | pass 79 | 80 | iterator.close() 81 | break 82 | 83 | epoch_losses = np.mean(np.array(loss_history), axis=0) 84 | epoch_losses[1] = np.log10(model.optim.param_groups[0]['lr']) 85 | for i in range(len(epoch_losses)): 86 | epoch_losses[i] = min(epoch_losses[i], c.loss_display_cutoff) 87 | 88 | with torch.no_grad(): 89 | ims = [] 90 | for x in data.test_loader: 91 | x_l, x_ab, cond, ab_pred = model.prepare_batch(x[:4]) 92 | 93 | for i in range(3): 94 | z = sample_outputs(c.sampling_temperature, model.output_dimensions) 95 | x_ab_sampled = model.combined_model.module.reverse_sample(z, cond) 96 | ims.extend(list(data.norm_lab_to_rgb(x_l, x_ab_sampled))) 97 | 98 | break 99 | 100 | if i_epoch >= c.pretrain_epochs * 
2: 101 | model.weight_scheduler.step(epoch_losses[0]) 102 | model.feature_scheduler.step(epoch_losses[0]) 103 | 104 | viz.show_imgs(*ims) 105 | viz.show_loss(epoch_losses) 106 | 107 | if i_epoch > 0 and (i_epoch % c.checkpoint_save_interval) == 0: 108 | model.save(c.filename + '_checkpoint_%.4i' % (i_epoch * (1-c.checkpoint_save_overwrite))) 109 | 110 | model.save(c.filename) 111 | 112 | except: 113 | if c.checkpoint_on_error: 114 | model.save(c.filename + '_ABORT') 115 | 116 | raise 117 | finally: 118 | viz.signal_stop() 119 | -------------------------------------------------------------------------------- /colorization_cINN/viz.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | 3 | from scipy.ndimage import zoom 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | 7 | import config as c 8 | import data 9 | 10 | n_imgs = 4 11 | n_plots = 2 12 | figsize = (4,4) 13 | im_width = c.img_dims[1] 14 | 15 | class Visualizer: 16 | def __init__(self, loss_labels): 17 | self.n_losses = len(loss_labels) 18 | self.loss_labels = loss_labels 19 | self.counter = 0 20 | 21 | header = 'Epoch' 22 | for l in loss_labels: 23 | header += '\t\t%s' % (l) 24 | 25 | self.config_str = "" 26 | self.config_str += "==="*30 + "\n" 27 | self.config_str += "Config options:\n\n" 28 | 29 | for v in dir(c): 30 | if v[0]=='_': continue 31 | s=eval('c.%s'%(v)) 32 | self.config_str += " {:25}\t{}\n".format(v,s) 33 | 34 | self.config_str += "==="*30 + "\n" 35 | 36 | print(self.config_str) 37 | print(header) 38 | 39 | def update_losses(self, losses, *args): 40 | print('\r', ' '*20, end='') 41 | line = '\r%.3i' % (self.counter) 42 | for l in losses: 43 | line += '\t\t%.4f' % (l) 44 | 45 | print(line) 46 | self.counter += 1 47 | 48 | def update_images(self, *img_list): 49 | w = img_list[0].shape[2] 50 | k = 0 51 | k_img = 0 52 | 53 | show_img = np.zeros((3, w*n_imgs, w*n_imgs), dtype=np.uint8) 54 | img_list_np = [] 55 | for im in img_list: 56 | im_np = im 57 | img_list_np.append(np.clip((255. * im_np), 0, 255).astype(np.uint8)) 58 | 59 | for i in range(n_imgs): 60 | for j in range(n_imgs): 61 | show_img[:, w*i:w*i+w, w*j:w*j+w] = img_list_np[k] 62 | 63 | k += 1 64 | if k >= len(img_list_np): 65 | k = 0 66 | k_img += 1 67 | 68 | plt.imsave(join(c.img_folder, '%.4d.jpg'%(self.counter)), show_img.transpose(1,2,0)) 69 | return zoom(show_img, (1., c.preview_upscale, c.preview_upscale), order=0) 70 | 71 | def update_hist(self, *args): 72 | pass 73 | 74 | def update_running(self, *args): 75 | pass 76 | 77 | 78 | if c.live_visualization: 79 | import visdom 80 | 81 | class LiveVisualizer(Visualizer): 82 | def __init__(self, loss_labels): 83 | super().__init__(loss_labels) 84 | self.viz = visdom.Visdom()#env='mnist') 85 | self.viz.close() 86 | self.config_box = self.viz.text('
' + self.config_str + '') 87 | self.running_box = self.viz.text('Running
') 88 | 89 | self.l_plots = self.viz.line(X = np.zeros((1,self.n_losses)), 90 | Y = np.zeros((1,self.n_losses)), 91 | opts = {'legend':self.loss_labels}) 92 | 93 | self.imgs = self.viz.image(np.random.random((3, im_width*n_imgs*c.preview_upscale, 94 | im_width*n_imgs*c.preview_upscale))) 95 | 96 | self.fig, self.axes = plt.subplots(n_plots, n_plots, figsize=figsize) 97 | self.hist = self.viz.matplot(self.fig) 98 | 99 | 100 | def update_losses(self, losses, logscale=False): 101 | super().update_losses(losses) 102 | its = min(len(data.train_loader), c.n_its_per_epoch) 103 | y = np.array([losses]) 104 | if logscale: 105 | y = np.log10(y) 106 | 107 | self.viz.line(X = (self.counter-1) * np.ones((1,self.n_losses)), 108 | Y = y, 109 | opts = {'legend':self.loss_labels}, 110 | win = self.l_plots, 111 | update = 'append') 112 | 113 | def update_images(self, *img_list): 114 | 115 | show_imgs = super().update_images(*img_list) 116 | self.viz.image(show_img, win = self.imgs) 117 | 118 | w = img_list[0].shape[2] 119 | k = 0 120 | k_img = 0 121 | 122 | show_img = np.zeros((3, w*n_imgs, w*n_imgs), dtype=np.uint8) 123 | img_list_np = [] 124 | for im in img_list: 125 | im_np = im 126 | img_list_np.append(np.clip((255. * im_np), 0, 255).astype(np.uint8)) 127 | 128 | for i in range(n_imgs): 129 | for j in range(n_imgs): 130 | show_img[:, w*i:w*i+w, w*j:w*j+w] = img_list_np[k] 131 | 132 | k += 1 133 | if k >= len(img_list_np): 134 | k = 0 135 | k_img += 1 136 | show_img = zoom(show_img, (1., c.preview_upscale, c.preview_upscale), order=0) 137 | self.viz.image(show_img, win = self.imgs) 138 | 139 | def update_hist(self, data): 140 | for i in range(n_plots): 141 | for j in range(n_plots): 142 | try: 143 | self.axes[i,j].clear() 144 | self.axes[i,j].hist(data[:, i*n_plots + j], bins=20, histtype='step') 145 | except ValueError: 146 | pass 147 | 148 | self.fig.tight_layout() 149 | self.viz.matplot(self.fig, win=self.hist) 150 | 151 | def update_running(self, running=True): 152 | if running: 153 | self.viz.text('
Running', win=self.running_box) 154 | else: 155 | self.viz.text('Done
', win=self.running_box) 156 | 157 | def close(self): 158 | self.viz.close(win=self.hist) 159 | self.viz.close(win=self.imgs) 160 | self.viz.close(win=self.l_plots) 161 | self.viz.close(win=self.running_box) 162 | self.viz.close(win=self.config_box) 163 | 164 | 165 | visualizer = LiveVisualizer(c.loss_names) 166 | else: 167 | visualizer = Visualizer(c.loss_names) 168 | 169 | def show_loss(losses, logscale=False): 170 | visualizer.update_losses(losses) 171 | 172 | def show_imgs(*imgs): 173 | visualizer.update_images(*imgs) 174 | 175 | def show_hist(data): 176 | visualizer.update_hist(data.data) 177 | 178 | def signal_start(): 179 | visualizer.update_running(True) 180 | 181 | def signal_stop(): 182 | visualizer.update_running(False) 183 | 184 | def close(): 185 | visualizer.close() 186 | 187 | -------------------------------------------------------------------------------- /colorization_minimal_example/.gitignore: -------------------------------------------------------------------------------- 1 | images/**/*.png 2 | -------------------------------------------------------------------------------- /colorization_minimal_example/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage import io, color 3 | import torch 4 | from PIL import Image 5 | from torch.utils.data import Dataset, DataLoader, TensorDataset 6 | import torchvision.transforms as T 7 | 8 | batch_size = 128 9 | offsets = (47.5, 2.4, 7.4) 10 | scales = (25.6, 11.2, 16.8) 11 | 12 | def norm_lab_to_rgb(L, ab, norm=True): 13 | '''given an Nx1xWxH Tensor L and an Nx2xwxh Tensor ab, normalized accoring to offsets and 14 | scales above, upsample the ab channels and combine with L, and form an RGB image. 15 | 16 | norm: If false, assume that L, ab are not normalized and already in the correct range''' 17 | 18 | lab = torch.cat([L, ab], dim=1) 19 | for i in range(1 + 2*norm): 20 | lab[:, i] = lab[:, i] * scales[i] + offsets[i] 21 | 22 | lab[:, 0].clamp_(0., 100.) 
23 | lab[:, 1:].clamp_(-128, 128) 24 | 25 | lab = lab.cpu().data.numpy() 26 | rgb = [color.lab2rgb(np.transpose(l, (1, 2, 0))).transpose(2, 0, 1) for l in lab] 27 | return np.array(rgb) 28 | 29 | class LabColorDataset(Dataset): 30 | def __init__(self, file_list, transform=None, noise=False): 31 | 32 | self.files = file_list 33 | self.transform = transform 34 | self.to_tensor = T.ToTensor() 35 | self.noise = noise 36 | 37 | def __len__(self): 38 | return len(self.files) 39 | 40 | def __getitem__(self, idx): 41 | 42 | im = Image.open(self.files[idx]) 43 | if self.transform: 44 | im = self.transform(im) 45 | im = self.to_tensor(im).numpy() 46 | 47 | im = np.transpose(im, (1,2,0)) 48 | if im.shape[2] != 3: 49 | im = np.stack([im[:,:,0]]*3, axis=2) 50 | im = color.rgb2lab(im).transpose((2, 0, 1)) 51 | 52 | for i in range(3): 53 | im[i] = (im[i] - offsets[i]) / scales[i] 54 | im = torch.Tensor(im) 55 | if self.noise: 56 | im += 0.005 * torch.rand_like(im) 57 | return im 58 | 59 | transf = T.Resize(64) 60 | 61 | test_list = [f'./train_data_128/{i}.jpg' for i in range(1, 1025)] 62 | val_list = [f'./train_data_128/{i}.jpg' for i in range(1025, 2049)] 63 | train_list = [f'./train_data_128/{i}.jpg' for i in range(2049, 3033042)] 64 | 65 | train_data = LabColorDataset(train_list, transf, noise=True) 66 | test_data = LabColorDataset(test_list, transf) 67 | val_data = LabColorDataset(val_list, transf) 68 | test_all = torch.stack(list(test_data), 0).cuda() 69 | val_all = torch.stack(list(test_data), 0).cuda() 70 | 71 | train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True, drop_last=True) 72 | test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=False) 73 | -------------------------------------------------------------------------------- /colorization_minimal_example/eval.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | 3 | import torch 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from tqdm import tqdm 7 | from scipy.spatial import distance_matrix 8 | 9 | import model 10 | import data 11 | 12 | cinn = model.ColorizationCINN(0) 13 | cinn.cuda() 14 | cinn.eval() 15 | state_dict = {k:v for k,v in torch.load('output/lsun_cinn.pt').items() if 'tmp_var' not in k} 16 | cinn.load_state_dict(state_dict) 17 | 18 | def colorize_test_set(temp=1., postfix=0, img_folder='images'): 19 | '''Colorize the whole test set once. 20 | temp: Sampling temperature 21 | postfix: Has to be integer. Append to file name (e.g. 
to make 10 diverse colorizations of test set) 22 | ''' 23 | counter = 0 24 | with torch.no_grad(): 25 | for Lab in tqdm(data.test_loader): 26 | Lab = Lab.cuda() 27 | z = temp * torch.randn(Lab.shape[0], model.ndim_total).cuda() 28 | L, ab = Lab[:, :1], Lab[:, 1:] 29 | 30 | ab_gen = cinn.reverse_sample(z, L) 31 | rgb_gen = data.norm_lab_to_rgb(L.cpu(), ab_gen.cpu()) 32 | 33 | for im in rgb_gen: 34 | im = np.transpose(im, (1,2,0)) 35 | plt.imsave(join(img_folder, '%.6i_%.3i.png' % (counter, postfix)), im) 36 | counter += 1 37 | 38 | 39 | def best_of_n(n): 40 | '''computes the best-of-n MSE metric''' 41 | with torch.no_grad(): 42 | errs_batches = [] 43 | for Lab in tqdm(data.test_loader, disable=True): 44 | L = Lab[:, :1].cuda() 45 | ab = Lab[:, 1:].cuda() 46 | B = L.shape[0] 47 | 48 | rgb_gt = data.norm_lab_to_rgb(L.cpu(), ab.cpu()) 49 | rgb_gt = rgb_gt.reshape(B, -1) 50 | 51 | errs = np.inf * np.ones(B) 52 | 53 | for k in range(n): 54 | z = torch.randn(B, model.ndim_total).cuda() 55 | ab_k = cinn.reverse_sample(z, L) 56 | rgb_k = data.norm_lab_to_rgb(L.cpu(), ab_k.cpu()).reshape(B, -1) 57 | 58 | errs_k = np.mean((rgb_k - rgb_gt)**2, axis=1) 59 | errs = np.minimum(errs, errs_k) 60 | 61 | errs_batches.append(np.mean(errs)) 62 | 63 | print(F'MSE best of {n}') 64 | print(np.sqrt(np.mean(errs_batches))) 65 | return np.sqrt(np.mean(errs_batches)) 66 | 67 | def rgb_var(n): 68 | '''computes the pixel-wise variance of samples''' 69 | with torch.no_grad(): 70 | var = [] 71 | for Lab in tqdm(data.test_all, disable=True): 72 | L = Lab[:1].view(1,1,64,64).expand(n, -1, -1, -1).cuda() 73 | z = torch.randn(n, model.ndim_total).cuda() 74 | 75 | ab = cinn.reverse_sample(z, L) 76 | rgb = data.norm_lab_to_rgb(L.cpu(), ab.cpu()).reshape(n, -1) 77 | 78 | var.append(np.mean(np.var(rgb, axis=0))) 79 | 80 | print(F'Var (of {n} samples)') 81 | print(np.mean(var)) 82 | print(F'sqrt(Var) (of {n} samples)') 83 | print(np.sqrt(np.mean(var))) 84 | 85 | for i in range(8): 86 | torch.manual_seed(i+111) 87 | colorize_test_set(postfix=i) 88 | -------------------------------------------------------------------------------- /colorization_minimal_example/images/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vislearn/conditional_INNs/19e316c606cae24815efa51305ce8d3a6476f819/colorization_minimal_example/images/.gitkeep -------------------------------------------------------------------------------- /colorization_minimal_example/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim 4 | 5 | import FrEIA.framework as Ff 6 | import FrEIA.modules as Fm 7 | 8 | ndim_total = 2 * 64 * 64 9 | 10 | class CondNet(nn.Module): 11 | '''conditioning network''' 12 | def __init__(self): 13 | super().__init__() 14 | 15 | class Flatten(nn.Module): 16 | def __init__(self, *args): 17 | super().__init__() 18 | def forward(self, x): 19 | return x.view(x.shape[0], -1) 20 | 21 | self.resolution_levels = nn.ModuleList([ 22 | nn.Sequential(nn.Conv2d(1, 64, 3, padding=1), 23 | nn.LeakyReLU(), 24 | nn.Conv2d(64, 64, 3, padding=1)), 25 | 26 | nn.Sequential(nn.LeakyReLU(), 27 | nn.Conv2d(64, 128, 3, padding=1), 28 | nn.LeakyReLU(), 29 | nn.Conv2d(128, 128, 3, padding=1, stride=2)), 30 | 31 | nn.Sequential(nn.LeakyReLU(), 32 | nn.Conv2d(128, 128, 3, padding=1, stride=2)), 33 | 34 | nn.Sequential(nn.LeakyReLU(), 35 | nn.AvgPool2d(4), 36 | Flatten(), 37 | nn.Linear(2048, 512))]) 38 | 
39 | def forward(self, c): 40 | outputs = [c] 41 | for m in self.resolution_levels: 42 | outputs.append(m(outputs[-1])) 43 | return outputs[1:] 44 | 45 | class ColorizationCINN(nn.Module): 46 | '''cINN, including the ocnditioning network''' 47 | def __init__(self, lr): 48 | super().__init__() 49 | 50 | self.cinn = self.build_inn() 51 | self.cond_net = CondNet() 52 | 53 | self.trainable_parameters = [p for p in self.cinn.parameters() if p.requires_grad] 54 | for p in self.trainable_parameters: 55 | p.data = 0.02 * torch.randn_like(p) 56 | 57 | self.trainable_parameters += list(self.cond_net.parameters()) 58 | self.optimizer = torch.optim.Adam(self.trainable_parameters, lr=lr) 59 | 60 | def build_inn(self): 61 | 62 | def sub_conv(ch_hidden, kernel): 63 | pad = kernel // 2 64 | return lambda ch_in, ch_out: nn.Sequential( 65 | nn.Conv2d(ch_in, ch_hidden, kernel, padding=pad), 66 | nn.ReLU(), 67 | nn.Conv2d(ch_hidden, ch_out, kernel, padding=pad)) 68 | 69 | def sub_fc(ch_hidden): 70 | return lambda ch_in, ch_out: nn.Sequential( 71 | nn.Linear(ch_in, ch_hidden), 72 | nn.ReLU(), 73 | nn.Linear(ch_hidden, ch_out)) 74 | 75 | nodes = [Ff.InputNode(2, 64, 64)] 76 | # outputs of the cond. net at different resolution levels 77 | conditions = [Ff.ConditionNode(64, 64, 64), 78 | Ff.ConditionNode(128, 32, 32), 79 | Ff.ConditionNode(128, 16, 16), 80 | Ff.ConditionNode(512)] 81 | 82 | split_nodes = [] 83 | 84 | subnet = sub_conv(32, 3) 85 | for k in range(2): 86 | nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock, 87 | {'subnet_constructor':subnet, 'clamp':1.0}, 88 | conditions=conditions[0])) 89 | 90 | nodes.append(Ff.Node(nodes[-1], Fm.HaarDownsampling, {'rebalance':0.5})) 91 | 92 | for k in range(4): 93 | subnet = sub_conv(64, 3 if k%2 else 1) 94 | 95 | nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock, 96 | {'subnet_constructor':subnet, 'clamp':1.0}, 97 | conditions=conditions[1])) 98 | nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed':k})) 99 | 100 | #split off 6/8 ch 101 | nodes.append(Ff.Node(nodes[-1], Fm.Split1D, 102 | {'split_size_or_sections':[2,6], 'dim':0})) 103 | split_nodes.append(Ff.Node(nodes[-1].out1, Fm.Flatten, {})) 104 | 105 | nodes.append(Ff.Node(nodes[-1], Fm.HaarDownsampling, {'rebalance':0.5})) 106 | 107 | for k in range(4): 108 | subnet = sub_conv(128, 3 if k%2 else 1) 109 | 110 | nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock, 111 | {'subnet_constructor':subnet, 'clamp':0.6}, 112 | conditions=conditions[2])) 113 | nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed':k})) 114 | 115 | #split off 4/8 ch 116 | nodes.append(Ff.Node(nodes[-1], Fm.Split1D, 117 | {'split_size_or_sections':[4,4], 'dim':0})) 118 | split_nodes.append(Ff.Node(nodes[-1].out1, Fm.Flatten, {})) 119 | nodes.append(Ff.Node(nodes[-1], Fm.Flatten, {}, name='flatten')) 120 | 121 | # fully_connected part 122 | subnet = sub_fc(512) 123 | for k in range(4): 124 | nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock, 125 | {'subnet_constructor':subnet, 'clamp':0.6}, 126 | conditions=conditions[3])) 127 | nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed':k})) 128 | 129 | # concat everything 130 | nodes.append(Ff.Node([s.out0 for s in split_nodes] + [nodes[-1].out0], 131 | Fm.Concat1d, {'dim':0})) 132 | nodes.append(Ff.OutputNode(nodes[-1])) 133 | 134 | return Ff.ReversibleGraphNet(nodes + split_nodes + conditions, verbose=False) 135 | 136 | def forward(self, Lab): 137 | z = self.cinn(Lab[:,1:], c=self.cond_net(Lab[:,:1])) 138 | jac = self.cinn.log_jacobian(run_forward=False) 139 | return 
z, jac 140 | 141 | def reverse_sample(self, z, L): 142 | return self.cinn(z, c=self.cond_net(L), rev=True) 143 | -------------------------------------------------------------------------------- /colorization_minimal_example/output/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vislearn/conditional_INNs/19e316c606cae24815efa51305ce8d3a6476f819/colorization_minimal_example/output/.gitkeep -------------------------------------------------------------------------------- /colorization_minimal_example/train.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | 3 | from tqdm import tqdm 4 | import torch 5 | import torch.optim 6 | import numpy as np 7 | 8 | import model 9 | import data 10 | 11 | cinn = model.ColorizationCINN(1e-3) 12 | cinn.cuda() 13 | scheduler = torch.optim.lr_scheduler.StepLR(cinn.optimizer, 1, gamma=0.1) 14 | 15 | N_epochs = 3 16 | t_start = time() 17 | nll_mean = [] 18 | 19 | print('Epoch\tBatch/Total \tTime \tNLL train\tNLL val\tLR') 20 | for epoch in range(N_epochs): 21 | for i, Lab in enumerate(data.train_loader): 22 | Lab = Lab.cuda() 23 | z, log_j = cinn(Lab) 24 | 25 | nll = torch.mean(z**2) / 2 - torch.mean(log_j) / model.ndim_total 26 | nll.backward() 27 | nll_mean.append(nll.item()) 28 | cinn.optimizer.step() 29 | cinn.optimizer.zero_grad() 30 | 31 | if not i % 20: 32 | with torch.no_grad(): 33 | z, log_j = cinn(data.val_all[:512]) 34 | nll_val = torch.mean(z**2) / 2 - torch.mean(log_j) / model.ndim_total 35 | 36 | print('%.3i \t%.5i/%.5i \t%.2f \t%.6f\t%.6f\t%.2e' % (epoch, 37 | i, len(data.train_loader), 38 | (time() - t_start)/60., 39 | np.mean(nll_mean), 40 | nll_val.item(), 41 | cinn.optimizer.param_groups[0]['lr'], 42 | ), flush=True) 43 | nll_mean = [] 44 | 45 | scheduler.step() 46 | torch.save(cinn.state_dict(), f'output/lsun_cinn.pt') 47 | -------------------------------------------------------------------------------- /colorization_minimal_example/train_data_128: -------------------------------------------------------------------------------- 1 | /home/diz/data/lsun/train_data_128 -------------------------------------------------------------------------------- /mnist_cINN/.run: -------------------------------------------------------------------------------- 1 | #!/bin/zsh 2 | 3 | case $1 in 4 | 1) 5 | exit 1 6 | ;; 7 | 2) 8 | exit 2 9 | ;; 10 | 3) 11 | exit 3 12 | ;; 13 | 4) 14 | exit 4 15 | ;; 16 | 5) 17 | exit 5 18 | ;; 19 | 6) 20 | exit 6 21 | ;; 22 | 7) 23 | exit 7 24 | ;; 25 | 8) 26 | exit 8 27 | ;; 28 | *) 29 | python eval.py 30 | ;; 31 | esac 32 | -------------------------------------------------------------------------------- /mnist_cINN/README.rst: -------------------------------------------------------------------------------- 1 | Class-conditional Generation 2 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 3 | 4 | Checkpoint for the MNIST model: 5 | 6 | https://drive.google.com/file/d/1Vf8RFX-n-HvBwgUTFPFcMBvf1kfBSEHg 7 | 8 | (Simply download, and set ``load_file`` to the file location in ``config.py``) 9 | 10 | 11 | Colorization on 'Color-MNIST' 12 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 13 | 14 | This folder also contains an additional experiment: 'color-MNIST' 15 | (set ``colorize = True`` in ``config.py``) 16 | 17 | The color-MNIST dataset contains color images, where each image is an MNIST digit, 18 | colorized in a certain color scheme corresponding to the digit, with some random fluctuations. 
19 | The task is to generate color images, the condition is black and white. 20 | The dataset can be generated by ``cd color_mnist_data; python color_mnist.py``. 21 | 22 | The model uses a small conditioning network to extract semantic information (which digit) from the condition. 23 | The conditioning network can also be pretrained by running ``python -m cond_net.py``. 24 | -------------------------------------------------------------------------------- /mnist_cINN/color_mnist_data/color_mnist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from matplotlib.colors import hsv_to_rgb, rgb_to_hsv 4 | from scipy.ndimage import zoom 5 | from skimage.filters import gaussian 6 | import torch 7 | 8 | from torch.utils.data import Dataset, DataLoader 9 | import torchvision.transforms as T 10 | import torchvision.datasets 11 | 12 | data_dir = '../mnist_data' 13 | 14 | hues_sigmas = { 15 | 0: (0, 20), 16 | 1: (50, 10), 17 | 2: (110, 30), 18 | 3: (180, 20), 19 | 4: (235, 25), 20 | 5: (305, 25), 21 | } 22 | 23 | pairings = [ 24 | (0,2), 25 | (1,3), 26 | (2,4), 27 | (3,5), 28 | (4,0), 29 | (5,1), 30 | (0,4), 31 | (2,0), 32 | (4,2), 33 | (5,3), 34 | ] 35 | 36 | imsize = 28 37 | 38 | def colorize(img, fg, bg): 39 | base_fg = hues_sigmas[fg][0] + hues_sigmas[fg][1] * np.random.randn() 40 | base_bg = hues_sigmas[bg][0] + hues_sigmas[bg][1] * np.random.randn() 41 | img_out = 0.8 * np.ones((imsize, imsize, 3)) 42 | img_out[:, :, 0] = img * base_fg / 360. 43 | img_out[:, :, 0] += (1.-img) * base_bg / 360. 44 | 45 | noise = np.random.randn(3, imsize, imsize) 46 | noise[0] = 0.25 * gaussian(noise[0], 4) 47 | noise[1] = 0.3 * gaussian(noise[1], 2) 48 | noise[2] = 0.05 * noise[2] 49 | 50 | img_out += noise.transpose((1,2,0)) 51 | img_out[:, :, 0] = img_out[:, :, 0] % 1. 
52 | img_out[:, :, 1:] = np.clip(img_out[:, :, 1:], 0, 1) 53 | 54 | return np.clip(hsv_to_rgb(img_out), 0, 1) 55 | 56 | train_data = torchvision.datasets.MNIST(data_dir, train=True, transform=T.ToTensor(), download=True) 57 | test_data = torchvision.datasets.MNIST(data_dir, train=False, transform=T.ToTensor(), download=True) 58 | 59 | train_loader = DataLoader(train_data, batch_size=512, shuffle=False) 60 | 61 | images = [im.numpy() for im, labels in train_loader] 62 | labels = [labels.numpy() for im, labels in train_loader] 63 | images = np.concatenate(images, axis=0) 64 | labels = np.concatenate(labels, axis=0) 65 | 66 | def export(): 67 | from tqdm import tqdm 68 | 69 | imgs_color = [] 70 | for i in tqdm(range(len(labels))): 71 | im_color = colorize(images[i, 0], *(pairings[labels[i]])) 72 | imgs_color.append(im_color.transpose((2,0,1))) 73 | 74 | imgs_torch = torch.Tensor(images) 75 | imgs_color_torch = torch.Tensor(np.stack(imgs_color, axis=0)) 76 | labels_torch = torch.Tensor(labels) 77 | 78 | torch.save(imgs_color_torch, 'color_mnist_images.pt') 79 | torch.save(imgs_torch, 'color_mnist_masks.pt') 80 | torch.save(labels_torch, 'color_mnist_labels.pt') 81 | 82 | def plot(): 83 | n_rows = 10 84 | n_cols = 16 85 | 86 | for i in range(n_rows): 87 | matching_ims = images[labels == i] 88 | colors = pairings[i] 89 | 90 | for j in range(n_cols): 91 | im = colorize(matching_ims[j, 0], *colors) 92 | 93 | plt.subplot(n_rows, n_cols, n_cols*i+j+1) 94 | plt.imshow(im) 95 | plt.xticks([]) 96 | plt.yticks([]) 97 | 98 | plt.show() 99 | 100 | export() 101 | -------------------------------------------------------------------------------- /mnist_cINN/cond_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | from torchvision import datasets, transforms 6 | 7 | import config as c 8 | import data as color_data 9 | 10 | class Net(nn.Module): 11 | def __init__(self): 12 | super(Net, self).__init__() 13 | self.conv = nn.Sequential( 14 | nn.Conv2d(1, 32, kernel_size=3), 15 | nn.Conv2d(32, 64, kernel_size=3), 16 | nn.MaxPool2d(2), 17 | nn.Conv2d(64, 64, kernel_size=3), 18 | nn.Conv2d(64, 64, kernel_size=3), 19 | nn.MaxPool2d(2), 20 | ) 21 | 22 | self.linear = nn.Sequential( 23 | nn.Dropout(), 24 | nn.Linear(1024, 512), 25 | nn.Dropout(), 26 | nn.Linear(512, 512), 27 | nn.Dropout(), 28 | nn.Linear(512, c.cond_width), 29 | ) 30 | 31 | self.fc_final = nn.Linear(c.cond_width, 10) 32 | 33 | def forward(self, x): 34 | x = self.conv(x) 35 | x = x.view(c.batch_size, -1) 36 | x = self.linear(x) 37 | x = self.fc_final(x) 38 | return F.log_softmax(x, dim=1) 39 | 40 | def features(self, x): 41 | x = self.conv(x) 42 | x = x.view(c.batch_size, -1) 43 | return self.linear(x) 44 | 45 | model = Net().cuda() 46 | log_interval = 25 47 | 48 | def train(): 49 | model.train() 50 | for batch_idx, (color, target, data) in enumerate(color_data.train_loader): 51 | data, target = data.cuda(), target.long().cuda() 52 | optimizer.zero_grad() 53 | output = model(data) 54 | loss = F.nll_loss(output, target) 55 | loss.backward() 56 | optimizer.step() 57 | if batch_idx % log_interval == 0: 58 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 59 | epoch, batch_idx * len(data), len(color_data.train_loader.dataset), 60 | 100. 
* batch_idx / len(color_data.train_loader), loss.item())) 61 | 62 | test_loader = torch.utils.data.DataLoader( 63 | datasets.MNIST('./mnist_data', train=False, transform=transforms.ToTensor()), 64 | batch_size=c.batch_size, shuffle=True, drop_last=True) 65 | 66 | def test(): 67 | model.train() 68 | test_loss = 0 69 | correct = 0 70 | with torch.no_grad(): 71 | for data, target in test_loader: 72 | data, target = data.cuda(), target.cuda() 73 | output = model(data) 74 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 75 | pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability 76 | correct += pred.eq(target.view_as(pred)).sum().item() 77 | 78 | test_loss /= len(test_loader.dataset) 79 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 80 | test_loss, correct, len(test_loader.dataset), 81 | 100. * correct / len(test_loader.dataset))) 82 | 83 | 84 | if __name__ == '__main__': 85 | 86 | optimizer = optim.SGD(model.parameters(), lr=0.03, momentum=0.5) 87 | 88 | for epoch in range(6): 89 | train() 90 | test() 91 | 92 | torch.save(model.state_dict(), c.cond_net_file) 93 | 94 | else: 95 | model.train() 96 | if c.cond_net_file: 97 | model.load_state_dict(torch.load(c.cond_net_file)) 98 | -------------------------------------------------------------------------------- /mnist_cINN/config.py: -------------------------------------------------------------------------------- 1 | ##################### 2 | # Which experiment: # 3 | ##################### 4 | 5 | # Train to colorize the 'colorized mnist' images, 6 | # instead of conditional generation 7 | colorize = False 8 | 9 | ######### 10 | # Data: # 11 | ######### 12 | 13 | data_mean = 0.0 14 | data_std = 1.0 15 | img_dims = (28, 28) 16 | output_dim = img_dims[0] * img_dims[1] 17 | if colorize: 18 | output_dim *= 3 19 | 20 | add_image_noise = 0.15 21 | 22 | ############## 23 | # Training: # 24 | ############## 25 | 26 | lr = 1e-4 27 | batch_size = 512 28 | decay_by = 0.01 29 | weight_decay = 1e-5 30 | betas = (0.9, 0.999) 31 | 32 | do_rev = False 33 | do_fwd = True 34 | 35 | n_epochs = 120 * 12 36 | n_its_per_epoch = 2**16 37 | 38 | init_scale = 0.03 39 | pre_low_lr = 1 40 | 41 | ################# 42 | # Architecture: # 43 | ################# 44 | 45 | # For cond. 
generation: 46 | n_blocks = 24 47 | internal_width = 512 48 | clamping = 1.5 49 | 50 | # For colorization: 51 | #n_blocks = 7 52 | #n_blocks_conv = 3 53 | #internal_width = 256 54 | #internal_width_conv = 64 55 | #clamping = 1.9 56 | cond_width = 64 # Output size of conditioning network 57 | 58 | fc_dropout = 0.0 59 | 60 | #################### 61 | # Logging/preview: # 62 | #################### 63 | 64 | loss_names = ['L', 'L_rev'] 65 | preview_upscale = 3 # Scale up the images for preview 66 | sampling_temperature = 0.8 # Sample at a reduced temperature for the preview 67 | live_visualization = False # Show samples and loss curves during training, using visdom 68 | progress_bar = True # Show a progress bar of each epoch 69 | 70 | ################### 71 | # Loading/saving: # 72 | ################### 73 | 74 | load_file = 'output/checkpoint.pt' # Load pre-trained network 75 | filename = 'output/mnist_cinn.pt' # Save parameters under this name 76 | cond_net_file = '' # Filename of the feature extraction network (only colorization) 77 | 78 | checkpoint_save_interval = 120 * 3 79 | checkpoint_save_overwrite = True # Overwrite each checkpoint with the next one 80 | checkpoint_on_error = True # Write out a checkpoint if the training crashes 81 | -------------------------------------------------------------------------------- /mnist_cINN/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join, isfile, basename 3 | from time import time 4 | from multiprocessing import Process 5 | 6 | from tqdm import tqdm 7 | import numpy as np 8 | from PIL import Image 9 | import torch 10 | from torch.utils.data import Dataset, DataLoader, TensorDataset 11 | import torchvision.transforms as T 12 | 13 | import config as c 14 | import torchvision.datasets 15 | 16 | def unnormalize(x): 17 | return x * c.data_std + c.data_mean 18 | 19 | if c.colorize: 20 | data_dir = 'color_mnist_data' 21 | 22 | ims = (torch.load(join(data_dir, 'color_mnist_images.pt')) - c.data_mean) / c.data_std 23 | labels = torch.load(join(data_dir, 'color_mnist_labels.pt')) 24 | masks = torch.load(join(data_dir, 'color_mnist_masks.pt')) 25 | 26 | dataset = torch.utils.data.TensorDataset(ims, labels, masks) 27 | 28 | train_loader = DataLoader(dataset, batch_size=c.batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True) 29 | test_loader = train_loader 30 | 31 | else: 32 | data_dir = 'mnist_data' 33 | 34 | train_data = torchvision.datasets.MNIST(data_dir, train=True, transform=T.ToTensor(), download=True) 35 | test_data = torchvision.datasets.MNIST(data_dir, train=False, transform=T.ToTensor(), download=True) 36 | 37 | train_loader = DataLoader(train_data, batch_size=c.batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True) 38 | test_loader = DataLoader(test_data, batch_size=c.batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=True) 39 | -------------------------------------------------------------------------------- /mnist_cINN/eval.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import numpy as np 4 | 5 | from sklearn.decomposition import PCA 6 | import matplotlib.pyplot as plt 7 | import torch 8 | from tqdm import tqdm 9 | from torch.nn.functional import avg_pool2d, interpolate 10 | 11 | import config as c 12 | import opts 13 | import data 14 | opts.parse(sys.argv) 15 | print('==='*30) 16 | print('Config options:\n') 17 | for v in dir(c): 18 | 19 | if 
v[0]=='_': continue 20 | s=eval('c.%s'%(v)) 21 | print(' {:25}\t{}'.format(v,s)) 22 | 23 | print('==='*30) 24 | 25 | import model 26 | 27 | model.load(c.load_file) 28 | model.model.eval() 29 | 30 | test_labels = torch.LongTensor((list(range(10))*(c.batch_size//10 + 1))[:c.batch_size]).cuda() 31 | test_cond = torch.zeros(c.batch_size, model.cond_size).cuda() 32 | test_cond.scatter_(1, test_labels.view(-1,1), 1.) 33 | 34 | def img_tile(imgs, row_col = None, transpose = False, channel_first=True, channels=3): 35 | '''tile a list of images to a large grid. 36 | imgs: iterable of images to use 37 | row_col: None (automatic), or tuple of (#rows, #columns) 38 | transpose: Wheter to stitch the list of images row-first or column-first 39 | channel_first: if true, assume images with CxWxH, else WxHxC 40 | channels: 3 or 1, number of color channels ''' 41 | if row_col == None: 42 | sqrt = np.sqrt(len(imgs)) 43 | rows = np.floor(sqrt) 44 | delt = sqrt - rows 45 | cols = np.ceil(rows + 2*delt + delt**2 / rows) 46 | rows, cols = int(rows), int(cols) 47 | else: 48 | rows, cols = row_col 49 | 50 | if channel_first: 51 | h, w = imgs[0].shape[1], imgs[0].shape[2] 52 | else: 53 | h, w = imgs[0].shape[0], imgs[0].shape[1] 54 | 55 | show_im = np.zeros((rows*h, cols*w, channels)) 56 | 57 | if transpose: 58 | def iterator(): 59 | for i in range(rows): 60 | for j in range(cols): 61 | yield i, j 62 | 63 | else: 64 | def iterator(): 65 | for j in range(cols): 66 | for i in range(rows): 67 | yield i, j 68 | 69 | k = 0 70 | for i, j in iterator(): 71 | 72 | im = imgs[k] 73 | if channel_first: 74 | im = np.transpose(im, (1, 2, 0)) 75 | 76 | show_im[h*i:h*i+h, w*j:w*j+w] = im 77 | 78 | k += 1 79 | if k == len(imgs): 80 | break 81 | 82 | return np.squeeze(show_im) 83 | 84 | def sample_outputs(sigma): 85 | '''Produce a random latent vector with sampling temperature sigma''' 86 | return sigma * torch.randn(c.batch_size, c.output_dim).cuda() 87 | 88 | def show_training_data(digit, n_imgs, save_as=None): 89 | '''Show some validation images (if you want to look for interesting examples etc.) 90 | digit: int 0-9, show images of this digit 91 | n_imgs: show this many images 92 | save_as: None, or filename, to save the image file''' 93 | imgs = [] 94 | while len(imgs) < n_imgs ** 2: 95 | color, label, img = next(iter(data.train_loader)) 96 | imgs += list(color[label==digit]) 97 | 98 | img_show = img_tile(imgs, (n_imgs, n_imgs)) 99 | plt.figure() 100 | plt.imshow(img_show) 101 | if save_as: 102 | plt.imsave(save_as, img_show) 103 | 104 | ####################################################################### 105 | # cINN with conditioning network for colorization of 'color MNIST' # 106 | ####################################################################### 107 | 108 | if c.colorize: 109 | def color_single(bw_img, n_show=11, save_as=None, subplot_args=None): 110 | '''colorize a sinlge black-and-white image. 
111 | bw_img: 1x28x28 bw image 112 | n_show: how many samples to generate 113 | save_as: if not None: save image filename 114 | subplot_args: If not None: use plt.sublot(*subplot_args) instead of plt.figure()''' 115 | 116 | with torch.no_grad(): 117 | cond_features = cond_net.model.features(bw_img.expand(c.batch_size, -1, -1, -1)) 118 | cond = torch.cat([bw_img.expand(c.batch_size, 1, *c.img_dims), 119 | cond_features.view(c.batch_size, c.cond_width, 1, 1).expand(-1, -1, *c.img_dims)], dim=1) 120 | 121 | z = sample_outputs(1.0) 122 | 123 | with torch.no_grad(): 124 | colored = data.unnormalize(model.model(z, cond, rev=True)[:n_show].data.cpu().numpy()) 125 | 126 | imgs = [torch.cat([bw_img]*3, 0).cpu().numpy()] 127 | imgs += list(colored) 128 | 129 | img_show = img_tile(imgs, (1, n_show+1)) 130 | img_show = np.clip(img_show, 0, 1) 131 | 132 | if subplot_args: 133 | plt.subplot(*subplot_args) 134 | else: 135 | plt.figure() 136 | 137 | plt.imshow(img_show) 138 | plt.xticks([]) 139 | plt.yticks([]) 140 | if save_as: 141 | plt.imsave(save_as, img_show) 142 | 143 | 144 | torch.manual_seed(0) 145 | import data 146 | import cond_net 147 | 148 | color, label, img = next(iter(data.train_loader)) 149 | color, label, img = color.cuda(), label.cuda(), img.cuda() 150 | 151 | # make a large figure comparing the colorizaton of more and less ambiguous bw images: 152 | plt.figure() 153 | indx_examples = list([16, 17, 18, 19, 20]) 154 | 155 | for i, n in enumerate(indx_examples): 156 | color_single(img[n], subplot_args=(len(indx_examples), 1, i+1)) 157 | 158 | plt.savefig('colors.png', dpi=300) 159 | 160 | ######################################################################## 161 | # Standard cINN for class-conditional generation of MNIST images # 162 | ######################################################################## 163 | else: 164 | def interpolation(temp=1., n_steps=12, seeds=None, save_as=None): 165 | '''Interpolate between to random latent vectors. 166 | temp: Sampling temperature 167 | n_steps: Interpolation steps 168 | seeds: Optional 2-tuple of seeds for the two random latent vectors 169 | save_as: Optional filename to save the image.''' 170 | 171 | if seeds is not None: 172 | torch.manual_seed(seeds[0]) 173 | 174 | z_sample_0 = sample_outputs(temp) 175 | z_0 = z_sample_0[0].expand_as(z_sample_0) 176 | 177 | if seeds is not None: 178 | torch.manual_seed(seeds[1]) 179 | 180 | z_sample_1 = sample_outputs(temp) 181 | z_1 = z_sample_1[1].expand_as(z_sample_1) 182 | 183 | interpolation_steps = np.linspace(0., 1., n_steps, endpoint=True) 184 | interp_imgs = [] 185 | 186 | for t in interpolation_steps: 187 | with torch.no_grad(): 188 | im = model.model((1.-t) * z_0 + t * z_1, test_cond, rev=True).cpu().data.numpy() 189 | interp_imgs.extend([im[i:i+1] for i in range(10)]) 190 | 191 | img_show = img_tile(interp_imgs, (10, len(interpolation_steps)), transpose=False, channels=1) 192 | plt.figure() 193 | plt.imshow(img_show, cmap='gray', vmin=0, vmax=1) 194 | 195 | if save_as: 196 | plt.imsave(save_as, img_show, cmap='gray', vmin=0, vmax=1) 197 | 198 | def style_transfer(index_in, save_as=None): 199 | '''Perform style transfer as described in the cINN paper. 200 | index_in: Index of the validation image to use for the transfer. 201 | save_as: Optional filename to save the image.''' 202 | 203 | if c_test[index_in] != 9: 204 | return 205 | cond = torch.zeros(1, 10).cuda() 206 | cond[0, c_test[index_in]] = 1. 
207 | 208 | with torch.no_grad(): 209 | z_reference = model.model(x_test[index_in:index_in+1], cond) 210 | z_reference = torch.cat([z_reference]*10, dim=0) 211 | 212 | imgs_generated = model.model(z_reference, test_cond[:10], rev=True).view(-1, 1, *c.img_dims) 213 | 214 | ref_img = x_test[index_in, 0].cpu() 215 | 216 | img_show = img_tile(imgs_generated.cpu(), (1,10), transpose=False, channel_first=True, channels=1) 217 | 218 | plt.figure() 219 | plt.subplot(1,2,1) 220 | plt.xlabel(str(index_in)) 221 | plt.imshow(ref_img, cmap='gray', vmin=0, vmax=1) 222 | plt.subplot(1,2,2) 223 | plt.imshow(img_show, cmap='gray', vmin=0, vmax=1) 224 | 225 | if save_as: 226 | plt.imsave(save_as, img_show, cmap='gray', vmin=0, vmax=1) 227 | 228 | def val_set_pca(I=0,C=9, save_as=None): 229 | '''Perform PCA using the latent codes of the validation set, to identify disentangled 230 | and semantic latent dimensions. 231 | I: Which of the predefined principal-component weight combinations to traverse (0-2). 232 | C: Which digit to use (0-9). 233 | save_as: Optional filename to save the image.''' 234 | cond = torch.zeros(len(c_test), model.cond_size).cuda() 235 | cond.scatter_(1, c_test.view(-1,1), 1.) 236 | 237 | with torch.no_grad(): 238 | z_all = model.model(x_test, cond).data.cpu().numpy() 239 | 240 | pca = PCA(whiten=True) 241 | pca.fit(z_all) 242 | u = pca.transform(z_all) 243 | 244 | gridsize = 10 245 | extent = 8. 246 | u_grid = np.zeros((gridsize, u.shape[1])) 247 | 248 | U = np.linspace(-extent, extent, gridsize) 249 | weights = [[(0,0.55)], 250 | [(1,0.1), (3, 0.4), (4, 0.5)], 251 | [(2,0.33), (3, 0.33), (1, -0.33)]] 252 | 253 | for i, u_i in enumerate(U): 254 | for j, w in weights[I]: 255 | u_grid[i, j] = u_i * w 256 | 257 | z_grid = pca.inverse_transform(u_grid) 258 | grid_cond = torch.zeros(gridsize, 10).cuda() 259 | grid_cond[:, C] = 1. 260 | 261 | with torch.no_grad(): 262 | imgs = model.model(torch.Tensor(z_grid).cuda(), grid_cond, rev=True).view(-1, 1, 28, 28) 263 | img_show = img_tile(imgs.cpu(), (1,gridsize), transpose=False, channel_first=True, channels=1) 264 | 265 | plt.imsave(F'./images/pca/digit_{C}_component_{I}.png', img_show, cmap='gray', vmin=0, vmax=1) 266 | plt.imshow(img_show, cmap='gray', vmin=0, vmax=1) 267 | 268 | def temperature(temp=None, rows=10, columns=24, save_as=None): 269 | '''Show the effect of changing sampling temperature. 270 | temp: If None, interpolate between 0 and 1.5 in `columns` steps. 271 | If float, use it as the sampling temperature.
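For every column, a single latent vector is drawn (at that column's temperature, or at `temp` if it is given) and shared across all digit conditions in that column.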
272 | rows: Number of rows (10=1 for each digit) to show 273 | columns: Number of columns (interpolation steps for temperature) 274 | save_as: Optional filename to save the image.''' 275 | 276 | temperature_imgs = [] 277 | temp_steps = np.linspace(0., 1.5, columns, endpoint=True) 278 | 279 | ticks = [ (i+0.5) * c.img_dims[1] for i in range(len(temp_steps))] 280 | labels = [ '%.2f' % (s) for s in temp_steps ] 281 | 282 | for s in temp_steps: 283 | 284 | if temp is None: 285 | z_sample = sample_outputs(s) 286 | else: 287 | z_sample = sample_outputs(temp) 288 | 289 | z_sample[:] = z_sample[0] 290 | 291 | with torch.no_grad(): 292 | temperature_imgs.append(model.model(z_sample, test_cond, rev=True).cpu().data.numpy()) 293 | 294 | imgs = [temperature_imgs[i][j:j+1] for j in range(rows) for i in range(len(temp_steps))] 295 | img_show = img_tile(imgs, (columns, rows), transpose=False, channel_first=True, channels=1) 296 | 297 | if save_as: 298 | plt.imsave(save_as, img_show, cmap='gray', vmin=0, vmax=1) 299 | 300 | for s in tqdm(range(0, 256)): 301 | torch.manual_seed(s) 302 | temperature(0.88, columns=1, save_as='./images/samples/T_%.4i.png' % (s)) 303 | plt.title(str(s)) 304 | 305 | x_test = [] 306 | c_test = [] 307 | for x,cc in data.test_loader: 308 | x_test.append(x) 309 | c_test.append(cc) 310 | x_test, c_test = torch.cat(x_test, dim=0).cuda(), torch.cat(c_test, dim=0).cuda() 311 | 312 | for i in [284, 394, 422, 759, 639, 599, 471, 449, 448, 426]: 313 | style_transfer(i, save_as=F'images/style_transf_{i}.png') 314 | plt.title(str(i)) 315 | 316 | interpolation(1.0, seeds=(51,89), n_steps=12) 317 | 318 | for j in range(3): 319 | plt.figure() 320 | for i in range(10): 321 | plt.subplot(10, 1, i+1) 322 | val_set_pca(I=j, C=i) 323 | plt.title(str(j)) 324 | 325 | plt.show() 326 | -------------------------------------------------------------------------------- /mnist_cINN/extra_modules.py: -------------------------------------------------------------------------------- 1 | from math import exp 2 | import warnings 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from torch.autograd import Variable 10 | 11 | from FrEIA.modules import * 12 | 13 | class F_fully_conv(nn.Module): 14 | 15 | def __init__(self, in_channels, out_channels, channels_hidden=64, kernel_size=3, leaky_slope=0.01): 16 | super().__init__() 17 | 18 | pad = kernel_size // 2 19 | 20 | self.leaky_slope = leaky_slope 21 | self.conv1 = nn.Conv2d(in_channels, channels_hidden, kernel_size=kernel_size, padding=pad) 22 | self.conv2 = nn.Conv2d(in_channels + channels_hidden, channels_hidden, kernel_size=kernel_size, padding=pad) 23 | self.conv3 = nn.Conv2d(in_channels + 2*channels_hidden, out_channels, kernel_size=1, padding=0) 24 | 25 | def forward(self, x): 26 | x1 = F.leaky_relu(self.conv1(x), self.leaky_slope) 27 | x2 = F.leaky_relu(self.conv2(torch.cat([x, x1], 1))) 28 | x3 = self.conv3(torch.cat([x, x1, x2], 1)) 29 | return x3 30 | 31 | class F_fully_shallow(nn.Module): 32 | 33 | def __init__(self, size_in, size, internal_size = None, dropout=0.0): 34 | super().__init__() 35 | if not internal_size: 36 | internal_size = 2*size 37 | 38 | self.d1 = nn.Dropout(p=dropout) 39 | self.d2 = nn.Dropout(p=dropout) 40 | 41 | self.fc1 = nn.Linear(size_in, internal_size) 42 | self.fc2 = nn.Linear(internal_size, internal_size) 43 | self.fc3 = nn.Linear(internal_size, size) 44 | 45 | self.nl1 = nn.LeakyReLU() 46 | self.nl2 = nn.LeakyReLU() 47 | 48 | def forward(self, x): 49 | out = 
self.nl1(self.d1(self.fc1(x))) 50 | out = self.nl2(self.d2(self.fc2(out))) 51 | out = self.fc3(out) 52 | return out 53 | -------------------------------------------------------------------------------- /mnist_cINN/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.autograd import Variable 4 | 5 | import config as c 6 | 7 | def MMD(x, y): 8 | xx, yy, xy = torch.mm(x,x.t()), torch.mm(y,y.t()), torch.mm(x,y.t()) 9 | 10 | rx = (xx.diag().unsqueeze(0).expand_as(xx)) 11 | ry = (yy.diag().unsqueeze(0).expand_as(yy)) 12 | 13 | dxx = rx.t() + rx - 2.*xx 14 | dyy = ry.t() + ry - 2.*yy 15 | dxy = rx.t() + ry - 2.*xy 16 | 17 | dxx = torch.clamp(dxx, 0., np.inf) 18 | dyy = torch.clamp(dyy, 0., np.inf) 19 | dxy = torch.clamp(dxy, 0., np.inf) 20 | 21 | XX, YY, XY = (Variable(torch.zeros(xx.shape).cuda()), 22 | Variable(torch.zeros(xx.shape).cuda()), 23 | Variable(torch.zeros(xx.shape).cuda())) 24 | 25 | for cw in c.kernel_widths: 26 | for a in c.kernel_powers: 27 | XX += cw**a * (cw + 0.5 * dxx / a)**-a 28 | YY += cw**a * (cw + 0.5 * dyy / a)**-a 29 | XY += cw**a * (cw + 0.5 * dxy / a)**-a 30 | 31 | return torch.mean(XX + YY - 2.*XY) 32 | 33 | def moment_match(x, y): 34 | return (torch.mean(x) - torch.mean(y))**2 + (torch.var(x) - torch.var(y))**2 35 | -------------------------------------------------------------------------------- /mnist_cINN/model.py: -------------------------------------------------------------------------------- 1 | import torch.optim 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | from FrEIA.framework import * 6 | from FrEIA.modules import * 7 | from extra_modules import * 8 | import data 9 | import config as c 10 | import cond_net 11 | 12 | if c.colorize: 13 | nodes = [InputNode(3, *c.img_dims, name='inp')] 14 | else: 15 | nodes = [InputNode(*c.img_dims, name='inp')] 16 | 17 | if c.colorize: 18 | cond_size = 1 + c.cond_width 19 | cond_node = ConditionNode(cond_size, *c.img_dims) 20 | else: 21 | cond_size = 10 22 | cond_node = ConditionNode(cond_size) 23 | 24 | if c.colorize: 25 | for i in range(c.n_blocks_conv): 26 | nodes.append(Node([nodes[-1].out0], permute_layer, {'seed':i}, name=F'permute_{i}')) 27 | nodes.append(Node([nodes[-1].out0], glow_coupling_layer, {'clamp':c.clamping, 'F_class':F_fully_conv, 28 | 'F_args':{'kernel_size':1, 'channels_hidden':c.internal_width_conv}}, 29 | conditions=cond_node, name=F'conv_{i}')) 30 | 31 | 32 | nodes.append(Node([nodes[-1].out0], flattening_layer, {}, name='flatten')) 33 | 34 | for i in range(c.n_blocks): 35 | nodes.append(Node([nodes[-1].out0], permute_layer, {'seed':i}, name=F'permute_{i}')) 36 | nodes.append(Node([nodes[-1].out0], glow_coupling_layer, {'clamp':c.clamping, 37 | 'F_class':F_fully_shallow, 38 | 'F_args':{'dropout':c.fc_dropout, 'internal_size':c.internal_width}}, 39 | name=F'fc_{i}')) 40 | 41 | else: 42 | nodes.append(Node([nodes[-1].out0], flattening_layer, {}, name='flatten')) 43 | for i in range(c.n_blocks): 44 | nodes.append(Node([nodes[-1].out0], permute_layer, {'seed':i}, name=F'permute_{i}')) 45 | nodes.append(Node([nodes[-1].out0], glow_coupling_layer, {'clamp':c.clamping,'F_class':F_fully_connected, 46 | 'F_args':{'dropout':c.fc_dropout, 'internal_size':c.internal_width}}, 47 | conditions=cond_node, 48 | name=F'fc_{i}')) 49 | 50 | 51 | nodes.append(OutputNode([nodes[-1].out0], name='out')) 52 | nodes.append(cond_node) 53 | 54 | def init_model(mod): 55 | for key, param in mod.named_parameters(): 56 | split = 
key.split('.') 57 | if param.requires_grad: 58 | param.data = c.init_scale * torch.randn(param.data.shape).cuda() 59 | if split[3][-1] == '3': # last convolution in the coeff func 60 | param.data.fill_(0.) 61 | 62 | 63 | def optim_step(): 64 | optim.step() 65 | optim.zero_grad() 66 | 67 | def save(name): 68 | save_dict = {'opt':optim.state_dict(), 69 | 'net':model.state_dict()} 70 | if c.colorize: 71 | save_dict['cond'] = cond_net.model.state_dict() 72 | 73 | torch.save(save_dict, name) 74 | 75 | def load(name): 76 | state_dicts = torch.load(name) 77 | model.load_state_dict(state_dicts['net']) 78 | if c.colorize: 79 | cond_net.model.load_state_dict(state_dicts['cond']) 80 | try: 81 | optim.load_state_dict(state_dicts['opt']) 82 | except ValueError: 83 | print('Cannot load optimizer for some reason or other') 84 | 85 | model = ReversibleGraphNet(nodes, verbose=False) 86 | model.cuda() 87 | init_model(model) 88 | 89 | params_trainable = list(filter(lambda p: p.requires_grad, model.parameters())) 90 | 91 | gamma = (c.decay_by)**(1./c.n_epochs) 92 | optim = torch.optim.Adam(params_trainable, lr=c.lr, betas=c.betas, eps=1e-6, weight_decay=c.weight_decay) 93 | weight_scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=1, gamma=gamma) 94 | 95 | -------------------------------------------------------------------------------- /mnist_cINN/opts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import config as c 3 | 4 | def parse(args): 5 | 6 | parser = argparse.ArgumentParser(prog=args[0]) 7 | 8 | parser.add_argument('-l', '--lr', default=c.lr, dest='lr', type=float) 9 | parser.add_argument('-d', '--decay', default=c.decay_by, dest='decay_by', type=float) 10 | 11 | parser.add_argument('-b', '--batchsize', default=c.batch_size, dest='batch_size', type=int) 12 | parser.add_argument('-n', '--batches', default=c.n_its_per_epoch, dest='n_its_per_epoch', type=int) 13 | parser.add_argument('-N', '--epochs', default=c.n_epochs, dest='n_epochs', type=int) 14 | 15 | parser.add_argument('-i', '--in', default=c.load_file, dest='load_file', type=str) 16 | parser.add_argument('-o', '--out', default=c.filename, dest='filename', type=str) 17 | 18 | parser.add_argument('--init-scale', default=c.init_scale, dest='init_scale', type=float) 19 | 20 | opts = parser.parse_args(args[1:]) 21 | 22 | c.lr = opts.lr 23 | c.batch_size = opts.batch_size 24 | c.decay_by = opts.decay_by 25 | c.n_its_per_epoch = opts.n_its_per_epoch 26 | c.n_epochs = opts.n_epochs 27 | c.load_file = opts.load_file 28 | c.filename = opts.filename 29 | c.init_scale = opts.init_scale 30 | 31 | -------------------------------------------------------------------------------- /mnist_cINN/output/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vislearn/conditional_INNs/19e316c606cae24815efa51305ce8d3a6476f819/mnist_cINN/output/.gitkeep -------------------------------------------------------------------------------- /mnist_cINN/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | 4 | import torch 5 | import torch.nn 6 | import torch.optim 7 | from torch.nn.functional import avg_pool2d, interpolate 8 | from torch.autograd import Variable 9 | import numpy as np 10 | import tqdm 11 | 12 | import config as c 13 | import opts 14 | opts.parse(sys.argv) 15 | config_str = "" 16 | config_str += "==="*30 + "\n" 17 | config_str += "Config 
options:\n\n" 18 | 19 | for v in dir(c): 20 | if v[0]=='_': continue 21 | s=eval('c.%s'%(v)) 22 | config_str += " {:25}\t{}\n".format(v,s) 23 | 24 | config_str += "==="*30 + "\n" 25 | 26 | print(config_str) 27 | 28 | import model 29 | import data 30 | import viz 31 | import losses 32 | 33 | if c.colorize: 34 | import cond_net 35 | 36 | class dummy_loss(object): 37 | def item(self): 38 | return 1. 39 | 40 | if c.load_file: 41 | model.load(c.load_file) 42 | 43 | def sample_outputs(sigma): 44 | return sigma * torch.cuda.FloatTensor(c.batch_size, c.output_dim).normal_() 45 | 46 | if c.colorize: 47 | cond_tensor = torch.zeros(c.batch_size, model.cond_size, *c.img_dims).cuda() 48 | 49 | def make_cond(mask, cond_features): 50 | cond_tensor[:, 0] = mask[:, 0] 51 | cond_tensor[:, 1:] = cond_features.view(c.batch_size, -1, 1, 1).expand(-1, -1, *c.img_dims) 52 | return cond_tensor 53 | 54 | else: 55 | cond_tensor = torch.zeros(c.batch_size, model.cond_size).cuda() 56 | def make_cond(labels): 57 | cond_tensor.zero_() 58 | cond_tensor.scatter_(1, labels.view(-1,1), 1.) 59 | return cond_tensor 60 | 61 | test_labels = torch.LongTensor((list(range(10))*(c.batch_size//10 + 1))[:c.batch_size]).cuda() 62 | test_cond = make_cond(test_labels).clone() 63 | 64 | try: 65 | for i_epoch in range(-c.pre_low_lr, c.n_epochs): 66 | 67 | loss_history = [] 68 | data_iter = iter(data.train_loader) 69 | 70 | if i_epoch < 0: 71 | for param_group in model.optim.param_groups: 72 | param_group['lr'] = c.lr * 2e-2 73 | 74 | for i_batch, data_tuple in tqdm.tqdm(enumerate(data_iter), 75 | total=min(len(data.train_loader), c.n_its_per_epoch), 76 | leave=False, 77 | mininterval=1., 78 | disable=(not c.progress_bar), 79 | ncols=83): 80 | 81 | if c.colorize: 82 | x, labels, masks = data_tuple 83 | #print() 84 | #print(x.shape, labels.shape, masks.shape, cond_tensor.shape) 85 | #torch.Size([512, 3, 28, 28]) torch.Size([512]) torch.Size([512, 1, 28, 28]) torch.Size([512, 65]) 86 | x, labels, masks = x.cuda(), labels.cuda(), masks.cuda() 87 | x += c.add_image_noise * torch.cuda.FloatTensor(x.shape).normal_() 88 | with torch.no_grad(): 89 | cond_features = cond_net.model.features(masks) 90 | cond = make_cond(masks, cond_features) 91 | 92 | else: 93 | x, labels = data_tuple 94 | x, labels = x.cuda(), labels.cuda() 95 | x += c.add_image_noise * torch.cuda.FloatTensor(x.shape).normal_() 96 | 97 | cond = make_cond(labels.cuda()) 98 | 99 | output = model.model(x, cond) 100 | 101 | if c.do_fwd: 102 | zz = torch.sum(output**2, dim=1) 103 | jac = model.model.log_jacobian(run_forward=False) 104 | 105 | neg_log_likeli = 0.5 * zz - jac 106 | 107 | l = torch.mean(neg_log_likeli) 108 | l.backward(retain_graph=c.do_rev) 109 | else: 110 | l = dummy_loss() 111 | 112 | if c.do_rev: 113 | samples_noisy = sample_outputs(c.latent_noise) + output.data 114 | 115 | x_rec = model.model(samples_noisy, rev=True) 116 | l_rev = torch.mean( (x-x_rec)**2 ) 117 | l_rev.backward() 118 | else: 119 | l_rev = dummy_loss() 120 | 121 | model.optim_step() 122 | loss_history.append([l.item(), l_rev.item()]) 123 | 124 | if i_batch+1 >= c.n_its_per_epoch: 125 | # somehow the data loader workers don't shut down automatically 126 | try: 127 | data_iter._shutdown_workers() 128 | except: 129 | pass 130 | 131 | break 132 | 133 | model.weight_scheduler.step() 134 | 135 | epoch_losses = np.mean(np.array(loss_history), axis=0) 136 | epoch_losses[0] = min(epoch_losses[0], 0) 137 | 138 | if i_epoch > 1 - c.pre_low_lr: 139 | viz.show_loss(epoch_losses, logscale=False) 140 | output_orig = 
output.cpu() 141 | viz.show_hist(output_orig) 142 | 143 | with torch.no_grad(): 144 | samples = sample_outputs(c.sampling_temperature) 145 | 146 | if not c.colorize: 147 | cond = test_cond 148 | 149 | rev_imgs = model.model(samples, cond, rev=True) 150 | ims = [rev_imgs] 151 | 152 | viz.show_imgs(*list(data.unnormalize(i) for i in ims)) 153 | 154 | model.model.zero_grad() 155 | 156 | if (i_epoch % c.checkpoint_save_interval) == 0: 157 | model.save(c.filename + '_checkpoint_%.4i' % (i_epoch * (1-c.checkpoint_save_overwrite))) 158 | 159 | model.save(c.filename) 160 | 161 | except: 162 | if c.checkpoint_on_error: 163 | model.save(c.filename + '_ABORT') 164 | 165 | raise 166 | -------------------------------------------------------------------------------- /mnist_cINN/viz.py: -------------------------------------------------------------------------------- 1 | from scipy.ndimage import zoom 2 | import numpy as np 3 | 4 | import config as c 5 | import data 6 | 7 | 8 | class Visualizer: 9 | def __init__(self, loss_labels): 10 | self.n_losses = len(loss_labels) 11 | self.loss_labels = loss_labels 12 | self.counter = 0 13 | 14 | header = 'Epoch' 15 | for l in loss_labels: 16 | header += '\t\t%s' % (l) 17 | 18 | print(header) 19 | 20 | def update_losses(self, losses, *args): 21 | print('\r', ' '*20, end='') 22 | line = '\r%.3i' % (self.counter) 23 | for l in losses: 24 | line += '\t\t%.4f' % (l) 25 | 26 | print(line) 27 | self.counter += 1 28 | 29 | def update_images(self, *args): 30 | pass 31 | 32 | def update_hist(self, *args): 33 | pass 34 | 35 | if c.live_visualization: 36 | import visdom 37 | import matplotlib 38 | matplotlib.use('Agg') 39 | import matplotlib.pyplot as plt 40 | 41 | n_imgs = 10 42 | n_plots = 2 43 | figsize = (4,4) 44 | im_width = c.img_dims[1] 45 | 46 | class LiveVisualizer(Visualizer): 47 | def __init__(self, loss_labels): 48 | super().__init__(loss_labels) 49 | self.viz = visdom.Visdom()#env='mnist') 50 | self.viz.close() 51 | 52 | self.l_plots = self.viz.line(X = np.zeros((1,self.n_losses)), 53 | Y = np.zeros((1,self.n_losses)), 54 | opts = {'legend':self.loss_labels}) 55 | 56 | self.imgs = self.viz.image(np.random.random((3, im_width*n_imgs*c.preview_upscale, 57 | im_width*n_imgs*c.preview_upscale))) 58 | 59 | self.fig, self.axes = plt.subplots(n_plots, n_plots, figsize=figsize) 60 | self.hist = self.viz.matplot(self.fig) 61 | 62 | 63 | def update_losses(self, losses, logscale=True): 64 | super().update_losses(losses) 65 | its = min(len(data.train_loader), c.n_its_per_epoch) 66 | y = np.array([losses]) 67 | if logscale: 68 | y = np.log10(y) 69 | 70 | self.viz.line(X = (self.counter-1) * its * np.ones((1,self.n_losses)), 71 | Y = y, 72 | opts = {'legend':self.loss_labels}, 73 | win = self.l_plots, 74 | update = 'append') 75 | 76 | def update_images(self, *img_list): 77 | 78 | w = img_list[0].shape[2] 79 | k = 0 80 | k_img = 0 81 | 82 | show_img = np.zeros((3, w*n_imgs, w*n_imgs), dtype=np.uint8) 83 | img_list_np = [] 84 | for im in img_list: 85 | im_np = im.cpu().data.numpy() 86 | img_list_np.append(np.clip((255. 
* im_np), 0, 255).astype(np.uint8)) 87 | 88 | for i in range(n_imgs): 89 | for j in range(n_imgs): 90 | show_img[:, w*i:w*i+w, w*j:w*j+w] = img_list_np[k][k_img] 91 | 92 | k += 1 93 | if k >= len(img_list_np): 94 | k = 0 95 | k_img += 1 96 | 97 | show_img = zoom(show_img, (1., c.preview_upscale, c.preview_upscale), order=0) 98 | 99 | self.viz.image(show_img, win = self.imgs) 100 | 101 | def update_hist(self, data): 102 | for i in range(n_plots): 103 | for j in range(n_plots): 104 | try: 105 | self.axes[i,j].clear() 106 | self.axes[i,j].hist(data[:, i*n_plots + j], bins=20, histtype='step') 107 | except ValueError: 108 | pass 109 | 110 | self.fig.tight_layout() 111 | self.viz.matplot(self.fig, win=self.hist) 112 | 113 | def close(self): 114 | self.viz.close(win=self.hist) 115 | self.viz.close(win=self.imgs) 116 | self.viz.close(win=self.l_plots) 117 | 118 | 119 | 120 | visualizer = LiveVisualizer(c.loss_names) 121 | else: 122 | visualizer = Visualizer(c.loss_names) 123 | 124 | def show_loss(losses, logscale=True): 125 | visualizer.update_losses(losses, logscale) 126 | 127 | def show_imgs(*imgs): 128 | visualizer.update_images(*imgs) 129 | 130 | def show_hist(data): 131 | visualizer.update_hist(data.data) 132 | 133 | def close(): 134 | visualizer.close() 135 | 136 | -------------------------------------------------------------------------------- /mnist_minimal_example/.run: -------------------------------------------------------------------------------- 1 | #!/bin/zsh 2 | 3 | case $1 in 4 | 1) 5 | python eval.py; (sleep 0.3; xdotool key "Control_L+d") & 6 | ;; 7 | 2) 8 | exit 2 9 | ;; 10 | 3) 11 | exit 3 12 | ;; 13 | 4) 14 | exit 4 15 | ;; 16 | 5) 17 | exit 5 18 | ;; 19 | 6) 20 | exit 6 21 | ;; 22 | 7) 23 | exit 7 24 | ;; 25 | 8) 26 | exit 8 27 | ;; 28 | *) 29 | python train.py; (sleep 0.3; xdotool key "Control_L+d") & 30 | ;; 31 | esac 32 | -------------------------------------------------------------------------------- /mnist_minimal_example/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader, TensorDataset 3 | import torchvision.transforms as T 4 | import torchvision.datasets 5 | 6 | batch_size = 256 7 | data_mean = 0.128 8 | data_std = 0.305 9 | 10 | # amplitude for the noise augmentation 11 | augm_sigma = 0.08 12 | data_dir = 'mnist_data' 13 | 14 | def unnormalize(x): 15 | '''go from normaized data x back to the original range''' 16 | return x * data_std + data_mean 17 | 18 | 19 | train_data = torchvision.datasets.MNIST(data_dir, train=True, download=True, 20 | transform=T.Compose([T.ToTensor(), lambda x: (x - data_mean) / data_std])) 21 | test_data = torchvision.datasets.MNIST(data_dir, train=False, download=True, 22 | transform=T.Compose([T.ToTensor(), lambda x: (x - data_mean) / data_std])) 23 | 24 | # Sample a fixed batch of 1024 validation examples 25 | val_x, val_l = zip(*list(train_data[i] for i in range(1024))) 26 | val_x = torch.stack(val_x, 0).cuda() 27 | val_l = torch.LongTensor(val_l).cuda() 28 | 29 | # Exclude the validation batch from the training data 30 | train_data.data = train_data.data[1024:] 31 | train_data.targets = train_data.targets[1024:] 32 | # Add the noise-augmentation to the (non-validation) training data: 33 | train_data.transform = T.Compose([train_data.transform, lambda x: x + augm_sigma * torch.randn_like(x)]) 34 | 35 | train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True) 36 | 
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=True) 37 | -------------------------------------------------------------------------------- /mnist_minimal_example/eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | import model 6 | import data 7 | 8 | cinn = model.MNIST_cINN(0) 9 | cinn.cuda() 10 | state_dict = {k:v for k,v in torch.load('output/mnist_cinn.pt').items() if 'tmp_var' not in k} 11 | cinn.load_state_dict(state_dict) 12 | 13 | cinn.eval() 14 | 15 | def show_samples(label): 16 | '''produces and shows cINN samples for a given label (0-9)''' 17 | 18 | N_samples = 100 19 | l = torch.cuda.LongTensor(N_samples) 20 | l[:] = label 21 | 22 | z = 1.0 * torch.randn(N_samples, model.ndim_total).cuda() 23 | 24 | with torch.no_grad(): 25 | samples = cinn.reverse_sample(z, l).cpu().numpy() 26 | samples = data.unnormalize(samples) 27 | 28 | full_image = np.zeros((28*10, 28*10)) 29 | 30 | for k in range(N_samples): 31 | i, j = k // 10, k % 10 32 | full_image[28 * i : 28 * (i + 1), 33 | 28 * j : 28 * (j + 1)] = samples[k, 0] 34 | 35 | full_image = np.clip(full_image, 0, 1) 36 | plt.figure() 37 | plt.title(F'Generated digits for c={label}') 38 | plt.imshow(full_image, vmin=0, vmax=1, cmap='gray') 39 | 40 | def val_loss(): 41 | '''prints the final validation loss of the model''' 42 | 43 | with torch.no_grad(): 44 | z, log_j = cinn(data.val_x, data.val_l) 45 | nll_val = torch.mean(z**2) / 2 - torch.mean(log_j) / model.ndim_total 46 | 47 | print('Validation loss:') 48 | print(nll_val.item()) 49 | 50 | val_loss() 51 | 52 | for i in range(10): 53 | show_samples(i) 54 | 55 | plt.show() 56 | -------------------------------------------------------------------------------- /mnist_minimal_example/images/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vislearn/conditional_INNs/19e316c606cae24815efa51305ce8d3a6476f819/mnist_minimal_example/images/.gitkeep -------------------------------------------------------------------------------- /mnist_minimal_example/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim 4 | 5 | import FrEIA.framework as Ff 6 | import FrEIA.modules as Fm 7 | 8 | ndim_total = 28 * 28 9 | 10 | def one_hot(labels, out=None): 11 | ''' 12 | Convert LongTensor labels (containing labels 0-9) to a one-hot vector. 13 | Can be done in-place using the out-argument (faster, re-use of GPU memory) 14 | ''' 15 | if out is None: 16 | out = torch.zeros(labels.shape[0], 10).to(labels.device) 17 | else: 18 | out.zero_() 19 | 20 | out.scatter_(dim=1, index=labels.view(-1,1), value=1.)
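    # scatter_ writes a 1. into column `label` of each row, producing the one-hot vectors
    # that the ConditionNode(10) of the cINN below expects as conditioning input.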
21 | return out 22 | 23 | class MNIST_cINN(nn.Module): 24 | '''cINN for class-conditional MNISt generation''' 25 | def __init__(self, lr): 26 | super().__init__() 27 | 28 | self.cinn = self.build_inn() 29 | 30 | self.trainable_parameters = [p for p in self.cinn.parameters() if p.requires_grad] 31 | for p in self.trainable_parameters: 32 | p.data = 0.01 * torch.randn_like(p) 33 | 34 | self.optimizer = torch.optim.Adam(self.trainable_parameters, lr=lr, weight_decay=1e-5) 35 | 36 | def build_inn(self): 37 | 38 | def subnet(ch_in, ch_out): 39 | return nn.Sequential(nn.Linear(ch_in, 512), 40 | nn.ReLU(), 41 | nn.Linear(512, ch_out)) 42 | 43 | cond = Ff.ConditionNode(10) 44 | nodes = [Ff.InputNode(1, 28, 28)] 45 | 46 | nodes.append(Ff.Node(nodes[-1], Fm.Flatten, {})) 47 | 48 | for k in range(20): 49 | nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom , {'seed':k})) 50 | nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock, 51 | {'subnet_constructor':subnet, 'clamp':1.0}, 52 | conditions=cond)) 53 | 54 | return Ff.ReversibleGraphNet(nodes + [cond, Ff.OutputNode(nodes[-1])], verbose=False) 55 | 56 | def forward(self, x, l): 57 | z = self.cinn(x, c=one_hot(l)) 58 | jac = self.cinn.log_jacobian(run_forward=False) 59 | return z, jac 60 | 61 | def reverse_sample(self, z, l): 62 | return self.cinn(z, c=one_hot(l), rev=True) 63 | -------------------------------------------------------------------------------- /mnist_minimal_example/output/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vislearn/conditional_INNs/19e316c606cae24815efa51305ce8d3a6476f819/mnist_minimal_example/output/.gitkeep -------------------------------------------------------------------------------- /mnist_minimal_example/train.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | 3 | from tqdm import tqdm 4 | import torch 5 | import torch.nn 6 | import torch.optim 7 | import numpy as np 8 | 9 | import model 10 | import data 11 | 12 | cinn = model.MNIST_cINN(5e-4) 13 | cinn.cuda() 14 | scheduler = torch.optim.lr_scheduler.MultiStepLR(cinn.optimizer, milestones=[20, 40], gamma=0.1) 15 | 16 | N_epochs = 60 17 | t_start = time() 18 | nll_mean = [] 19 | 20 | print('Epoch\tBatch/Total \tTime \tNLL train\tNLL val\tLR') 21 | for epoch in range(N_epochs): 22 | for i, (x, l) in enumerate(data.train_loader): 23 | x, l = x.cuda(), l.cuda() 24 | z, log_j = cinn(x, l) 25 | 26 | nll = torch.mean(z**2) / 2 - torch.mean(log_j) / model.ndim_total 27 | nll.backward() 28 | torch.nn.utils.clip_grad_norm_(cinn.trainable_parameters, 10.) 29 | nll_mean.append(nll.item()) 30 | cinn.optimizer.step() 31 | cinn.optimizer.zero_grad() 32 | 33 | if not i % 50: 34 | with torch.no_grad(): 35 | z, log_j = cinn(data.val_x, data.val_l) 36 | nll_val = torch.mean(z**2) / 2 - torch.mean(log_j) / model.ndim_total 37 | 38 | print('%.3i \t%.5i/%.5i \t%.2f \t%.6f\t%.6f\t%.2e' % (epoch, 39 | i, len(data.train_loader), 40 | (time() - t_start)/60., 41 | np.mean(nll_mean), 42 | nll_val.item(), 43 | cinn.optimizer.param_groups[0]['lr'], 44 | ), flush=True) 45 | nll_mean = [] 46 | scheduler.step() 47 | 48 | torch.save(cinn.state_dict(), 'output/mnist_cinn.pt') 49 | --------------------------------------------------------------------------------
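A minimal sketch of how the checkpoint written by ``train.py`` above can be used for temperature-scaled, class-conditional sampling. This is an illustration, not part of the repository: it assumes the ``model`` and ``data`` modules of ``mnist_minimal_example`` are importable, a CUDA device is available, and ``output/mnist_cinn.pt`` exists; the temperature of 0.8 is only an example value.

```
import torch

import model
import data

# Build the cINN and load the trained weights (the learning rate argument is unused at eval time).
cinn = model.MNIST_cINN(0)
cinn.cuda()
state_dict = {k: v for k, v in torch.load('output/mnist_cinn.pt').items() if 'tmp_var' not in k}
cinn.load_state_dict(state_dict)
cinn.eval()

# One sample per digit: labels 0..9, latent vectors scaled by the sampling temperature.
labels = torch.arange(10).cuda()
temperature = 0.8   # below 1.0 trades diversity for more typical-looking digits
z = temperature * torch.randn(10, model.ndim_total).cuda()

with torch.no_grad():
    samples = cinn.reverse_sample(z, labels)            # shape (10, 1, 28, 28), normalized
    samples = data.unnormalize(samples.cpu().numpy())   # back to the original pixel range
```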