├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── _assets ├── example_groundtruth.png └── example_prior.png ├── data ├── __init__.py ├── convert_to_numpy.py ├── datasets.py └── loader.py ├── model.py ├── requirements.txt ├── run.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Jens Petersen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # gqn-pytorch 2 | Implementation of [GQN](http://science.sciencemag.org/content/360/6394/1204.full?ijkey=kGcNflzOLiIKQ&keytype=ref&siteid=sci) in PyTorch. 3 | 4 | I'd like to encourage you to also check out these two implementations that I used extensively for inspiration, troubleshooting, etc.: 5 | 6 | https://github.com/iShohei220/torch-gqn 7 | https://github.com/wohlert/generative-query-network-pytorch 8 | 9 | They're probably more accessible than this implementation, because I tried to make mine very flexible. 10 | That being said, what I found a little bit annoying with the others was the need to convert the data first, so I uploaded the datasets in numpy format for your convenience, to be found [here](https://console.cloud.google.com/storage/browser/gqn-datasets), so you should have this one up and running faster than the others ;) UPDATE: All datasets are now online! 11 | 12 | No readme is complete without at least one nice picture, so here are some example predictions from the prior (bottom) along with the groundtruth (top) after training for 1 million batches (the model could see all other viewpoints for the prediction). 13 | 14 | ![Groundtruth](_assets/example_groundtruth.png) 15 | 16 | ![Prediction](_assets/example_prior.png) 17 | 18 | ### Prerequisites 19 | 20 | I've only tried this with Python 3, but feel free to give it a go with version 2 and report any errors (or that it's working fine); it shouldn't take much to get it to work. Also, Windows probably won't work, not least because the batchgenerators I'm using don't support it. All dependencies are listed in requirements.txt; at the moment there are only two, [batchgenerators](https://github.com/MIC-DKFZ/batchgenerators) and [trixi](https://github.com/MIC-DKFZ/trixi), and everything else you need will be installed with them. This GQN implementation will also become an official trixi example very soon :) 21 | 22 | ### Running 23 | 24 | 1. Download the data from [here](https://console.cloud.google.com/storage/browser/gqn-datasets). By default the loader will assume that the data folder is on the same level as this repository. You can also set the data location from the CLI. To unpack the data, use pigz with as many processes as you can afford. 25 | 26 | pigz -p PROCESSES -d data_folder/* 27 | 28 | 2. Install dependencies (if you work with virtualenv or conda, it's probably good practice to set up a new environment first). 29 | 30 | pip install -r requirements.txt 31 | 32 | 33 | 3. If you want live monitoring using Visdom, start the server 34 | 35 | python -m visdom.server -port 8080 36 | 37 | trixi uses port 8080 by default instead of 8097. 38 | 39 | 4. Run the experiment with 40 | 41 | python run.py OUTPUT_FOLDER [-v] 42 | 43 | where -v indicates you want to use a VisdomLogger. If you look into run.py, you will find a rather large default Config. Everything in this Config will be exposed to the command line automatically, as described in the [trixi docs](https://trixi.readthedocs.io/en/latest/_api/trixi.util.html#module-trixi.util.config). For example, if you don't want to hardcode the location of the data, you can just use the data_dir attribute via `--data_dir SOME_OTHER_LOCATION`. So far there is only one mod, SHAREDCORES, i.e. a modification to the default Config, but in principle mods are designed to be combined, e.g. `-m MOD1 MOD2`.
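A complete call that sets the data location and switches on the shared-cores mod might then look like this (the output folder and the data path below are just placeholders):

    python run.py OUTPUT_FOLDER -v --data_dir /path/to/gqn-datasets -m SHAREDCORES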
You can of course resume experiments and do everything else that trixi offers. 44 | 45 | ### Speed, Stability and Performance 46 | 47 | I'm using PyTorch 1.0.0 with CUDA 10.0.130 and CuDNN 7.4.1, and sometimes I get a CUDNN_STATUS_INTERNAL_ERROR in the backward pass. I'm not entirely sure what the reason for this is; other people have had similar errors in different contexts and with different PyTorch versions. If you encounter this problem, set `cudnn_benchmark = False` in run.py. Unfortunately that makes the whole thing quite a bit slower, ~1.1s/~1.5s per batch with and without shared cores versus ~0.6s/~0.8s when cudnn_benchmark is on. 48 | 49 | I can fit batches of size 36 on my TitanXp with 12GB memory. I'm not sure that's how it's supposed to be, because in the paper the authors state that they're working with batch size 36 on 4 K80 GPUs. Not sure whether that's 24GB K80s or 12GB like the Google Colab ones. My best guess is that it's the latter and they just had a full batch on each at a time. Running the experiment for 1 million batches (the paper used twice as many) took me 22 days. 50 | 51 | ### TBD 52 | 53 | To be honest I don't see myself investing loads of time into this project, so feel free to work on it yourself; pull requests are more than welcome :) A few things that are missing: 54 | 55 | * Multi-GPU support 56 | * A test function (as opposed to training and validation) 57 | * Actual tests (as in unittests) 58 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jenspetersen/gqn-pytorch/5062a970e3c23990ca121e82222503338243302f/__init__.py -------------------------------------------------------------------------------- /_assets/example_groundtruth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jenspetersen/gqn-pytorch/5062a970e3c23990ca121e82222503338243302f/_assets/example_groundtruth.png -------------------------------------------------------------------------------- /_assets/example_prior.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jenspetersen/gqn-pytorch/5062a970e3c23990ca121e82222503338243302f/_assets/example_prior.png -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jenspetersen/gqn-pytorch/5062a970e3c23990ca121e82222503338243302f/data/__init__.py -------------------------------------------------------------------------------- /data/convert_to_numpy.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/l3robot/gqn_datasets_translator 2 | 3 | import os 4 | import numpy as np 5 | from numpy.lib.format import open_memmap 6 | import tensorflow as tf 7 | import subprocess as sp 8 | from .datasets import all_datasets 9 | import argparse as ap 10 | 11 | tf.logging.set_verbosity(tf.logging.ERROR) # disable annoying logging 12 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # disable gpu 13 | 14 | _pose_dim = 5 15 | 16 | 17 | def collect_files(path, ext=None, key=None): 18 | if key is None: 19 | files = sorted(os.listdir(path)) 20 | else: 21 | files = sorted(os.listdir(path), key=key) 22 | 23 | if ext is not None: 24 | files = [f for f in files if
os.path.splitext(f)[-1] == ext] 25 | 26 | return [os.path.join(path, fname) for fname in files] 27 | 28 | 29 | def convert_record(record, info): 30 | print(record) 31 | 32 | path, filename = os.path.split(record) 33 | images, viewpoints = process_record(record, info) 34 | 35 | return images, viewpoints 36 | 37 | 38 | def process_record(record, info): 39 | engine = tf.python_io.tf_record_iterator(record) 40 | 41 | images = [] 42 | viewpoints = [] 43 | for i, data in enumerate(engine): 44 | image, viewpoint = convert_to_numpy(data, info) 45 | images.append(image) 46 | viewpoints.append(viewpoint) 47 | 48 | return np.stack(images), np.stack(viewpoints) 49 | 50 | 51 | def process_images(example, seq_length, image_size): 52 | """Instantiates the ops used to preprocess the frames data.""" 53 | images = tf.concat(example['frames'], axis=0) 54 | images = tf.map_fn(tf.image.decode_jpeg, tf.reshape(images, [-1]), 55 | dtype=tf.uint8, back_prop=False) 56 | shape = (image_size, image_size, 3) 57 | images = tf.reshape(images, (-1, seq_length) + shape) 58 | return images 59 | 60 | 61 | def process_poses(example, seq_length): 62 | """Instantiates the ops used to preprocess the cameras data.""" 63 | poses = example['cameras'] 64 | poses = tf.reshape(poses, (-1, seq_length, _pose_dim)) 65 | return poses 66 | 67 | 68 | def convert_to_numpy(raw_data, info): 69 | seq_length = info.seq_length 70 | image_size = info.image_size 71 | 72 | feature = {'frames': tf.FixedLenFeature(shape=seq_length, dtype=tf.string), 73 | 'cameras': tf.FixedLenFeature(shape=seq_length * _pose_dim, dtype=tf.float32)} 74 | example = tf.parse_single_example(raw_data, feature) 75 | 76 | images = process_images(example, seq_length, image_size) 77 | poses = process_poses(example, seq_length) 78 | 79 | return images.numpy().squeeze(), poses.numpy().squeeze() 80 | 81 | 82 | if __name__ == '__main__': 83 | 84 | tf.enable_eager_execution() 85 | 86 | parser = ap.ArgumentParser(description='Convert gqn tfrecords to gzipped numpy arrays.') 87 | parser.add_argument('base_dir', nargs=1, 88 | help='base directory of gqn dataset') 89 | parser.add_argument('dataset', nargs=1, 90 | help='datasets to convert, eg. 
shepard_metzler_5_parts') 91 | parser.add_argument('-n', '--first-n', type=int, default=None, 92 | help='convert only the first n tfrecords if given') 93 | parser.add_argument('-m', '--mode', type=str, default='train', 94 | help='whether to convert train or test') 95 | parser.add_argument("-o", "--output_dir", type=str, default=os.getcwd(), help="Output directory, default current working dir") 96 | parser.add_argument("-c", "--compression_cores", type=int, default=8, help="Use this many cores for compression.") 97 | args = parser.parse_args() 98 | 99 | base_dir = os.path.expanduser(args.base_dir[0]) 100 | dataset = args.dataset[0] 101 | output_dir = os.path.join(args.output_dir, dataset) 102 | os.makedirs(output_dir, exist_ok=True) 103 | 104 | print(f'base_dir: {base_dir}') 105 | print(f'dataset: {dataset}') 106 | print(f'output_dir: {output_dir}') 107 | 108 | info = all_datasets[dataset] 109 | data_dir = os.path.join(base_dir, dataset) 110 | records = collect_files(os.path.join(data_dir, args.mode), '.tfrecord') 111 | 112 | if args.first_n is not None: 113 | records = records[:args.first_n] 114 | 115 | images_shape = (getattr(info, "{}_instances".format(args.mode)), info.seq_length, info.image_size, info.image_size, 3) 116 | viewpoints_shape = (getattr(info, "{}_instances".format(args.mode)), info.seq_length, _pose_dim) 117 | 118 | images_arr = open_memmap(os.path.join(output_dir, "{}_images.npy".format(args.mode)), mode="w+", dtype=np.uint8, shape=images_shape) 119 | viewpoints_arr = open_memmap(os.path.join(output_dir, "{}_viewpoints.npy".format(args.mode)), mode="w+", dtype=np.float32, shape=viewpoints_shape) 120 | 121 | print(f'converting {len(records)} records in {dataset}/{args.mode}') 122 | 123 | index = 0 124 | for r, record in enumerate(records): 125 | images, viewpoints = convert_record(record, info) 126 | images_arr[index:index+images.shape[0]] = images 127 | viewpoints_arr[index:index+images.shape[0]] = viewpoints 128 | index += images.shape[0] 129 | 130 | del images_arr 131 | del viewpoints_arr 132 | 133 | sp.check_output(["pigz", "--fast", "-p", "{}".format(args.compression_cores), os.path.join(output_dir, "{}_viewpoints.npy".format(args.mode))]) 134 | sp.check_output(["pigz", "--fast", "-p", "{}".format(args.compression_cores), os.path.join(output_dir, "{}_images.npy".format(args.mode))]) 135 | 136 | print('Done') 137 | -------------------------------------------------------------------------------- /data/datasets.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | 4 | DatasetInfo = namedtuple('DatasetInfo', ['image_size', 'seq_length', 'train_instances', 'test_instances']) 5 | 6 | all_datasets = dict( 7 | jaco=DatasetInfo(image_size=64, seq_length=11, train_instances=7200000, test_instances=800000), 8 | mazes=DatasetInfo(image_size=84, seq_length=300, train_instances=108000, test_instances=12000), 9 | rooms_free_camera_with_object_rotations=DatasetInfo(image_size=128, seq_length=10, train_instances=10170000, test_instances=630000), 10 | rooms_ring_camera=DatasetInfo(image_size=64, seq_length=10, train_instances=10800000, test_instances=1200000), 11 | rooms_free_camera_no_object_rotations=DatasetInfo(image_size=64, seq_length=10, train_instances=10800000, test_instances=1200000), 12 | shepard_metzler_5_parts=DatasetInfo(image_size=64, seq_length=15, train_instances=810000, test_instances=200000), 13 | shepard_metzler_7_parts=DatasetInfo(image_size=64, seq_length=15, train_instances=810000, 
test_instances=200000) 14 | ) 15 | -------------------------------------------------------------------------------- /data/loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from batchgenerators.dataloading.data_loader import SlimDataLoaderBase 4 | from .datasets import all_datasets 5 | 6 | data_dir = "../../" 7 | file_dir = os.path.dirname(os.path.abspath(__file__)) 8 | data_dir = os.path.join(file_dir, data_dir) 9 | 10 | 11 | def split(dataset="shepard_metzler_5_parts", N=5, seed=1): 12 | 13 | indices = np.arange(all_datasets[dataset][2]) 14 | r = np.random.RandomState(seed) 15 | r.shuffle(indices) 16 | num = len(indices) // N 17 | splits = [] 18 | for i in range(N): 19 | if i < (N-1): 20 | splits.append(sorted(indices[i*num:(i+1)*num])) 21 | else: 22 | splits.append(sorted(indices[i*num:])) 23 | return splits 24 | 25 | 26 | def load(dataset="shepard_metzler_5_parts", mode="train", image_kwargs=None, viewpoint_kwargs=None): 27 | """Use image_kwargs and viewpoint_kwargs to set e.g. mmap_mode.""" 28 | 29 | data_dir_ = os.path.join(data_dir, dataset) 30 | 31 | if image_kwargs is None: image_kwargs = {} 32 | if viewpoint_kwargs is None: viewpoint_kwargs = {} 33 | 34 | images = np.load(os.path.join(data_dir_, "{}_images.npy".format(mode)), **image_kwargs) 35 | viewpoints = np.load(os.path.join(data_dir_, "{}_viewpoints.npy".format(mode)), **viewpoint_kwargs) 36 | 37 | return {"data": images, "viewpoints": viewpoints} 38 | 39 | 40 | def transform_viewpoint(v): 41 | """Transforms the viewpoint vector into a consistentrepresentation""" 42 | 43 | return np.concatenate([v[:, :3], 44 | np.cos(v[:, 3:4]), 45 | np.sin(v[:, 3:4]), 46 | np.cos(v[:, 4:5]), 47 | np.sin(v[:, 4:5])], 1) 48 | 49 | 50 | class LinearBatchGenerator(SlimDataLoaderBase): 51 | 52 | def __init__(self, 53 | data, 54 | batch_size, 55 | dtype=np.float32, 56 | num_viewpoints=10, # both input and query 57 | shuffle_viewpoints=False, 58 | data_order=None, 59 | **kwargs): 60 | 61 | super(LinearBatchGenerator, self).__init__(data, batch_size, **kwargs) 62 | self.dtype = dtype 63 | self.num_viewpoints = num_viewpoints 64 | self.shuffle_viewpoints = shuffle_viewpoints 65 | 66 | self.current_position = 0 67 | self.was_initialized = False 68 | if self.number_of_threads_in_multithreaded is None: 69 | self.number_of_threads_in_multithreaded = 1 70 | if data_order is None: 71 | self.data_order = np.arange(data["viewpoints"].shape[0]) 72 | else: 73 | self.data_order = data_order 74 | 75 | self.num_restarted = 0 76 | 77 | def reset(self): 78 | 79 | self.current_position = self.thread_id * self.batch_size 80 | self.was_initialized = True 81 | self.rs = np.random.RandomState(self.num_restarted) 82 | self.num_restarted = self.num_restarted + 1 83 | 84 | def __len__(self): 85 | 86 | return len(self.data_order) 87 | 88 | def generate_train_batch(self): 89 | 90 | if not self.was_initialized: 91 | self.reset() 92 | if self.current_position >= len(self): 93 | self.reset() 94 | raise StopIteration 95 | batch = self.make_batch(self.current_position) 96 | self.current_position += self.number_of_threads_in_multithreaded * self.batch_size 97 | return batch 98 | 99 | def make_batch(self, idx): 100 | 101 | batch_images = [] 102 | batch_viewpoints = [] 103 | data_indices = [] 104 | viewpoint_indices = [] 105 | 106 | if self.num_viewpoints == "random": 107 | num_viewpoints = self.rs.randint(2, 16) 108 | else: 109 | num_viewpoints = self.num_viewpoints 110 | 111 | while 
len(batch_images) < self.batch_size: 112 | 113 | idx = idx % len(self.data_order) 114 | idx_data = self.data_order[idx] 115 | 116 | viewpoint_indices_current = np.arange(15) 117 | # for linear generator we leave the existing viewpoint order 118 | if self.shuffle_viewpoints: 119 | self.rs.shuffle(viewpoint_indices_current) 120 | viewpoint_indices_current = viewpoint_indices_current[:num_viewpoints] 121 | viewpoint_indices.append(viewpoint_indices_current) 122 | 123 | batch_images.append(np.array(self._data["data"][idx_data, viewpoint_indices_current])) 124 | batch_viewpoints.append(self._data["viewpoints"][idx_data, viewpoint_indices_current]) 125 | data_indices.append(idx_data) 126 | 127 | idx += 1 128 | 129 | batch_images = np.stack(batch_images)\ 130 | .astype(self.dtype)\ 131 | .reshape(self.batch_size * num_viewpoints, 64, 64, 3)\ 132 | .transpose(0, 3, 1, 2) 133 | batch_viewpoints = np.stack(batch_viewpoints)\ 134 | .astype(np.float32)\ 135 | .reshape(self.batch_size * num_viewpoints, -1) 136 | batch_viewpoints = transform_viewpoint(batch_viewpoints) 137 | data_indices = np.array(data_indices) 138 | viewpoint_indices = np.array(viewpoint_indices) 139 | 140 | # images are saved as uint8, so we need to normalize 141 | if self.dtype != np.uint8: 142 | batch_images /= 255. 143 | 144 | return {"data": batch_images, 145 | "viewpoints": batch_viewpoints, 146 | "num_viewpoints": num_viewpoints, 147 | "data_indices": data_indices, 148 | "viewpoint_indices": viewpoint_indices} 149 | 150 | 151 | class RandomOrderBatchGenerator(LinearBatchGenerator): 152 | 153 | def __init__(self, 154 | *args, 155 | num_viewpoints="random", 156 | shuffle_viewpoints=True, 157 | infinite=True, 158 | **kwargs): 159 | 160 | super(RandomOrderBatchGenerator, self).__init__(*args, num_viewpoints=num_viewpoints, **kwargs) 161 | self.infinite = infinite 162 | 163 | def reset(self): 164 | 165 | super(RandomOrderBatchGenerator, self).reset() 166 | self.rs.shuffle(self.data_order) 167 | 168 | def generate_train_batch(self): 169 | 170 | if not self.was_initialized: 171 | self.reset() 172 | if self.current_position >= len(self): 173 | self.reset() 174 | if not self.infinite: 175 | raise StopIteration 176 | batch = self.make_batch(self.current_position) 177 | self.current_position += self.number_of_threads_in_multithreaded * self.batch_size 178 | return batch 179 | 180 | 181 | class RandomBatchGenerator(RandomOrderBatchGenerator): 182 | pass 183 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.distributions import Normal, kl_divergence 6 | from util import ConvModule 7 | 8 | 9 | class ConvLSTMCell(ConvModule): 10 | 11 | _default_conv_kwargs = dict( 12 | kernel_size=3, 13 | stride=1, 14 | padding=1, 15 | dilation=1 16 | ) 17 | 18 | def __init__(self, 19 | in_channels, 20 | out_channels, 21 | conv_op=nn.Conv2d, 22 | conv_kwargs=None, 23 | concat_hidden_and_input=True, 24 | **kwargs): 25 | 26 | super(ConvLSTMCell, self).__init__(**kwargs) 27 | 28 | self.in_channels = in_channels 29 | self.out_channels = out_channels 30 | self.conv_op = conv_op 31 | self.conv_kwargs = self._default_conv_kwargs.copy() 32 | if conv_kwargs is not None: 33 | self.conv_kwargs.update(conv_kwargs) 34 | self.concat_hidden_and_input = concat_hidden_and_input 35 | 36 | self.add_module("forget", conv_op(in_channels, 
out_channels, **conv_kwargs)) 37 | self.add_module("input", conv_op(in_channels, out_channels, **conv_kwargs)) 38 | self.add_module("output", conv_op(in_channels, out_channels, **conv_kwargs)) 39 | self.add_module("state", conv_op(in_channels, out_channels, **conv_kwargs)) 40 | 41 | def forward(self, input, hidden, cell): 42 | 43 | if self.concat_hidden_and_input: 44 | input = torch.cat((input, hidden), 1) 45 | cell = torch.sigmoid(self.forget(input)) * cell + torch.sigmoid(self.input(input)) * torch.tanh(self.state(input)) 46 | hidden = torch.sigmoid(self.output(input)) * torch.tanh(cell) 47 | 48 | return hidden, cell 49 | 50 | 51 | class TowerRepresentation(ConvModule): 52 | 53 | def __init__(self, 54 | in_channels=3, 55 | query_channels=7, 56 | r_channels=256, 57 | conv_op=nn.Conv2d, 58 | activation_op=nn.ReLU, 59 | activation_kwargs={"inplace": True}, 60 | pool_op=nn.AdaptiveAvgPool2d, 61 | pool_kwargs={"output_size": 1}, 62 | **kwargs): 63 | 64 | super(TowerRepresentation, self).__init__(**kwargs) 65 | self.in_channels = in_channels 66 | self.query_channels = query_channels 67 | self.r_channels = r_channels 68 | self.conv_op = conv_op 69 | self.activation_op = activation_op 70 | self.activation_kwargs = activation_kwargs if activation_kwargs is not None else {} 71 | self.pool_op = pool_op 72 | self.pool_kwargs = pool_kwargs if pool_kwargs is not None else {} 73 | 74 | k = self.r_channels 75 | 76 | self.add_module("conv1", self.conv_op(in_channels, k, kernel_size=2, stride=2)) 77 | self.add_module("conv2", self.conv_op(k, k//2, kernel_size=3, stride=1, padding=1)) 78 | self.add_module("conv2_skip", self.conv_op(k, k, kernel_size=2, stride=2)) 79 | self.add_module("conv3", self.conv_op(k//2, k, kernel_size=2, stride=2)) 80 | 81 | self.add_module("conv4", self.conv_op(k + self.query_channels, k//2, kernel_size=3, stride=1, padding=1)) 82 | self.add_module("conv4_skip", self.conv_op(k + self.query_channels, k, kernel_size=3, stride=1, padding=1)) 83 | self.add_module("conv5", self.conv_op(k//2, k, kernel_size=3, stride=1, padding=1)) 84 | self.add_module("conv6", self.conv_op(k, k, kernel_size=1, stride=1)) 85 | 86 | self.add_module("activation", self.activation_op(**self.activation_kwargs)) 87 | if self.pool_op is not None: 88 | self.pool = self.pool_op(**self.pool_kwargs) 89 | 90 | def forward(self, image, query): 91 | 92 | image = self.activation(self.conv1(image)) 93 | skip = self.activation(self.conv2_skip(image)) 94 | image = self.activation(self.conv2(image)) 95 | image = self.activation(self.conv3(image)) + skip 96 | 97 | query = query.view(query.shape[0], -1, 1, 1).repeat(1, 1, *image.shape[2:]) 98 | 99 | image = torch.cat((image, query), 1) 100 | skip = self.activation(self.conv4_skip(image)) 101 | image = self.activation(self.conv4(image)) 102 | image = self.activation(self.conv5(image)) + skip 103 | image = self.activation(self.conv6(image)) 104 | 105 | if self.pool_op is not None: 106 | image = self.pool(image) 107 | 108 | return image 109 | 110 | 111 | class GeneratorCore(ConvModule): 112 | 113 | _default_core_kwargs = dict( 114 | conv_op=nn.Conv2d, 115 | conv_kwargs=dict(kernel_size=5, stride=1, padding=2), 116 | concat_hidden_and_input=True 117 | ) 118 | 119 | _default_upsample_kwargs = dict( 120 | padding=0, 121 | bias=False 122 | ) 123 | 124 | def __init__(self, 125 | query_channels=7, 126 | r_channels=256, 127 | z_channels=64, 128 | h_channels=128, 129 | scale=4, 130 | core_op=ConvLSTMCell, 131 | core_kwargs=None, 132 | upsample_op=nn.ConvTranspose2d, 133 | 
upsample_kwargs=None, 134 | **kwargs): 135 | 136 | super(GeneratorCore, self).__init__(**kwargs) 137 | 138 | self.query_channels = query_channels 139 | self.r_channels = r_channels 140 | self.z_channels = z_channels 141 | self.h_channels = h_channels 142 | self.scale = scale 143 | self.core_op = core_op 144 | self.core_kwargs = self._default_core_kwargs.copy() 145 | if core_kwargs is not None: 146 | self.core_kwargs.update(core_kwargs) 147 | self.upsample_op = upsample_op 148 | self.upsample_kwargs = self._default_upsample_kwargs.copy() 149 | if upsample_kwargs is not None: 150 | self.upsample_kwargs.update(upsample_kwargs) 151 | 152 | self.add_module("core", core_op(h_channels + query_channels + r_channels + z_channels, h_channels, **self.core_kwargs)) 153 | self.add_module("upsample", upsample_op(h_channels, h_channels, kernel_size=scale, stride=scale, **self.upsample_kwargs)) 154 | 155 | def forward(self, viewpoint_query, representation, z, cell, hidden, u): 156 | 157 | batch_size = hidden.shape[0] 158 | spatial_dims = hidden.shape[2:] 159 | 160 | # we assume all following tensors are either (b,c) or (b,c,1,1,...) 161 | # we further assume that cell, hidden & u have correct shapes 162 | if viewpoint_query.shape[2:] != spatial_dims: 163 | viewpoint_query = viewpoint_query.view(batch_size, -1, *[1, ] * len(spatial_dims)).repeat(1, 1, *spatial_dims) 164 | if representation.shape[2:] != spatial_dims: 165 | representation = representation.view(batch_size, -1, *[1, ] * len(spatial_dims)).repeat(1, 1, *spatial_dims) 166 | if z.shape[2:] != spatial_dims: 167 | z = z.view(batch_size, -1, *[1, ] * len(spatial_dims)).repeat(1, 1, *spatial_dims) 168 | 169 | cell, hidden = self.core(torch.cat([viewpoint_query, representation, z], 1), cell, hidden) 170 | u = self.upsample(hidden) + u 171 | 172 | return cell, hidden, u 173 | 174 | 175 | class InferenceCore(ConvModule): 176 | 177 | _default_core_kwargs = dict( 178 | conv_op=nn.Conv2d, 179 | conv_kwargs=dict( 180 | kernel_size=5, 181 | stride=1, 182 | padding=2 183 | ), 184 | concat_hidden_and_input=True 185 | ) 186 | 187 | _default_downsample_kwargs = dict( 188 | padding=0, 189 | bias=False 190 | ) 191 | 192 | def __init__(self, 193 | in_channels=3, 194 | query_channels=7, 195 | r_channels=256, 196 | z_channels=64, 197 | h_channels=128, 198 | scale=4, 199 | core_op=ConvLSTMCell, 200 | core_kwargs=None, 201 | downsample_op=nn.Conv2d, 202 | downsample_kwargs=None, 203 | **kwargs): 204 | 205 | super(InferenceCore, self).__init__(**kwargs) 206 | 207 | self.in_channels = in_channels 208 | self.query_channels = query_channels 209 | self.r_channels = r_channels 210 | self.z_channels = z_channels 211 | self.h_channels = h_channels 212 | self.scale = scale 213 | self.core_op = core_op 214 | self.core_kwargs = self._default_core_kwargs.copy() 215 | if core_kwargs is not None: 216 | self.core_kwargs.update(core_kwargs) 217 | self.downsample_op = downsample_op 218 | self.downsample_kwargs = self._default_downsample_kwargs.copy() 219 | if downsample_kwargs is not None: 220 | self.downsample_kwargs.update(downsample_kwargs) 221 | 222 | self.add_module("core", core_op(in_channels + query_channels + r_channels + 3*h_channels, h_channels, **self.core_kwargs)) 223 | self.add_module("downsample_image", downsample_op(in_channels, in_channels, kernel_size=scale, stride=scale, **self.downsample_kwargs)) 224 | self.add_module("downsample_u", downsample_op(h_channels, h_channels, kernel_size=scale, stride=scale, **self.downsample_kwargs)) 225 | 226 | def forward(self, 
image_query, viewpoint_query, representation, cell, hidden_self, hidden_gen, u): 227 | 228 | batch_size = hidden_self.shape[0] 229 | spatial_dims = hidden_self.shape[2:] 230 | 231 | # we assume all following tensors are either (b,c) or (b,c,1,1,...) 232 | # we further assume that image_query, cell, hidden_self, hidden_gen & u have correct shapes 233 | if viewpoint_query.shape[2:] != spatial_dims: 234 | viewpoint_query = viewpoint_query.view(batch_size, -1, *[1, ] * len(spatial_dims)).repeat(1, 1, *spatial_dims) 235 | if representation.shape[2:] != spatial_dims: 236 | representation = representation.view(batch_size, -1, *[1, ] * len(spatial_dims)).repeat(1, 1, *spatial_dims) 237 | 238 | image_query = self.downsample_image(image_query) 239 | u = self.downsample_u(u) 240 | cell, hidden_self = self.core(torch.cat([image_query, viewpoint_query, representation, hidden_gen, u], 1), cell, hidden_self) 241 | 242 | return cell, hidden_self 243 | 244 | 245 | class GQNDecoder(ConvModule): 246 | 247 | _default_conv_kwargs = dict( 248 | kernel_size=5, 249 | stride=1, 250 | padding=2 251 | ) 252 | 253 | def __init__(self, 254 | in_channels=3, 255 | query_channels=7, 256 | r_channels=256, 257 | z_channels=64, 258 | h_channels=128, 259 | scale=4, 260 | core_repeat=12, 261 | core_shared=False, 262 | generator_op=GeneratorCore, 263 | generator_kwargs=None, 264 | inference_op=InferenceCore, 265 | inference_kwargs=None, 266 | conv_op=nn.Conv2d, 267 | conv_kwargs=None, 268 | output_activation_op=nn.Sigmoid, 269 | output_activation_kwargs=None, 270 | **kwargs): 271 | 272 | super(GQNDecoder, self).__init__(**kwargs) 273 | self.in_channels = in_channels 274 | self.query_channels = query_channels 275 | self.r_channels = r_channels 276 | self.z_channels = z_channels 277 | self.h_channels = h_channels 278 | self.scale = scale 279 | self.core_repeat = core_repeat 280 | self.core_shared = core_shared 281 | self.generator_op = generator_op 282 | self.generator_kwargs = {} if generator_kwargs is None else generator_kwargs 283 | self.inference_op = inference_op 284 | self.inference_kwargs = {} if inference_kwargs is None else inference_kwargs 285 | self.conv_op = conv_op 286 | self.conv_kwargs = self._default_conv_kwargs.copy() 287 | if conv_kwargs is not None: 288 | self.conv_kwargs.update(conv_kwargs) 289 | self.output_activation_op = output_activation_op 290 | self.output_activation_kwargs = output_activation_kwargs if output_activation_kwargs is not None else {} 291 | 292 | for dict_ in (self.generator_kwargs, self.inference_kwargs): 293 | dict_.update(dict( 294 | query_channels=query_channels, 295 | r_channels=r_channels, 296 | z_channels=z_channels, 297 | h_channels=h_channels, 298 | scale=scale 299 | )) 300 | self.inference_kwargs["in_channels"] = in_channels 301 | 302 | if self.core_shared: 303 | self.add_module("generator_core", self.generator_op(**self.generator_kwargs)) 304 | self.add_module("inference_core", self.inference_op(**self.inference_kwargs)) 305 | else: 306 | self.add_module("generator_core", nn.ModuleList([self.generator_op(**self.generator_kwargs) for _ in range(self.core_repeat)])) 307 | self.add_module("inference_core", nn.ModuleList([self.inference_op(**self.inference_kwargs) for _ in range(self.core_repeat)])) 308 | 309 | self.add_module("posterior_net", conv_op(h_channels, 2*z_channels, **self.conv_kwargs)) 310 | self.add_module("prior_net", conv_op(h_channels, 2*z_channels, **self.conv_kwargs)) 311 | self.add_module("observation_net", conv_op(h_channels, in_channels, kernel_size=1, stride=1, 
padding=0)) 312 | 313 | self.add_module("output_activation", output_activation_op(**self.output_activation_kwargs)) 314 | 315 | def forward(self, representation, viewpoint_query, image_query): 316 | 317 | batch_size, _, *spatial_dims = image_query.shape 318 | spatial_dims_scaled = tuple(np.array(spatial_dims) // self.scale) 319 | kl = 0 320 | 321 | # Increase dimensions 322 | viewpoint_query = viewpoint_query.view(batch_size, -1, *[1, ] * len(spatial_dims)).repeat(1, 1, *spatial_dims_scaled) 323 | if representation.shape[2:] != spatial_dims_scaled: 324 | representation = representation.view(batch_size, -1, *[1, ] * len(spatial_dims)).repeat(1, 1, *spatial_dims_scaled) 325 | 326 | # Reset hidden state 327 | hidden_g = image_query.new_zeros((batch_size, self.h_channels, *spatial_dims_scaled)) 328 | hidden_i = image_query.new_zeros((batch_size, self.h_channels, *spatial_dims_scaled)) 329 | 330 | # Reset cell state 331 | cell_g = image_query.new_zeros((batch_size, self.h_channels, *spatial_dims_scaled)) 332 | cell_i = image_query.new_zeros((batch_size, self.h_channels, *spatial_dims_scaled)) 333 | 334 | # better name for u? 335 | u = image_query.new_zeros((batch_size, self.h_channels, *spatial_dims)) 336 | 337 | for i in range(self.core_repeat): 338 | 339 | if self.core_shared: 340 | current_generator_core = self.generator_core 341 | current_inference_core = self.inference_core 342 | else: 343 | current_generator_core = self.generator_core[i] 344 | current_inference_core = self.inference_core[i] 345 | 346 | # Prior 347 | o = self.prior_net(hidden_g) 348 | prior_mu, prior_std_pseudo = torch.split(o, self.z_channels, dim=1) 349 | prior = Normal(prior_mu, F.softplus(prior_std_pseudo)) 350 | 351 | # Inference state update 352 | cell_i, hidden_i = current_inference_core(image_query, viewpoint_query, representation, cell_i, hidden_i, hidden_g, u) 353 | 354 | # Posterior 355 | o = self.posterior_net(hidden_i) 356 | posterior_mu, posterior_std_pseudo = torch.split(o, self.z_channels, dim=1) 357 | posterior = Normal(posterior_mu, F.softplus(posterior_std_pseudo)) 358 | 359 | # Posterior sample 360 | if self.training: 361 | z = posterior.rsample() 362 | else: 363 | z = prior.loc 364 | 365 | # Generator update 366 | cell_g, hidden_g, u = current_generator_core(viewpoint_query, representation, z, cell_g, hidden_g, u) 367 | 368 | # Calculate KL-divergence 369 | kl += kl_divergence(posterior, prior) 370 | 371 | image_prediction = self.output_activation(self.observation_net(u)) 372 | 373 | return image_prediction, kl 374 | 375 | def sample(self, representation, viewpoint_query, spatial_dims=(64, 64)): 376 | 377 | batch_size = viewpoint_query.shape[0] 378 | spatial_dims_scaled = tuple(np.array(spatial_dims) // self.scale) 379 | 380 | # Increase dimensions 381 | viewpoint_query = viewpoint_query.view(batch_size, -1, 1, 1).repeat(1, 1, *spatial_dims_scaled) 382 | if representation.shape[2:] != spatial_dims_scaled: 383 | representation = representation.repeat(1, 1, *spatial_dims_scaled) 384 | 385 | # Reset hidden and cell state for generator 386 | hidden_g = viewpoint_query.new_zeros((batch_size, self.h_channels, *spatial_dims_scaled)) 387 | cell_g = viewpoint_query.new_zeros((batch_size, self.h_channels, *spatial_dims_scaled)) 388 | 389 | u = viewpoint_query.new_zeros((batch_size, self.h_channels, *spatial_dims)) 390 | 391 | for i in range(self.core_repeat): 392 | 393 | if self.core_shared: 394 | current_generator_core = self.generator_core 395 | else: 396 | current_generator_core = self.generator_core[i] 397 | 
398 | o = self.prior_net(hidden_g) 399 | prior_mu, prior_std_pseudo = torch.split(o, self.z_channels, dim=1) 400 | prior = Normal(prior_mu, F.softplus(prior_std_pseudo)) 401 | 402 | # Prior sample 403 | z = prior.sample() 404 | 405 | # Update 406 | cell_g, hidden_g, u = current_generator_core(viewpoint_query, representation, z, cell_g, hidden_g, u) 407 | 408 | return self.output_activation(self.observation_net(u)) 409 | 410 | 411 | class GenerativeQueryNetwork(ConvModule): 412 | 413 | def __init__(self, 414 | in_channels=3, 415 | query_channels=7, 416 | r_channels=256, 417 | encoder_op=TowerRepresentation, 418 | encoder_kwargs=None, 419 | decoder_op=GQNDecoder, 420 | decoder_kwargs=None, 421 | **kwargs): 422 | 423 | super(GenerativeQueryNetwork, self).__init__(**kwargs) 424 | self.in_channels = in_channels 425 | self.query_channels = query_channels 426 | self.r_channels = r_channels 427 | self.encoder_op = encoder_op 428 | self.encoder_kwargs = encoder_kwargs if encoder_kwargs is not None else {} 429 | self.decoder_op = decoder_op 430 | self.decoder_kwargs = decoder_kwargs if decoder_kwargs is not None else {} 431 | 432 | self.add_module("encoder", self.encoder_op(self.in_channels, self.query_channels, self.r_channels, **self.encoder_kwargs)) 433 | self.add_module("decoder", self.decoder_op(self.in_channels, self.query_channels, self.r_channels, **self.decoder_kwargs)) 434 | 435 | @staticmethod 436 | def split_batch(images, viewpoints, num_viewpoints): 437 | 438 | # for debugging, if you want to reconstruct known images 439 | if num_viewpoints == 1: 440 | return images, viewpoints, images, viewpoints 441 | 442 | # images are (batch_size * num_viewpoints, channels, space) 443 | # viewpoints are (batch_size * num_viewpoints, channels) 444 | batch_size = images.shape[0] // num_viewpoints 445 | 446 | # separate input and target 447 | images = images.view(batch_size, num_viewpoints, *images.shape[1:]) 448 | images, image_query = images[:, :-1], images[:, -1] 449 | images = images.contiguous().view(batch_size * (num_viewpoints-1), *images.shape[2:]) 450 | 451 | viewpoints = viewpoints.view(batch_size, num_viewpoints, *viewpoints.shape[1:]) 452 | viewpoints, viewpoint_query = viewpoints[:, :-1], viewpoints[:, -1] 453 | viewpoints = viewpoints.contiguous().view(batch_size * (num_viewpoints-1), *viewpoints.shape[2:]) 454 | 455 | return images, viewpoints, image_query, viewpoint_query 456 | 457 | def encode(self, images, viewpoints, num_viewpoints): 458 | 459 | representation = self.encoder(images, viewpoints) 460 | representation = representation.view(-1, max(num_viewpoints, 1), *representation.shape[1:]) 461 | representation = representation.mean(1, keepdim=False) 462 | 463 | return representation 464 | 465 | def forward(self, images, viewpoints, num_viewpoints): 466 | """Will automatically split input and query. 
num_viewpoints includes the query viewpoint.""" 467 | 468 | images, viewpoints, image_query, viewpoint_query = self.split_batch(images, viewpoints, num_viewpoints) 469 | representation = self.encode(images, viewpoints, num_viewpoints - 1) 470 | image_mu, kl = self.decoder(representation, viewpoint_query, image_query) 471 | 472 | return image_mu, image_query, representation, kl 473 | 474 | def sample(self, images, viewpoints, viewpoint_query, num_viewpoints, sigma=None): 475 | 476 | if sigma is None: 477 | query_sample = viewpoint_query 478 | else: 479 | # note that this might produce incorrect viewpoints 480 | query_sample = Normal(viewpoint_query, sigma).sample() 481 | representation = self.encode(images, viewpoints, num_viewpoints) 482 | image_mu = self.decoder.sample(representation, query_sample, images.shape[2:]) 483 | 484 | return image_mu 485 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | trixi>=0.1.2.0 2 | -e git://github.com/MIC-DKFZ/batchgenerators#egg=batchgenerators 3 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import torch 4 | from torch import nn, optim, distributions 5 | import torch.backends.cudnn as cudnn 6 | cudnn.benchmark = False 7 | 8 | from batchgenerators.dataloading import MultiThreadedAugmenter 9 | from trixi.util import Config, ResultLogDict 10 | from trixi.experiment import PytorchExperiment 11 | 12 | from model import GenerativeQueryNetwork 13 | from util import get_default_experiment_parser, set_seeds, run_experiment 14 | from data import loader 15 | 16 | 17 | DESCRIPTION = """This experiment just tries to reproduce GQN results, 18 | specifically for the Shepard-Metzler-5 dataset.""" 19 | 20 | 21 | def make_defaults(): 22 | 23 | DEFAULTS = Config( 24 | 25 | # Base 26 | name="gqn", 27 | description=DESCRIPTION, 28 | n_epochs=1000000, 29 | batch_size=36, 30 | batch_size_val=36, 31 | seed=1, 32 | device="cuda", 33 | 34 | # Data 35 | split_val=3, # index for set of 5 36 | split_test=4, # index for set of 5 37 | data_module=loader, 38 | dataset="shepard_metzler_5_parts", 39 | data_dir=None, # will be set for data_module if not None 40 | debug=0, # 1 for single repeating batch, 2 for single viewpoint (i.e. reconstruct known images) 41 | generator_train=loader.RandomBatchGenerator, 42 | generator_val=loader.LinearBatchGenerator, 43 | num_viewpoints_val=8, # use this many viewpoints in validation 44 | shuffle_viewpoints_val=False, 45 | augmenter=MultiThreadedAugmenter, 46 | augmenter_kwargs={"num_processes": 8}, 47 | 48 | # Model 49 | model=GenerativeQueryNetwork, 50 | model_kwargs={ 51 | "in_channels": 3, 52 | "query_channels": 7, 53 | "r_channels": 256, 54 | "encoder_kwargs": { 55 | "activation_op": nn.ReLU 56 | }, 57 | "decoder_kwargs": { 58 | "z_channels": 64, 59 | "h_channels": 128, 60 | "scale": 4, 61 | "core_repeat": 12 62 | } 63 | }, 64 | model_init_weights_args=None, # e.g. [nn.init.kaiming_normal_, 1e-2], 65 | model_init_bias_args=None, # e.g. 
[nn.init.constant_, 0], 66 | 67 | # Learning 68 | optimizer=optim.Adam, 69 | optimizer_kwargs={"weight_decay": 1e-5}, 70 | lr_initial=5e-4, 71 | lr_final=5e-5, 72 | lr_cutoff=16e4, # lr is increased linearly in cutoff epochs 73 | sigma_initial=2.0, 74 | sigma_final=0.7, 75 | sigma_cutoff=2e4, # sigma is increased linearly in cutoff epochs 76 | kl_weight_initial=0.05, 77 | kl_weight_final=1.0, 78 | kl_weight_cutoff=1e5, # kl_weight is increased linearly in cutoff epochs 79 | nll_weight=1.0, 80 | 81 | # Logging 82 | backup_every=10000, 83 | validate_every=1000, 84 | validate_subset=0.01, # validate only this percentage randomly 85 | show_every=100, 86 | val_example_samples=10, # draw this many random samples for last validation item 87 | test_on_val=True, # test on the validation set 88 | 89 | ) 90 | 91 | SHAREDCORES = Config( 92 | model_kwargs={"decoder_kwargs": {"core_shared": True}} 93 | ) 94 | 95 | MODS = { 96 | "SHAREDCORES": SHAREDCORES 97 | } 98 | 99 | return {"DEFAULTS": DEFAULTS}, MODS 100 | 101 | 102 | class GQNExperiment(PytorchExperiment): 103 | 104 | def setup(self): 105 | 106 | set_seeds(self.config.seed, "cuda" in self.config.device) 107 | 108 | self.setup_data() 109 | self.setup_model() 110 | 111 | self.config.epoch_str_template = "{:0" + str(len(str(self.config.n_epochs))) + "d}" 112 | self.clog.show_text(self.model.__repr__(), "Model") 113 | 114 | def setup_data(self): 115 | 116 | c = self.config 117 | 118 | if c.data_dir is not None: 119 | c.data_module.data_dir = c.data_dir 120 | 121 | # set actual data 122 | self.data_train_val = c.data_module.load(c.dataset, "train", image_kwargs={"mmap_mode": "r"}) 123 | self.data_test = c.data_module.load(c.dataset, "test", image_kwargs={"mmap_mode": "r"}) 124 | 125 | # train, val, test split 126 | indices_split = c.data_module.split(c.dataset) 127 | indices_val = indices_split[c.split_val] 128 | indices_test = indices_split[c.split_test] 129 | indices_train = [] 130 | for i in range(5): 131 | if i not in (c.split_val, c.split_test): 132 | indices_train += indices_split[i] 133 | indices_train = sorted(indices_train) 134 | 135 | # for debugging we only use a single batch and validate on training data 136 | if c.debug > 0: 137 | indices_train = indices_train[:c.batch_size] 138 | indices_val = indices_train 139 | indices_test = indices_test[:c.batch_size_val] 140 | 141 | # construct generators 142 | self.generator_train = c.generator_train( 143 | self.data_train_val, 144 | c.batch_size, 145 | data_order=indices_train, 146 | num_viewpoints=1 if c.debug == 2 else "random", 147 | shuffle_viewpoints=not c.debug, 148 | number_of_threads_in_multithreaded=c.augmenter_kwargs.num_processes) 149 | self.generator_val = c.generator_val( 150 | self.data_train_val, 151 | c.batch_size_val, 152 | data_order=indices_val, 153 | num_viewpoints=1 if c.debug == 2 else c.num_viewpoints_val, 154 | shuffle_viewpoints=c.shuffle_viewpoints_val, 155 | number_of_threads_in_multithreaded=c.augmenter_kwargs.num_processes) 156 | self.generator_test = c.generator_val( 157 | self.data_test, 158 | c.batch_size_val, 159 | data_order=indices_test, 160 | num_viewpoints=1 if c.debug == 2 else c.num_viewpoints_val, 161 | number_of_threads_in_multithreaded=c.augmenter_kwargs.num_processes) 162 | 163 | # construct augmenters (no actual augmentation at the moment, just multithreading) 164 | self.augmenter_train = c.augmenter(self.generator_train, None, **c.augmenter_kwargs) 165 | self.augmenter_val = c.augmenter(self.generator_val, None, **c.augmenter_kwargs) 166 | 
self.augmenter_test = c.augmenter(self.generator_test, None, **c.augmenter_kwargs) 167 | 168 | def setup_model(self): 169 | 170 | c = self.config 171 | 172 | # intialize model and weights 173 | self.model = c.model(**c.model_kwargs) 174 | if c.model_init_weights_args is not None and hasattr(self.model, "init_weights"): 175 | self.model.init_weights(*c.model_init_weights_args) 176 | if c.model_init_bias_args is not None and hasattr(self.model, "init_bias"): 177 | import IPython 178 | IPython.embed() 179 | self.model.init_bias(*c.model_init_bias_args) 180 | 181 | # optimization 182 | self.optimizer = c.optimizer(self.model.parameters(), lr=c.lr_initial, **c.optimizer_kwargs) 183 | self.lr = c.lr_initial 184 | self.sigma = c.sigma_initial 185 | 186 | def _setup_internal(self): 187 | 188 | super(GQNExperiment, self)._setup_internal() 189 | self.elog.save_config(self.config, "config") # default PytorchExperiment only saves self._config_raw 190 | 191 | # we want a results dictionary with running mean, so close default and construct new 192 | self.results.close() 193 | self.results = ResultLogDict("results-log.json", base_dir=self.elog.result_dir, mode="w", running_mean_length=self.config.show_every) 194 | 195 | def prepare(self): 196 | 197 | # move everything to selected device 198 | for name, model in self.get_pytorch_modules().items(): 199 | model.to(self.config.device) 200 | 201 | def train(self, epoch): 202 | 203 | c = self.config 204 | 205 | t0 = time.time() 206 | 207 | # set learning rates, sigmas, loss weights 208 | self.train_prepare(epoch) 209 | 210 | # get data 211 | data = next(self.augmenter_train) 212 | data["data"] = torch.from_numpy(data["data"]).to(dtype=torch.float32, device=c.device) 213 | data["viewpoints"] = torch.from_numpy(data["viewpoints"]).to(dtype=torch.float32, device=c.device) 214 | 215 | # forward 216 | image_pred, image_query, representation, kl = self.model(data["data"], data["viewpoints"], data["num_viewpoints"]) 217 | loss_elbo, loss_nll, loss_kl = self.criterion(image_pred, image_query, kl) 218 | 219 | # backward 220 | loss_elbo.backward() 221 | self.optimizer.step() 222 | self.optimizer.zero_grad() 223 | 224 | training_time = time.time() - t0 225 | 226 | # use data dictionary as training summary 227 | data["data"] = data["data"].cpu() 228 | data["viewpoints"] = data["viewpoints"].cpu() 229 | data["image_query"] = image_query.cpu() # also in "data" but we're lazy 230 | data["image_pred"] = image_pred.cpu() 231 | data["loss_elbo"] = loss_elbo.item() 232 | data["loss_nll"] = loss_nll.item() 233 | data["loss_kl"] = loss_kl.item() 234 | data["training_time"] = training_time 235 | 236 | self.train_log(data, epoch) 237 | 238 | def train_prepare(self, epoch): 239 | 240 | c = self.config 241 | 242 | # sets parameters as is done in the paper, additionally start with lower KL weight 243 | self.lr = max(c.lr_final + (c.lr_initial - c.lr_final) * (1 - epoch / c.lr_cutoff), c.lr_final) 244 | self.sigma = max(c.sigma_final + (c.sigma_initial - c.sigma_final) * (1 - epoch / c.sigma_cutoff), c.sigma_final) 245 | _lr = self.lr * np.sqrt(1 - 0.999**(epoch+1)) / (1 - 0.9**(epoch+1)) 246 | for group in self.optimizer.param_groups: 247 | group["lr"] = _lr 248 | self.nll_weight = c.nll_weight 249 | self.kl_weight = min(c.kl_weight_final, c.kl_weight_initial + (c.kl_weight_final - c.kl_weight_initial) * epoch / c.kl_weight_cutoff) 250 | 251 | self.model.train() 252 | self.optimizer.zero_grad() 253 | 254 | def criterion(self, image_predicted, image_query, kl, batch_mean=True): 255 | 
256 | # mean over batch but sum over individual 257 | nll = -distributions.Normal(image_predicted, self.sigma).log_prob(image_query) 258 | # nll = nn.MSELoss(reduction="none")(image_predicted, image_query) 259 | nll = nll.view(nll.shape[0], -1) 260 | kl = kl.view(kl.shape[0], -1) 261 | if batch_mean: 262 | nll = nll.mean(0) 263 | kl = kl.mean(0) 264 | nll = nll.sum(-1) 265 | kl = kl.sum(-1) 266 | 267 | elbo = self.nll_weight * nll + self.kl_weight * kl 268 | 269 | return elbo, nll, kl 270 | 271 | def train_log(self, summary, epoch): 272 | 273 | _backup = (epoch + 1) % self.config.backup_every == 0 274 | _show = (epoch + 1) % self.config.show_every == 0 275 | 276 | self.elog.show_text("{}/{}: {}".format(epoch, self.config.n_epochs, summary["training_time"]), name="Training Time") 277 | 278 | # add_result will show graphs and log to json file at the same time 279 | self.add_result(summary["loss_elbo"], "loss_elbo", epoch, "Loss", plot_result=_show, plot_running_mean=True) 280 | self.add_result(summary["loss_nll"], "loss_nll", epoch, "Loss", plot_result=_show, plot_running_mean=True) 281 | self.add_result(summary["loss_kl"], "loss_kl", epoch, "Loss", plot_result=_show, plot_running_mean=True) 282 | 283 | self.make_images(summary["image_query"], 284 | "reference", 285 | epoch, 286 | save=_backup, 287 | show=_show) 288 | self.make_images(summary["image_pred"], 289 | "reconstruction", 290 | epoch, 291 | save=_backup, 292 | show=_show) 293 | 294 | def validate(self, epoch): 295 | 296 | c = self.config 297 | 298 | if (epoch+1) % c.validate_every == 0: 299 | 300 | with torch.no_grad(): 301 | 302 | t0 = time.time() 303 | self.model.eval() 304 | 305 | validation_scores = [] 306 | info = {} # holds info on score array axes 307 | info["dims"] = ["Object Index", "Loss"] 308 | info["coords"] = {"Object Index": [], "Loss": ["NLL", "KL", "ELBO"]} 309 | 310 | example_output_shown = False 311 | for d, data in enumerate(self.augmenter_val): 312 | 313 | # this ensures we always validate at least one item even for very small subset ratios 314 | if c.validate_subset not in (False, None, 1.) 
and c.debug == 0: 315 | rand_number = np.random.rand() 316 | if rand_number < 1 - c.validate_subset: 317 | if not (d * c.batch_size_val >= len(self.generator_val) - 1 and len(validation_scores) == 0): 318 | continue 319 | 320 | # get data 321 | data["data"] = torch.from_numpy(data["data"]).to(dtype=torch.float32, device=c.device) 322 | data["viewpoints"] = torch.from_numpy(data["viewpoints"]).to(dtype=torch.float32, device=c.device) 323 | 324 | # forward 325 | image_pred, image_query, representation, kl = self.model(data["data"], data["viewpoints"], data["num_viewpoints"]) 326 | loss_elbo, loss_nll, loss_kl = self.criterion(image_pred, image_query, kl, batch_mean=False) 327 | 328 | # use data dict as summary dict 329 | data["data"] = data["data"].cpu() 330 | data["viewpoints"] = data["viewpoints"].cpu() 331 | data["image_query"] = image_query.cpu() 332 | data["image_pred"] = image_pred.cpu() 333 | data["loss_elbo"] = loss_elbo.cpu() 334 | data["loss_nll"] = loss_nll.cpu() 335 | data["loss_kl"] = loss_kl.cpu() 336 | 337 | current_scores = np.array([data["loss_nll"].cpu().numpy(), 338 | data["loss_kl"].cpu().numpy(), 339 | data["loss_elbo"].cpu().numpy()]).T 340 | validation_scores.append(current_scores) 341 | info["coords"]["Object Index"].append(data["data_indices"]) 342 | 343 | self.make_images(data["image_query"], 344 | "val/{}_reference".format(d), 345 | epoch, 346 | save=True, 347 | show=False) 348 | self.make_images(data["image_pred"], 349 | "val/{}_prediction".format(d), 350 | epoch, 351 | save=True, 352 | show=False) 353 | self.make_images(data["data"], 354 | "val/{}_seen".format(d), 355 | epoch, 356 | save=True, 357 | show=False, 358 | images_per_row=data["data"].shape[0] // c.batch_size_val) 359 | 360 | # only show one validation item 361 | if not example_output_shown: 362 | self.make_images(data["image_query"], 363 | "val_reference", 364 | epoch, 365 | save=False, 366 | show=True) 367 | self.make_images(data["image_pred"], 368 | "val_prediction", 369 | epoch, 370 | save=False, 371 | show=True) 372 | self.make_images(data["data"], 373 | "val_seen", 374 | epoch, 375 | save=False, 376 | show=True, 377 | images_per_row=data["data"].shape[0] // c.batch_size_val) 378 | example_output_shown = True 379 | 380 | validation_time = time.time() - t0 381 | validation_scores = np.concatenate(validation_scores, 0) 382 | info["coords"]["Object Index"] = np.concatenate(info["coords"]["Object Index"], 0) 383 | 384 | # there can be duplicates in the last batch 385 | for i in range(c.batch_size_val): 386 | if info["coords"]["Object Index"][-(i+1)] not in info["coords"]["Object Index"][:-(i+1)]: 387 | break 388 | if i > 0: 389 | validation_scores = validation_scores[:-i] 390 | info["coords"]["Object Index"] = info["coords"]["Object Index"][:-i] 391 | 392 | summary = {} 393 | summary["validation_time"] = validation_time 394 | summary["validation_scores"] = validation_scores 395 | summary["validation_info"] = info 396 | 397 | self.validate_log(summary, epoch) 398 | 399 | # draw a few different samples for the last data item 400 | # item could have been skipped, so we might need to transfer again 401 | if c.val_example_samples > 0: 402 | if isinstance(data["data"], np.ndarray): 403 | data["data"] = torch.from_numpy(data["data"]).to(dtype=torch.float32, device=c.device) 404 | data["viewpoints"] = torch.from_numpy(data["viewpoints"]).to(dtype=torch.float32, device=c.device) 405 | images_context, viewpoints_context, _, viewpoint_query =\ 406 | self.model.split_batch(data["data"], 407 | data["viewpoints"], 
408 | data["num_viewpoints"]) 409 | images_context = images_context.to(device=c.device) 410 | viewpoints_context = viewpoints_context.to(device=c.device) 411 | viewpoint_query = viewpoint_query.to(device=c.device) 412 | 413 | samples = [] 414 | for i in range(c.val_example_samples): 415 | samples.append(self.model.sample(images_context, viewpoints_context, viewpoint_query, data["num_viewpoints"] - 1, self.sigma).cpu()) 416 | samples = torch.cat(samples, 0) 417 | # samples should now be (batch * samples, 3, 64, 64) 418 | self.make_images(samples, "samples", epoch, save=True, show=True, images_per_row=c.batch_size_val) 419 | 420 | def validate_log(self, summary, epoch): 421 | 422 | epoch_str = self.config.epoch_str_template.format(epoch) 423 | validation_scores_mean = np.nanmean(summary["validation_scores"], 0) 424 | 425 | self.elog.save_numpy_data(summary["validation_scores"], "validation/{}.npy".format(epoch_str)) 426 | self.elog.save_dict(summary["validation_info"], "validation/{}.json".format(epoch_str)) 427 | self.elog.show_text("{}/{}: {}".format(epoch, self.config.n_epochs, summary["validation_time"]), name="Validation Time") 428 | 429 | self.add_result(float(validation_scores_mean[2]), "loss_elbo_val", epoch, "Loss") 430 | self.add_result(float(validation_scores_mean[0]), "loss_nll_val", epoch, "Loss") 431 | self.add_result(float(validation_scores_mean[1]), "loss_kl_val", epoch, "Loss") 432 | 433 | def _end_epoch_internal(self, epoch): 434 | 435 | self.save_results() 436 | if (epoch+1) % self.config.backup_every == 0: 437 | self.save_temp_checkpoint() 438 | 439 | def make_images(self, 440 | images, 441 | name, 442 | epoch, 443 | save=False, 444 | show=True, 445 | images_per_row=None): 446 | 447 | n_images = images.shape[0] 448 | if images_per_row is None: 449 | images_per_row = int(np.sqrt(n_images)) 450 | 451 | if show and self.vlog is not None: 452 | self.vlog.show_image_grid(images, name, 453 | image_args={"normalize": True, 454 | "nrow": images_per_row, 455 | "pad_value": 1}) 456 | if save and self.elog is not None: 457 | name = self.config.epoch_str_template.format(epoch) + "/" + name 458 | self.elog.show_image_grid(images, name, 459 | image_args={"normalize": True, 460 | "nrow": images_per_row, 461 | "pad_value": 1}) 462 | 463 | def test(self): 464 | 465 | pass 466 | 467 | 468 | if __name__ == '__main__': 469 | 470 | parser = get_default_experiment_parser() 471 | args, _ = parser.parse_known_args() 472 | DEFAULTS, MODS = make_defaults() 473 | run_experiment(GQNExperiment, 474 | DEFAULTS, 475 | args, 476 | mods=MODS, 477 | explogger_kwargs=dict(folder_format="{experiment_name}_%Y%m%d-%H%M%S"), 478 | globs=globals(), 479 | resume_save_types=("model", "simple", "th_vars", "results")) 480 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import os 4 | import torch 5 | import torch.nn as nn 6 | from trixi.util import Config, GridSearch 7 | 8 | 9 | class ConvModule(nn.Module): 10 | """Utility Module for more convenient weight initialization""" 11 | 12 | conv_types = (nn.Conv1d, 13 | nn.Conv2d, 14 | nn.Conv3d, 15 | nn.ConvTranspose1d, 16 | nn.ConvTranspose2d, 17 | nn.ConvTranspose3d) 18 | 19 | @classmethod 20 | def is_conv(cls, op): 21 | 22 | if type(op) == type and issubclass(op, cls.conv_types): 23 | return True 24 | elif type(op) in cls.conv_types: 25 | return True 26 | else: 27 | return False 28 | 29 | def 
__init__(self, *args, **kwargs): 30 | 31 | super(ConvModule, self).__init__(*args, **kwargs) 32 | 33 | def init_weights(self, init_fn, *args, **kwargs): 34 | 35 | class init_(object): 36 | 37 | def __init__(self): 38 | self.fn = init_fn 39 | self.args = args 40 | self.kwargs = kwargs 41 | 42 | def __call__(self, module): 43 | if ConvModule.is_conv(type(module)): 44 | module.weight = self.fn(module.weight, *self.args, **self.kwargs) 45 | 46 | _init_ = init_() 47 | self.apply(_init_) 48 | 49 | def init_bias(self, init_fn, *args, **kwargs): 50 | 51 | class init_(object): 52 | 53 | def __init__(self): 54 | self.fn = init_fn 55 | self.args = args 56 | self.kwargs = kwargs 57 | 58 | def __call__(self, module): 59 | if ConvModule.is_conv(type(module)) and module.bias is not None: 60 | module.bias = self.fn(module.bias, *self.args, **self.kwargs) 61 | 62 | _init_ = init_() 63 | self.apply(_init_) 64 | 65 | 66 | def get_default_experiment_parser(): 67 | 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument("base_dir", type=str, help="Working directory for experiment.") 70 | parser.add_argument("-c", "--config", type=str, default=None, help="Path to a config file.") 71 | parser.add_argument("-v", "--visdomlogger", action="store_true", help="Use visdomlogger.") 72 | parser.add_argument("-dc", "--default_config", type=str, default="DEFAULTS", help="Select a default Config") 73 | parser.add_argument("--resume", type=str, default=None, help="Path to resume from") 74 | parser.add_argument("-ir", "--ignore_resume_config", action="store_true", help="Ignore Config in experiment we resume from.") 75 | parser.add_argument("--test", action="store_true", help="Run test instead of training") 76 | parser.add_argument("--grid", type=str, help="Path to a config for grid search") 77 | parser.add_argument("-s", "--skip_existing", action="store_true", help="Skip configs fpr which an experiment exists") 78 | parser.add_argument("-m", "--mods", type=str, nargs="+", default=None, help="Mods are Config stubs to update only relevant parts for a certain setup.") 79 | 80 | return parser 81 | 82 | 83 | def run_experiment(experiment, configs, args, mods=None, **kwargs): 84 | 85 | config = Config(file_=args.config) if args.config is not None else Config() 86 | config.update_missing(configs[args.default_config]) 87 | if args.mods is not None: 88 | for mod in args.mods: 89 | config.update(mods[mod]) 90 | config = Config(config=config, update_from_argv=True) 91 | 92 | # GET EXISTING EXPERIMENTS TO BE ABLE TO SKIP CERTAIN CONFIGS 93 | if args.skip_existing: 94 | existing_configs = [] 95 | for exp in os.listdir(args.base_dir): 96 | try: 97 | existing_configs.append(Config(file_=os.path.join(args.base_dir, exp, "config", "config.json"))) 98 | except Exception as e: 99 | pass 100 | 101 | if args.grid is not None: 102 | grid = GridSearch().read(args.grid) 103 | else: 104 | grid = [{}] 105 | 106 | for combi in grid: 107 | 108 | config.update(combi) 109 | 110 | if args.skip_existing: 111 | skip_this = False 112 | for existing_config in existing_configs: 113 | if existing_config.contains(config): 114 | skip_this = True 115 | break 116 | if skip_this: 117 | continue 118 | 119 | loggers = {} 120 | if args.visdomlogger: 121 | loggers["visdom"] = ("visdom", {}, 1) 122 | 123 | exp = experiment(config=config, 124 | base_dir=args.base_dir, 125 | resume=args.resume, 126 | ignore_resume_config=args.ignore_resume_config, 127 | loggers=loggers, 128 | **kwargs) 129 | 130 | if not args.test: 131 | exp.run() 132 | else: 133 | exp.run_test() 
134 | 135 | 136 | def set_seeds(seed, cuda=True): 137 | 138 | if not hasattr(seed, "__iter__"): 139 | seed = (seed, seed, seed) 140 | np.random.seed(seed[0]) 141 | torch.manual_seed(seed[1]) 142 | if cuda: torch.cuda.manual_seed_all(seed[2]) 143 | --------------------------------------------------------------------------------