├── EmbeddingsImagesDataset.py ├── GSN.py ├── ImageDataset.py ├── README.md ├── analyze_input_scat.py ├── compute_scattering.py ├── generator_architecture.py ├── main.py └── utils.py /EmbeddingsImagesDataset.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """ 4 | Author: angles 5 | Date and time: 27/04/18 - 17:58 6 | """ 7 | 8 | import os 9 | 10 | import numpy as np 11 | from PIL import Image 12 | from torch.utils.data import Dataset 13 | 14 | from utils import get_nb_files 15 | 16 | 17 | class EmbeddingsImagesDataset(Dataset): 18 | def __init__(self, dir_z, dir_x, nb_channels=3): 19 | assert get_nb_files(dir_z) == get_nb_files(dir_x) 20 | assert nb_channels in [1, 3] 21 | 22 | self.nb_files = get_nb_files(dir_z) 23 | 24 | self.nb_channels = nb_channels 25 | 26 | self.dir_z = dir_z 27 | self.dir_x = dir_x 28 | 29 | def __len__(self): 30 | return self.nb_files 31 | 32 | def __getitem__(self, idx): 33 | filename = os.path.join(self.dir_z, '{}.npy'.format(idx)) 34 | z = np.load(filename) 35 | 36 | filename = os.path.join(self.dir_x, '{}.png'.format(idx)) 37 | if self.nb_channels == 3: 38 | x = (np.ascontiguousarray(Image.open(filename), dtype=np.uint8).transpose((2, 0, 1)) / 127.5) - 1.0 39 | else: 40 | x = np.expand_dims(np.ascontiguousarray(Image.open(filename), dtype=np.uint8), axis=-1) 41 | x = (x.transpose((2, 0, 1)) / 127.5) - 1.0 42 | 43 | sample = {'z': z, 'x': x} 44 | return sample 45 | -------------------------------------------------------------------------------- /GSN.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """ 4 | Author: angles 5 | Date and time: 27/04/18 - 17:58 6 | """ 7 | 8 | import os 9 | import time 10 | from pathlib import Path 11 | 12 | import numpy as np 13 | import torch 14 | import torch.optim as optim 15 | from PIL import Image 16 | from tensorboardX import SummaryWriter 17 | from torch.autograd import 
Variable 18 | from torch.utils.data import DataLoader 19 | from torchvision.utils import make_grid 20 | from tqdm import tqdm 21 | 22 | from EmbeddingsImagesDataset import EmbeddingsImagesDataset 23 | from generator_architecture import Generator, weights_init 24 | from utils import create_folder, AverageMeter, now, get_hms 25 | 26 | 27 | class GSN: 28 | def __init__(self, parameters): 29 | dir_datasets = Path('~/datasets').expanduser() 30 | dir_experiments = Path('~/experiments').expanduser() 31 | 32 | dataset = parameters['dataset'] 33 | dataset_attribute = parameters['dataset_attribute'] 34 | embedding_attribute = parameters['embedding_attribute'] 35 | 36 | self.dim = parameters['dim'] 37 | self.nb_channels_first_layer = parameters['nb_channels_first_layer'] 38 | 39 | name_experiment = parameters['name_experiment'] 40 | 41 | self.dir_x_train = dir_datasets / dataset / dataset_attribute / 'train' 42 | self.dir_x_test = dir_datasets / dataset / dataset_attribute / 'test' 43 | self.dir_z_train = dir_datasets / dataset / '{0}_{1}'.format(dataset_attribute, embedding_attribute) / 'train' 44 | self.dir_z_test = dir_datasets / dataset / '{0}_{1}'.format(dataset_attribute, embedding_attribute) / 'test' 45 | 46 | self.dir_experiment = dir_experiments / 'gsn' / name_experiment 47 | self.dir_models = self.dir_experiment / 'models' 48 | self.dir_logs_train = self.dir_experiment / 'logs_train' 49 | self.dir_logs_test = self.dir_experiment / 'logs_test' 50 | 51 | self.batch_size = 128 52 | self.nb_epochs_to_save = 1 53 | 54 | def make_dirs(self): 55 | self.dir_experiment.mkdir() 56 | self.dir_models.mkdir() 57 | self.dir_logs_train.mkdir() 58 | self.dir_logs_test.mkdir() 59 | 60 | def train(self, epoch_to_restore=0): 61 | if epoch_to_restore == 0: 62 | self.make_dirs() 63 | 64 | g = Generator(self.nb_channels_first_layer, self.dim) 65 | 66 | if epoch_to_restore > 0: 67 | filename_model = self.dir_models / 'epoch_{}.pth'.format(epoch_to_restore) 68 | 
g.load_state_dict(torch.load(filename_model)) 69 | else: 70 | g.apply(weights_init) 71 | 72 | g.cuda() 73 | g.train() 74 | 75 | dataset_train = EmbeddingsImagesDataset(self.dir_z_train, self.dir_x_train) 76 | dataloader_train = DataLoader(dataset_train, self.batch_size, shuffle=True, num_workers=4, pin_memory=True) 77 | dataset_test = EmbeddingsImagesDataset(self.dir_z_test, self.dir_x_test) 78 | dataloader_test = DataLoader(dataset_test, self.batch_size, shuffle=True, num_workers=4, pin_memory=True) 79 | 80 | criterion = torch.nn.L1Loss() 81 | 82 | optimizer = optim.Adam(g.parameters()) 83 | writer_train = SummaryWriter(str(self.dir_logs_train)) 84 | writer_test = SummaryWriter(str(self.dir_logs_test)) 85 | 86 | try: 87 | epoch = epoch_to_restore 88 | while True: 89 | start_time = time.time() 90 | 91 | g.train() 92 | for _ in range(self.nb_epochs_to_save): 93 | epoch += 1 94 | 95 | for idx_batch, current_batch in enumerate(dataloader_train): 96 | g.zero_grad() 97 | x = Variable(current_batch['x']).float().cuda() 98 | z = Variable(current_batch['z']).float().cuda() 99 | g_z = g.forward(z) 100 | 101 | loss = criterion(g_z, x) 102 | loss.backward() 103 | optimizer.step() 104 | 105 | g.eval() 106 | with torch.no_grad(): 107 | train_l1_loss = AverageMeter() 108 | for idx_batch, current_batch in enumerate(dataloader_train): 109 | if idx_batch == 32: 110 | break 111 | x = current_batch['x'].float().cuda() 112 | z = current_batch['z'].float().cuda() 113 | g_z = g.forward(z) 114 | loss = criterion(g_z, x) 115 | train_l1_loss.update(loss) 116 | 117 | writer_train.add_scalar('l1_loss', train_l1_loss.avg, epoch) 118 | 119 | test_l1_loss = AverageMeter() 120 | for idx_batch, current_batch in enumerate(dataloader_test): 121 | if idx_batch == 32: 122 | break 123 | x = current_batch['x'].float().cuda() 124 | z = current_batch['z'].float().cuda() 125 | g_z = g.forward(z) 126 | loss = criterion(g_z, x) 127 | test_l1_loss.update(loss) 128 | 129 | writer_test.add_scalar('l1_loss', 
test_l1_loss.avg, epoch) 130 | images = make_grid(g_z.data[:16], nrow=4, normalize=True) 131 | writer_test.add_image('generations', images, epoch) 132 | 133 | if epoch % self.nb_epochs_to_save == 0: 134 | filename = os.path.join(self.dir_models, 'epoch_{}.pth'.format(epoch)) 135 | torch.save(g.state_dict(), filename) 136 | 137 | end_time = time.time() 138 | print("[*] Finished epoch {} in {}".format(epoch, get_hms(end_time - start_time))) 139 | 140 | finally: 141 | print('[*] Closing Writer.') 142 | writer_train.close() 143 | writer_test.close() 144 | 145 | def save_originals(self): 146 | def _save_originals(dir_z, dir_x, train_test): 147 | dataset = EmbeddingsImagesDataset(dir_z, dir_x) 148 | fixed_dataloader = DataLoader(dataset, 16) 149 | fixed_batch = next(iter(fixed_dataloader)) 150 | 151 | temp = make_grid(fixed_batch['x'], nrow=4).numpy().transpose((1, 2, 0)) 152 | 153 | filename_images = os.path.join(self.dir_experiment, 'originals_{}.png'.format(train_test)) 154 | Image.fromarray(np.uint8((temp + 1) * 127.5)).save(filename_images) 155 | 156 | _save_originals(self.dir_z_train, self.dir_x_train, 'train') 157 | _save_originals(self.dir_z_test, self.dir_x_test, 'test') 158 | 159 | def compute_errors(self, epoch): 160 | filename_model = self.dir_models / 'epoch_{}.pth'.format(epoch) 161 | g = Generator(self.nb_channels_first_layer, self.dim) 162 | g.cuda() 163 | g.load_state_dict(torch.load(filename_model)) 164 | g.eval() 165 | 166 | with torch.no_grad(): 167 | criterion = torch.nn.MSELoss() 168 | 169 | def _compute_error(dir_z, dir_x, train_test): 170 | dataset = EmbeddingsImagesDataset(dir_z, dir_x) 171 | dataloader = DataLoader(dataset, batch_size=512, num_workers=4, pin_memory=True) 172 | 173 | error = 0 174 | 175 | for idx_batch, current_batch in enumerate(tqdm(dataloader)): 176 | x = current_batch['x'].float().cuda() 177 | z = current_batch['z'].float().cuda() 178 | g_z = g.forward(z) 179 | 180 | error += criterion(g_z, x).data.cpu().numpy() 181 | 182 | 
error /= len(dataloader) 183 | 184 | print('Error for {}: {}'.format(train_test, error)) 185 | 186 | _compute_error(self.dir_z_train, self.dir_x_train, 'train') 187 | _compute_error(self.dir_z_test, self.dir_x_test, 'test') 188 | 189 | def get_generator(self, epoch_to_load): 190 | filename_model = self.dir_models / 'epoch_{}.pth'.format(epoch_to_load) 191 | g = Generator(self.nb_channels_first_layer, self.dim) 192 | g.load_state_dict(torch.load(filename_model)) 193 | g.cuda() 194 | g.eval() 195 | return g 196 | 197 | def conditional_generation(self, epoch_to_load, idx_image, z_initial_idx, z_end_idx): 198 | g = self.get_generator(epoch_to_load) 199 | 200 | dir_to_save = self.dir_experiment / 'conditional_generation_epoch{}_img{}_zi{}_ze{}_{}'.format(epoch_to_load, idx_image, 201 | z_initial_idx, z_end_idx, now()) 202 | dir_to_save.mkdir() 203 | 204 | with torch.no_grad(): 205 | def _generate_random(dir_z, dir_x): 206 | dataset = EmbeddingsImagesDataset(dir_z, dir_x) 207 | fixed_dataloader = DataLoader(dataset, idx_image + 1, shuffle=False) 208 | fixed_batch = next(iter(fixed_dataloader)) 209 | 210 | x = fixed_batch['x'][[idx_image]] 211 | filename_images = os.path.join(dir_to_save, 'original.png'.format(epoch_to_load)) 212 | temp = make_grid(x.data, nrow=1).cpu().numpy().transpose((1, 2, 0)) 213 | Image.fromarray(np.uint8((temp + 1) * 127.5)).save(filename_images) 214 | 215 | z0 = fixed_batch['z'][[idx_image]].numpy() 216 | nb_samples = 16 217 | batch_z = np.repeat(z0, nb_samples, axis=0) 218 | batch_z[:, z_initial_idx:z_end_idx] = np.random.randn(nb_samples, z_end_idx - z_initial_idx) 219 | z = torch.from_numpy(batch_z).float().cuda() 220 | 221 | g_z = g.forward(z) 222 | filename_images = os.path.join(dir_to_save, 'modified.png'.format(epoch_to_load)) 223 | temp = make_grid(g_z.data[:16], nrow=4).cpu().numpy().transpose((1, 2, 0)) 224 | Image.fromarray(np.uint8((temp + 1) * 127.5)).save(filename_images) 225 | 226 | _generate_random(self.dir_z_train, 
self.dir_x_train) 227 | 228 | def generate_from_model(self, epoch_to_load): 229 | g = self.get_generator(epoch_to_load) 230 | 231 | dir_to_save = self.dir_experiment / 'generations_epoch{}_{}'.format(epoch_to_load, now()) 232 | dir_to_save.mkdir() 233 | 234 | with torch.no_grad(): 235 | def _generate_from_model(dir_z, dir_x, train_test): 236 | dataset = EmbeddingsImagesDataset(dir_z, dir_x) 237 | fixed_dataloader = DataLoader(dataset, 16) 238 | fixed_batch = next(iter(fixed_dataloader)) 239 | 240 | z = fixed_batch['z'].float().cuda() 241 | g_z = g.forward(z) 242 | filename_images = dir_to_save / 'epoch_{}_{}.png'.format(epoch_to_load, train_test) 243 | temp = make_grid(g_z.data[:16], nrow=4).cpu().numpy().transpose((1, 2, 0)) 244 | Image.fromarray(np.uint8((temp + 1) * 127.5)).save(filename_images) 245 | 246 | _generate_from_model(self.dir_z_train, self.dir_x_train, 'train') 247 | _generate_from_model(self.dir_z_test, self.dir_x_test, 'test') 248 | 249 | def _generate_path(dir_z, dir_x, train_test): 250 | dataset = EmbeddingsImagesDataset(dir_z, dir_x) 251 | fixed_dataloader = DataLoader(dataset, 2, shuffle=True) 252 | fixed_batch = next(iter(fixed_dataloader)) 253 | 254 | z0 = fixed_batch['z'][[0]].numpy() 255 | z1 = fixed_batch['z'][[1]].numpy() 256 | 257 | batch_z = np.copy(z0) 258 | 259 | nb_samples = 100 260 | 261 | interval = np.linspace(0, 1, nb_samples) 262 | for t in interval: 263 | if t > 0: 264 | # zt = normalize((1 - t) * z0 + t * z1) 265 | zt = (1 - t) * z0 + t * z1 266 | batch_z = np.vstack((batch_z, zt)) 267 | 268 | z = torch.from_numpy(batch_z).float().cuda() 269 | g_z = g.forward(z) 270 | 271 | # filename_images = os.path.join(self.dir_experiment, 'path_epoch_{}_{}.png'.format(epoch, train_test)) 272 | # temp = make_grid(g_z.data, nrow=nb_samples).cpu().numpy().transpose((1, 2, 0)) 273 | # Image.fromarray(np.uint8((temp + 1) * 127.5)).save(filename_images) 274 | 275 | g_z = g_z.data.cpu().numpy().transpose((0, 2, 3, 1)) 276 | 277 | folder_to_save = 
dir_to_save / 'epoch_{}_{}_path'.format(epoch_to_load, train_test) 278 | create_folder(folder_to_save) 279 | 280 | for idx in range(nb_samples): 281 | filename_image = os.path.join(folder_to_save, '{}.png'.format(idx)) 282 | Image.fromarray(np.uint8((g_z[idx] + 1) * 127.5)).save(filename_image) 283 | 284 | _generate_path(self.dir_z_train, self.dir_x_train, 'train') 285 | _generate_path(self.dir_z_test, self.dir_x_test, 'test') 286 | 287 | def _generate_random(): 288 | nb_samples = 16 289 | z = np.random.randn(nb_samples, self.dim) 290 | # norms = np.sqrt(np.sum(z ** 2, axis=1)) 291 | # norms = np.expand_dims(norms, axis=1) 292 | # norms = np.repeat(norms, self.dim, axis=1) 293 | # z /= norms 294 | 295 | z = torch.from_numpy(z).float().cuda() 296 | g_z = g.forward(z) 297 | filename_images = os.path.join(dir_to_save, 'epoch_{}_random.png'.format(epoch_to_load)) 298 | temp = make_grid(g_z.data[:16], nrow=4).cpu().numpy().transpose((1, 2, 0)) 299 | Image.fromarray(np.uint8((temp + 1) * 127.5)).save(filename_images) 300 | 301 | _generate_random() 302 | 303 | def analyze_model(self, epoch): 304 | filename_model = os.path.join(self.dir_models, 'epoch_{}.pth'.format(epoch)) 305 | g = Generator(self.nb_channels_first_layer, self.dim) 306 | g.cuda() 307 | g.load_state_dict(torch.load(filename_model)) 308 | g.eval() 309 | 310 | nb_samples = 50 311 | batch_z = np.zeros((nb_samples, 32 * self.nb_channels_first_layer, 4, 4)) 312 | # batch_z = np.maximum(5*np.random.randn(nb_samples, 32 * self.nb_channels_first_layer, 4, 4), 0) 313 | # batch_z = 5 * np.random.randn(nb_samples, 32 * self.nb_channels_first_layer, 4, 4) 314 | 315 | for i in range(4): 316 | for j in range(4): 317 | batch_z[:, :, i, j] = create_path(nb_samples) 318 | # batch_z[:, :, 0, 0] = create_path(nb_samples) 319 | # batch_z[:, :, 0, 1] = create_path(nb_samples) 320 | # batch_z[:, :, 1, 0] = create_path(nb_samples) 321 | # batch_z[:, :, 1, 1] = create_path(nb_samples) 322 | batch_z = np.maximum(batch_z, 0) 323 | 
324 | z = Variable(torch.from_numpy(batch_z)).type(torch.FloatTensor).cuda() 325 | temp = g.main._modules['4'].forward(z) 326 | for i in range(5, 10): 327 | temp = g.main._modules['{}'.format(i)].forward(temp) 328 | 329 | g_z = temp.data.cpu().numpy().transpose((0, 2, 3, 1)) 330 | 331 | folder_to_save = os.path.join(self.dir_experiment, 'epoch_{}_path_after_linear_only00_path'.format(epoch)) 332 | create_folder(folder_to_save) 333 | 334 | for idx in range(nb_samples): 335 | filename_image = os.path.join(folder_to_save, '{}.png'.format(idx)) 336 | Image.fromarray(np.uint8((g_z[idx] + 1) * 127.5)).save(filename_image) 337 | 338 | 339 | def create_path(nb_samples): 340 | z0 = 5 * np.random.randn(1, 32 * 32) 341 | z1 = 5 * np.random.randn(1, 32 * 32) 342 | 343 | # z0 = np.zeros((1, 32 * 32)) 344 | # z1 = np.zeros((1, 32 * 32)) 345 | 346 | # z0[0, 0] = -20 347 | # z1[0, 0] = 20 348 | 349 | batch_z = np.copy(z0) 350 | 351 | interval = np.linspace(0, 1, nb_samples) 352 | for t in interval: 353 | if t > 0: 354 | zt = (1 - t) * z0 + t * z1 355 | batch_z = np.vstack((batch_z, zt)) 356 | 357 | return batch_z 358 | -------------------------------------------------------------------------------- /ImageDataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: ta 3 | Date and time: 20/11/2018 - 23:56 4 | """ 5 | 6 | from torch.utils.data import Dataset 7 | 8 | from utils import get_nb_files, load_image 9 | 10 | 11 | class ImageDataset(Dataset): 12 | def __init__(self, dir_x, nb_channels=3): 13 | assert nb_channels in [1, 3] 14 | self.nb_channels = nb_channels 15 | self.nb_files = get_nb_files(dir_x) 16 | self.dir_x = dir_x 17 | 18 | def __len__(self): 19 | return self.nb_files 20 | 21 | def __getitem__(self, idx): 22 | filename = self.dir_x / '{}.png'.format(idx) 23 | if self.nb_channels == 3: 24 | x = load_image(filename).transpose((2, 0, 1)) / 255.0 25 | else: 26 | x = load_image(filename) / 255.0 27 | 28 | return x 29 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Regularized inverse Scattering 2 | 3 | This repository contains the code to reproduce the experiments of the paper: 4 | 5 | [Generative networks as inverse problems with Scattering transforms](https://openreview.net/pdf?id=r1NYjfbR-) 6 | 7 | Specifically, it contains the code necessary to invert the representations given by a fixed encoder (or embedding operator). 8 | 9 | It is implemented using [PyTorch](http://pytorch.org/). The file GSN.py contains the class GSN that implements the optimization of a network defined in generator_architecture.py; in the file main.py one can modify all the parameters, taking into account the following: 10 | 11 | - There should be a 'datasets' folder in your home folder which contains two folders for the train and test images and two folders for their corresponding representations (embedding_attribute in main.py). 12 | 13 | - The models and the generations are saved in the folder 'experiments/gsn' inside your home folder; the name of the subfolder is the name of the experiment indicated as a parameter in main.py. 14 | 15 | To compute the representations you can use [PyScatWave](https://github.com/edouardoyallon/pyscatwave) and to whiten them you can use PCA from [scikit-learn](http://scikit-learn.org).
16 | -------------------------------------------------------------------------------- /analyze_input_scat.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """ 4 | Author: angles 5 | Date and time: 27/04/18 - 17:58 6 | """ 7 | 8 | import os 9 | 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | from torch.utils.data import DataLoader 13 | 14 | from EmbeddingsImagesDataset import EmbeddingsImagesDataset 15 | 16 | dir_datasets = os.path.expanduser('~/datasets') 17 | dataset = 'diracs' 18 | dataset_attribute = '1024' 19 | embedding_attribute = 'ScatJ4' 20 | 21 | dir_x_train = os.path.join(dir_datasets, dataset, '{0}'.format(dataset_attribute)) 22 | dir_z_train = os.path.join(dir_datasets, dataset, '{0}_{1}'.format(dataset_attribute, embedding_attribute)) 23 | 24 | dataset = EmbeddingsImagesDataset(dir_z_train, dir_x_train, nb_channels=1) 25 | fixed_dataloader = DataLoader(dataset, batch_size=256) 26 | fixed_batch = next(iter(fixed_dataloader)) 27 | 28 | x = fixed_batch['x'].numpy() 29 | z = fixed_batch['z'].numpy() 30 | 31 | min_distance = np.inf 32 | i_tilde = 0 33 | j_tilde = 0 34 | 35 | distances = list() 36 | for i in range(256): 37 | for j in range(256): 38 | if i < j: 39 | temp = (z[i] - z[j]) ** 2 40 | temp = np.sum(temp) 41 | temp = np.sqrt(temp) 42 | 43 | if temp < min_distance: 44 | min_distance = temp 45 | i_tilde = i 46 | j_tilde = j 47 | 48 | distances.append(temp) 49 | 50 | distances = np.array(distances) 51 | print('Most similar indexes:', i_tilde, j_tilde) 52 | 53 | print('Min distances:', distances.min()) 54 | 55 | plt.hist(distances) 56 | plt.show() 57 | -------------------------------------------------------------------------------- /compute_scattering.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: ta 3 | Date and time: 20/11/2018 - 22:59 4 | """ 5 | 6 | import os 7 | 8 | os.environ["KYMATIO_BACKEND_2D"] = "skcuda" 9 
| 10 | from kymatio import Scattering2D 11 | import kymatio.scattering2d.backend as backend 12 | 13 | print('[*] Backend for Scattering2D: {}'.format(backend.NAME)) 14 | 15 | from pathlib import Path 16 | 17 | import numpy as np 18 | from torch.utils.data import DataLoader 19 | from tqdm import tqdm 20 | 21 | from ImageDataset import ImageDataset 22 | 23 | 24 | def compute_scattering(dir_images, dir_to_save, scattering, batch_size): 25 | dir_to_save.mkdir() 26 | 27 | dataset = ImageDataset(dir_images) 28 | dataloader = DataLoader(dataset, batch_size, pin_memory=True, num_workers=1) 29 | 30 | for idx_batch, current_batch in enumerate(tqdm(dataloader)): 31 | images = current_batch.float().cuda() 32 | if scattering is not None: 33 | s_images = scattering(images).cpu().numpy() 34 | s_images = np.reshape(s_images, (batch_size, -1, s_images.shape[-1], s_images.shape[-1])) 35 | else: 36 | s_images = images.cpu().numpy() 37 | for idx_local in range(batch_size): 38 | idx_global = idx_local + idx_batch * batch_size 39 | filename = dir_to_save / '{}.npy'.format(idx_global) 40 | temp = s_images[idx_local] 41 | np.save(filename, temp) 42 | 43 | 44 | def main(): 45 | dir_datasets = Path('~/datasets/').expanduser() 46 | dataset_attribute = '128_rgb_512_512' 47 | dir_dataset = dir_datasets / 'celeba' / dataset_attribute 48 | 49 | batch_size = 512 50 | 51 | for J in range(4, 5): 52 | dir_to_save = dir_datasets / 'celeba' / '{}_SJ{}'.format(dataset_attribute, J) 53 | dir_to_save.mkdir() 54 | 55 | if J == 0: 56 | scattering = None 57 | else: 58 | scattering = Scattering2D(J, (128, 128)) 59 | scattering.cuda() 60 | 61 | compute_scattering(dir_dataset / 'train', dir_to_save / 'train', scattering, batch_size) 62 | compute_scattering(dir_dataset / 'test', dir_to_save / 'test', scattering, batch_size) 63 | 64 | 65 | if __name__ == '__main__': 66 | main() 67 | -------------------------------------------------------------------------------- /generator_architecture.py: 
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """ 4 | Author: angles 5 | Date and time: 27/04/18 - 17:58 6 | """ 7 | 8 | from pathlib import Path 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.utils.data import DataLoader 14 | from torchvision.utils import save_image 15 | 16 | from EmbeddingsImagesDataset import EmbeddingsImagesDataset 17 | 18 | 19 | class View(nn.Module): 20 | def __init__(self, *shape): 21 | super(View, self).__init__() 22 | self.shape = shape 23 | 24 | def forward(self, input_tensor): 25 | return input_tensor.view(*self.shape) 26 | 27 | 28 | class Generator(nn.Module): 29 | def __init__(self, nb_channels_first_layer, z_dim, size_first_layer=4): 30 | super(Generator, self).__init__() 31 | 32 | nb_channels_input = nb_channels_first_layer * 16 33 | 34 | self.main = nn.Sequential( 35 | nn.Linear(in_features=z_dim, 36 | out_features=size_first_layer * size_first_layer * nb_channels_input, 37 | bias=False), 38 | View(-1, nb_channels_input, size_first_layer, size_first_layer), 39 | nn.BatchNorm2d(nb_channels_input, eps=0.001, momentum=0.9), 40 | nn.ReLU(inplace=True), 41 | 42 | # ConvBlock(nb_channels_input, nb_channels_first_layer * 16, upsampling=True), 43 | ConvBlock(nb_channels_first_layer * 16, nb_channels_first_layer * 8, upsampling=True), 44 | ConvBlock(nb_channels_first_layer * 8, nb_channels_first_layer * 4, upsampling=True), 45 | ConvBlock(nb_channels_first_layer * 4, nb_channels_first_layer * 2, upsampling=True), 46 | ConvBlock(nb_channels_first_layer * 2, nb_channels_first_layer, upsampling=True), 47 | # ConvBlock(nb_channels_first_layer * 2, nb_channels_first_layer * 2, upsampling=True), 48 | # ConvBlock(nb_channels_first_layer * 2, nb_channels_first_layer * 2, upsampling=False), 49 | # ConvBlock(nb_channels_first_layer * 2, nb_channels_first_layer * 2, upsampling=False), 50 | # ConvBlock(nb_channels_first_layer * 2, 
nb_channels_first_layer, upsampling=False), 51 | # ConvBlock(nb_channels_first_layer, nb_channels_first_layer, upsampling=False), 52 | 53 | ConvBlock(nb_channels_first_layer, nb_channels_output=3, tanh=True) 54 | ) 55 | 56 | def forward(self, input_tensor): 57 | return self.main(input_tensor) 58 | 59 | 60 | class ConvBlock(nn.Module): 61 | def __init__(self, nb_channels_input, nb_channels_output, upsampling=False, tanh=False): 62 | super(ConvBlock, self).__init__() 63 | 64 | self.tanh = tanh 65 | self.upsampling = upsampling 66 | 67 | filter_size = 5 68 | padding = (filter_size - 1) // 2 69 | 70 | self.pad = nn.ReflectionPad2d(padding) 71 | self.conv = nn.Conv2d(nb_channels_input, nb_channels_output, filter_size, bias=False) 72 | self.bn_layer = nn.BatchNorm2d(nb_channels_output, eps=0.001, momentum=0.9) 73 | 74 | def forward(self, input_tensor): 75 | if self.upsampling: 76 | output = F.interpolate(input_tensor, scale_factor=2, mode='bilinear', align_corners=False) 77 | else: 78 | output = input_tensor 79 | 80 | output = self.pad(output) 81 | output = self.conv(output) 82 | output = self.bn_layer(output) 83 | 84 | if self.tanh: 85 | output = torch.tanh(output) 86 | else: 87 | output = F.relu(output) 88 | 89 | return output 90 | 91 | 92 | def weights_init(layer): 93 | if isinstance(layer, nn.Linear): 94 | layer.weight.data.normal_(0.0, 0.02) 95 | elif isinstance(layer, nn.Conv2d): 96 | layer.weight.data.normal_(0.0, 0.02) 97 | elif isinstance(layer, nn.BatchNorm2d): 98 | layer.weight.data.normal_(1.0, 0.02) 99 | layer.bias.data.fill_(0) 100 | 101 | 102 | if __name__ == '__main__': 103 | dir_datasets = Path('~/datasets').expanduser() 104 | dataset = 'celeba' 105 | dataset_attribute = '64_rgb_65536_8192' 106 | embedding_attribute = 'SJ4_pca_norm_1024' 107 | 108 | dir_x_train = dir_datasets / dataset / '{0}'.format(dataset_attribute) / 'train' 109 | dir_z_train = dir_datasets / dataset / '{0}_{1}'.format(dataset_attribute, embedding_attribute) / 'train' 110 | 111 | 
dataset = EmbeddingsImagesDataset(dir_z_train, dir_x_train) 112 | fixed_dataloader = DataLoader(dataset, batch_size=128) 113 | fixed_batch = next(iter(fixed_dataloader)) 114 | 115 | nb_channels_first_layer = 4 116 | z_dim = 1024 117 | 118 | input_tensor = fixed_batch['z'].float().cuda() 119 | g = Generator(nb_channels_first_layer, z_dim) 120 | g.apply(weights_init) 121 | g.cuda() 122 | g.train() 123 | 124 | output = g.forward(input_tensor) 125 | save_image(output[:16].data, 'temp.png', nrow=4) 126 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """ 4 | Author: angles 5 | Date and time: 27/04/18 - 17:58 6 | """ 7 | 8 | import os 9 | 10 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 11 | 12 | from utils import create_name_experiment 13 | from GSN import GSN 14 | 15 | parameters = dict() 16 | parameters['dataset'] = 'celeba' 17 | parameters['dataset_attribute'] = '64_rgb_65536_8192' 18 | parameters['dim'] = 812 19 | parameters['embedding_attribute'] = 'IPSJ5_d{}'.format(parameters['dim']) 20 | # parameters['embedding_attribute'] = 'PSKJ4_randproj_{0}_pca_norm_{0}'.format(parameters['dim']) 21 | # parameters['embedding_attribute'] = 'SJ4_pca_norm_{0}'.format(parameters['dim']) 22 | # parameters['embedding_attribute'] = 'SJ4_randproj_{0}_pca_norm_{0}'.format(parameters['dim']) 23 | parameters['nb_channels_first_layer'] = 32 24 | 25 | parameters['name_experiment'] = create_name_experiment(parameters, 'pilot_dsencoder') 26 | 27 | gsn = GSN(parameters) 28 | # gsn.train() 29 | # gsn.save_originals() 30 | # gsn.generate_from_model(222) 31 | # gsn.compute_errors(60) 32 | # gsn.analyze_model(404) 33 | gsn.conditional_generation(222, idx_image=76, z_initial_idx=0, z_end_idx=160) 34 | -------------------------------------------------------------------------------- /utils.py: 
# coding=utf-8

"""
Author: angles
Date and time: 27/04/18 - 17:58

Shared helpers: image loading, vector normalization, filesystem utilities,
experiment naming, and simple training bookkeeping (AverageMeter, timers).
"""

import os
from datetime import datetime

import numpy as np


def load_image(filename):
    """Load an image file as a contiguous uint8 numpy array (H, W[, C])."""
    # Imported lazily so this module stays importable in environments
    # without Pillow (only callers of load_image actually need it).
    from PIL import Image
    return np.ascontiguousarray(Image.open(filename), dtype=np.uint8)


def normalize(vector):
    """Return `vector` scaled to unit Euclidean norm.

    NOTE: a zero vector divides by zero and yields nan/inf entries,
    mirroring the original behavior.
    """
    norm = np.sqrt(np.sum(vector ** 2))
    return vector / norm


def get_nb_files(input_dir):
    """Count the regular files directly inside `input_dir` (non-recursive)."""
    # `entry` instead of `file` avoids shadowing the builtin name.
    return sum(1 for entry in os.listdir(input_dir)
               if os.path.isfile(os.path.join(input_dir, entry)))


def create_folder(folder):
    """Create `folder` (including parents) if it does not already exist."""
    # exist_ok avoids the check-then-create race of the previous
    # `if not os.path.exists(folder): os.makedirs(folder)` idiom.
    os.makedirs(folder, exist_ok=True)


def create_name_experiment(parameters, attribute_experiment):
    """Build (and print) the experiment name from the parameter dict.

    Expects keys: 'dataset', 'dataset_attribute', 'embedding_attribute'
    and 'nb_channels_first_layer'.
    """
    name_experiment = '{}_{}_{}_ncfl{}_{}'.format(parameters['dataset'],
                                                  parameters['dataset_attribute'],
                                                  parameters['embedding_attribute'],
                                                  parameters['nb_channels_first_layer'],
                                                  attribute_experiment)

    print('Name experiment: {}'.format(name_experiment))

    return name_experiment


class AverageMeter(object):
    """
    Computes and stores the average and current value
    """

    def __init__(self):
        self.val = 0    # last value seen
        self.avg = 0    # running average (sum / count)
        self.sum = 0    # weighted sum of values
        self.count = 0  # total weight accumulated

    def update(self, val, n=1):
        """Record `val`, weighted by `n` (e.g. a batch size)."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def now():
    """Current local time as a compact 'DDMMYYYYHHMMSS' string."""
    return datetime.now().strftime("%d%m%Y%H%M%S")


def get_hms(seconds):
    """Format a duration in seconds as '<m>m<s>s' (e.g. 125 -> '2m5s')."""
    minutes, rseconds = divmod(int(seconds), 60)
    return '{}m{}s'.format(minutes, rseconds)