├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── custom_layers.py ├── data_util.py ├── dataio.py ├── environment.yml ├── geometry.py ├── hyperlayers.py ├── srns.py ├── test.py ├── test_configs ├── cars_few_shot_novel_view.yml └── cars_training_set_novel_view.yml ├── train.py ├── train_configs ├── cars.yml ├── cars_one_shot.yml ├── cars_two_shot.yml ├── chairs.yml └── shepard_metzler.yml └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "pytorch_prototyping"] 2 | path = pytorch_prototyping 3 | url = https://github.com/vsitzmann/pytorch_prototyping.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Vincent Sitzmann 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Scene Representation Networks 2 | 3 | [![Paper](http://img.shields.io/badge/paper-arxiv.1906.01618-B31B1B.svg)](https://arxiv.org/abs/1906.01618) 4 | [![Conference](http://img.shields.io/badge/NeurIPS-2019-4b44ce.svg)]() 5 | 6 | This is the official implementation of the NeurIPS 2019 paper "Scene Representation Networks: 7 | Continuous 3D-Structure-Aware Neural Scene Representations". 8 | 9 | Scene Representation Networks (SRNs) are a continuous, 3D-structure-aware scene representation that encodes both geometry and appearance. 10 | SRNs represent scenes as continuous functions that map world coordinates to a feature representation of local scene properties. 11 | By formulating image formation as a neural, 3D-aware rendering algorithm, SRNs can be trained end-to-end from only 2D observations, 12 | without access to depth or geometry. SRNs do not discretize space; they parameterize scene surfaces smoothly, and their 13 | memory complexity does not scale directly with scene resolution. This formulation naturally generalizes across scenes, 14 | learning powerful geometry and appearance priors in the process. 15 | 16 | [![srns_video](https://img.youtube.com/vi/6vMEBWD8O20/0.jpg)](https://youtu.be/6vMEBWD8O20) 17 | 18 | ## Usage 19 | ### Installation 20 | This code was tested with Python 3.7 and PyTorch 1.2. I recommend using anaconda for dependency management. 21 | You can create an environment named "srns" with all dependencies like so: 22 | ``` 23 | conda env create -f environment.yml 24 | ``` 25 | 26 | This repository depends on a git submodule, [pytorch-prototyping](https://github.com/vsitzmann/pytorch_prototyping). 27 | To clone both the main repo and the submodule, use 28 | ``` 29 | git clone --recurse-submodules https://github.com/vsitzmann/scene-representation-networks.git 30 | ``` 31 | 32 | ### High-level structure 33 | The code is organized as follows: 34 | * dataio.py loads training and testing data. 35 | * data_util.py contains utility functions for data loading and processing. 36 | * train.py contains the training code. 37 | * test.py contains the testing code. 38 | * srns.py contains the core SRNs model. 39 | * hyperlayers.py contains implementations of different hypernetworks. 40 | * custom_layers.py contains implementations of the raymarcher and the DeepVoxels U-Net renderer. 41 | * geometry.py contains utility functions for 3D and projective geometry. 42 | * util.py contains misc utility functions. 43 | 44 | ### Pre-trained models 45 | Pre-trained models are available for the shapenet car and chair datasets, including tensorboard event files of the 46 | full training process. 47 | 48 | Please download them [here](https://drive.google.com/open?id=1IdOywOSLuK6WlkO5_h-ykr3ubeY9eDig).
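The standard way to use a downloaded checkpoint is through `test.py` with the `--checkpoint_path` flag described below. If you instead want to restore the weights from your own Python code, a minimal sketch could look as follows; the checkpoint filename is a placeholder, and the constructor arguments (2433 training instances for shapenet cars, latent dimension 256, 10 tracing steps) simply mirror the example values and defaults used elsewhere in this README and in `test.py`:
```
import util
from srns import SRNsModel

# Values mirror the test.py defaults / README example -- adjust to the model you downloaded.
model = SRNsModel(num_instances=2433,  # number of instances the model was trained with
                  latent_dim=256,      # --embedding_size in test.py
                  tracing_steps=10)    # --tracing_steps in test.py

# util.custom_load is the same helper test.py uses to restore weights from a checkpoint.
util.custom_load(model,
                 path="checkpoints/<downloaded checkpoint file>",  # placeholder path
                 discriminator=None,
                 overwrite_embeddings=False)

model.eval()
model.cuda()
```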
49 | 50 | The checkpoint is in the "checkpoints" directory. To load weights from the checkpoint, simply pass its full path 51 | to the "--checkpoint_path" command-line argument. 52 | 53 | To inspect the training progress of these models, run tensorboard in the "events" subdirectory. 54 | 55 | ### Data 56 | Four different datasets appear in the paper: 57 | * Shapenet v2 chair and car classes. 58 | * Shepard-Metzler objects. 59 | * Basel face dataset. 60 | 61 | Please download the datasets [here](https://drive.google.com/drive/folders/1OkYgeRcIcLOFu1ft5mRODWNQaPJ0ps90?usp=sharing). 62 | 63 | ### Rendering your own datasets 64 | I have put together a few scripts for the Blender Python interface that make it easy to render your own dataset. Please find them [here](https://github.com/vsitzmann/shapenet_renderer/blob/master/shapenet_spherical_renderer.py). 65 | 66 | ### Coordinate and camera parameter conventions 67 | This code uses an "OpenCV" style camera coordinate system, where the Y-axis points downwards (the up-vector points in the negative Y-direction), 68 | the X-axis points right, and the Z-axis points into the image plane. Camera poses are assumed to be in a "camera2world" format, 69 | i.e., they denote the matrix that transforms camera coordinates to world coordinates. 70 | 71 | The code also reads an "intrinsics.txt" file from the dataset directory. This file is expected to be structured as follows (unnamed constants are unused): 72 | ``` 73 | f cx cy 0. 74 | 0. 0. 0. 75 | 1. 76 | img_height img_width 77 | ``` 78 | The focal length f and the principal point (cx, cy) are in pixels. Height and width are the resolution of the image. 79 | 80 | ### Training 81 | See `python train.py --help` for all train options. 82 | Example train call: 83 | ``` 84 | python train.py --data_root [path to directory with dataset] \ 85 | --val_root [path to directory with train_val dataset] \ 86 | --logging_root [path to directory where tensorboard summaries and checkpoints should be written to] 87 | ``` 88 | To monitor progress, the training code writes tensorboard summaries every 100 steps into an "events" subdirectory in the logging_root. 89 | 90 | For experiments described in the paper, config files are available that configure the command-line flags according to 91 | the settings in the paper. You only need to edit the dataset path. Example call: 92 | ``` 93 | [edit train_configs/cars.yml to point to the correct dataset and logging paths] 94 | python train.py --config_filepath train_configs/cars.yml 95 | ``` 96 | 97 | ### Testing 98 | Example test call: 99 | ``` 100 | python test.py --data_root [path to directory with dataset] \ 101 | --logging_root [path to directory where test output should be written to] \ 102 | --num_instances [number of instances in training set (for instance, 2433 for shapenet cars)] \ 103 | --checkpoint_path [path to checkpoint] 104 | ``` 105 | Again, for experiments described in the paper, config files are available that configure the command-line flags according to 106 | the settings in the paper.
Example call: 107 | ``` 108 | [edit test_configs/cars.yml to point to the correct dataset and logging paths] 109 | python test.py --config_filepath test_configs/cars_training_set_novel_view.yml 110 | ``` 111 | 112 | ## Misc 113 | ### Citation 114 | If you find our work useful in your research, please cite: 115 | ``` 116 | @inproceedings{sitzmann2019srns, 117 | author = {Sitzmann, Vincent 118 | and Zollh{\"o}fer, Michael 119 | and Wetzstein, Gordon}, 120 | title = {Scene Representation Networks: Continuous 3D-Structure-Aware Neural Scene Representations}, 121 | booktitle = {Advances in Neural Information Processing Systems}, 122 | year={2019} 123 | } 124 | ``` 125 | 126 | ### Submodule "pytorch_prototyping" 127 | The code in the subdirectory "pytorch_prototyping" comes from a library of custom pytorch modules that I use throughout my 128 | research projects. You can find it [here](https://github.com/vsitzmann/pytorch_prototyping). 129 | 130 | ### Contact 131 | If you have any questions, please email Vincent Sitzmann at sitzmann@cs.stanford.edu. 132 | -------------------------------------------------------------------------------- /custom_layers.py: -------------------------------------------------------------------------------- 1 | import geometry 2 | import torchvision 3 | import util 4 | 5 | from pytorch_prototyping import pytorch_prototyping 6 | 7 | import torch 8 | from torch import nn 9 | 10 | 11 | def init_recurrent_weights(self): 12 | for m in self.modules(): 13 | if type(m) in [nn.GRU, nn.LSTM, nn.RNN]: 14 | for name, param in m.named_parameters(): 15 | if 'weight_ih' in name: 16 | nn.init.kaiming_normal_(param.data) 17 | elif 'weight_hh' in name: 18 | nn.init.orthogonal_(param.data) 19 | elif 'bias' in name: 20 | param.data.fill_(0) 21 | 22 | 23 | def lstm_forget_gate_init(lstm_layer): 24 | for name, parameter in lstm_layer.named_parameters(): 25 | if not "bias" in name: continue 26 | n = parameter.size(0) 27 | start, end = n // 4, n // 2 28 | parameter.data[start:end].fill_(1.) 29 | 30 | 31 | def clip_grad_norm_hook(x, max_norm=10): 32 | total_norm = x.norm() 33 | total_norm = total_norm ** (1 / 2.) 
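    # Backward hook: if the gradient norm exceeds max_norm, return a rescaled gradient;
    # otherwise fall through (return None) so the gradient is left unchanged.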
34 | clip_coef = max_norm / (total_norm + 1e-6) 35 | if clip_coef < 1: 36 | return x * clip_coef 37 | 38 | 39 | class DepthSampler(nn.Module): 40 | def __init__(self): 41 | super().__init__() 42 | 43 | def forward(self, 44 | xy, 45 | depth, 46 | cam2world, 47 | intersection_net, 48 | intrinsics): 49 | self.logs = list() 50 | 51 | batch_size, _, _ = cam2world.shape 52 | 53 | intersections = geometry.world_from_xy_depth(xy=xy, depth=depth, cam2world=cam2world, intrinsics=intrinsics) 54 | 55 | depth = geometry.depth_from_world(intersections, cam2world) 56 | 57 | if self.training: 58 | print(depth.min(), depth.max()) 59 | 60 | return intersections, depth 61 | 62 | 63 | class Raymarcher(nn.Module): 64 | def __init__(self, 65 | num_feature_channels, 66 | raymarch_steps): 67 | super().__init__() 68 | 69 | self.n_feature_channels = num_feature_channels 70 | self.steps = raymarch_steps 71 | 72 | hidden_size = 16 73 | self.lstm = nn.LSTMCell(input_size=self.n_feature_channels, 74 | hidden_size=hidden_size) 75 | 76 | self.lstm.apply(init_recurrent_weights) 77 | lstm_forget_gate_init(self.lstm) 78 | 79 | self.out_layer = nn.Linear(hidden_size, 1) 80 | self.counter = 0 81 | 82 | def forward(self, 83 | cam2world, 84 | phi, 85 | uv, 86 | intrinsics): 87 | batch_size, num_samples, _ = uv.shape 88 | log = list() 89 | 90 | ray_dirs = geometry.get_ray_directions(uv, 91 | cam2world=cam2world, 92 | intrinsics=intrinsics) 93 | 94 | initial_depth = torch.zeros((batch_size, num_samples, 1)).normal_(mean=0.05, std=5e-4).cuda() 95 | init_world_coords = geometry.world_from_xy_depth(uv, 96 | initial_depth, 97 | intrinsics=intrinsics, 98 | cam2world=cam2world) 99 | 100 | world_coords = [init_world_coords] 101 | depths = [initial_depth] 102 | states = [None] 103 | 104 | for step in range(self.steps): 105 | v = phi(world_coords[-1]) 106 | 107 | state = self.lstm(v.view(-1, self.n_feature_channels), states[-1]) 108 | 109 | if state[0].requires_grad: 110 | state[0].register_hook(lambda x: x.clamp(min=-10, max=10)) 111 | 112 | signed_distance = self.out_layer(state[0]).view(batch_size, num_samples, 1) 113 | new_world_coords = world_coords[-1] + ray_dirs * signed_distance 114 | 115 | states.append(state) 116 | world_coords.append(new_world_coords) 117 | 118 | depth = geometry.depth_from_world(world_coords[-1], cam2world) 119 | 120 | if self.training: 121 | print("Raymarch step %d: Min depth %0.6f, max depth %0.6f" % 122 | (step, depths[-1].min().detach().cpu().numpy(), depths[-1].max().detach().cpu().numpy())) 123 | 124 | depths.append(depth) 125 | 126 | if not self.counter % 100: 127 | # Write tensorboard summary for each step of ray-marcher. 
128 | drawing_depths = torch.stack(depths, dim=0)[:, 0, :, :] 129 | drawing_depths = util.lin2img(drawing_depths).repeat(1, 3, 1, 1) 130 | log.append(('image', 'raycast_progress', 131 | torch.clamp(torchvision.utils.make_grid(drawing_depths, scale_each=False, normalize=True), 0.0, 132 | 5), 133 | 100)) 134 | 135 | # Visualize residual step distance (i.e., the size of the final step) 136 | fig = util.show_images([util.lin2img(signed_distance)[i, :, :, :].detach().cpu().numpy().squeeze() 137 | for i in range(batch_size)]) 138 | log.append(('figure', 'stopping_distances', fig, 100)) 139 | self.counter += 1 140 | 141 | return world_coords[-1], depths[-1], log 142 | 143 | 144 | class DeepvoxelsRenderer(nn.Module): 145 | def __init__(self, 146 | nf0, 147 | in_channels, 148 | input_resolution, 149 | img_sidelength): 150 | super().__init__() 151 | 152 | self.nf0 = nf0 153 | self.in_channels = in_channels 154 | self.input_resolution = input_resolution 155 | self.img_sidelength = img_sidelength 156 | 157 | self.num_down_unet = util.num_divisible_by_2(input_resolution) 158 | self.num_upsampling = util.num_divisible_by_2(img_sidelength) - self.num_down_unet 159 | 160 | self.build_net() 161 | 162 | def build_net(self): 163 | self.net = [ 164 | pytorch_prototyping.Unet(in_channels=self.in_channels, 165 | out_channels=3 if self.num_upsampling <= 0 else 4 * self.nf0, 166 | outermost_linear=True if self.num_upsampling <= 0 else False, 167 | use_dropout=True, 168 | dropout_prob=0.1, 169 | nf0=self.nf0 * (2 ** self.num_upsampling), 170 | norm=nn.BatchNorm2d, 171 | max_channels=8 * self.nf0, 172 | num_down=self.num_down_unet) 173 | ] 174 | 175 | if self.num_upsampling > 0: 176 | self.net += [ 177 | pytorch_prototyping.UpsamplingNet(per_layer_out_ch=self.num_upsampling * [self.nf0], 178 | in_channels=4 * self.nf0, 179 | upsampling_mode='transpose', 180 | use_dropout=True, 181 | dropout_prob=0.1), 182 | pytorch_prototyping.Conv2dSame(self.nf0, out_channels=self.nf0 // 2, kernel_size=3, bias=False), 183 | nn.BatchNorm2d(self.nf0 // 2), 184 | nn.ReLU(True), 185 | pytorch_prototyping.Conv2dSame(self.nf0 // 2, 3, kernel_size=3) 186 | ] 187 | 188 | self.net += [nn.Tanh()] 189 | self.net = nn.Sequential(*self.net) 190 | 191 | def forward(self, input): 192 | batch_size, _, ch = input.shape 193 | input = input.permute(0, 2, 1).view(batch_size, ch, self.img_sidelength, self.img_sidelength) 194 | out = self.net(input) 195 | return out.view(batch_size, 3, -1).permute(0, 2, 1) 196 | -------------------------------------------------------------------------------- /data_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import cv2 3 | import numpy as np 4 | import imageio 5 | from glob import glob 6 | import os 7 | import shutil 8 | import skimage 9 | import pandas as pd 10 | 11 | 12 | def load_rgb(path, sidelength=None): 13 | img = imageio.imread(path)[:, :, :3] 14 | img = skimage.img_as_float32(img) 15 | 16 | img = square_crop_img(img) 17 | 18 | if sidelength is not None: 19 | img = cv2.resize(img, (sidelength, sidelength), interpolation=cv2.INTER_AREA) 20 | 21 | img -= 0.5 22 | img *= 2. 
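    # Pixel values are now normalized to [-1, 1]; the transpose below converts the image to channel-first (C, H, W) layout.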
23 | img = img.transpose(2, 0, 1) 24 | return img 25 | 26 | 27 | def load_depth(path, sidelength=None): 28 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) 29 | 30 | if sidelength is not None: 31 | img = cv2.resize(img, (sidelength, sidelength), interpolation=cv2.INTER_NEAREST) 32 | 33 | img *= 1e-4 34 | 35 | if len(img.shape) == 3: 36 | img = img[:, :, :1] 37 | img = img.transpose(2, 0, 1) 38 | else: 39 | img = img[None, :, :] 40 | return img 41 | 42 | 43 | def load_pose(filename): 44 | lines = open(filename).read().splitlines() 45 | if len(lines) == 1: 46 | pose = np.zeros((4, 4), dtype=np.float32) 47 | for i in range(16): 48 | pose[i // 4, i % 4] = lines[0].split(" ")[i] 49 | return pose.squeeze() 50 | else: 51 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines[:4])] 52 | return np.asarray(lines).astype(np.float32).squeeze() 53 | 54 | 55 | def load_params(filename): 56 | lines = open(filename).read().splitlines() 57 | 58 | params = np.array([float(x) for x in lines[0].split()]).astype(np.float32).squeeze() 59 | return params 60 | 61 | 62 | def cond_mkdir(path): 63 | if not os.path.exists(path): 64 | os.makedirs(path) 65 | 66 | 67 | def square_crop_img(img): 68 | min_dim = np.amin(img.shape[:2]) 69 | center_coord = np.array(img.shape[:2]) // 2 70 | img = img[center_coord[0] - min_dim // 2:center_coord[0] + min_dim // 2, 71 | center_coord[1] - min_dim // 2:center_coord[1] + min_dim // 2] 72 | return img 73 | 74 | 75 | def train_val_split(object_dir, train_dir, val_dir): 76 | dirs = [os.path.join(object_dir, x) for x in ['pose', 'rgb', 'depth']] 77 | data_lists = [sorted(glob(os.path.join(dir, x))) 78 | for dir, x in zip(dirs, ['*.txt', "*.png", "*.png"])] 79 | 80 | cond_mkdir(train_dir) 81 | cond_mkdir(val_dir) 82 | 83 | [cond_mkdir(os.path.join(train_dir, x)) for x in ['pose', 'rgb', 'depth']] 84 | [cond_mkdir(os.path.join(val_dir, x)) for x in ['pose', 'rgb', 'depth']] 85 | 86 | shutil.copy(os.path.join(object_dir, 'intrinsics.txt'), os.path.join(val_dir, 'intrinsics.txt')) 87 | shutil.copy(os.path.join(object_dir, 'intrinsics.txt'), os.path.join(train_dir, 'intrinsics.txt')) 88 | 89 | for data_name, data_ending, data_list in zip(['pose', 'rgb', 'depth'], ['.txt', '.png', '.png'], data_lists): 90 | val_counter = 0 91 | train_counter = 0 92 | for i, item in enumerate(data_list): 93 | if not i % 3: 94 | shutil.copy(item, os.path.join(train_dir, data_name, "%06d" % train_counter + data_ending)) 95 | train_counter += 1 96 | else: 97 | shutil.copy(item, os.path.join(val_dir, data_name, "%06d" % val_counter + data_ending)) 98 | val_counter += 1 99 | 100 | 101 | def glob_imgs(path): 102 | imgs = [] 103 | for ext in ['*.png', '*.jpg', '*.JPEG', '*.JPG']: 104 | imgs.extend(glob(os.path.join(path, ext))) 105 | return imgs 106 | 107 | 108 | def read_view_direction_rays(direction_file): 109 | img = cv2.imread(direction_file, cv2.IMREAD_UNCHANGED).astype(np.float32) 110 | img -= 40000 111 | img /= 10000 112 | return img 113 | 114 | 115 | def shapenet_train_test_split(shapenet_path, synset_id, name, csv_path): 116 | ''' 117 | 118 | :param synset_id: synset ID as a string. 
119 | :param name: 120 | :param csv_path: 121 | :return: 122 | ''' 123 | parsed_csv = pd.read_csv(filepath_or_buffer=csv_path) 124 | synset_df = parsed_csv[parsed_csv['synsetId'] == int(synset_id)] 125 | 126 | train = synset_df[synset_df['split'] == 'train'] 127 | val = synset_df[synset_df['split'] == 'val'] 128 | test = synset_df[synset_df['split'] == 'test'] 129 | print(len(train), len(val), len(test)) 130 | 131 | train_path, val_path, test_path = [os.path.join(shapenet_path, str(synset_id) + '_' + name + '_' + x) 132 | for x in ['train', 'val', 'test']] 133 | cond_mkdir(train_path) 134 | cond_mkdir(val_path) 135 | cond_mkdir(test_path) 136 | 137 | for split_df, trgt_path in zip([train, val, test], [train_path, val_path, test_path]): 138 | for row_no, row in split_df.iterrows(): 139 | try: 140 | shutil.copytree(os.path.join(shapenet_path, str(synset_id), str(row.modelId)), 141 | os.path.join(shapenet_path, trgt_path, str(row.modelId))) 142 | except FileNotFoundError: 143 | print("%s does not exist" % str(row.modelId)) 144 | 145 | 146 | def transform_viewpoint(v): 147 | """Transforms the viewpoint vector into a consistent representation""" 148 | 149 | return np.concatenate([v[:, :3], 150 | np.cos(v[:, 3:4]), 151 | np.sin(v[:, 3:4]), 152 | np.cos(v[:, 4:5]), 153 | np.sin(v[:, 4:5])], 1) 154 | 155 | 156 | def euler2mat(z=0, y=0, x=0): 157 | Ms = [] 158 | if z: 159 | cosz = np.cos(z) 160 | sinz = np.sin(z) 161 | Ms.append(np.array( 162 | [[cosz, -sinz, 0], 163 | [sinz, cosz, 0], 164 | [0, 0, 1]])) 165 | if y: 166 | cosy = np.cos(y) 167 | siny = np.sin(y) 168 | Ms.append(np.array( 169 | [[cosy, 0, siny], 170 | [0, 1, 0], 171 | [-siny, 0, cosy]])) 172 | if x: 173 | cosx = np.cos(x) 174 | sinx = np.sin(x) 175 | Ms.append(np.array( 176 | [[1, 0, 0], 177 | [0, cosx, -sinx], 178 | [0, sinx, cosx]])) 179 | if Ms: 180 | return functools.reduce(np.dot, Ms[::-1]) 181 | return np.eye(3) 182 | 183 | 184 | def look_at(vec_pos, vec_look_at): 185 | z = vec_look_at - vec_pos 186 | z = z / np.linalg.norm(z) 187 | 188 | x = np.cross(z, np.array([0., 1., 0.])) 189 | x = x / np.linalg.norm(x) 190 | 191 | y = np.cross(x, z) 192 | y = y / np.linalg.norm(y) 193 | 194 | view_mat = np.zeros((3, 3)) 195 | 196 | view_mat[:3, 0] = x 197 | view_mat[:3, 1] = y 198 | view_mat[:3, 2] = z 199 | 200 | return view_mat 201 | -------------------------------------------------------------------------------- /dataio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from glob import glob 5 | import data_util 6 | import util 7 | 8 | 9 | def pick(list, item_idcs): 10 | if not list: 11 | return list 12 | return [list[i] for i in item_idcs] 13 | 14 | 15 | class SceneInstanceDataset(): 16 | """This creates a dataset class for a single object instance (such as a single car).""" 17 | 18 | def __init__(self, 19 | instance_idx, 20 | instance_dir, 21 | specific_observation_idcs=None, # For few-shot case: Can pick specific observations only 22 | img_sidelength=None, 23 | num_images=-1): 24 | self.instance_idx = instance_idx 25 | self.img_sidelength = img_sidelength 26 | self.instance_dir = instance_dir 27 | 28 | color_dir = os.path.join(instance_dir, "rgb") 29 | pose_dir = os.path.join(instance_dir, "pose") 30 | param_dir = os.path.join(instance_dir, "params") 31 | 32 | if not os.path.isdir(color_dir): 33 | print("Error! 
root dir %s is wrong" % instance_dir) 34 | return 35 | 36 | self.has_params = os.path.isdir(param_dir) 37 | self.color_paths = sorted(data_util.glob_imgs(color_dir)) 38 | self.pose_paths = sorted(glob(os.path.join(pose_dir, "*.txt"))) 39 | 40 | if self.has_params: 41 | self.param_paths = sorted(glob(os.path.join(param_dir, "*.txt"))) 42 | else: 43 | self.param_paths = [] 44 | 45 | if specific_observation_idcs is not None: 46 | self.color_paths = pick(self.color_paths, specific_observation_idcs) 47 | self.pose_paths = pick(self.pose_paths, specific_observation_idcs) 48 | self.param_paths = pick(self.param_paths, specific_observation_idcs) 49 | elif num_images != -1: 50 | idcs = np.linspace(0, stop=len(self.color_paths), num=num_images, endpoint=False, dtype=int) 51 | self.color_paths = pick(self.color_paths, idcs) 52 | self.pose_paths = pick(self.pose_paths, idcs) 53 | self.param_paths = pick(self.param_paths, idcs) 54 | 55 | def set_img_sidelength(self, new_img_sidelength): 56 | """For multi-resolution training: Updates the image sidelength with whichimages are loaded.""" 57 | self.img_sidelength = new_img_sidelength 58 | 59 | def __len__(self): 60 | return len(self.pose_paths) 61 | 62 | def __getitem__(self, idx): 63 | intrinsics, _, _, _ = util.parse_intrinsics(os.path.join(self.instance_dir, "intrinsics.txt"), 64 | trgt_sidelength=self.img_sidelength) 65 | intrinsics = torch.Tensor(intrinsics).float() 66 | 67 | rgb = data_util.load_rgb(self.color_paths[idx], sidelength=self.img_sidelength) 68 | rgb = rgb.reshape(3, -1).transpose(1, 0) 69 | 70 | pose = data_util.load_pose(self.pose_paths[idx]) 71 | 72 | if self.has_params: 73 | params = data_util.load_params(self.param_paths[idx]) 74 | else: 75 | params = np.array([0]) 76 | 77 | uv = np.mgrid[0:self.img_sidelength, 0:self.img_sidelength].astype(np.int32) 78 | uv = torch.from_numpy(np.flip(uv, axis=0).copy()).long() 79 | uv = uv.reshape(2, -1).transpose(1, 0) 80 | 81 | sample = { 82 | "instance_idx": torch.Tensor([self.instance_idx]).squeeze(), 83 | "rgb": torch.from_numpy(rgb).float(), 84 | "pose": torch.from_numpy(pose).float(), 85 | "uv": uv, 86 | "param": torch.from_numpy(params).float(), 87 | "intrinsics": intrinsics 88 | } 89 | return sample 90 | 91 | 92 | class SceneClassDataset(torch.utils.data.Dataset): 93 | """Dataset for a class of objects, where each datapoint is a SceneInstanceDataset.""" 94 | 95 | def __init__(self, 96 | root_dir, 97 | img_sidelength=None, 98 | max_num_instances=-1, 99 | max_observations_per_instance=-1, 100 | specific_observation_idcs=None, # For few-shot case: Can pick specific observations only 101 | samples_per_instance=2): 102 | 103 | self.samples_per_instance = samples_per_instance 104 | self.instance_dirs = sorted(glob(os.path.join(root_dir, "*/"))) 105 | 106 | assert (len(self.instance_dirs) != 0), "No objects in the data directory" 107 | 108 | if max_num_instances != -1: 109 | self.instance_dirs = self.instance_dirs[:max_num_instances] 110 | 111 | self.all_instances = [SceneInstanceDataset(instance_idx=idx, 112 | instance_dir=dir, 113 | specific_observation_idcs=specific_observation_idcs, 114 | img_sidelength=img_sidelength, 115 | num_images=max_observations_per_instance) 116 | for idx, dir in enumerate(self.instance_dirs)] 117 | 118 | self.num_per_instance_observations = [len(obj) for obj in self.all_instances] 119 | self.num_instances = len(self.all_instances) 120 | 121 | def set_img_sidelength(self, new_img_sidelength): 122 | """For multi-resolution training: Updates the image sidelength with 
whichimages are loaded.""" 123 | for instance in self.all_instances: 124 | instance.set_img_sidelength(new_img_sidelength) 125 | 126 | def __len__(self): 127 | return np.sum(self.num_per_instance_observations) 128 | 129 | def get_instance_idx(self, idx): 130 | """Maps an index into all tuples of all objects to the idx of the tuple relative to the other tuples of that 131 | object 132 | """ 133 | obj_idx = 0 134 | while idx >= 0: 135 | idx -= self.num_per_instance_observations[obj_idx] 136 | obj_idx += 1 137 | return obj_idx - 1, int(idx + self.num_per_instance_observations[obj_idx - 1]) 138 | 139 | def collate_fn(self, batch_list): 140 | batch_list = zip(*batch_list) 141 | 142 | all_parsed = [] 143 | for entry in batch_list: 144 | # make them all into a new dict 145 | ret = {} 146 | for k in entry[0][0].keys(): 147 | ret[k] = [] 148 | # flatten the list of list 149 | for b in entry: 150 | for k in entry[0][0].keys(): 151 | ret[k].extend( [bi[k] for bi in b]) 152 | for k in ret.keys(): 153 | if type(ret[k][0]) == torch.Tensor: 154 | ret[k] = torch.stack(ret[k]) 155 | all_parsed.append(ret) 156 | 157 | return tuple(all_parsed) 158 | 159 | def __getitem__(self, idx): 160 | """Each __getitem__ call yields a list of self.samples_per_instance observations of a single scene (each a dict), 161 | as well as a list of ground-truths for each observation (also a dict).""" 162 | obj_idx, rel_idx = self.get_instance_idx(idx) 163 | 164 | observations = [] 165 | observations.append(self.all_instances[obj_idx][rel_idx]) 166 | 167 | for i in range(self.samples_per_instance - 1): 168 | observations.append(self.all_instances[obj_idx][np.random.randint(len(self.all_instances[obj_idx]))]) 169 | 170 | ground_truth = [{'rgb':ray_bundle['rgb']} for ray_bundle in observations] 171 | 172 | return observations, ground_truth 173 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: srns 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - blas=1.0=mkl 8 | - bzip2=1.0.8=h7b6447c_0 9 | - ca-certificates=2019.8.28=0 10 | - cairo=1.14.12=h8948797_3 11 | - certifi=2019.9.11=py37_0 12 | - cffi=1.12.3=py37h2e261b9_0 13 | - cloudpickle=1.2.2=py_0 14 | - configargparse=0.14.0=py37_0 15 | - cudatoolkit=10.0.130=0 16 | - cycler=0.10.0=py37_0 17 | - cytoolz=0.10.0=py37h7b6447c_0 18 | - dask-core=2.4.0=py_0 19 | - dbus=1.13.6=h746ee38_0 20 | - decorator=4.4.0=py37_1 21 | - expat=2.2.6=he6710b0_0 22 | - ffmpeg=4.0=hcdf2ecd_0 23 | - fontconfig=2.13.0=h9420a91_0 24 | - freeglut=3.0.0=hf484d3e_5 25 | - freetype=2.9.1=h8a8886c_1 26 | - glib=2.56.2=hd408876_0 27 | - graphite2=1.3.13=h23475e2_0 28 | - gst-plugins-base=1.14.0=hbbd80ab_1 29 | - gstreamer=1.14.0=hb453b48_1 30 | - harfbuzz=1.8.8=hffaf4a1_0 31 | - hdf5=1.10.2=hba1933b_1 32 | - icu=58.2=h9c2bf20_1 33 | - imageio=2.5.0=py37_0 34 | - intel-openmp=2019.4=243 35 | - jasper=2.0.14=h07fcdf6_1 36 | - jpeg=9b=h024ee3a_2 37 | - kiwisolver=1.1.0=py37he6710b0_0 38 | - libedit=3.1.20181209=hc058e9b_0 39 | - libffi=3.2.1=hd88cf55_4 40 | - libgcc-ng=9.1.0=hdf63c60_0 41 | - libgfortran-ng=7.3.0=hdf63c60_0 42 | - libglu=9.0.0=hf484d3e_1 43 | - libopencv=3.4.2=hb342d67_1 44 | - libopus=1.3=h7b6447c_0 45 | - libpng=1.6.37=hbc83047_0 46 | - libstdcxx-ng=9.1.0=hdf63c60_0 47 | - libtiff=4.0.10=h2733197_2 48 | - libuuid=1.0.3=h1bed415_2 49 | - libvpx=1.7.0=h439df22_0 50 | - libxcb=1.13=h1bed415_1 51 | - 
libxml2=2.9.9=hea5a465_1 52 | - matplotlib=3.1.1=py37h5429711_0 53 | - mkl=2019.4=243 54 | - mkl-service=2.3.0=py37he904b0f_0 55 | - mkl_fft=1.0.14=py37ha843d7b_0 56 | - mkl_random=1.1.0=py37hd6b4f25_0 57 | - ncurses=6.1=he6710b0_1 58 | - networkx=2.3=py_0 59 | - ninja=1.9.0=py37hfd86e86_0 60 | - numpy=1.17.2=py37haad9e8e_0 61 | - numpy-base=1.17.2=py37hde5b4d6_0 62 | - olefile=0.46=py37_0 63 | - opencv=3.4.2=py37h6fd60c2_1 64 | - openssl=1.1.1d=h7b6447c_1 65 | - pandas=0.25.1=py37he6710b0_0 66 | - pcre=8.43=he6710b0_0 67 | - pillow=6.1.0=py37h34e0f95_0 68 | - pip=19.2.3=py37_0 69 | - pixman=0.38.0=h7b6447c_0 70 | - py-opencv=3.4.2=py37hb342d67_1 71 | - pycparser=2.19=py37_0 72 | - pyparsing=2.4.2=py_0 73 | - pyqt=5.9.2=py37h05f1152_2 74 | - python=3.7.4=h265db76_1 75 | - python-dateutil=2.8.0=py37_0 76 | - pytz=2019.2=py_0 77 | - pywavelets=1.0.3=py37hdd07704_1 78 | - qt=5.9.7=h5867ecd_1 79 | - readline=7.0=h7b6447c_5 80 | - scikit-image=0.15.0=py37he6710b0_0 81 | - scipy=1.3.1=py37h7c811a0_0 82 | - setuptools=41.2.0=py37_0 83 | - sip=4.19.8=py37hf484d3e_0 84 | - six=1.12.0=py37_0 85 | - sqlite=3.29.0=h7b6447c_0 86 | - tk=8.6.8=hbc83047_0 87 | - toolz=0.10.0=py_0 88 | - tornado=6.0.3=py37h7b6447c_0 89 | - wheel=0.33.6=py37_0 90 | - xz=5.2.4=h14c3975_4 91 | - zlib=1.2.11=h7b6447c_3 92 | - zstd=1.3.7=h0b5b093_0 93 | - pytorch=1.2.0=py3.7_cuda10.0.130_cudnn7.6.2_0 94 | - torchvision=0.4.0=py37_cu100 95 | - pip: 96 | - absl-py==0.8.0 97 | - dask==2.4.0 98 | - future==0.17.1 99 | - grpcio==1.23.0 100 | - markdown==3.1.1 101 | - protobuf==3.9.2 102 | - pybind11==2.2.4 103 | - tensorboard==2.0.0 104 | - torch==1.2.0 105 | - werkzeug==0.16.0 106 | -------------------------------------------------------------------------------- /geometry.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from torch.nn import functional as F 5 | import util 6 | 7 | 8 | def compute_normal_map(x_img, y_img, z, intrinsics): 9 | cam_coords = lift(x_img, y_img, z, intrinsics) 10 | cam_coords = util.lin2img(cam_coords) 11 | 12 | shift_left = cam_coords[:, :, 2:, :] 13 | shift_right = cam_coords[:, :, :-2, :] 14 | 15 | shift_up = cam_coords[:, :, :, 2:] 16 | shift_down = cam_coords[:, :, :, :-2] 17 | 18 | diff_hor = F.normalize(shift_right - shift_left, dim=1)[:, :, :, 1:-1] 19 | diff_ver = F.normalize(shift_up - shift_down, dim=1)[:, :, 1:-1, :] 20 | 21 | cross = torch.cross(diff_hor, diff_ver, dim=1) 22 | return cross 23 | 24 | 25 | def get_ray_directions_cam(uv, intrinsics): 26 | '''Translates meshgrid of uv pixel coordinates to normalized directions of rays through these pixels, 27 | in camera coordinates. 
28 | ''' 29 | batch_size, num_samples, _ = uv.shape 30 | 31 | x_cam = uv[:, :, 0].view(batch_size, -1) 32 | y_cam = uv[:, :, 1].view(batch_size, -1) 33 | z_cam = torch.ones((batch_size, num_samples)).cuda() 34 | 35 | pixel_points_cam = lift(x_cam, y_cam, z_cam, intrinsics=intrinsics, homogeneous=False) # (batch_size, -1, 4) 36 | ray_dirs = F.normalize(pixel_points_cam, dim=2) 37 | return ray_dirs 38 | 39 | 40 | def reflect_vector_on_vector(vector_to_reflect, reflection_axis): 41 | refl = F.normalize(vector_to_reflect.cuda()) 42 | ax = F.normalize(reflection_axis.cuda()) 43 | 44 | r = 2 * (ax * refl).sum(dim=1, keepdim=True) * ax - refl 45 | return r 46 | 47 | 48 | def parse_intrinsics(intrinsics): 49 | intrinsics = intrinsics.cuda() 50 | 51 | fx = intrinsics[:, 0, 0] 52 | fy = intrinsics[:, 1, 1] 53 | cx = intrinsics[:, 0, 2] 54 | cy = intrinsics[:, 1, 2] 55 | return fx, fy, cx, cy 56 | 57 | 58 | def expand_as(x, y): 59 | if len(x.shape) == len(y.shape): 60 | return x 61 | 62 | for i in range(len(y.shape) - len(x.shape)): 63 | x = x.unsqueeze(-1) 64 | 65 | return x 66 | 67 | 68 | def lift(x, y, z, intrinsics, homogeneous=False): 69 | ''' 70 | 71 | :param self: 72 | :param x: Shape (batch_size, num_points) 73 | :param y: 74 | :param z: 75 | :param intrinsics: 76 | :return: 77 | ''' 78 | fx, fy, cx, cy = parse_intrinsics(intrinsics) 79 | 80 | x_lift = (x - expand_as(cx, x)) / expand_as(fx, x) * z 81 | y_lift = (y - expand_as(cy, y)) / expand_as(fy, y) * z 82 | 83 | if homogeneous: 84 | return torch.stack((x_lift, y_lift, z, torch.ones_like(z).cuda()), dim=-1) 85 | else: 86 | return torch.stack((x_lift, y_lift, z), dim=-1) 87 | 88 | 89 | def project(x, y, z, intrinsics): 90 | ''' 91 | 92 | :param self: 93 | :param x: Shape (batch_size, num_points) 94 | :param y: 95 | :param z: 96 | :param intrinsics: 97 | :return: 98 | ''' 99 | fx, fy, cx, cy = parse_intrinsics(intrinsics) 100 | 101 | x_proj = expand_as(fx, x) * x / z + expand_as(cx, x) 102 | y_proj = expand_as(fy, y) * y / z + expand_as(cy, y) 103 | 104 | return torch.stack((x_proj, y_proj, z), dim=-1) 105 | 106 | 107 | def world_from_xy_depth(xy, depth, cam2world, intrinsics): 108 | '''Translates meshgrid of xy pixel coordinates plus depth to world coordinates. 109 | ''' 110 | batch_size, _, _ = cam2world.shape 111 | 112 | x_cam = xy[:, :, 0].view(batch_size, -1) 113 | y_cam = xy[:, :, 1].view(batch_size, -1) 114 | z_cam = depth.view(batch_size, -1) 115 | 116 | pixel_points_cam = lift(x_cam, y_cam, z_cam, intrinsics=intrinsics, homogeneous=True) # (batch_size, -1, 4) 117 | 118 | # permute for batch matrix product 119 | pixel_points_cam = pixel_points_cam.permute(0, 2, 1) 120 | 121 | world_coords = torch.bmm(cam2world, pixel_points_cam).permute(0, 2, 1)[:, :, :3] # (batch_size, -1, 3) 122 | 123 | return world_coords 124 | 125 | 126 | def project_point_on_line(projection_point, line_direction, point_on_line, dim): 127 | '''Projects a batch of points on a batch of lines as defined by their direction and a point on each line. ''' 128 | assert torch.allclose((line_direction ** 2).sum(dim=dim, keepdim=True).cuda(), torch.Tensor([1]).cuda()) 129 | return point_on_line + ((projection_point - point_on_line) * line_direction).sum(dim=dim, 130 | keepdim=True) * line_direction 131 | 132 | def get_ray_directions(xy, cam2world, intrinsics): 133 | '''Translates meshgrid of xy pixel coordinates to normalized directions of rays through these pixels. 
134 | ''' 135 | batch_size, num_samples, _ = xy.shape 136 | 137 | z_cam = torch.ones((batch_size, num_samples)).cuda() 138 | pixel_points = world_from_xy_depth(xy, z_cam, intrinsics=intrinsics, cam2world=cam2world) # (batch, num_samples, 3) 139 | 140 | cam_pos = cam2world[:, :3, 3] 141 | ray_dirs = pixel_points - cam_pos[:, None, :] # (batch, num_samples, 3) 142 | ray_dirs = F.normalize(ray_dirs, dim=2) 143 | return ray_dirs 144 | 145 | 146 | def depth_from_world(world_coords, cam2world): 147 | batch_size, num_samples, _ = world_coords.shape 148 | 149 | points_hom = torch.cat((world_coords, torch.ones((batch_size, num_samples, 1)).cuda()), 150 | dim=2) # (batch, num_samples, 4) 151 | 152 | # permute for bmm 153 | points_hom = points_hom.permute(0, 2, 1) 154 | 155 | points_cam = torch.inverse(cam2world).bmm(points_hom) # (batch, 4, num_samples) 156 | depth = points_cam[:, 2, :][:, :, None] # (batch, num_samples, 1) 157 | return depth 158 | 159 | 160 | -------------------------------------------------------------------------------- /hyperlayers.py: -------------------------------------------------------------------------------- 1 | '''Pytorch implementations of hyper-network modules.''' 2 | import torch 3 | import torch.nn as nn 4 | from pytorch_prototyping import pytorch_prototyping 5 | import functools 6 | 7 | 8 | def partialclass(cls, *args, **kwds): 9 | 10 | class NewCls(cls): 11 | __init__ = functools.partialmethod(cls.__init__, *args, **kwds) 12 | 13 | return NewCls 14 | 15 | 16 | class LookupLayer(nn.Module): 17 | def __init__(self, in_ch, out_ch, num_objects): 18 | super().__init__() 19 | 20 | self.out_ch = out_ch 21 | self.lookup_lin = LookupLinear(in_ch, 22 | out_ch, 23 | num_objects=num_objects) 24 | self.norm_nl = nn.Sequential( 25 | nn.LayerNorm([self.out_ch], elementwise_affine=False), 26 | nn.ReLU(inplace=True) 27 | ) 28 | 29 | def forward(self, obj_idx): 30 | net = nn.Sequential( 31 | self.lookup_lin(obj_idx), 32 | self.norm_nl 33 | ) 34 | return net 35 | 36 | class LookupFC(nn.Module): 37 | def __init__(self, 38 | hidden_ch, 39 | num_hidden_layers, 40 | num_objects, 41 | in_ch, 42 | out_ch, 43 | outermost_linear=False): 44 | super().__init__() 45 | self.layers = nn.ModuleList() 46 | self.layers.append(LookupLayer(in_ch=in_ch, out_ch=hidden_ch, num_objects=num_objects)) 47 | 48 | for i in range(num_hidden_layers): 49 | self.layers.append(LookupLayer(in_ch=hidden_ch, out_ch=hidden_ch, num_objects=num_objects)) 50 | 51 | if outermost_linear: 52 | self.layers.append(LookupLinear(in_ch=hidden_ch, out_ch=out_ch, num_objects=num_objects)) 53 | else: 54 | self.layers.append(LookupLayer(in_ch=hidden_ch, out_ch=out_ch, num_objects=num_objects)) 55 | 56 | def forward(self, obj_idx): 57 | net = [] 58 | for i in range(len(self.layers)): 59 | net.append(self.layers[i](obj_idx)) 60 | 61 | return nn.Sequential(*net) 62 | 63 | 64 | class LookupLinear(nn.Module): 65 | def __init__(self, 66 | in_ch, 67 | out_ch, 68 | num_objects): 69 | super().__init__() 70 | self.in_ch = in_ch 71 | self.out_ch = out_ch 72 | 73 | self.hypo_params = nn.Embedding(num_objects, in_ch * out_ch + out_ch) 74 | 75 | for i in range(num_objects): 76 | nn.init.kaiming_normal_(self.hypo_params.weight.data[i, :self.in_ch * self.out_ch].view(self.out_ch, self.in_ch), 77 | a=0.0, 78 | nonlinearity='relu', 79 | mode='fan_in') 80 | self.hypo_params.weight.data[i, self.in_ch * self.out_ch:].fill_(0.) 
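        # Each embedding row stores a flattened (out_ch x in_ch) weight matrix followed by an out_ch bias vector;
        # forward() slices these back out and wraps them in a BatchLinear layer.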
81 | 82 | def forward(self, obj_idx): 83 | hypo_params = self.hypo_params(obj_idx) 84 | 85 | # Indices explicit to catch erros in shape of output layer 86 | weights = hypo_params[..., :self.in_ch * self.out_ch] 87 | biases = hypo_params[..., self.in_ch * self.out_ch:(self.in_ch * self.out_ch)+self.out_ch] 88 | 89 | biases = biases.view(*(biases.size()[:-1]), 1, self.out_ch) 90 | weights = weights.view(*(weights.size()[:-1]), self.out_ch, self.in_ch) 91 | 92 | return BatchLinear(weights=weights, biases=biases) 93 | 94 | 95 | class HyperLayer(nn.Module): 96 | '''A hypernetwork that predicts a single Dense Layer, including LayerNorm and a ReLU.''' 97 | def __init__(self, 98 | in_ch, 99 | out_ch, 100 | hyper_in_ch, 101 | hyper_num_hidden_layers, 102 | hyper_hidden_ch): 103 | super().__init__() 104 | 105 | self.hyper_linear = HyperLinear(in_ch=in_ch, 106 | out_ch=out_ch, 107 | hyper_in_ch=hyper_in_ch, 108 | hyper_num_hidden_layers=hyper_num_hidden_layers, 109 | hyper_hidden_ch=hyper_hidden_ch) 110 | self.norm_nl = nn.Sequential( 111 | nn.LayerNorm([out_ch], elementwise_affine=False), 112 | nn.ReLU(inplace=True) 113 | ) 114 | 115 | def forward(self, hyper_input): 116 | ''' 117 | :param hyper_input: input to hypernetwork. 118 | :return: nn.Module; predicted fully connected network. 119 | ''' 120 | return nn.Sequential(self.hyper_linear(hyper_input), self.norm_nl) 121 | 122 | 123 | class HyperFC(nn.Module): 124 | '''Builds a hypernetwork that predicts a fully connected neural network. 125 | ''' 126 | def __init__(self, 127 | hyper_in_ch, 128 | hyper_num_hidden_layers, 129 | hyper_hidden_ch, 130 | hidden_ch, 131 | num_hidden_layers, 132 | in_ch, 133 | out_ch, 134 | outermost_linear=False): 135 | super().__init__() 136 | 137 | PreconfHyperLinear = partialclass(HyperLinear, 138 | hyper_in_ch=hyper_in_ch, 139 | hyper_num_hidden_layers=hyper_num_hidden_layers, 140 | hyper_hidden_ch=hyper_hidden_ch) 141 | PreconfHyperLayer = partialclass(HyperLayer, 142 | hyper_in_ch=hyper_in_ch, 143 | hyper_num_hidden_layers=hyper_num_hidden_layers, 144 | hyper_hidden_ch=hyper_hidden_ch) 145 | 146 | self.layers = nn.ModuleList() 147 | self.layers.append(PreconfHyperLayer(in_ch=in_ch, out_ch=hidden_ch)) 148 | 149 | for i in range(num_hidden_layers): 150 | self.layers.append(PreconfHyperLayer(in_ch=hidden_ch, out_ch=hidden_ch)) 151 | 152 | if outermost_linear: 153 | self.layers.append(PreconfHyperLinear(in_ch=hidden_ch, out_ch=out_ch)) 154 | else: 155 | self.layers.append(PreconfHyperLayer(in_ch=hidden_ch, out_ch=out_ch)) 156 | 157 | 158 | def forward(self, hyper_input): 159 | ''' 160 | :param hyper_input: Input to hypernetwork. 161 | :return: nn.Module; Predicted fully connected neural network. 162 | ''' 163 | net = [] 164 | for i in range(len(self.layers)): 165 | net.append(self.layers[i](hyper_input)) 166 | 167 | return nn.Sequential(*net) 168 | 169 | 170 | class BatchLinear(nn.Module): 171 | def __init__(self, 172 | weights, 173 | biases): 174 | '''Implements a batch linear layer. 
175 | 176 | :param weights: Shape: (batch, out_ch, in_ch) 177 | :param biases: Shape: (batch, 1, out_ch) 178 | ''' 179 | super().__init__() 180 | 181 | self.weights = weights 182 | self.biases = biases 183 | 184 | def __repr__(self): 185 | return "BatchLinear(in_ch=%d, out_ch=%d)"%(self.weights.shape[-1], self.weights.shape[-2]) 186 | 187 | def forward(self, input): 188 | output = input.matmul(self.weights.permute(*[i for i in range(len(self.weights.shape)-2)], -1, -2)) 189 | output += self.biases 190 | return output 191 | 192 | 193 | def last_hyper_layer_init(m): 194 | if type(m) == nn.Linear: 195 | nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in') 196 | m.weight.data *= 1e-1 197 | 198 | 199 | class HyperLinear(nn.Module): 200 | '''A hypernetwork that predicts a single linear layer (weights & biases).''' 201 | def __init__(self, 202 | in_ch, 203 | out_ch, 204 | hyper_in_ch, 205 | hyper_num_hidden_layers, 206 | hyper_hidden_ch): 207 | 208 | super().__init__() 209 | self.in_ch = in_ch 210 | self.out_ch = out_ch 211 | 212 | self.hypo_params = pytorch_prototyping.FCBlock(in_features=hyper_in_ch, 213 | hidden_ch=hyper_hidden_ch, 214 | num_hidden_layers=hyper_num_hidden_layers, 215 | out_features=(in_ch * out_ch) + out_ch, 216 | outermost_linear=True) 217 | self.hypo_params[-1].apply(last_hyper_layer_init) 218 | 219 | def forward(self, hyper_input): 220 | hypo_params = self.hypo_params(hyper_input.cuda()) 221 | 222 | # Indices explicit to catch erros in shape of output layer 223 | weights = hypo_params[..., :self.in_ch * self.out_ch] 224 | biases = hypo_params[..., self.in_ch * self.out_ch:(self.in_ch * self.out_ch)+self.out_ch] 225 | 226 | biases = biases.view(*(biases.size()[:-1]), 1, self.out_ch) 227 | weights = weights.view(*(weights.size()[:-1]), self.out_ch, self.in_ch) 228 | 229 | return BatchLinear(weights=weights, biases=biases) 230 | 231 | -------------------------------------------------------------------------------- /srns.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | import torchvision 6 | import util 7 | 8 | import skimage.measure 9 | from torch.nn import functional as F 10 | 11 | from pytorch_prototyping import pytorch_prototyping 12 | import custom_layers 13 | import geometry 14 | import hyperlayers 15 | 16 | 17 | class SRNsModel(nn.Module): 18 | def __init__(self, 19 | num_instances, 20 | latent_dim, 21 | tracing_steps, 22 | has_params=False, 23 | fit_single_srn=False, 24 | use_unet_renderer=False, 25 | freeze_networks=False): 26 | super().__init__() 27 | 28 | self.latent_dim = latent_dim 29 | self.has_params = has_params 30 | 31 | self.num_hidden_units_phi = 256 32 | self.phi_layers = 4 # includes the in and out layers 33 | self.rendering_layers = 5 # includes the in and out layers 34 | self.sphere_trace_steps = tracing_steps 35 | self.freeze_networks = freeze_networks 36 | self.fit_single_srn = fit_single_srn 37 | 38 | if self.fit_single_srn: # Fit a single scene with a single SRN (no hypernetworks) 39 | self.phi = pytorch_prototyping.FCBlock(hidden_ch=self.num_hidden_units_phi, 40 | num_hidden_layers=self.phi_layers - 2, 41 | in_features=3, 42 | out_features=self.num_hidden_units_phi) 43 | else: 44 | # Auto-decoder: each scene instance gets its own code vector z 45 | self.latent_codes = nn.Embedding(num_instances, latent_dim).cuda() 46 | nn.init.normal_(self.latent_codes.weight, mean=0, std=0.01) 47 | 48 | self.hyper_phi = 
hyperlayers.HyperFC(hyper_in_ch=self.latent_dim, 49 | hyper_num_hidden_layers=1, 50 | hyper_hidden_ch=self.latent_dim, 51 | hidden_ch=self.num_hidden_units_phi, 52 | num_hidden_layers=self.phi_layers - 2, 53 | in_ch=3, 54 | out_ch=self.num_hidden_units_phi) 55 | 56 | self.ray_marcher = custom_layers.Raymarcher(num_feature_channels=self.num_hidden_units_phi, 57 | raymarch_steps=self.sphere_trace_steps) 58 | 59 | if use_unet_renderer: 60 | self.pixel_generator = custom_layers.DeepvoxelsRenderer(nf0=32, in_channels=self.num_hidden_units_phi, 61 | input_resolution=128, img_sidelength=128) 62 | else: 63 | self.pixel_generator = pytorch_prototyping.FCBlock(hidden_ch=self.num_hidden_units_phi, 64 | num_hidden_layers=self.rendering_layers - 1, 65 | in_features=self.num_hidden_units_phi, 66 | out_features=3, 67 | outermost_linear=True) 68 | 69 | if self.freeze_networks: 70 | all_network_params = (list(self.pixel_generator.parameters()) 71 | + list(self.ray_marcher.parameters()) 72 | + list(self.hyper_phi.parameters())) 73 | for param in all_network_params: 74 | param.requires_grad = False 75 | 76 | # Losses 77 | self.l2_loss = nn.MSELoss(reduction="mean") 78 | 79 | # List of logs 80 | self.logs = list() 81 | 82 | print(self) 83 | print("Number of parameters:") 84 | util.print_network(self) 85 | 86 | def get_regularization_loss(self, prediction, ground_truth): 87 | """Computes regularization loss on final depth map (L_{depth} in eq. 6 in paper) 88 | 89 | :param prediction (tuple): Output of forward pass. 90 | :param ground_truth: Ground-truth (unused). 91 | :return: Regularization loss on final depth map. 92 | """ 93 | _, depth = prediction 94 | 95 | neg_penalty = (torch.min(depth, torch.zeros_like(depth)) ** 2) 96 | return torch.mean(neg_penalty) * 10000 97 | 98 | def get_image_loss(self, prediction, ground_truth): 99 | """Computes loss on predicted image (L_{img} in eq. 6 in paper) 100 | 101 | :param prediction (tuple): Output of forward pass. 102 | :param ground_truth: Ground-truth (unused). 103 | :return: image reconstruction loss. 104 | """ 105 | pred_imgs, _ = prediction 106 | trgt_imgs = ground_truth['rgb'] 107 | 108 | trgt_imgs = trgt_imgs.cuda() 109 | 110 | loss = self.l2_loss(pred_imgs, trgt_imgs) 111 | return loss 112 | 113 | def get_latent_loss(self): 114 | """Computes loss on latent code vectors (L_{latent} in eq. 6 in paper) 115 | :return: Latent loss. 116 | """ 117 | if self.fit_single_srn: 118 | self.latent_reg_loss = 0 119 | else: 120 | self.latent_reg_loss = torch.mean(self.z ** 2) 121 | 122 | return self.latent_reg_loss 123 | 124 | def get_psnr(self, prediction, ground_truth): 125 | """Compute PSNR of model image predictions. 126 | 127 | :param prediction: Return value of forward pass. 128 | :param ground_truth: Ground truth. 129 | :return: (psnr, ssim): tuple of floats 130 | """ 131 | pred_imgs, _ = prediction 132 | trgt_imgs = ground_truth['rgb'] 133 | 134 | trgt_imgs = trgt_imgs.cuda() 135 | batch_size = pred_imgs.shape[0] 136 | 137 | if not isinstance(pred_imgs, np.ndarray): 138 | pred_imgs = util.lin2img(pred_imgs).detach().cpu().numpy() 139 | 140 | if not isinstance(trgt_imgs, np.ndarray): 141 | trgt_imgs = util.lin2img(trgt_imgs).detach().cpu().numpy() 142 | 143 | psnrs, ssims = list(), list() 144 | for i in range(batch_size): 145 | p = pred_imgs[i].squeeze().transpose(1, 2, 0) 146 | trgt = trgt_imgs[i].squeeze().transpose(1, 2, 0) 147 | 148 | p = (p / 2.) + 0.5 149 | p = np.clip(p, a_min=0., a_max=1.) 150 | 151 | trgt = (trgt / 2.) 
+ 0.5 152 | 153 | ssim = skimage.measure.compare_ssim(p, trgt, multichannel=True, data_range=1) 154 | psnr = skimage.measure.compare_psnr(p, trgt, data_range=1) 155 | 156 | psnrs.append(psnr) 157 | ssims.append(ssim) 158 | 159 | return psnrs, ssims 160 | 161 | def get_comparisons(self, model_input, prediction, ground_truth=None): 162 | predictions, depth_maps = prediction 163 | 164 | batch_size = predictions.shape[0] 165 | 166 | # Parse model input. 167 | intrinsics = model_input["intrinsics"].cuda() 168 | uv = model_input["uv"].cuda().float() 169 | 170 | x_cam = uv[:, :, 0].view(batch_size, -1) 171 | y_cam = uv[:, :, 1].view(batch_size, -1) 172 | z_cam = depth_maps.view(batch_size, -1) 173 | 174 | normals = geometry.compute_normal_map(x_img=x_cam, y_img=y_cam, z=z_cam, intrinsics=intrinsics) 175 | normals = F.pad(normals, pad=(1, 1, 1, 1), mode="constant", value=1.) 176 | 177 | predictions = util.lin2img(predictions) 178 | 179 | if ground_truth is not None: 180 | trgt_imgs = ground_truth["rgb"] 181 | trgt_imgs = util.lin2img(trgt_imgs) 182 | 183 | return torch.cat((normals.cpu(), predictions.cpu(), trgt_imgs.cpu()), dim=3).numpy() 184 | else: 185 | return torch.cat((normals.cpu(), predictions.cpu()), dim=3).numpy() 186 | 187 | def get_output_img(self, prediction): 188 | pred_imgs, _ = prediction 189 | return util.lin2img(pred_imgs) 190 | 191 | def write_updates(self, writer, predictions, ground_truth, iter, prefix=""): 192 | """Writes tensorboard summaries using tensorboardx api. 193 | 194 | :param writer: tensorboardx writer object. 195 | :param predictions: Output of forward pass. 196 | :param ground_truth: Ground truth. 197 | :param iter: Iteration number. 198 | :param prefix: Every summary will be prefixed with this string. 199 | """ 200 | predictions, depth_maps = predictions 201 | trgt_imgs = ground_truth['rgb'] 202 | 203 | trgt_imgs = trgt_imgs.cuda() 204 | 205 | batch_size, num_samples, _ = predictions.shape 206 | 207 | # Module"s own log 208 | for type, name, content, every_n in self.logs: 209 | name = prefix + name 210 | 211 | if not iter % every_n: 212 | if type == "image": 213 | writer.add_image(name, content.detach().cpu().numpy(), iter) 214 | writer.add_scalar(name + "_min", content.min(), iter) 215 | writer.add_scalar(name + "_max", content.max(), iter) 216 | elif type == "figure": 217 | writer.add_figure(name, content, iter, close=True) 218 | elif type == "histogram": 219 | writer.add_histogram(name, content.detach().cpu().numpy(), iter) 220 | elif type == "scalar": 221 | writer.add_scalar(name, content.detach().cpu().numpy(), iter) 222 | elif type == "embedding": 223 | writer.add_embedding(mat=content, global_step=iter) 224 | 225 | if not iter % 100: 226 | output_vs_gt = torch.cat((predictions, trgt_imgs), dim=0) 227 | output_vs_gt = util.lin2img(output_vs_gt) 228 | writer.add_image(prefix + "Output_vs_gt", 229 | torchvision.utils.make_grid(output_vs_gt, 230 | scale_each=False, 231 | normalize=True).cpu().detach().numpy(), 232 | iter) 233 | 234 | rgb_loss = ((predictions.float().cuda() - trgt_imgs.float().cuda()) ** 2).mean(dim=2, keepdim=True) 235 | rgb_loss = util.lin2img(rgb_loss) 236 | 237 | fig = util.show_images([rgb_loss[i].detach().cpu().numpy().squeeze() 238 | for i in range(batch_size)]) 239 | writer.add_figure(prefix + "rgb_error_fig", 240 | fig, 241 | iter, 242 | close=True) 243 | 244 | depth_maps_plot = util.lin2img(depth_maps) 245 | writer.add_image(prefix + "pred_depth", 246 | torchvision.utils.make_grid(depth_maps_plot.repeat(1, 3, 1, 1), 247 | 
scale_each=True, 248 | normalize=True).cpu().detach().numpy(), 249 | iter) 250 | 251 | writer.add_scalar(prefix + "out_min", predictions.min(), iter) 252 | writer.add_scalar(prefix + "out_max", predictions.max(), iter) 253 | 254 | writer.add_scalar(prefix + "trgt_min", trgt_imgs.min(), iter) 255 | writer.add_scalar(prefix + "trgt_max", trgt_imgs.max(), iter) 256 | 257 | if iter: 258 | writer.add_scalar(prefix + "latent_reg_loss", self.latent_reg_loss, iter) 259 | 260 | def forward(self, input, z=None): 261 | self.logs = list() # log saves tensors that"ll receive summaries when model"s write_updates function is called 262 | 263 | # Parse model input. 264 | instance_idcs = input["instance_idx"].long().cuda() 265 | pose = input["pose"].cuda() 266 | intrinsics = input["intrinsics"].cuda() 267 | uv = input["uv"].cuda().float() 268 | 269 | if self.fit_single_srn: 270 | phi = self.phi 271 | else: 272 | if self.has_params: # If each instance has a latent parameter vector, we"ll use that one. 273 | if z is None: 274 | self.z = input["param"].cuda() 275 | else: 276 | self.z = z 277 | else: # Else, we"ll use the embedding. 278 | self.z = self.latent_codes(instance_idcs) 279 | 280 | phi = self.hyper_phi(self.z) # Forward pass through hypernetwork yields a (callable) SRN. 281 | 282 | # Raymarch SRN phi along rays defined by camera pose, intrinsics and uv coordinates. 283 | points_xyz, depth_maps, log = self.ray_marcher(cam2world=pose, 284 | intrinsics=intrinsics, 285 | uv=uv, 286 | phi=phi) 287 | self.logs.extend(log) 288 | 289 | # Sapmle phi a last time at the final ray-marched world coordinates. 290 | v = phi(points_xyz) 291 | 292 | # Translate features at ray-marched world coordinates to RGB colors. 293 | novel_views = self.pixel_generator(v) 294 | 295 | # Calculate normal map 296 | with torch.no_grad(): 297 | batch_size = uv.shape[0] 298 | x_cam = uv[:, :, 0].view(batch_size, -1) 299 | y_cam = uv[:, :, 1].view(batch_size, -1) 300 | z_cam = depth_maps.view(batch_size, -1) 301 | 302 | normals = geometry.compute_normal_map(x_img=x_cam, y_img=y_cam, z=z_cam, intrinsics=intrinsics) 303 | self.logs.append(("image", "normals", 304 | torchvision.utils.make_grid(normals, scale_each=True, normalize=True), 100)) 305 | 306 | if not self.fit_single_srn: 307 | self.logs.append(("embedding", "", self.latent_codes.weight, 500)) 308 | self.logs.append(("scalar", "embed_min", self.z.min(), 1)) 309 | self.logs.append(("scalar", "embed_max", self.z.max(), 1)) 310 | 311 | return novel_views, depth_maps 312 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import configargparse 2 | import os, time, datetime 3 | 4 | import torch 5 | import numpy as np 6 | 7 | import dataio 8 | from torch.utils.data import DataLoader 9 | from srns import * 10 | import util 11 | 12 | p = configargparse.ArgumentParser() 13 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.') 14 | 15 | # Note: in contrast to training, no multi-resolution! 
16 | p.add_argument('--img_sidelength', type=int, default=128, required=False, 17 | help='Sidelength of test images.') 18 | 19 | p.add_argument('--data_root', required=True, help='Path to directory with training data.') 20 | p.add_argument('--logging_root', type=str, default='./logs', 21 | required=False, help='Path to directory where checkpoints & tensorboard events will be saved.') 22 | p.add_argument('--batch_size', type=int, default=32, help='Batch size.') 23 | p.add_argument('--preload', action='store_true', default=False, help='Whether to preload data to RAM.') 24 | 25 | p.add_argument('--max_num_instances', type=int, default=-1, 26 | help='If \'data_root\' has more instances, only the first max_num_instances are used') 27 | p.add_argument('--specific_observation_idcs', type=str, default=None, 28 | help='Only pick a subset of specific observations for each instance.') 29 | p.add_argument('--has_params', action='store_true', default=False, 30 | help='Whether each object instance already comes with its own parameter vector.') 31 | 32 | p.add_argument('--save_out_first_n',type=int, default=250, help='Only saves images of first n object instances.') 33 | p.add_argument('--checkpoint_path', default=None, help='Path to trained model.') 34 | 35 | # Model options 36 | p.add_argument('--num_instances', type=int, required=True, 37 | help='The number of object instances that the model was trained with.') 38 | p.add_argument('--tracing_steps', type=int, default=10, help='Number of steps of intersection tester.') 39 | p.add_argument('--fit_single_srn', action='store_true', required=False, 40 | help='Only fit a single SRN for a single scene (not a class of SRNs) --> no hypernetwork') 41 | p.add_argument('--use_unet_renderer', action='store_true', 42 | help='Whether to use a DeepVoxels-style unet as rendering network or a per-pixel 1x1 convnet') 43 | p.add_argument('--embedding_size', type=int, default=256, 44 | help='Dimensionality of latent embedding.') 45 | 46 | opt = p.parse_args() 47 | 48 | device = torch.device('cuda') 49 | 50 | 51 | def test(): 52 | if opt.specific_observation_idcs is not None: 53 | specific_observation_idcs = list(map(int, opt.specific_observation_idcs.split(','))) 54 | else: 55 | specific_observation_idcs = None 56 | 57 | dataset = dataio.SceneClassDataset(root_dir=opt.data_root, 58 | max_num_instances=opt.max_num_instances, 59 | specific_observation_idcs=specific_observation_idcs, 60 | max_observations_per_instance=-1, 61 | samples_per_instance=1, 62 | img_sidelength=opt.img_sidelength) 63 | dataset = DataLoader(dataset, 64 | collate_fn=dataset.collate_fn, 65 | batch_size=1, 66 | shuffle=False, 67 | drop_last=False) 68 | 69 | model = SRNsModel(num_instances=opt.num_instances, 70 | latent_dim=opt.embedding_size, 71 | has_params=opt.has_params, 72 | fit_single_srn=opt.fit_single_srn, 73 | use_unet_renderer=opt.use_unet_renderer, 74 | tracing_steps=opt.tracing_steps) 75 | 76 | assert (opt.checkpoint_path is not None), "Have to pass checkpoint!" 
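    # util.custom_load (defined in util.py) accepts either a single .pth file or a directory of checkpoints;
    # if a directory is passed, the checkpoint that sorts last by filename (i.e. the latest epoch/iteration)
    # is loaded. overwrite_embeddings=False keeps the latent codes stored in the checkpoint.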
77 | 78 | print("Loading model from %s" % opt.checkpoint_path) 79 | util.custom_load(model, path=opt.checkpoint_path, discriminator=None, 80 | overwrite_embeddings=False) 81 | 82 | model.eval() 83 | model.cuda() 84 | 85 | # directory structure: month_day/ 86 | renderings_dir = os.path.join(opt.logging_root, 'renderings') 87 | gt_comparison_dir = os.path.join(opt.logging_root, 'gt_comparisons') 88 | util.cond_mkdir(opt.logging_root) 89 | util.cond_mkdir(gt_comparison_dir) 90 | util.cond_mkdir(renderings_dir) 91 | 92 | # Save command-line parameters to log directory. 93 | with open(os.path.join(opt.logging_root, "params.txt"), "w") as out_file: 94 | out_file.write('\n'.join(["%s: %s" % (key, value) for key, value in vars(opt).items()])) 95 | 96 | print('Beginning evaluation...') 97 | with torch.no_grad(): 98 | instance_idx = 0 99 | idx = 0 100 | psnrs, ssims = list(), list() 101 | for model_input, ground_truth in dataset: 102 | model_outputs = model(model_input) 103 | psnr, ssim = model.get_psnr(model_outputs, ground_truth) 104 | 105 | psnrs.extend(psnr) 106 | ssims.extend(ssim) 107 | 108 | instance_idcs = model_input['instance_idx'] 109 | print("Object instance %d. Running mean PSNR %0.6f SSIM %0.6f" % 110 | (instance_idcs[-1], np.mean(psnrs), np.mean(ssims))) 111 | 112 | if instance_idx < opt.save_out_first_n: 113 | output_imgs = model.get_output_img(model_outputs).cpu().numpy() 114 | comparisons = model.get_comparisons(model_input, 115 | model_outputs, 116 | ground_truth) 117 | for i in range(len(output_imgs)): 118 | prev_instance_idx = instance_idx 119 | instance_idx = instance_idcs[i] 120 | 121 | if prev_instance_idx != instance_idx: 122 | idx = 0 123 | 124 | img_only_path = os.path.join(renderings_dir, "%06d" % instance_idx) 125 | comp_path = os.path.join(gt_comparison_dir, "%06d" % instance_idx) 126 | 127 | util.cond_mkdir(img_only_path) 128 | util.cond_mkdir(comp_path) 129 | 130 | pred = util.convert_image(output_imgs[i].squeeze()) 131 | comp = util.convert_image(comparisons[i].squeeze()) 132 | 133 | util.write_img(pred, os.path.join(img_only_path, "%06d.png" % idx)) 134 | util.write_img(comp, os.path.join(comp_path, "%06d.png" % idx)) 135 | 136 | idx += 1 137 | 138 | with open(os.path.join(opt.logging_root, "results.txt"), "w") as out_file: 139 | out_file.write("%0.6f, %0.6f" % (np.mean(psnrs), np.mean(ssims))) 140 | 141 | print("Final mean PSNR %0.6f SSIM %0.6f" % (np.mean(psnrs), np.mean(ssims))) 142 | 143 | 144 | def main(): 145 | test() 146 | 147 | 148 | if __name__ == '__main__': 149 | main() 150 | -------------------------------------------------------------------------------- /test_configs/cars_few_shot_novel_view.yml: -------------------------------------------------------------------------------- 1 | # Configuration file for testing novel view synthesis on test set (see paper section 4, paragraph 4) 2 | # Datasets can be downloaded here: https://drive.google.com/drive/folders/1OkYgeRcIcLOFu1ft5mRODWNQaPJ0ps90 3 | 4 | data_root: "" # Path to cars test set (cars from test set with novel views). 5 | logging_root: ./logs/cars 6 | num_instances: 2434 # The number of cars that the model was trained with. 7 | checkpoint: 8 | img_sidelength: 128 9 | batch_size: 16 # This is for a GPU with 48 GB of memory. Adapt accordingly for your GPU memory. 
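# Example invocation (hypothetical paths; data_root and the checkpoint path above must be filled in first,
# or the checkpoint can be passed on the command line via test.py's --checkpoint_path flag):
#   python test.py --config_filepath test_configs/cars_few_shot_novel_view.yml --checkpoint_path /path/to/model.pth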
10 | -------------------------------------------------------------------------------- /test_configs/cars_training_set_novel_view.yml: -------------------------------------------------------------------------------- 1 | # Configuration file for testing novel view synthesis on training set (see paper section 4, paragraph 4) 2 | # Datasets can be downloaded here: https://drive.google.com/drive/folders/1OkYgeRcIcLOFu1ft5mRODWNQaPJ0ps90 3 | 4 | data_root: "" # Path to cars training_test set (cars from training set with novel views). 5 | logging_root: ./logs/cars 6 | num_instances: 2433 # The number of cars that the model was trained with. 7 | checkpoint: '' # The path to the trained checkpoint 8 | img_sidelength: 128 9 | batch_size: 16 # This is for a GPU with 48 GB of memory. Adapt accordingly for your GPU memory. 10 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import configargparse 2 | import os, time, datetime 3 | 4 | import torch 5 | from torch.utils.tensorboard import SummaryWriter 6 | import numpy as np 7 | 8 | import dataio 9 | from torch.utils.data import DataLoader 10 | from srns import * 11 | import util 12 | 13 | p = configargparse.ArgumentParser() 14 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.') 15 | 16 | # Multi-resolution training: Instead of passing only a single value, each of these command-line arguments take comma- 17 | # separated lists. If no multi-resolution training is required, simply pass single values (see default values). 18 | p.add_argument('--img_sidelengths', type=str, default='64', required=False, 19 | help='Progression of image sidelengths.' 20 | 'If comma-separated list, will train on each sidelength for respective max_steps.' 21 | 'Images are downsampled to the respective resolution.') 22 | p.add_argument('--max_steps_per_img_sidelength', type=str, default="200000", 23 | help='Maximum number of optimization steps.' 24 | 'If comma-separated list, is understood as steps per image_sidelength.') 25 | p.add_argument('--batch_size_per_img_sidelength', type=str, default="64", 26 | help='Training batch size.' 27 | 'If comma-separated list, will train each image sidelength with respective batch size.') 28 | 29 | # Training options 30 | p.add_argument('--data_root', required=True, help='Path to directory with training data.') 31 | p.add_argument('--val_root', required=False, help='Path to directory with validation data.') 32 | p.add_argument('--logging_root', type=str, default='./logs', 33 | required=False, help='path to directory where checkpoints & tensorboard events will be saved.') 34 | 35 | p.add_argument('--lr', type=float, default=5e-5, help='learning rate. 
default=5e-5') 36 | 37 | p.add_argument('--l1_weight', type=float, default=200, 38 | help='Weight for l1 loss term (lambda_img in paper).') 39 | p.add_argument('--kl_weight', type=float, default=1, 40 | help='Weight for l2 loss term on code vectors z (lambda_latent in paper).') 41 | p.add_argument('--reg_weight', type=float, default=1e-3, 42 | help='Weight for depth regularization term (lambda_depth in paper).') 43 | 44 | p.add_argument('--steps_til_ckpt', type=int, default=10000, 45 | help='Number of iterations until checkpoint is saved.') 46 | p.add_argument('--steps_til_val', type=int, default=1000, 47 | help='Number of iterations until validation set is run.') 48 | p.add_argument('--no_validation', action='store_true', default=False, 49 | help='If no validation set should be used.') 50 | 51 | p.add_argument('--preload', action='store_true', default=False, 52 | help='Whether to preload data to RAM.') 53 | 54 | p.add_argument('--checkpoint_path', default=None, 55 | help='Checkpoint to trained model.') 56 | p.add_argument('--overwrite_embeddings', action='store_true', default=False, 57 | help='When loading from checkpoint: Whether to discard checkpoint embeddings and initialize at random.') 58 | p.add_argument('--start_step', type=int, default=0, 59 | help='If continuing from checkpoint, which iteration to start counting at.') 60 | 61 | p.add_argument('--specific_observation_idcs', type=str, default=None, 62 | help='Only pick a subset of specific observations for each instance.') 63 | 64 | p.add_argument('--max_num_instances_train', type=int, default=-1, 65 | help='If \'data_root\' has more instances, only the first max_num_instances_train are used') 66 | p.add_argument('--max_num_observations_train', type=int, default=50, required=False, 67 | help='If an instance has more observations, only the first max_num_observations_train are used') 68 | p.add_argument('--max_num_instances_val', type=int, default=10, required=False, 69 | help='If \'val_root\' has more instances, only the first max_num_instances_val are used') 70 | p.add_argument('--max_num_observations_val', type=int, default=10, required=False, 71 | help='Maximum numbers of observations per validation instance') 72 | 73 | p.add_argument('--has_params', action='store_true', default=False, 74 | help='Whether each object instance already comes with its own parameter vector.') 75 | 76 | # Model options 77 | p.add_argument('--tracing_steps', type=int, default=10, help='Number of steps of intersection tester.') 78 | p.add_argument('--freeze_networks', action='store_true', 79 | help='Whether to freeze weights of all networks in SRN (not the embeddings!).') 80 | p.add_argument('--fit_single_srn', action='store_true', required=False, 81 | help='Only fit a single SRN for a single scene (not a class of SRNs) --> no hypernetwork') 82 | p.add_argument('--use_unet_renderer', action='store_true', 83 | help='Whether to use a DeepVoxels-style unet as rendering network or a per-pixel 1x1 convnet') 84 | p.add_argument('--embedding_size', type=int, default=256, 85 | help='Dimensionality of latent embedding.') 86 | 87 | opt = p.parse_args() 88 | 89 | 90 | def train(): 91 | # Parses indices of specific observations from comma-separated list. 
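    # e.g. a config value of "64,104" (as in train_configs/cars_two_shot.yml) is parsed to the list [64, 104].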
92 | if opt.specific_observation_idcs is not None: 93 | specific_observation_idcs = util.parse_comma_separated_integers(opt.specific_observation_idcs) 94 | else: 95 | specific_observation_idcs = None 96 | 97 | img_sidelengths = util.parse_comma_separated_integers(opt.img_sidelengths) 98 | batch_size_per_sidelength = util.parse_comma_separated_integers(opt.batch_size_per_img_sidelength) 99 | max_steps_per_sidelength = util.parse_comma_separated_integers(opt.max_steps_per_img_sidelength) 100 | 101 | train_dataset = dataio.SceneClassDataset(root_dir=opt.data_root, 102 | max_num_instances=opt.max_num_instances_train, 103 | max_observations_per_instance=opt.max_num_observations_train, 104 | img_sidelength=img_sidelengths[0], 105 | specific_observation_idcs=specific_observation_idcs, 106 | samples_per_instance=1) 107 | 108 | assert (len(img_sidelengths) == len(batch_size_per_sidelength)), \ 109 | "Different number of image sidelengths passed than batch sizes." 110 | assert (len(img_sidelengths) == len(max_steps_per_sidelength)), \ 111 | "Different number of image sidelengths passed than max steps." 112 | 113 | if not opt.no_validation: 114 | assert (opt.val_root is not None), "No validation directory passed." 115 | 116 | val_dataset = dataio.SceneClassDataset(root_dir=opt.val_root, 117 | max_num_instances=opt.max_num_instances_val, 118 | max_observations_per_instance=opt.max_num_observations_val, 119 | img_sidelength=img_sidelengths[0], 120 | samples_per_instance=1) 121 | collate_fn = val_dataset.collate_fn 122 | val_dataloader = DataLoader(val_dataset, 123 | batch_size=2, 124 | shuffle=False, 125 | drop_last=True, 126 | collate_fn=val_dataset.collate_fn) 127 | 128 | model = SRNsModel(num_instances=train_dataset.num_instances, 129 | latent_dim=opt.embedding_size, 130 | has_params=opt.has_params, 131 | fit_single_srn=opt.fit_single_srn, 132 | use_unet_renderer=opt.use_unet_renderer, 133 | tracing_steps=opt.tracing_steps, 134 | freeze_networks=opt.freeze_networks) 135 | model.train() 136 | model.cuda() 137 | 138 | if opt.checkpoint_path is not None: 139 | print("Loading model from %s" % opt.checkpoint_path) 140 | util.custom_load(model, path=opt.checkpoint_path, 141 | discriminator=None, 142 | optimizer=None, 143 | overwrite_embeddings=opt.overwrite_embeddings) 144 | 145 | ckpt_dir = os.path.join(opt.logging_root, 'checkpoints') 146 | events_dir = os.path.join(opt.logging_root, 'events') 147 | 148 | util.cond_mkdir(opt.logging_root) 149 | util.cond_mkdir(ckpt_dir) 150 | util.cond_mkdir(events_dir) 151 | 152 | # Save command-line parameters log directory. 153 | with open(os.path.join(opt.logging_root, "params.txt"), "w") as out_file: 154 | out_file.write('\n'.join(["%s: %s" % (key, value) for key, value in vars(opt).items()])) 155 | 156 | # Save text summary of model into log directory. 157 | with open(os.path.join(opt.logging_root, "model.txt"), "w") as out_file: 158 | out_file.write(str(model)) 159 | 160 | optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr) 161 | 162 | writer = SummaryWriter(events_dir) 163 | iter = opt.start_step 164 | epoch = iter // len(train_dataset) 165 | step = 0 166 | 167 | print('Beginning training...') 168 | # This loop implements training with an increasing image sidelength. 169 | cum_max_steps = 0 # Tracks max_steps cumulatively over all image sidelengths. 
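    # With train_configs/cars.yml, for example, this first trains 5000 steps at sidelength 64 (batch size 64),
    # then continues until the cumulative step count reaches 175000 at sidelength 128 (batch size 16).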
170 | for img_sidelength, max_steps, batch_size in zip(img_sidelengths, max_steps_per_sidelength, 171 | batch_size_per_sidelength): 172 | print("\n" + "#" * 10) 173 | print("Training with sidelength %d for %d steps with batch size %d" % (img_sidelength, max_steps, batch_size)) 174 | print("#" * 10 + "\n") 175 | train_dataset.set_img_sidelength(img_sidelength) 176 | 177 | # Need to instantiate DataLoader every time to set new batch size. 178 | train_dataloader = DataLoader(train_dataset, 179 | batch_size=batch_size, 180 | shuffle=True, 181 | drop_last=True, 182 | collate_fn=train_dataset.collate_fn, 183 | pin_memory=opt.preload) 184 | 185 | cum_max_steps += max_steps 186 | 187 | # Loops over epochs. 188 | while True: 189 | for model_input, ground_truth in train_dataloader: 190 | model_outputs = model(model_input) 191 | 192 | optimizer.zero_grad() 193 | 194 | dist_loss = model.get_image_loss(model_outputs, ground_truth) 195 | reg_loss = model.get_regularization_loss(model_outputs, ground_truth) 196 | latent_loss = model.get_latent_loss() 197 | 198 | weighted_dist_loss = opt.l1_weight * dist_loss 199 | weighted_reg_loss = opt.reg_weight * reg_loss 200 | weighted_latent_loss = opt.kl_weight * latent_loss 201 | 202 | total_loss = (weighted_dist_loss 203 | + weighted_reg_loss 204 | + weighted_latent_loss) 205 | 206 | total_loss.backward() 207 | 208 | optimizer.step() 209 | 210 | print("Iter %07d Epoch %03d L_img %0.4f L_latent %0.4f L_depth %0.4f" % 211 | (iter, epoch, weighted_dist_loss, weighted_latent_loss, weighted_reg_loss)) 212 | 213 | model.write_updates(writer, model_outputs, ground_truth, iter) 214 | writer.add_scalar("scaled_distortion_loss", weighted_dist_loss, iter) 215 | writer.add_scalar("scaled_regularization_loss", weighted_reg_loss, iter) 216 | writer.add_scalar("scaled_latent_loss", weighted_latent_loss, iter) 217 | writer.add_scalar("total_loss", total_loss, iter) 218 | 219 | if iter % opt.steps_til_val == 0 and not opt.no_validation: 220 | print("Running validation set...") 221 | 222 | model.eval() 223 | with torch.no_grad(): 224 | psnrs = [] 225 | ssims = [] 226 | dist_losses = [] 227 | for model_input, ground_truth in val_dataloader: 228 | model_outputs = model(model_input) 229 | 230 | dist_loss = model.get_image_loss(model_outputs, ground_truth).cpu().numpy() 231 | psnr, ssim = model.get_psnr(model_outputs, ground_truth) 232 | psnrs.append(psnr) 233 | ssims.append(ssim) 234 | dist_losses.append(dist_loss) 235 | 236 | model.write_updates(writer, model_outputs, ground_truth, iter, prefix='val_') 237 | 238 | writer.add_scalar("val_dist_loss", np.mean(dist_losses), iter) 239 | writer.add_scalar("val_psnr", np.mean(psnrs), iter) 240 | writer.add_scalar("val_ssim", np.mean(ssims), iter) 241 | model.train() 242 | 243 | iter += 1 244 | step += 1 245 | 246 | if iter == cum_max_steps: 247 | break 248 | 249 | if iter % opt.steps_til_ckpt == 0: 250 | util.custom_save(model, 251 | os.path.join(ckpt_dir, 'epoch_%04d_iter_%06d.pth' % (epoch, iter)), 252 | discriminator=None, 253 | optimizer=optimizer) 254 | 255 | if iter == cum_max_steps: 256 | break 257 | epoch += 1 258 | 259 | util.custom_save(model, 260 | os.path.join(ckpt_dir, 'epoch_%04d_iter_%06d.pth' % (epoch, iter)), 261 | discriminator=None, 262 | optimizer=optimizer) 263 | 264 | 265 | def main(): 266 | train() 267 | 268 | 269 | if __name__ == '__main__': 270 | main() 271 | -------------------------------------------------------------------------------- /train_configs/cars.yml: 
-------------------------------------------------------------------------------- 1 | # Configuration file for shapenet cars training (see paper section 4, paragraph 4) 2 | # Datasets can be downloaded here: https://drive.google.com/drive/folders/1OkYgeRcIcLOFu1ft5mRODWNQaPJ0ps90 3 | 4 | data_root: "" # Path to cars training dataset. 5 | val_root: # Path to cars validation set, consisting of the same cars as in the training set, but with novel camera views. 6 | logging_root: ./logs/cars 7 | img_sidelengths: 64,128 8 | batch_size_per_img_sidelength: 64,16 # This is for a GPU with 48 GB of memory. Adapt accordingly for your GPU memory. 9 | max_steps_per_img_sidelength: 5000,170000 10 | -------------------------------------------------------------------------------- /train_configs/cars_one_shot.yml: -------------------------------------------------------------------------------- 1 | # Configuration file for few-shot shapenet cars experiment (see paper section 4, paragraph 4) 2 | 3 | data_root: # Path to cars dataset, can be downloaded here: https://drive.google.com/drive/folders/1OkYgeRcIcLOFu1ft5mRODWNQaPJ0ps90 4 | logging_root: ./ 5 | checkpoint_path: # Path to the pre-trained checkpoint 6 | overwrite_embeddings: True 7 | specific_observation_idcs: 64 8 | freeze_networks: True 9 | img_sidelengths: 128 10 | batch_size_per_img_sidelength: 8 11 | max_steps_per_img_sidelength: 100000 12 | no_validation: True 13 | -------------------------------------------------------------------------------- /train_configs/cars_two_shot.yml: -------------------------------------------------------------------------------- 1 | # Configuration file for few-shot shapenet cars experiment (see paper section 4, paragraph 4) 2 | 3 | data_root: # Path to cars dataset, can be downloaded here: https://drive.google.com/drive/folders/1OkYgeRcIcLOFu1ft5mRODWNQaPJ0ps90 4 | logging_root: ./ 5 | checkpoint_path: # Path to the pre-trained checkpoint 6 | overwrite_embeddings: True 7 | specific_observation_idcs: 64,104 8 | freeze_networks: True 9 | img_sidelengths: 128 10 | batch_size_per_img_sidelength: 8 11 | max_steps_per_img_sidelength: 100000 12 | no_validation: True 13 | -------------------------------------------------------------------------------- /train_configs/chairs.yml: -------------------------------------------------------------------------------- 1 | # Configuration file for shapenet chairs training (see paper section 4, paragraph 4) 2 | # Datasets can be downloaded here: https://drive.google.com/drive/folders/1OkYgeRcIcLOFu1ft5mRODWNQaPJ0ps90 3 | 4 | data_root: "" # Path to chairs training dataset. 5 | val_root: # Path to chairs validation set, consisting of the same chairs as in the training set, but with novel camera views. 6 | logging_root: ./logs/chairs 7 | img_sidelengths: 64,128 8 | batch_size_per_img_sidelength: 64,16 # This is for a GPU with 48 GB of memory. Adapt accordingly for your GPU memory. 9 | max_steps_per_img_sidelength: 20000,554000 10 | -------------------------------------------------------------------------------- /train_configs/shepard_metzler.yml: -------------------------------------------------------------------------------- 1 | # Configuration file for shepard-metzler training (see paper section 4, paragraph 3) 2 | # Datasets can be downloaded here: https://drive.google.com/drive/folders/1OkYgeRcIcLOFu1ft5mRODWNQaPJ0ps90 3 | 4 | data_root: "" # Path to shepard-metzler training dataset.
5 | val_root: # Path to shepard-metzler dataset, consisting of the same cars as in the training set, but with novel camera views. 6 | logging_root: ./logs/shepard_metzler 7 | img_sidelengths: 64 8 | batch_size_per_img_sidelength: 48 # This is for a GPU with 48 GB of memory. Adapt accordingly for your GPU memory. 9 | max_steps_per_img_sidelength: 352000 10 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import os, struct, math 2 | import numpy as np 3 | import torch 4 | from glob import glob 5 | 6 | import cv2 7 | import torch.nn.functional as F 8 | 9 | import matplotlib.pyplot as plt 10 | from mpl_toolkits.axes_grid1 import make_axes_locatable 11 | 12 | 13 | def get_latest_file(root_dir): 14 | """Returns path to latest file in a directory.""" 15 | list_of_files = glob.glob(os.path.join(root_dir, '*')) 16 | latest_file = max(list_of_files, key=os.path.getctime) 17 | return latest_file 18 | 19 | 20 | def parse_comma_separated_integers(string): 21 | return list(map(int, string.split(','))) 22 | 23 | 24 | def convert_image(img): 25 | if not isinstance(img, np.ndarray): 26 | img = np.array(img.cpu().detach().numpy()) 27 | 28 | img = img.squeeze() 29 | img = img.transpose(1,2,0) 30 | img += 1. 31 | img /= 2. 32 | img *= 2**8 - 1 33 | img = img.round().clip(0, 2**8-1) 34 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 35 | return img 36 | 37 | def write_img(img, path): 38 | cv2.imwrite(path, img.astype(np.uint8)) 39 | 40 | 41 | def in_out_to_param_count(in_out_tuples): 42 | return np.sum([np.prod(in_out) + in_out[-1] for in_out in in_out_tuples]) 43 | 44 | def parse_intrinsics(filepath, trgt_sidelength=None, invert_y=False): 45 | # Get camera intrinsics 46 | with open(filepath, 'r') as file: 47 | f, cx, cy, _ = map(float, file.readline().split()) 48 | grid_barycenter = torch.Tensor(list(map(float, file.readline().split()))) 49 | scale = float(file.readline()) 50 | height, width = map(float, file.readline().split()) 51 | 52 | try: 53 | world2cam_poses = int(file.readline()) 54 | except ValueError: 55 | world2cam_poses = None 56 | 57 | if world2cam_poses is None: 58 | world2cam_poses = False 59 | 60 | world2cam_poses = bool(world2cam_poses) 61 | 62 | if trgt_sidelength is not None: 63 | cx = cx/width * trgt_sidelength 64 | cy = cy/height * trgt_sidelength 65 | f = trgt_sidelength / height * f 66 | 67 | fx = f 68 | if invert_y: 69 | fy = -f 70 | else: 71 | fy = f 72 | 73 | # Build the intrinsic matrices 74 | full_intrinsic = np.array([[fx, 0., cx, 0.], 75 | [0., fy, cy, 0], 76 | [0., 0, 1, 0], 77 | [0, 0, 0, 1]]) 78 | 79 | return full_intrinsic, grid_barycenter, scale, world2cam_poses 80 | 81 | def lin2img(tensor): 82 | batch_size, num_samples, channels = tensor.shape 83 | sidelen = np.sqrt(num_samples).astype(int) 84 | return tensor.permute(0,2,1).view(batch_size, channels, sidelen, sidelen) 85 | 86 | def num_divisible_by_2(number): 87 | i = 0 88 | while not number%2: 89 | number = number // 2 90 | i += 1 91 | 92 | return i 93 | 94 | def cond_mkdir(path): 95 | if not os.path.exists(path): 96 | os.makedirs(path) 97 | 98 | 99 | def load_pose(filename): 100 | assert os.path.isfile(filename) 101 | lines = open(filename).read().splitlines() 102 | assert len(lines) == 4 103 | lines = [[x[0],x[1],x[2],x[3]] for x in (x.split(" ") for x in lines)] 104 | return torch.from_numpy(np.asarray(lines).astype(np.float32)) 105 | 106 | 107 | def normalize(img): 108 | return (img - img.min()) 
/ (img.max() - img.min()) 109 | 110 | 111 | def write_image(writer, name, img, iter): 112 | writer.add_image(name, normalize(img.permute([0,3,1,2])), iter) 113 | 114 | 115 | def print_network(net): 116 | model_parameters = filter(lambda p: p.requires_grad, net.parameters()) 117 | params = sum([np.prod(p.size()) for p in model_parameters]) 118 | print("%d"%params) 119 | 120 | 121 | def custom_load(model, path, discriminator=None, overwrite_embeddings=False, overwrite_renderer=False, optimizer=None): 122 | if os.path.isdir(path): 123 | checkpoint_path = sorted(glob(os.path.join(path, "*.pth")))[-1] 124 | else: 125 | checkpoint_path = path 126 | 127 | whole_dict = torch.load(checkpoint_path) 128 | 129 | if overwrite_embeddings: 130 | del whole_dict['model']['latent_codes.weight'] 131 | 132 | if overwrite_renderer: 133 | keys_to_remove = [key for key in whole_dict['model'].keys() if 'rendering_net' in key] 134 | for key in keys_to_remove: 135 | print(key) 136 | whole_dict['model'].pop(key, None) 137 | 138 | state = model.state_dict() 139 | state.update(whole_dict['model']) 140 | model.load_state_dict(state) 141 | 142 | if discriminator: 143 | discriminator.load_state_dict(whole_dict['discriminator']) 144 | 145 | if optimizer: 146 | optimizer.load_state_dict(whole_dict['optimizer']) 147 | 148 | 149 | def custom_save(model, path, discriminator=None, optimizer=None): 150 | whole_dict = {'model':model.state_dict()} 151 | if discriminator: 152 | whole_dict.update({'discriminator':discriminator.state_dict()}) 153 | if optimizer: 154 | whole_dict.update({'optimizer':optimizer.state_dict()}) 155 | 156 | torch.save(whole_dict, path) 157 | 158 | 159 | def show_images(images, titles=None): 160 | """Display a list of images in a single figure with matplotlib. 161 | 162 | Parameters 163 | --------- 164 | images: List of np.arrays compatible with plt.imshow. 165 | 166 | cols (Default = 1): Number of columns in figure (number of rows is 167 | set to np.ceil(n_images/float(cols))). 168 | 169 | titles: List of titles corresponding to each image. Must have 170 | the same length as titles. 171 | """ 172 | assert ((titles is None) or (len(images) == len(titles))) 173 | cols = np.ceil(np.sqrt(len(images))).astype(int) 174 | 175 | n_images = len(images) 176 | if titles is None: titles = ['Image (%d)' % i for i in range(1, n_images + 1)] 177 | fig = plt.figure() 178 | for n, (image, title) in enumerate(zip(images, titles)): 179 | a = fig.add_subplot(np.ceil(n_images / float(cols)), cols, n + 1) 180 | im = a.imshow(image) 181 | 182 | a.get_xaxis().set_visible(False) 183 | a.get_yaxis().set_visible(False) 184 | 185 | if len(images) < 10: 186 | divider = make_axes_locatable(a) 187 | cax = divider.append_axes("right", size="5%", pad=0.05) 188 | fig.colorbar(im, cax=cax, orientation='vertical') 189 | 190 | 191 | plt.tight_layout() 192 | 193 | # fig.set_size_inches(np.array(fig.get_size_inches()) * n_images) 194 | return fig 195 | 196 | --------------------------------------------------------------------------------