├── .gitignore ├── LICENSE ├── README.org ├── main.py ├── model.py ├── mount_remote_runs.sh ├── push_to_workstation.sh ├── requirements.txt └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | runs 3 | 4 | ############################################################### 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # Environments 89 | .env 90 | .venv 91 | env/ 92 | venv/ 93 | ENV/ 94 | env.bak/ 95 | venv.bak/ 96 | 97 | # Spyder project settings 98 | .spyderproject 99 | .spyproject 100 | 101 | # Rope project settings 102 | .ropeproject 103 | 104 | # mkdocs documentation 105 | /site 106 | 107 | # mypy 108 | .mypy_cache/ 109 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 nathbo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.org: -------------------------------------------------------------------------------- 1 | 2 | * PyTorch Implementation: Optimizing the Latent Space of Generative Networks 3 | My PyTorch implementation of the paper [[https://arxiv.org/abs/1707.05776]["Optimizing the Latent Space of 4 | Generative Networks"]] by Piotr Bojanowski, Armand Joulin, David Lopez-Paz, Arthur 5 | Szlam. It is a very interesting read and easy to follow! 6 | 7 | My personal goal with this project was to practice reimplementing a paper, in 8 | order to gain more experience. This paper was not completely trivial, but at the 9 | same time also very /non-standard/. 10 | 11 | ** Setup and Installation 12 | Install the dependencies, ideally in a [[https://docs.python.org/3/library/venv.html][virtual environment]]: 13 | #+BEGIN_SRC sh 14 | pip install -r requirements.txt 15 | #+END_SRC 16 | 17 | Install PyTorch as described on the [[https://pytorch.org/][website]], depending on your python version, 18 | CUDA, etc. 19 | 20 | ** Train the model 21 | #+BEGIN_SRC sh 22 | python main.py 23 | #+END_SRC 24 | 25 | You can also list all available options, e.g. to choose the dataset, the 26 | path where the dataset is stored, or the training parameters, with: 27 | #+BEGIN_SRC sh 28 | python main.py -h 29 | #+END_SRC 30 | 31 | ** TensorboardX 32 | Training loss, parameter histograms, and reconstructed images are logged with tensorboardX to =./runs/= by default; view them with =tensorboard --logdir runs=. 33 | ** To-Do 34 | - [ ] explanation for the tensorboard usage 35 | - [X] logging of the model parameters for nice histograms 36 | - [X] for visual testing: evaluate always the same images to send to tensorboard 37 | - [X] rename model 38 | - [ ] Describe the `plac` parameters well 39 | - [X] Cleanup the laploss and pca parts 40 | - [X] Store the PCA part locally to save some time 41 | - [X] Use ignite 42 | ** Related: 43 | Another implementation I found on the topic: 44 | https://github.com/tneumann/minimal_glo 45 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import os 3 | import logging 4 | 5 | import torch 6 | import torch.optim as optim 7 | import torchvision 8 | from torchvision import datasets, transforms 9 | from tensorboardX import SummaryWriter 10 | import tqdm 11 | import numpy as np 12 | from ignite.engine import Events, create_supervised_trainer 13 | from ignite.metrics import RunningAverage, Loss 14 | 15 | from model import CombinedModel 16 | from utils import IndexToImageDataset, LapLoss 17 | 18 | 19 | def setup_logger(level='DEBUG'): 20 | """Setup personal logger and its handler for this module 21 | 22 | All of this is necessary in order to only get the debug messages from this 23 | module, otherwise I get tons of messages from all possible third party 24 | imports.
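
    Typical use, as done further down in this module (the message text is
    illustrative):

        logger = setup_logger()
        logger.debug('some debug message from this module')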
25 | """ 26 | logger = logging.getLogger(__name__) 27 | hdlr = logging.StreamHandler() 28 | formatter = logging.Formatter( 29 | fmt='%(levelname)s|%(name)s|%(message)s', datefmt='%Y-%m-%d %H:%M:%S') 30 | hdlr.setFormatter(formatter) 31 | logger.addHandler(hdlr) 32 | logger.setLevel(level) 33 | return logger 34 | 35 | 36 | logger = setup_logger() 37 | 38 | 39 | def get_dataloader(dataset, batch_size, data_path='./data'): 40 | if dataset.lower() == 'mnist': 41 | data_transforms = [ 42 | transforms.ToTensor(), 43 | transforms.Normalize((0.1307, ), (0.3081, )) 44 | ] 45 | 46 | train_data = datasets.MNIST( 47 | data_path, 48 | train=True, 49 | download=True, 50 | transform=transforms.Compose(data_transforms)) 51 | elif dataset.lower() == 'cifar10': 52 | data_transforms = [ 53 | transforms.ToTensor(), 54 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 55 | ] 56 | 57 | train_data = datasets.CIFAR10( 58 | data_path, 59 | train=True, 60 | download=True, 61 | transform=transforms.Compose(data_transforms)) 62 | 63 | train_loader = torch.utils.data.DataLoader( 64 | IndexToImageDataset(train_data), 65 | batch_size=batch_size, 66 | shuffle=True, 67 | num_workers=2) 68 | return train_loader 69 | 70 | 71 | def get_tensorboard_writer(description, path): 72 | if description != '': 73 | description = '_' + description 74 | writer = SummaryWriter(log_dir=path, comment=description) 75 | return writer 76 | 77 | 78 | def log_images_to_tensorboard_writer(writer, images, epoch, tag='image'): 79 | writer.add_image( 80 | tag, torchvision.utils.make_grid(images, nrow=5, normalize=True), 81 | epoch) 82 | 83 | 84 | def log_graph_to_tensorboard(writer, model, device): 85 | # Log graph to tensorboard 86 | dummy_index = torch.ones([1, 1], dtype=torch.int64, device=device) 87 | writer.add_graph(model, dummy_index) 88 | 89 | 90 | def main( 91 | # 'use_cuda': True and torch.cuda.is_available(), 92 | data_path: ('Path where the dataset is stored', 'option', '', 93 | str) = './data/', 94 | seed=1, 95 | model_learning_rate: ('', 'option', '', float) = 1, 96 | model_momentum: ('', 'option', '', float) = 0, 97 | latent_learning_rate: ('', 'option', '', float) = 10, 98 | latent_momentum: ('', 'option', '', float) = 0, 99 | batch_size: ('', 'option', '', int) = 128, 100 | # test_batch_size: ('', 'option')=1000, 101 | latent_size: ('', 'option', '', int) = 100, 102 | epochs: ('', 'option', '', int) = 250, 103 | dataset: ('', 'option', '', str) = 'cifar10', 104 | tensorboard_description: ('', 'option', 'tensorboard_description', 105 | str) = '', 106 | no_cuda: ('Do not use CUDA, even if available.', 'flag', 107 | 'no-cuda') = False, 108 | # no_tensorboard: ('', 'flag', 'no-tensorboard') = False, 109 | tensorboard_log_dir: ( 110 | 'Directory to use for the tensorboard logs. Default is `./runs/`.', 111 | 'option', 'tensorboard_log_dir', str) = None, 112 | log_interval: ('', 'option', 'log_interval', int) = 10, 113 | ): 114 | # Setup 115 | use_cuda = torch.cuda.is_available() and not no_cuda 116 | device = torch.device('cuda' if use_cuda else 'cpu') 117 | torch.manual_seed(seed) 118 | 119 | # Define data, model, optimizer, loss, etc. 
120 | train_loader = get_dataloader(dataset, batch_size, data_path) 121 | model = CombinedModel(train_loader, latent_size).to(device) 122 | optimizer = optim.Adam(model.parameters()) 123 | loss_fn = LapLoss() 124 | trainer = create_supervised_trainer( 125 | model, optimizer, loss_fn, device=device) 126 | 127 | # Setup tensorboard and the overall logging and log some static data 128 | writer = get_tensorboard_writer(tensorboard_description, 129 | tensorboard_log_dir) 130 | log_graph_to_tensorboard(writer, model, device) 131 | 132 | test_indices = torch.randint( 133 | len(train_loader.dataset), size=(10, )).to(torch.int64).to(device) 134 | test_images = [train_loader.dataset[int(i)][1] for i in test_indices] 135 | test_images = torch.cat( 136 | [x.view(1, *x.size()) for x in test_images]).to(device) 137 | log_images_to_tensorboard_writer(writer, test_images, 0, 'original_image') 138 | 139 | desc = '[Epoch {:d}/{:d}] Loss: {:.4f}' 140 | pbar = tqdm.tqdm(total=len(train_loader), desc=desc.format(0, epochs, 0)) 141 | 142 | @trainer.on(Events.EPOCH_STARTED) 143 | def initialize_running_loss(engine): 144 | engine.state.running_loss = 0 145 | engine.state._running_loss_sum = 0 146 | 147 | @trainer.on(Events.ITERATION_COMPLETED) 148 | def calculate_running_loss(engine): 149 | total_iteration = (engine.state.iteration - 1) % len(train_loader) + 1 150 | engine.state._running_loss_sum += engine.state.output 151 | engine.state.running_loss = ( 152 | engine.state._running_loss_sum / total_iteration) 153 | 154 | @trainer.on(Events.ITERATION_COMPLETED) 155 | def update_progress_bar(engine): 156 | total_iteration = (engine.state.iteration - 1) % len(train_loader) + 1 157 | if total_iteration % log_interval == 0: 158 | pbar.desc = desc.format(engine.state.epoch, epochs, 159 | engine.state.running_loss) 160 | pbar.update(log_interval) 161 | 162 | @trainer.on(Events.EPOCH_COMPLETED) 163 | def refresh_progress_bar(engine): 164 | print() 165 | pbar.n = pbar.last_print_n = 0 166 | 167 | @trainer.on(Events.EPOCH_COMPLETED) 168 | def log_training_loss(engine): 169 | epoch = engine.state.epoch 170 | writer.add_scalar('metrics/train_loss', engine.state.running_loss, 171 | epoch) 172 | 173 | @trainer.on(Events.EPOCH_COMPLETED) 174 | def log_model_parameters(engine): 175 | epoch = engine.state.epoch 176 | for tag, value in model.named_parameters(): 177 | tag = tag.replace('.', '/') 178 | writer.add_histogram(tag, value.data.cpu().numpy(), epoch) 179 | writer.add_histogram(tag + '/grad', 180 | value.grad.data.cpu().numpy(), epoch) 181 | 182 | @trainer.on(Events.EPOCH_COMPLETED) 183 | def log_reconstructed_images(engine): 184 | output = model(test_indices) 185 | log_images_to_tensorboard_writer(writer, output, engine.state.epoch, 186 | 'reconstructed_image') 187 | 188 | @trainer.on(Events.COMPLETED) 189 | def close_pbar(engine): 190 | pbar.close() 191 | 192 | trainer.run(train_loader, max_epochs=epochs) 193 | 194 | 195 | if __name__ == '__main__': 196 | import plac 197 | plac.call(main) 198 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn.parameter import Parameter 8 | from sklearn.decomposition import PCA 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def _project_to_l2_ball(z): 14 | # return z / np.maximum(np.sqrt(np.sum(z**2, axis=1))[:, None], 1) 15 | # return 
z / np.sqrt(np.sum(z**2, axis=1))[:, None] 16 | # return z / np.max(np.sqrt(np.sum(z**2, axis=1))[:, None]) 17 | return z 18 | 19 | 20 | def _generate_latent_from_pca(train_loader, z_dim): 21 | print("[Latent Init] Preparing PCA") 22 | indices, images = zip(*[ 23 | (indices, images) for indices, images in train_loader]) 24 | indices, images = torch.cat(indices), torch.cat(images) 25 | 26 | print("[Latent Init] Performing the actual PCA") 27 | pca = PCA(n_components=z_dim) 28 | pca.fit(images.view(images.size()[0], -1).numpy()) 29 | 30 | print("[Latent Init] Creating and populating the latent variables") 31 | Z = np.empty((len(train_loader.dataset), z_dim)) 32 | Z[indices] = pca.transform(images.view(images.size()[0], -1).numpy()) 33 | Z = _project_to_l2_ball(Z) 34 | Z = torch.tensor(Z, requires_grad=True).float() 35 | return Z 36 | 37 | 38 | def _disc_or_generation(train_loader, z_dim): 39 | """Load the PCA-initialized latent codes from a disk cache if available, 40 | otherwise generate and cache them""" 41 | _path = '/tmp/GLO_pca_init_{}_{}.pt'.format( 42 | train_loader.dataset.base.filename, z_dim) 43 | if os.path.isfile(_path): 44 | print( 45 | '[Latent Init] PCA already calculated before and saved at {}'. 46 | format(_path)) 47 | Z = torch.load(_path) 48 | else: 49 | Z = _generate_latent_from_pca(train_loader, z_dim) 50 | torch.save(Z, _path) 51 | return Z 52 | 53 | 54 | class LatentVariables(nn.Module): 55 | def __init__(self, train_loader, z_dim=100): 56 | super(LatentVariables, self).__init__() 57 | self.Z = Parameter(_disc_or_generation(train_loader, z_dim)) 58 | 59 | def forward(self, indices): 60 | return self.Z[indices] 61 | 62 | 63 | class Generator(nn.Module): 64 | """Vanilla DCGAN generator 65 | 66 | Copied from https://github.com/pytorch/examples/blob/master/dcgan/main.py 67 | Minor adaptation to match the 32x32 dimension on CIFAR10""" 68 | 69 | def __init__(self, train_loader, z_dim=100, n_filters=64): 70 | super(Generator, self).__init__() 71 | 72 | self.z_dim = z_dim 73 | index, image = train_loader.dataset[0] 74 | out_channels, out_width, out_height = image.size() 75 | 76 | assert out_width in [32, 64] 77 | self.main = nn.Sequential( 78 | # input is Z, going into a convolution 79 | nn.ConvTranspose2d(z_dim, n_filters * 8, 4, 1, 0, bias=False), 80 | nn.BatchNorm2d(n_filters * 8), 81 | nn.ReLU(True), 82 | # state size. (n_filters*8) x 4 x 4 83 | nn.ConvTranspose2d( 84 | n_filters * 8, n_filters * 4, 4, 2, 1, bias=False), 85 | nn.BatchNorm2d(n_filters * 4), 86 | nn.ReLU(True), 87 | # state size. (n_filters*4) x 8 x 8 88 | nn.ConvTranspose2d( 89 | n_filters * 4, n_filters * 2, 4, 2, 1, bias=False), 90 | nn.BatchNorm2d(n_filters * 2), 91 | nn.ReLU(True), 92 | # state size. (n_filters*2) x 16 x 16 93 | # nn.ConvTranspose2d( 94 | # n_filters * 2, n_filters, 4, 2, 1, bias=False), 95 | # nn.BatchNorm2d(n_filters), 96 | # nn.ReLU(True), 97 | # state size. (n_filters) x 32 x 32 98 | # nn.ConvTranspose2d(n_filters, out_channels, 4, 2, 1, bias=False), 99 | nn.ConvTranspose2d( 100 | n_filters * 2, out_channels, 4, 2, 1, bias=False), 101 | nn.Tanh() 102 | # state size.
(out_channels) x 32 x 32 103 | ) 104 | 105 | def forward(self, code): 106 | return self.main(code.view(code.size(0), self.z_dim, 1, 1)) 107 | 108 | 109 | class CombinedModel(nn.Module): 110 | def __init__(self, train_loader, z_dim=100, n_filters=64): 111 | super(CombinedModel, self).__init__() 112 | self.Z = LatentVariables(train_loader, z_dim) 113 | self.Generator = Generator(train_loader, z_dim) 114 | 115 | def forward(self, index): 116 | code = self.Z(index) 117 | # code = code.view(code.size(0), self.z_dim, 1, 1) 118 | return self.Generator(code) 119 | -------------------------------------------------------------------------------- /mount_remote_runs.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | sshfs -p 58022 w0126@atcremers47.informatik.tu-muenchen.de:/usr/prakt/w0126/opt_lat/runs runs 3 | -------------------------------------------------------------------------------- /push_to_workstation.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | rsync -avrz \ 3 | --exclude '.venv' \ 4 | --exclude 'data' \ 5 | --exclude 'logs' \ 6 | --exclude 'runs' \ 7 | --exclude '__pycache__' \ 8 | --include '*.py' \ 9 | -e 'ssh -p 58022' \ 10 | * w0126@atbeetz21.informatik.tu-muenchen.de:opt_lat 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pandas 3 | plac 4 | scikit-learn 5 | tensorboard 6 | tensorflow 7 | tensorboardX 8 | torchvision 9 | tqdm 10 | ignite 11 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class IndexToImageDataset(Dataset): 9 | """Wrap a dataset to map indices to images 10 | 11 | In other words, instead of producing (X, y) it produces (idx, X). The label 12 | y is not relevant for our task. 13 | """ 14 | def __init__(self, base_dataset): 15 | self.base = base_dataset 16 | 17 | def __len__(self): 18 | return len(self.base) 19 | 20 | def __getitem__(self, idx): 21 | img, _ = self.base[idx] 22 | return (idx, img) 23 | 24 | 25 | def gaussian(x, sigma=1.0): 26 | return np.exp(-(x**2) / (2*(sigma**2))) 27 | 28 | 29 | def build_gauss_kernel( 30 | size=5, sigma=1.0, n_channels=1, device=None): 31 | """Construct the convolution kernel for a gaussian blur 32 | 33 | See https://en.wikipedia.org/wiki/Gaussian_blur for a definition. 34 | Overall I first generate a 2xNxN array of index offsets, then evaluate the 35 | one-dimensional gaussian on each element. The two dimensional Gaussian 36 | kernel is then the product along axis=0.
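
    A rough shape check for illustration (assuming the default size and a
    three-channel image batch):

        kernel = build_gauss_kernel(size=5, sigma=1.0, n_channels=3)
        # kernel.shape == torch.Size([3, 1, 5, 5])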
Also, in_channels == out_channels == n_channels 38 | """ 39 | if size % 2 != 1: 40 | raise ValueError("kernel size must be odd") 41 | grid = np.mgrid[range(size), range(size)] - size//2 42 | kernel = np.prod(gaussian(grid, sigma), axis=0) 43 | # kernel = np.sum(gaussian(grid, sigma), axis=0) 44 | kernel /= np.sum(kernel) 45 | 46 | # repeat same kernel for all pictures and all channels 47 | # Also, conv weight should be (out_channels, in_channels/groups, h, w) 48 | kernel = np.tile(kernel, (n_channels, 1, 1, 1)) 49 | kernel = torch.from_numpy(kernel).to(torch.float).to(device) 50 | return kernel 51 | 52 | 53 | def blur_images(images, kernel): 54 | """Convolve the gaussian kernel with the given stack of images""" 55 | _, n_channels, _, _ = images.shape 56 | _, _, kw, kh = kernel.shape 57 | imgs_padded = F.pad(images, (kw//2, kh//2, kw//2, kh//2), mode='replicate') 58 | return F.conv2d(imgs_padded, kernel, groups=n_channels) 59 | 60 | 61 | def laplacian_pyramid(images, kernel, max_levels=5): 62 | """Laplacian pyramid of each image 63 | 64 | https://en.wikipedia.org/wiki/Pyramid_(image_processing)#Laplacian_pyramid 65 | """ 66 | current = images 67 | pyramid = [] 68 | 69 | for level in range(max_levels): 70 | filtered = blur_images(current, kernel) 71 | diff = current - filtered 72 | pyramid.append(diff) 73 | current = F.avg_pool2d(filtered, 2) 74 | pyramid.append(current) 75 | return pyramid 76 | 77 | 78 | class LapLoss(nn.Module): 79 | def __init__(self, max_levels=5, kernel_size=5, sigma=1.0): 80 | super(LapLoss, self).__init__() 81 | self.max_levels = max_levels 82 | self.kernel_size = kernel_size 83 | self.sigma = sigma 84 | self._gauss_kernel = None 85 | 86 | def forward(self, output, target): 87 | if (self._gauss_kernel is None 88 | or self._gauss_kernel.shape[0] != output.shape[1]): 89 | self._gauss_kernel = build_gauss_kernel( 90 | size=self.kernel_size, sigma=self.sigma, 91 | n_channels=output.shape[1], device=output.device) 92 | output_pyramid = laplacian_pyramid( 93 | output, self._gauss_kernel, max_levels=self.max_levels) 94 | target_pyramid = laplacian_pyramid( 95 | target, self._gauss_kernel, max_levels=self.max_levels) 96 | diff_levels = [F.l1_loss(o, t) 97 | for o, t in zip(output_pyramid, target_pyramid)] 98 | loss = sum([2**(-2*j) * diff_levels[j] 99 | for j in range(self.max_levels)]) 100 | return loss 101 | --------------------------------------------------------------------------------
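
A minimal sketch of how the pieces above fit together, roughly the per-batch work
that ignite's create_supervised_trainer performs in main.py (the epoch count and
hyperparameters here are illustrative):

#+BEGIN_SRC python
import torch
import torch.optim as optim

from main import get_dataloader
from model import CombinedModel
from utils import LapLoss

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# (index, image) batches: get_dataloader wraps the dataset in IndexToImageDataset
train_loader = get_dataloader('cifar10', batch_size=128)
model = CombinedModel(train_loader, z_dim=100).to(device)  # latent codes + DCGAN generator
optimizer = optim.Adam(model.parameters())
loss_fn = LapLoss()

for epoch in range(2):  # short run, just to illustrate the loop
    for indices, images in train_loader:
        indices, images = indices.to(device), images.to(device)
        optimizer.zero_grad()
        reconstruction = model(indices)         # index -> latent code -> generated image
        loss = loss_fn(reconstruction, images)  # Laplacian pyramid L1 loss
        loss.backward()                         # gradients flow into generator and codes
        optimizer.step()                        # one update for all parameters
#+END_SRC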