├── .gitignore ├── LICENSE ├── README.md ├── data ├── __init__.py ├── dataset_sampling.py ├── error_sources.py ├── load_scene.py └── scannet_dataset.py ├── docs ├── index.html └── static │ ├── css │ ├── bulma-carousel.min.css │ ├── bulma-slider.min.css │ ├── bulma.css.map.txt │ ├── bulma.min.css │ ├── fontawesome.all.min.css │ └── index.css │ ├── images │ ├── pipeline.jpg │ ├── video_s708_reduced.mp4 │ ├── video_s710_reduced.mp4 │ ├── video_s738_reduced.mp4 │ ├── video_s758_reduced.mp4 │ └── video_s781_reduced.mp4 │ ├── js │ ├── bulma-carousel.js │ ├── bulma-carousel.min.js │ ├── bulma-slider.js │ ├── bulma-slider.min.js │ ├── fontawesome.all.min.js │ └── index.js │ └── ppt │ └── 20220809_dense_depth_priors_nerf.pptx ├── metric ├── __init__.py └── rmse.py ├── model ├── __init__.py ├── cspn.py ├── cspn_affinity.py └── run_nerf_helpers.py ├── preprocessing ├── CMakeLists.txt ├── camera │ ├── CMakeLists.txt │ ├── include │ │ └── camera.h │ └── src │ │ └── camera.cpp ├── extract_scannet_scene.cpp ├── io │ ├── CMakeLists.txt │ ├── include │ │ ├── camera_frames.h │ │ ├── file_utils.h │ │ └── rgbd.h │ └── src │ │ ├── camera_frames.cpp │ │ ├── file_utils.cpp │ │ └── rgbd.cpp └── io_colmap │ ├── CMakeLists.txt │ ├── include │ └── colmap_reader.h │ └── src │ └── colmap_reader.cpp ├── requirements.txt ├── run_depth_completion.py ├── run_nerf.py └── train_utils ├── __init__.py ├── hyperparameter_update.py └── logging.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # IPython 78 | profile_default/ 79 | ipython_config.py 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # pipenv 85 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 86 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 87 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 88 | # install all needed dependencies. 89 | #Pipfile.lock 90 | 91 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
92 | __pypackages__/
93 | 
94 | # Celery stuff
95 | celerybeat-schedule
96 | celerybeat.pid
97 | 
98 | # SageMath parsed files
99 | *.sage.py
100 | 
101 | # Environments
102 | .env
103 | .venv
104 | env/
105 | venv/
106 | ENV/
107 | env.bak/
108 | venv.bak/
109 | 
110 | # Spyder project settings
111 | .spyderproject
112 | .spyproject
113 | 
114 | # Rope project settings
115 | .ropeproject
116 | 
117 | # mkdocs documentation
118 | /site
119 | 
120 | # mypy
121 | .mypy_cache/
122 | .dmypy.json
123 | dmypy.json
124 | 
125 | # Pyre type checker
126 | .pyre/
127 | 
128 | # ide
129 | .vscode/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 Barbara Roessle
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dense Depth Priors for NeRF from Sparse Input Views
2 | This repository contains the implementation of the CVPR 2022 paper: Dense Depth Priors for Neural Radiance Fields from Sparse Input Views.
3 | 
4 | [arXiv](https://arxiv.org/abs/2112.03288) | [Video](https://t.co/zjH9JvkuQq) | [Project Page](https://barbararoessle.github.io/dense_depth_priors_nerf/)
5 | 
6 | ![](docs/static/images/pipeline.jpg)
7 | 
8 | ## Step 1: Train Dense Depth Priors
9 | You can skip this step and download the depth completion model trained on ScanNet from [here](https://drive.google.com/drive/folders/1HTyigHPJKZKBWzGFoY8J2bcS-h8_SfX9?usp=sharing).
10 | 
11 | ### Prepare ScanNet
12 | Extract the ScanNet dataset, e.g. using [SensReader](https://github.com/ScanNet/ScanNet/tree/master/SensReader/python), and place the files `scannetv2_test.txt`,
13 | `scannetv2_train.txt`, `scannetv2_val.txt` from the [ScanNet Benchmark](https://github.com/ScanNet/ScanNet/tree/master/Tasks/Benchmark) into the same directory.
14 | 
15 | ### Precompute Sampling Locations
16 | Run the [COLMAP](https://github.com/colmap/colmap) feature extractor on all RGB images of ScanNet.
17 | For this, the RGB files need to be isolated from the other scene data, e.g. create a temporary directory `tmp` and copy each `<scene>/color/` to `tmp/<scene>/color/`.
18 | Then run:
19 | ```
20 | colmap feature_extractor --database_path scannet_sift_database.db --image_path tmp
21 | ```
22 | When working with different relative paths or filenames, the database reading in `scannet_dataset.py` needs to be adapted accordingly.
23 | 
24 | ### Download pretrained ResNet
25 | Download the pretrained ResNet from [here](https://drive.google.com/file/d/17adZHo5dkcU8_M_6OvYzGUTDguF6k-Qu/view).
26 | 
27 | ### Train
28 | ```
29 | python3 run_depth_completion.py train --dataset_dir <dataset_dir> --db_path <db_path> --pretrained_resnet_path <pretrained_resnet_path> --ckpt_dir <ckpt_dir>
30 | ```
31 | Checkpoints are written into a subdirectory of the provided checkpoint directory. The subdirectory is named after the training start time in the format `yyyymmdd_hhmmss`, which also serves as the experiment name in the following.
32 | 
33 | ### Test
34 | ```
35 | python3 run_depth_completion.py test --expname <experiment_name> --dataset_dir <dataset_dir> --db_path <db_path> --ckpt_dir <ckpt_dir>
36 | ```
37 | 
38 | ## Step 2: Optimizing NeRF with Dense Depth Priors
39 | ### Prepare scenes
40 | You can skip the scene preparation and directly download the [scenes](https://drive.google.com/drive/folders/1vJ5sZaYljmaxMc1vltm6u4GUH11oqfYU?usp=sharing).
41 | To prepare a scene and render sparse depth maps from COLMAP sparse reconstructions, run:
42 | ```
43 | cd preprocessing
44 | mkdir build
45 | cd build
46 | cmake ..
47 | make -j
48 | ./extract_scannet_scene <scene_dir>
49 | ```
50 | The scene directory must contain the following:
51 | - `train.csv`: list of training views from the ScanNet scene
52 | - `test.csv`: list of test views from the ScanNet scene
53 | - `config.json`: parameters for the scene (an illustrative example is given below, after the usage commands):
54 |   - `name`: name of the scene
55 |   - `max_depth`: maximum depth value in the scene; larger values are invalidated
56 |   - `dist2m`: scaling factor that scales the sparse reconstruction to meters
57 |   - `rgb_only`: write RGB only, e.g. to get input for COLMAP
58 | - `colmap`: directory containing two sparse reconstructions:
59 |   - `sparse`: reconstruction run on train and test images together to determine the camera poses
60 |   - `sparse_train`: reconstruction run on the train images alone to determine the sparse depth maps
61 | 
62 | Please check the provided scenes as an example.
63 | The option `rgb_only` is used to preprocess the RGB images before running COLMAP. It crops away the dark image borders caused by calibration, which harm the NeRF optimization. It is essential to crop them before running COLMAP to ensure that the determined intrinsics match the cropped RGB images.
64 | 
65 | ### Optimize
66 | ```
67 | python3 run_nerf.py train --scene_id <scene_id> --data_dir <data_dir> --depth_prior_network_path <depth_prior_network_path> --ckpt_dir <ckpt_dir>
68 | ```
69 | Checkpoints are written into a subdirectory of the provided checkpoint directory. The subdirectory is named after the training start time in the format `yyyymmdd_hhmmss`, which also serves as the experiment name in the following.
70 | 
71 | ### Test
72 | ```
73 | python3 run_nerf.py test --expname <experiment_name> --data_dir <data_dir> --ckpt_dir <ckpt_dir>
74 | ```
75 | The test results are stored in the experiment directory.
76 | Running `python3 run_nerf.py test_opt ...` performs test-time optimization of the latent codes before computing the test metrics.
77 | 
78 | ### Render Video
79 | ```
80 | python3 run_nerf.py video --expname <experiment_name> --data_dir <data_dir> --ckpt_dir <ckpt_dir>
81 | ```
82 | The video is stored in the experiment directory.
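For reference, the `config.json` described in the "Prepare scenes" step above might look like the following. This is only an illustrative sketch: the scene name and the numeric values are made-up placeholders, not values shipped with this repository.

```python
import json

# Hypothetical scene configuration; adjust the values to your own scene.
config = {
    "name": "scene0710_00",   # placeholder scene name
    "max_depth": 4.5,         # depth values beyond this (in meters) are invalidated
    "dist2m": 0.5,            # placeholder scale factor from reconstruction units to meters
    "rgb_only": False,        # set True to only write the cropped RGB images for COLMAP
}

with open("config.json", "w") as f:
    json.dump(config, f, indent=4)
```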
83 | 
84 | ---
85 | 
86 | ### Citation
87 | If you find this repository useful, please cite:
88 | ```
89 | @inproceedings{roessle2022depthpriorsnerf,
90 |       title={Dense Depth Priors for Neural Radiance Fields from Sparse Input Views},
91 |       author={Barbara Roessle and Jonathan T. Barron and Ben Mildenhall and Pratul P. Srinivasan and Matthias Nie{\ss}ner},
92 |       booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
93 |       month={June},
94 |       year={2022}
95 | }
96 | ```
97 | 
98 | ### Acknowledgements
99 | We thank [nerf-pytorch](https://github.com/yenchenlin/nerf-pytorch) and [CSPN](https://github.com/XinJCheng/CSPN), from which this repository borrows code.
100 | 
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .scannet_dataset import ScanNetDataset, convert_depth_completion_scaling_to_m, convert_m_to_depth_completion_scaling, \
2 |     get_pretrained_normalize, resize_sparse_depth
3 | from .load_scene import load_scene
4 | from .dataset_sampling import create_random_subsets
5 | 
--------------------------------------------------------------------------------
/data/dataset_sampling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import random_split
3 | 
4 | def compute_samples_per_subset(sample_count, validate_on_at_least_n_samples):
5 |     validate_on_at_least_n_samples = min(validate_on_at_least_n_samples, sample_count)
6 |     number_subsets = int(sample_count / validate_on_at_least_n_samples)
7 |     samples_per_subset = int(sample_count / number_subsets)
8 |     extra_sample_subsets = sample_count % samples_per_subset
9 |     normal_subsets = number_subsets - extra_sample_subsets
10 |     return samples_per_subset, normal_subsets, extra_sample_subsets
11 | 
12 | def create_random_subsets(dataset, validate_on_at_least_n_samples, device='cpu'):
13 |     samples_per_subset, normal_subsets, extra_sample_subsets = compute_samples_per_subset(len(dataset), validate_on_at_least_n_samples)
14 |     subsets = random_split(dataset, (samples_per_subset,) * normal_subsets + (samples_per_subset + 1,) * extra_sample_subsets, \
15 |         torch.Generator(device=device))
16 |     return subsets
17 | 
--------------------------------------------------------------------------------
/data/error_sources.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | def add_missing_depth(depth, valid_depth, p=0.1, invalid_depth_value=0):
4 |     n_pixels = valid_depth.numel()
5 |     n_valid = valid_depth.sum()
6 |     p_before = float(n_pixels - n_valid) / float(n_pixels)
7 |     p_gap = p - p_before
8 |     if p_gap <= 0.:
9 |         return depth, valid_depth
10 |     else:
11 |         p_to_be_invalidated = p_gap * float(n_pixels) / float(n_valid)
12 |         invalid = torch.rand_like(depth) < p_to_be_invalidated
13 |         valid_depth[invalid] = False
14 |         depth[invalid] = invalid_depth_value
15 |         return depth, valid_depth
16 | 
17 | def add_quadratic_depth_noise(depth, valid_depth, a=1.68e-3, b=6.58e-3, c=4.78e-2):
18 |     std = a * depth[valid_depth].pow(2) + b * depth[valid_depth] + c
19 |     noise = torch.randn_like(std) * std
20 |     depth[valid_depth] = (depth[valid_depth] + noise).clamp(min=0.)
21 | return depth 22 | -------------------------------------------------------------------------------- /data/load_scene.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import json 5 | import cv2 6 | 7 | def read_files(basedir, rgb_file, depth_file): 8 | fname = os.path.join(basedir, rgb_file) 9 | img = cv2.imread(fname, cv2.IMREAD_UNCHANGED) 10 | if img.shape[-1] == 4: 11 | convert_fn = cv2.COLOR_BGRA2RGBA 12 | else: 13 | convert_fn = cv2.COLOR_BGR2RGB 14 | img = (cv2.cvtColor(img, convert_fn) / 255.).astype(np.float32) # keep 4 channels (RGBA) if available 15 | depth_fname = os.path.join(basedir, depth_file) 16 | depth = cv2.imread(depth_fname, cv2.IMREAD_UNCHANGED).astype(np.float64) 17 | return img, depth 18 | 19 | def load_ground_truth_depth(basedir, train_filenames, image_size, depth_scaling_factor): 20 | H, W = image_size 21 | gt_depths = [] 22 | gt_valid_depths = [] 23 | for filename in train_filenames: 24 | filename = filename.replace("rgb", "target_depth") 25 | filename = filename.replace(".jpg", ".png") 26 | gt_depth_fname = os.path.join(basedir, filename) 27 | if os.path.exists(gt_depth_fname): 28 | gt_depth = cv2.imread(gt_depth_fname, cv2.IMREAD_UNCHANGED).astype(np.float64) 29 | gt_valid_depth = gt_depth > 0.5 30 | gt_depth = (gt_depth / depth_scaling_factor).astype(np.float32) 31 | else: 32 | gt_depth = np.zeros((H, W)) 33 | gt_valid_depth = np.full_like(gt_depth, False) 34 | gt_depths.append(np.expand_dims(gt_depth, -1)) 35 | gt_valid_depths.append(gt_valid_depth) 36 | gt_depths = np.stack(gt_depths, 0) 37 | gt_valid_depths = np.stack(gt_valid_depths, 0) 38 | return gt_depths, gt_valid_depths 39 | 40 | def load_scene(basedir): 41 | splits = ['train', 'val', 'test', 'video'] 42 | 43 | all_imgs = [] 44 | all_depths = [] 45 | all_valid_depths = [] 46 | all_poses = [] 47 | all_intrinsics = [] 48 | counts = [0] 49 | filenames = [] 50 | for s in splits: 51 | if os.path.exists(os.path.join(basedir, 'transforms_{}.json'.format(s))): 52 | with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp: 53 | meta = json.load(fp) 54 | 55 | if 'train' in s: 56 | near = float(meta['near']) 57 | far = float(meta['far']) 58 | depth_scaling_factor = float(meta['depth_scaling_factor']) 59 | 60 | imgs = [] 61 | depths = [] 62 | valid_depths = [] 63 | poses = [] 64 | intrinsics = [] 65 | 66 | for frame in meta['frames']: 67 | if len(frame['file_path']) != 0 or len(frame['depth_file_path']) != 0: 68 | img, depth = read_files(basedir, frame['file_path'], frame['depth_file_path']) 69 | 70 | if depth.ndim == 2: 71 | depth = np.expand_dims(depth, -1) 72 | 73 | valid_depth = depth[:, :, 0] > 0.5 # 0 values are invalid depth 74 | depth = (depth / depth_scaling_factor).astype(np.float32) 75 | 76 | filenames.append(frame['file_path']) 77 | 78 | imgs.append(img) 79 | depths.append(depth) 80 | valid_depths.append(valid_depth) 81 | 82 | poses.append(np.array(frame['transform_matrix'])) 83 | H, W = img.shape[:2] 84 | fx, fy, cx, cy = frame['fx'], frame['fy'], frame['cx'], frame['cy'] 85 | intrinsics.append(np.array((fx, fy, cx, cy))) 86 | 87 | counts.append(counts[-1] + len(poses)) 88 | if len(imgs) > 0: 89 | all_imgs.append(np.array(imgs)) 90 | all_depths.append(np.array(depths)) 91 | all_valid_depths.append(np.array(valid_depths)) 92 | all_poses.append(np.array(poses).astype(np.float32)) 93 | all_intrinsics.append(np.array(intrinsics).astype(np.float32)) 94 | else: 95 | counts.append(counts[-1]) 96 | 97 | 
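    # Note: `counts` accumulates the running total of loaded images after each split, so the
    # list comprehension below recovers, for every split, its index range into the concatenated arrays.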
i_split = [np.arange(counts[i], counts[i+1]) for i in range(len(splits))] 98 | imgs = np.concatenate(all_imgs, 0) 99 | depths = np.concatenate(all_depths, 0) 100 | valid_depths = np.concatenate(all_valid_depths, 0) 101 | poses = np.concatenate(all_poses, 0) 102 | intrinsics = np.concatenate(all_intrinsics, 0) 103 | 104 | gt_depths, gt_valid_depths = load_ground_truth_depth(basedir, filenames, (H, W), depth_scaling_factor) 105 | 106 | return imgs, depths, valid_depths, poses, H, W, intrinsics, near, far, i_split, gt_depths, gt_valid_depths 107 | -------------------------------------------------------------------------------- /data/scannet_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import math 4 | import sqlite3 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import cv2 9 | import torch 10 | from torchvision import transforms 11 | 12 | from .error_sources import add_missing_depth, add_quadratic_depth_noise 13 | 14 | def is_in_list(file, list_to_check): 15 | for h in list_to_check: 16 | if h in file: 17 | return True 18 | return False 19 | 20 | def get_whitelist(dataset_dir, dataset_split): 21 | whitelist_txt = os.path.join(dataset_dir, "scannetv2_{}.txt".format(dataset_split)) 22 | scenes = pd.read_csv(whitelist_txt, names=["scenes"], header=None) 23 | return scenes["scenes"].tolist() 24 | 25 | def apply_filter(files, dataset_dir, dataset_split): 26 | whitelist = get_whitelist(dataset_dir, dataset_split) 27 | return [f for f in files if is_in_list(f, whitelist)] 28 | 29 | def read_rgb(rgb_file): 30 | bgr = cv2.imread(rgb_file) 31 | rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) 32 | assert rgb.shape[2] == 3 33 | 34 | to_tensor = transforms.ToTensor() 35 | rgb = to_tensor(rgb) 36 | return rgb 37 | 38 | def read_depth(depth_file): 39 | depth = cv2.imread(depth_file, cv2.IMREAD_UNCHANGED) 40 | assert len(depth.shape) == 2 41 | 42 | valid_depth = depth.astype('bool') 43 | depth = depth.astype('float32') 44 | 45 | # 16bit integer range corresponds to range 0 .. 65.54m 46 | # use the first quarter of this range up to 16.38m and invalidate depth values beyond 47 | # scale depth, such that range 0 .. 1 corresponds to range 0 .. 16.38m 48 | max_depth = np.float32(2 ** 16 - 1) / 4. 49 | depth = depth / max_depth 50 | invalidate_mask = depth > 1. 51 | depth[invalidate_mask] = 0. 52 | valid_depth[invalidate_mask] = False 53 | return transforms.functional.to_tensor(depth), transforms.functional.to_tensor(valid_depth) 54 | 55 | def convert_depth_completion_scaling_to_m(depth): 56 | # convert from depth completion scaling to meter, that means map range 0 .. 1 to range 0 .. 16,38m 57 | return depth * (2 ** 16 - 1) / 4000. 58 | 59 | def convert_m_to_depth_completion_scaling(depth): 60 | # convert from meter to depth completion scaling, which maps range 0 .. 16,38m to range 0 .. 1 61 | return depth * 4000. / (2 ** 16 - 1) 62 | 63 | def get_normalize(mean, std): 64 | normalize = transforms.Normalize(mean=mean, std=std) 65 | unnormalize = transforms.Normalize(mean=np.divide(-mean, std), std=(1. 
/ std)) 66 | return normalize, unnormalize 67 | 68 | def get_pretrained_normalize(): 69 | normalize = dict() 70 | unnormalize = dict() 71 | mean = np.array([0.485, 0.456, 0.406]) 72 | std = np.array([0.229, 0.224, 0.225]) 73 | normalize['rgb'], unnormalize['rgb'] = get_normalize(mean, std) 74 | normalize['rgbd'], unnormalize['rgbd'] = get_normalize(np.concatenate((mean, [0.,]), axis=0), np.concatenate((std, [1.,]), axis=0)) 75 | return normalize, unnormalize 76 | 77 | def resize_sparse_depth(depths, valid_depths, size): 78 | device = depths.device 79 | orig_size = (depths.shape[1], depths.shape[2]) 80 | col, row = torch.meshgrid(torch.tensor(range(orig_size[1])), torch.tensor(range(orig_size[0])), indexing='ij') 81 | rowcol2rowcol = torch.stack((row.t(), col.t()), -1) 82 | rowcol2rowcol = rowcol2rowcol.unsqueeze(0).expand(depths.shape[0], -1, -1, -1) 83 | image_index = torch.arange(depths.shape[0]).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand(-1, orig_size[0], orig_size[1], 1) 84 | rowcol2rowcol = torch.cat((image_index, rowcol2rowcol), -1) 85 | factor_h, factor_w = float(size[0]) / float(orig_size[0]), float(size[1]) / float(orig_size[1]) 86 | depths_out = torch.zeros((depths.shape[0], size[0], size[1]), device=device) 87 | valid_depths_out = torch.zeros_like(depths_out).bool() 88 | idx_row_col = rowcol2rowcol[valid_depths] 89 | idx_row_col_resized = idx_row_col 90 | idx_row_col_resized = ((idx_row_col + 0.5) * torch.tensor((1., factor_h, factor_w))).long() # consider pixel centers 91 | depths_out[idx_row_col_resized[..., 0], idx_row_col_resized[..., 1], idx_row_col_resized[..., 2]] \ 92 | = depths[idx_row_col[..., 0], idx_row_col[..., 1], idx_row_col[..., 2]] 93 | valid_depths_out[idx_row_col_resized[..., 0], idx_row_col_resized[..., 1], idx_row_col_resized[..., 2]] = True 94 | return depths_out, valid_depths_out 95 | 96 | class ScanNetDataset(torch.utils.data.dataset.Dataset): 97 | def __init__(self, dataset_dir, data_split, db_path, random_rot=0, load_size=(240, 320), \ 98 | horizontal_flip=False, color_jitter=None, depth_noise=False, missing_depth_percent=0.998): 99 | super(ScanNetDataset, self).__init__() 100 | 101 | # apply train val test split 102 | self.dataset_dir = dataset_dir 103 | dir_suffix = "" 104 | if data_split == "test": 105 | dir_suffix = "_test" 106 | input_scenes_dir = "scans{}".format(dir_suffix) 107 | filtered_scenes = [os.path.join(input_scenes_dir, s) for s in 108 | apply_filter(os.listdir(os.path.join(dataset_dir, input_scenes_dir)), dataset_dir, data_split)] 109 | 110 | # create file list 111 | self.rgb_files = [] 112 | for rel_scene_path in filtered_scenes: 113 | rel_scene_color_path = os.path.join(rel_scene_path, "color") 114 | for rgb in os.listdir(os.path.join(dataset_dir, rel_scene_color_path)): 115 | rel_rgb_path = os.path.join(rel_scene_color_path, rgb) 116 | self.rgb_files.append(rel_rgb_path) 117 | 118 | # transformation 119 | self.normalize, self.unnormalize = get_pretrained_normalize() 120 | self.random_rot = random_rot 121 | self.load_size = load_size 122 | self.horizontal_flip = horizontal_flip 123 | self.color_jitter = color_jitter 124 | 125 | # depth sampling 126 | self.missing_depth_percent = missing_depth_percent # add percentage of missing depth 127 | self.depth_noise = depth_noise # add gaussian depth noise 128 | # open keypoint database for sampling at image keypoints 129 | self.feature_db = sqlite3.connect(db_path).cursor() 130 | self.id2dbid = dict((n[:-4], id) for n, id in self.feature_db.execute("SELECT name, image_id FROM images")) 
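        # The SIFT database cursor above is used by __getitem__ (via sample_depth_at_image_features)
        # to keep depth only at COLMAP keypoint locations before further sparsification by add_missing_depth.
        # Hypothetical usage sketch (paths are placeholders, not part of the repository):
        #   train_set = ScanNetDataset("/path/to/scannet", "train", "/path/to/scannet_sift_database.db",
        #                              random_rot=10, horizontal_flip=True, color_jitter=0.1, depth_noise=True)
        #   sample = train_set[0]  # dict with 'rgbd', 'valid_depth', 'target_depth', 'target_valid_depth'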
131 | 132 | def __getitem__(self, index): 133 | rgb_file = os.path.join(self.dataset_dir, self.rgb_files[index]) 134 | depth_file = rgb_file.replace("color", "depth").replace(".jpg", ".png") 135 | rgb = read_rgb(rgb_file) 136 | depth, valid_depth = read_depth(depth_file) 137 | # pad to make aspect ratio of rgb (968x1296) and depth (480x640) match 138 | if rgb.shape[1] == 968 and rgb.shape[2] == 1296: 139 | # pad 2 pixels on both sides in height dimension 140 | pad_rgb_height = 2 141 | rgb = torch.nn.functional.pad(rgb, (0, 0, pad_rgb_height, pad_rgb_height)) 142 | depth_shape = depth.shape 143 | rgb_shape = rgb.shape 144 | scale_rgb = (float(depth_shape[1]) / float(rgb_shape[1]), float(depth_shape[2]) / float(rgb_shape[2])) 145 | rgb = transforms.functional.resize(rgb, (depth_shape[1], depth_shape[2]), interpolation=transforms.functional.InterpolationMode.NEAREST) 146 | else: 147 | pad_rgb_height = 0. 148 | scale_rgb = (1., 1.) 149 | id = self.rgb_files[index][:-4].replace("scans_test/", "").replace("scans/", "") 150 | 151 | # precompute random rotation 152 | rot = random.uniform(-self.random_rot, self.random_rot) 153 | 154 | # precompute resize and crop 155 | tan_abs_rot = math.tan(math.radians(abs(rot))) 156 | border_width = math.ceil(self.load_size[0] * tan_abs_rot) 157 | border_height = math.ceil(self.load_size[1] * tan_abs_rot) 158 | top = math.floor(0.5 * border_height) 159 | left = math.floor(0.5 * border_width) 160 | resize_size = (self.load_size[0] + border_height, self.load_size[1] + border_width) 161 | 162 | # precompute random horizontal flip 163 | apply_hflip = self.horizontal_flip and random.random() > 0.5 164 | 165 | # create a sparsified depth and a complete target depth 166 | target_valid_depth = valid_depth.clone() 167 | target_depth = depth.clone() 168 | depth, valid_depth = self.sample_depth_at_image_features(depth, valid_depth, id, scale_rgb, pad_rgb_height) 169 | depth, valid_depth = add_missing_depth(depth, valid_depth, self.missing_depth_percent) 170 | 171 | rgbd = torch.cat((rgb, depth), 0) 172 | data = {'rgbd': rgbd, 'valid_depth' : valid_depth, 'target_depth' : target_depth, 'target_valid_depth' : target_valid_depth} 173 | 174 | # apply transformation 175 | for key in data.keys(): 176 | # resize 177 | if key == 'rgbd': 178 | # resize such that sparse points are preserved 179 | B_depth, data['valid_depth'] = resize_sparse_depth(data['rgbd'][3, :, :].unsqueeze(0), data['valid_depth'], resize_size) 180 | B_rgb = transforms.functional.resize(data['rgbd'][:3, :, :], resize_size, interpolation=transforms.functional.InterpolationMode.NEAREST) 181 | data['rgbd'] = torch.cat((B_rgb, B_depth), 0) 182 | else: 183 | # avoid blurring the depth channel with invalid values by using interpolation mode nearest 184 | data[key] = transforms.functional.resize(data[key], resize_size, interpolation=transforms.functional.InterpolationMode.NEAREST) 185 | 186 | # augment color 187 | if key == 'rgbd': 188 | if self.color_jitter is not None: 189 | cj = transforms.ColorJitter(brightness=self.color_jitter, contrast=self.color_jitter, saturation=self.color_jitter, \ 190 | hue=self.color_jitter) 191 | data['rgbd'][:3, :, :] = cj(data['rgbd'][:3, :, :]) 192 | 193 | # rotate 194 | if self.random_rot != 0: 195 | data[key] = transforms.functional.rotate(data[key], rot) 196 | 197 | # crop 198 | data[key] = transforms.functional.crop(data[key], top, left, self.load_size[0], self.load_size[1]) 199 | 200 | # horizontal flip 201 | if apply_hflip: 202 | data[key] = transforms.functional.hflip(data[key]) 
203 | 204 | # normalize 205 | if key == 'rgbd': 206 | data[key] = self.normalize['rgbd'](data[key]) 207 | # scale depth according to resizing due to rotation 208 | data[key][3, :, :] /= (1. + tan_abs_rot) 209 | 210 | # add depth noise 211 | if self.depth_noise: 212 | data['rgbd'][3, :, :] = convert_m_to_depth_completion_scaling(add_quadratic_depth_noise( \ 213 | convert_depth_completion_scaling_to_m(data['rgbd'][3, :, :]), data['valid_depth'].squeeze())) 214 | 215 | return data 216 | 217 | def sample_depth_at_image_features(self, depth, valid_depth, id, scale, pad_height): 218 | depth_shape = depth.shape 219 | db_id = self.id2dbid[id] 220 | # 6 affine coordinates 221 | keypoints = [np.frombuffer(coords[0], dtype=np.float32).reshape(-1, 6) if coords[0] is not None else None for coords in self.feature_db.execute( \ 222 | "SELECT data FROM keypoints WHERE image_id=={}".format(db_id))] 223 | if keypoints[0] is not None: 224 | cols = keypoints[0][:, 0] 225 | rows = keypoints[0][:, 1] 226 | rows = rows + pad_height 227 | cols = (cols * scale[1]).astype(int) 228 | rows = (rows * scale[0]).astype(int) 229 | row_col_mask = (rows >= 0) & (rows < depth_shape[1]) & (cols >= 0) & (cols < depth_shape[2]) 230 | rows = rows[row_col_mask] 231 | cols = cols[row_col_mask] 232 | keypoints_mask = torch.full(depth_shape, False) 233 | keypoints_mask[0, rows, cols] = True 234 | valid_depth = torch.logical_and(keypoints_mask, valid_depth) 235 | depth[torch.logical_not(valid_depth)] = 0. 236 | else: 237 | depth = torch.zeros_like(depth) 238 | valid_depth = torch.zeros_like(valid_depth) 239 | return depth, valid_depth 240 | 241 | def __len__(self): 242 | return len(self.rgb_files) 243 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 7 | 8 | 9 | Dense Depth Priors for Neural Radiance Fields from Sparse Input Views 10 | 11 | 13 | 14 | 15 | 16 | 17 | 18 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 60 | 61 |
62 |
63 |
64 |
65 |
66 |

Dense Depth Priors for Neural Radiance Fields from Sparse Input Views

67 |
68 | 69 | Barbara Roessle1, 70 | 71 | 72 | Jonathan T. Barron2, 73 | 74 | 75 | Ben Mildenhall2, 76 | 77 | 78 | Pratul P. Srinivasan2, 79 | 80 | 81 | Matthias Nießner1 82 | 83 |
84 | 85 |
86 | 1Technical University of Munich, 87 | 2Google Research 88 |
89 | 90 |
91 | 142 |
143 |
144 |
145 |
146 |
147 |
148 | 149 |
150 |
151 |
152 | Dense Depth Priors for NeRF Pipeline 153 |

154 | Dense Depth Priors guide the radiance field optimization to render novel views of complete rooms from just a handful of input images. 155 |

156 |
157 |
158 |
159 | 160 | 161 |
162 |
163 |
164 | 196 |
197 |
198 |
199 | 200 | 201 |
202 |
203 | 204 |
205 |
206 |

Abstract

207 |
208 |

209 | Neural radiance fields (NeRF) encode a scene into a neural representation that enables photo-realistic rendering of novel views. 210 | However, a successful reconstruction from RGB images requires a large number of input views taken under static conditions — typically up to a few hundred images for room-size scenes. 211 | Our method aims to synthesize novel views of whole rooms from an order of magnitude fewer images. 212 | To this end, we leverage dense depth priors in order to constrain the NeRF optimization. 213 | First, we take advantage of the sparse depth data that is freely available from the structure from motion (SfM) preprocessing step used to estimate camera poses. 214 | Second, we use depth completion to convert these sparse points into dense depth maps and uncertainty estimates, which are used to guide NeRF optimization. 215 | Our method enables data-efficient novel view synthesis on challenging indoor scenes, using as few as 18 images for an entire scene. 216 |

217 |
218 |
219 |
220 | 221 | 222 | 223 |
224 |
225 |

Video

226 |
227 | 229 |
230 |
231 |
232 | 233 |
234 |
235 | 236 |
237 |
238 |

BibTeX

239 |
@inproceedings{roessle2022depthpriorsnerf,
240 |       title={Dense Depth Priors for Neural Radiance Fields from Sparse Input Views}, 
241 |       author={Barbara Roessle and Jonathan T. Barron and Ben Mildenhall and Pratul P. Srinivasan and Matthias Nie{\ss}ner},
242 |       booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
243 |       month={June},
244 |       year={2022}
245 | }
246 |
247 |
248 | 249 | 250 | 273 | 274 | 275 | 276 | -------------------------------------------------------------------------------- /docs/static/css/bulma-carousel.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}.slider{position:relative;width:100%}.slider-container{display:flex;flex-wrap:nowrap;flex-direction:row;overflow:hidden;-webkit-transform:translate3d(0,0,0);transform:translate3d(0,0,0);min-height:100%}.slider-container.is-vertical{flex-direction:column}.slider-container .slider-item{flex:none}.slider-container .slider-item .image.is-covered img{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.slider-container .slider-item .video-container{height:0;padding-bottom:0;padding-top:56.25%;margin:0;position:relative}.slider-container .slider-item .video-container.is-1by1,.slider-container .slider-item .video-container.is-square{padding-top:100%}.slider-container .slider-item .video-container.is-4by3{padding-top:75%}.slider-container .slider-item .video-container.is-21by9{padding-top:42.857143%}.slider-container .slider-item .video-container embed,.slider-container .slider-item .video-container iframe,.slider-container .slider-item .video-container object{position:absolute;top:0;left:0;width:100%!important;height:100%!important}.slider-navigation-next,.slider-navigation-previous{display:flex;justify-content:center;align-items:center;position:absolute;width:42px;height:42px;background:#fff center center no-repeat;background-size:20px 20px;border:1px solid #fff;border-radius:25091983px;box-shadow:0 2px 5px #3232321a;top:50%;margin-top:-20px;left:0;cursor:pointer;transition:opacity .3s,-webkit-transform .3s;transition:transform .3s,opacity .3s;transition:transform .3s,opacity .3s,-webkit-transform .3s}.slider-navigation-next:hover,.slider-navigation-previous:hover{-webkit-transform:scale(1.2);transform:scale(1.2)}.slider-navigation-next.is-hidden,.slider-navigation-previous.is-hidden{display:none;opacity:0}.slider-navigation-next svg,.slider-navigation-previous svg{width:25%}.slider-navigation-next{left:auto;right:0;background:#fff center center no-repeat;background-size:20px 20px}.slider-pagination{display:none;justify-content:center;align-items:center;position:absolute;bottom:0;left:0;right:0;padding:.5rem 1rem;text-align:center}.slider-pagination .slider-page{background:#fff;width:10px;height:10px;border-radius:25091983px;display:inline-block;margin:0 3px;box-shadow:0 2px 5px #3232321a;transition:-webkit-transform .3s;transition:transform .3s;transition:transform .3s,-webkit-transform .3s;cursor:pointer}.slider-pagination .slider-page.is-active,.slider-pagination .slider-page:hover{-webkit-transform:scale(1.4);transform:scale(1.4)}@media screen and (min-width:800px){.slider-pagination{display:flex}}.hero.has-carousel{position:relative}.hero.has-carousel+.hero-body,.hero.has-carousel+.hero-footer,.hero.has-carousel+.hero-head{z-index:10;overflow:hidden}.hero.has-carousel .hero-carousel{position:absolute;top:0;left:0;bottom:0;right:0;height:auto;border:none;margin:auto;padding:0;z-index:0}.hero.has-carousel .hero-carousel 
.slider{width:100%;max-width:100%;overflow:hidden;height:100%!important;max-height:100%;z-index:0}.hero.has-carousel .hero-carousel .slider .has-background{max-height:100%}.hero.has-carousel .hero-carousel .slider .has-background .is-background{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.hero.has-carousel .hero-body{margin:0 3rem;z-index:10} -------------------------------------------------------------------------------- /docs/static/css/bulma-slider.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}input[type=range].slider{-webkit-appearance:none;-moz-appearance:none;appearance:none;margin:1rem 0;background:0 0;touch-action:none}input[type=range].slider.is-fullwidth{display:block;width:100%}input[type=range].slider:focus{outline:0}input[type=range].slider:not([orient=vertical])::-webkit-slider-runnable-track{width:100%}input[type=range].slider:not([orient=vertical])::-moz-range-track{width:100%}input[type=range].slider:not([orient=vertical])::-ms-track{width:100%}input[type=range].slider:not([orient=vertical]).has-output+output,input[type=range].slider:not([orient=vertical]).has-output-tooltip+output{width:3rem;background:#4a4a4a;border-radius:4px;padding:.4rem .8rem;font-size:.75rem;line-height:.75rem;text-align:center;text-overflow:ellipsis;white-space:nowrap;color:#fff;overflow:hidden;pointer-events:none;z-index:200}input[type=range].slider:not([orient=vertical]).has-output-tooltip:disabled+output,input[type=range].slider:not([orient=vertical]).has-output:disabled+output{opacity:.5}input[type=range].slider:not([orient=vertical]).has-output{display:inline-block;vertical-align:middle;width:calc(100% - (4.2rem))}input[type=range].slider:not([orient=vertical]).has-output+output{display:inline-block;margin-left:.75rem;vertical-align:middle}input[type=range].slider:not([orient=vertical]).has-output-tooltip{display:block}input[type=range].slider:not([orient=vertical]).has-output-tooltip+output{position:absolute;left:0;top:-.1rem}input[type=range].slider[orient=vertical]{-webkit-appearance:slider-vertical;-moz-appearance:slider-vertical;appearance:slider-vertical;-webkit-writing-mode:bt-lr;-ms-writing-mode:bt-lr;writing-mode:bt-lr}input[type=range].slider[orient=vertical]::-webkit-slider-runnable-track{height:100%}input[type=range].slider[orient=vertical]::-moz-range-track{height:100%}input[type=range].slider[orient=vertical]::-ms-track{height:100%}input[type=range].slider::-webkit-slider-runnable-track{cursor:pointer;animate:.2s;box-shadow:0 0 0 #7a7a7a;background:#dbdbdb;border-radius:4px;border:0 solid #7a7a7a}input[type=range].slider::-moz-range-track{cursor:pointer;animate:.2s;box-shadow:0 0 0 #7a7a7a;background:#dbdbdb;border-radius:4px;border:0 solid #7a7a7a}input[type=range].slider::-ms-track{cursor:pointer;animate:.2s;box-shadow:0 0 0 #7a7a7a;background:#dbdbdb;border-radius:4px;border:0 solid #7a7a7a}input[type=range].slider::-ms-fill-lower{background:#dbdbdb;border-radius:4px}input[type=range].slider::-ms-fill-upper{background:#dbdbdb;border-radius:4px}input[type=range].slider::-webkit-slider-thumb{box-shadow:none;border:1px solid 
#b5b5b5;border-radius:4px;background:#fff;cursor:pointer}input[type=range].slider::-moz-range-thumb{box-shadow:none;border:1px solid #b5b5b5;border-radius:4px;background:#fff;cursor:pointer}input[type=range].slider::-ms-thumb{box-shadow:none;border:1px solid #b5b5b5;border-radius:4px;background:#fff;cursor:pointer}input[type=range].slider::-webkit-slider-thumb{-webkit-appearance:none;appearance:none}input[type=range].slider.is-circle::-webkit-slider-thumb{border-radius:290486px}input[type=range].slider.is-circle::-moz-range-thumb{border-radius:290486px}input[type=range].slider.is-circle::-ms-thumb{border-radius:290486px}input[type=range].slider:active::-webkit-slider-thumb{-webkit-transform:scale(1.25);transform:scale(1.25)}input[type=range].slider:active::-moz-range-thumb{transform:scale(1.25)}input[type=range].slider:active::-ms-thumb{transform:scale(1.25)}input[type=range].slider:disabled{opacity:.5;cursor:not-allowed}input[type=range].slider:disabled::-webkit-slider-thumb{cursor:not-allowed;-webkit-transform:scale(1);transform:scale(1)}input[type=range].slider:disabled::-moz-range-thumb{cursor:not-allowed;transform:scale(1)}input[type=range].slider:disabled::-ms-thumb{cursor:not-allowed;transform:scale(1)}input[type=range].slider:not([orient=vertical]){min-height:calc((1rem + 2px) * 1.25)}input[type=range].slider:not([orient=vertical])::-webkit-slider-runnable-track{height:.5rem}input[type=range].slider:not([orient=vertical])::-moz-range-track{height:.5rem}input[type=range].slider:not([orient=vertical])::-ms-track{height:.5rem}input[type=range].slider[orient=vertical]::-webkit-slider-runnable-track{width:.5rem}input[type=range].slider[orient=vertical]::-moz-range-track{width:.5rem}input[type=range].slider[orient=vertical]::-ms-track{width:.5rem}input[type=range].slider::-webkit-slider-thumb{height:1rem;width:1rem}input[type=range].slider::-moz-range-thumb{height:1rem;width:1rem}input[type=range].slider::-ms-thumb{height:1rem;width:1rem}input[type=range].slider::-ms-thumb{margin-top:0}input[type=range].slider::-webkit-slider-thumb{margin-top:-.25rem}input[type=range].slider[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.25rem}input[type=range].slider.is-small:not([orient=vertical]){min-height:calc((.75rem + 2px) * 1.25)}input[type=range].slider.is-small:not([orient=vertical])::-webkit-slider-runnable-track{height:.375rem}input[type=range].slider.is-small:not([orient=vertical])::-moz-range-track{height:.375rem}input[type=range].slider.is-small:not([orient=vertical])::-ms-track{height:.375rem}input[type=range].slider.is-small[orient=vertical]::-webkit-slider-runnable-track{width:.375rem}input[type=range].slider.is-small[orient=vertical]::-moz-range-track{width:.375rem}input[type=range].slider.is-small[orient=vertical]::-ms-track{width:.375rem}input[type=range].slider.is-small::-webkit-slider-thumb{height:.75rem;width:.75rem}input[type=range].slider.is-small::-moz-range-thumb{height:.75rem;width:.75rem}input[type=range].slider.is-small::-ms-thumb{height:.75rem;width:.75rem}input[type=range].slider.is-small::-ms-thumb{margin-top:0}input[type=range].slider.is-small::-webkit-slider-thumb{margin-top:-.1875rem}input[type=range].slider.is-small[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.1875rem}input[type=range].slider.is-medium:not([orient=vertical]){min-height:calc((1.25rem + 2px) * 
1.25)}input[type=range].slider.is-medium:not([orient=vertical])::-webkit-slider-runnable-track{height:.625rem}input[type=range].slider.is-medium:not([orient=vertical])::-moz-range-track{height:.625rem}input[type=range].slider.is-medium:not([orient=vertical])::-ms-track{height:.625rem}input[type=range].slider.is-medium[orient=vertical]::-webkit-slider-runnable-track{width:.625rem}input[type=range].slider.is-medium[orient=vertical]::-moz-range-track{width:.625rem}input[type=range].slider.is-medium[orient=vertical]::-ms-track{width:.625rem}input[type=range].slider.is-medium::-webkit-slider-thumb{height:1.25rem;width:1.25rem}input[type=range].slider.is-medium::-moz-range-thumb{height:1.25rem;width:1.25rem}input[type=range].slider.is-medium::-ms-thumb{height:1.25rem;width:1.25rem}input[type=range].slider.is-medium::-ms-thumb{margin-top:0}input[type=range].slider.is-medium::-webkit-slider-thumb{margin-top:-.3125rem}input[type=range].slider.is-medium[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.3125rem}input[type=range].slider.is-large:not([orient=vertical]){min-height:calc((1.5rem + 2px) * 1.25)}input[type=range].slider.is-large:not([orient=vertical])::-webkit-slider-runnable-track{height:.75rem}input[type=range].slider.is-large:not([orient=vertical])::-moz-range-track{height:.75rem}input[type=range].slider.is-large:not([orient=vertical])::-ms-track{height:.75rem}input[type=range].slider.is-large[orient=vertical]::-webkit-slider-runnable-track{width:.75rem}input[type=range].slider.is-large[orient=vertical]::-moz-range-track{width:.75rem}input[type=range].slider.is-large[orient=vertical]::-ms-track{width:.75rem}input[type=range].slider.is-large::-webkit-slider-thumb{height:1.5rem;width:1.5rem}input[type=range].slider.is-large::-moz-range-thumb{height:1.5rem;width:1.5rem}input[type=range].slider.is-large::-ms-thumb{height:1.5rem;width:1.5rem}input[type=range].slider.is-large::-ms-thumb{margin-top:0}input[type=range].slider.is-large::-webkit-slider-thumb{margin-top:-.375rem}input[type=range].slider.is-large[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.375rem}input[type=range].slider.is-white::-moz-range-track{background:#fff!important}input[type=range].slider.is-white::-webkit-slider-runnable-track{background:#fff!important}input[type=range].slider.is-white::-ms-track{background:#fff!important}input[type=range].slider.is-white::-ms-fill-lower{background:#fff}input[type=range].slider.is-white::-ms-fill-upper{background:#fff}input[type=range].slider.is-white .has-output-tooltip+output,input[type=range].slider.is-white.has-output+output{background-color:#fff;color:#0a0a0a}input[type=range].slider.is-black::-moz-range-track{background:#0a0a0a!important}input[type=range].slider.is-black::-webkit-slider-runnable-track{background:#0a0a0a!important}input[type=range].slider.is-black::-ms-track{background:#0a0a0a!important}input[type=range].slider.is-black::-ms-fill-lower{background:#0a0a0a}input[type=range].slider.is-black::-ms-fill-upper{background:#0a0a0a}input[type=range].slider.is-black 
.has-output-tooltip+output,input[type=range].slider.is-black.has-output+output{background-color:#0a0a0a;color:#fff}input[type=range].slider.is-light::-moz-range-track{background:#f5f5f5!important}input[type=range].slider.is-light::-webkit-slider-runnable-track{background:#f5f5f5!important}input[type=range].slider.is-light::-ms-track{background:#f5f5f5!important}input[type=range].slider.is-light::-ms-fill-lower{background:#f5f5f5}input[type=range].slider.is-light::-ms-fill-upper{background:#f5f5f5}input[type=range].slider.is-light .has-output-tooltip+output,input[type=range].slider.is-light.has-output+output{background-color:#f5f5f5;color:#363636}input[type=range].slider.is-dark::-moz-range-track{background:#363636!important}input[type=range].slider.is-dark::-webkit-slider-runnable-track{background:#363636!important}input[type=range].slider.is-dark::-ms-track{background:#363636!important}input[type=range].slider.is-dark::-ms-fill-lower{background:#363636}input[type=range].slider.is-dark::-ms-fill-upper{background:#363636}input[type=range].slider.is-dark .has-output-tooltip+output,input[type=range].slider.is-dark.has-output+output{background-color:#363636;color:#f5f5f5}input[type=range].slider.is-primary::-moz-range-track{background:#00d1b2!important}input[type=range].slider.is-primary::-webkit-slider-runnable-track{background:#00d1b2!important}input[type=range].slider.is-primary::-ms-track{background:#00d1b2!important}input[type=range].slider.is-primary::-ms-fill-lower{background:#00d1b2}input[type=range].slider.is-primary::-ms-fill-upper{background:#00d1b2}input[type=range].slider.is-primary .has-output-tooltip+output,input[type=range].slider.is-primary.has-output+output{background-color:#00d1b2;color:#fff}input[type=range].slider.is-link::-moz-range-track{background:#3273dc!important}input[type=range].slider.is-link::-webkit-slider-runnable-track{background:#3273dc!important}input[type=range].slider.is-link::-ms-track{background:#3273dc!important}input[type=range].slider.is-link::-ms-fill-lower{background:#3273dc}input[type=range].slider.is-link::-ms-fill-upper{background:#3273dc}input[type=range].slider.is-link .has-output-tooltip+output,input[type=range].slider.is-link.has-output+output{background-color:#3273dc;color:#fff}input[type=range].slider.is-info::-moz-range-track{background:#209cee!important}input[type=range].slider.is-info::-webkit-slider-runnable-track{background:#209cee!important}input[type=range].slider.is-info::-ms-track{background:#209cee!important}input[type=range].slider.is-info::-ms-fill-lower{background:#209cee}input[type=range].slider.is-info::-ms-fill-upper{background:#209cee}input[type=range].slider.is-info .has-output-tooltip+output,input[type=range].slider.is-info.has-output+output{background-color:#209cee;color:#fff}input[type=range].slider.is-success::-moz-range-track{background:#23d160!important}input[type=range].slider.is-success::-webkit-slider-runnable-track{background:#23d160!important}input[type=range].slider.is-success::-ms-track{background:#23d160!important}input[type=range].slider.is-success::-ms-fill-lower{background:#23d160}input[type=range].slider.is-success::-ms-fill-upper{background:#23d160}input[type=range].slider.is-success 
.has-output-tooltip+output,input[type=range].slider.is-success.has-output+output{background-color:#23d160;color:#fff}input[type=range].slider.is-warning::-moz-range-track{background:#ffdd57!important}input[type=range].slider.is-warning::-webkit-slider-runnable-track{background:#ffdd57!important}input[type=range].slider.is-warning::-ms-track{background:#ffdd57!important}input[type=range].slider.is-warning::-ms-fill-lower{background:#ffdd57}input[type=range].slider.is-warning::-ms-fill-upper{background:#ffdd57}input[type=range].slider.is-warning .has-output-tooltip+output,input[type=range].slider.is-warning.has-output+output{background-color:#ffdd57;color:rgba(0,0,0,.7)}input[type=range].slider.is-danger::-moz-range-track{background:#ff3860!important}input[type=range].slider.is-danger::-webkit-slider-runnable-track{background:#ff3860!important}input[type=range].slider.is-danger::-ms-track{background:#ff3860!important}input[type=range].slider.is-danger::-ms-fill-lower{background:#ff3860}input[type=range].slider.is-danger::-ms-fill-upper{background:#ff3860}input[type=range].slider.is-danger .has-output-tooltip+output,input[type=range].slider.is-danger.has-output+output{background-color:#ff3860;color:#fff} -------------------------------------------------------------------------------- /docs/static/css/index.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Noto Sans', sans-serif; 3 | } 4 | 5 | 6 | .footer .icon-link { 7 | font-size: 25px; 8 | color: #000; 9 | } 10 | 11 | .link-block a { 12 | margin-top: 5px; 13 | margin-bottom: 5px; 14 | } 15 | 16 | .dnerf { 17 | font-variant: small-caps; 18 | } 19 | 20 | 21 | .teaser .hero-body { 22 | padding-top: 0; 23 | padding-bottom: 3rem; 24 | } 25 | 26 | .teaser { 27 | font-family: 'Google Sans', sans-serif; 28 | } 29 | 30 | 31 | .publication-title { 32 | } 33 | 34 | .publication-banner { 35 | max-height: parent; 36 | 37 | } 38 | 39 | .publication-banner video { 40 | position: relative; 41 | left: auto; 42 | top: auto; 43 | transform: none; 44 | object-fit: fit; 45 | } 46 | 47 | .publication-header .hero-body { 48 | } 49 | 50 | .publication-title { 51 | font-family: 'Google Sans', sans-serif; 52 | } 53 | 54 | .publication-authors { 55 | font-family: 'Google Sans', sans-serif; 56 | } 57 | 58 | .publication-venue { 59 | color: #555; 60 | width: fit-content; 61 | font-weight: bold; 62 | } 63 | 64 | .publication-awards { 65 | color: #ff3860; 66 | width: fit-content; 67 | font-weight: bolder; 68 | } 69 | 70 | .publication-authors { 71 | } 72 | 73 | .publication-authors a { 74 | color: hsl(204, 86%, 53%) !important; 75 | } 76 | 77 | .publication-authors a:hover { 78 | text-decoration: underline; 79 | } 80 | 81 | .author-block { 82 | display: inline-block; 83 | } 84 | 85 | .publication-banner img { 86 | } 87 | 88 | .publication-authors { 89 | /*color: #4286f4;*/ 90 | } 91 | 92 | .publication-video { 93 | position: relative; 94 | width: 100%; 95 | height: 0; 96 | padding-bottom: 56.25%; 97 | 98 | overflow: hidden; 99 | border-radius: 10px !important; 100 | } 101 | 102 | .publication-video iframe { 103 | position: absolute; 104 | top: 0; 105 | left: 0; 106 | width: 100%; 107 | height: 100%; 108 | } 109 | 110 | .publication-body img { 111 | } 112 | 113 | .results-carousel { 114 | overflow: hidden; 115 | } 116 | 117 | .results-carousel .item { 118 | margin: 5px; 119 | overflow: hidden; 120 | border: 1px solid #bbb; 121 | border-radius: 10px; 122 | padding: 0; 123 | font-size: 0; 124 | } 125 | 126 | 
.results-carousel video { 127 | margin: 0; 128 | } 129 | 130 | 131 | .interpolation-panel { 132 | background: #f5f5f5; 133 | border-radius: 10px; 134 | } 135 | 136 | .interpolation-panel .interpolation-image { 137 | width: 100%; 138 | border-radius: 5px; 139 | } 140 | 141 | .interpolation-video-column { 142 | } 143 | 144 | .interpolation-panel .slider { 145 | margin: 0 !important; 146 | } 147 | 148 | .interpolation-panel .slider { 149 | margin: 0 !important; 150 | } 151 | 152 | #interpolation-image-wrapper { 153 | width: 100%; 154 | } 155 | #interpolation-image-wrapper img { 156 | border-radius: 5px; 157 | } 158 | -------------------------------------------------------------------------------- /docs/static/images/pipeline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/barbararoessle/dense_depth_priors_nerf/5bb060894984ec51952142b7526990990dc87618/docs/static/images/pipeline.jpg -------------------------------------------------------------------------------- /docs/static/images/video_s708_reduced.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/barbararoessle/dense_depth_priors_nerf/5bb060894984ec51952142b7526990990dc87618/docs/static/images/video_s708_reduced.mp4 -------------------------------------------------------------------------------- /docs/static/images/video_s710_reduced.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/barbararoessle/dense_depth_priors_nerf/5bb060894984ec51952142b7526990990dc87618/docs/static/images/video_s710_reduced.mp4 -------------------------------------------------------------------------------- /docs/static/images/video_s738_reduced.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/barbararoessle/dense_depth_priors_nerf/5bb060894984ec51952142b7526990990dc87618/docs/static/images/video_s738_reduced.mp4 -------------------------------------------------------------------------------- /docs/static/images/video_s758_reduced.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/barbararoessle/dense_depth_priors_nerf/5bb060894984ec51952142b7526990990dc87618/docs/static/images/video_s758_reduced.mp4 -------------------------------------------------------------------------------- /docs/static/images/video_s781_reduced.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/barbararoessle/dense_depth_priors_nerf/5bb060894984ec51952142b7526990990dc87618/docs/static/images/video_s781_reduced.mp4 -------------------------------------------------------------------------------- /docs/static/js/bulma-carousel.min.js: -------------------------------------------------------------------------------- 1 | !function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.bulmaCarousel=e():t.bulmaCarousel=e()}("undefined"!=typeof self?self:this,function(){return function(i){var n={};function s(t){if(n[t])return n[t].exports;var e=n[t]={i:t,l:!1,exports:{}};return i[t].call(e.exports,e,e.exports,s),e.l=!0,e.exports}return s.m=i,s.c=n,s.d=function(t,e,i){s.o(t,e)||Object.defineProperty(t,e,{configurable:!1,enumerable:!0,get:i})},s.n=function(t){var 
e=t&&t.__esModule?function(){return t.default}:function(){return t};return s.d(e,"a",e),e},s.o=function(t,e){return Object.prototype.hasOwnProperty.call(t,e)},s.p="",s(s.s=5)}([function(t,e,i){"use strict";i.d(e,"d",function(){return s}),i.d(e,"e",function(){return r}),i.d(e,"b",function(){return o}),i.d(e,"c",function(){return a}),i.d(e,"a",function(){return l});var n=i(2),s=function(e,t){(t=Array.isArray(t)?t:t.split(" ")).forEach(function(t){e.classList.remove(t)})},r=function(t){return t.getBoundingClientRect().width||t.offsetWidth},o=function(t){return t.getBoundingClientRect().height||t.offsetHeight},a=function(t){var e=1=t._x&&this._x<=e._x&&this._y>=t._y&&this._y<=e._y}},{key:"constrain",value:function(t,e){if(t._x>e._x||t._y>e._y)return this;var i=this._x,n=this._y;return null!==t._x&&(i=Math.max(i,t._x)),null!==e._x&&(i=Math.min(i,e._x)),null!==t._y&&(n=Math.max(n,t._y)),null!==e._y&&(n=Math.min(n,e._y)),new s(i,n)}},{key:"reposition",value:function(t){t.style.top=this._y+"px",t.style.left=this._x+"px"}},{key:"toString",value:function(){return"("+this._x+","+this._y+")"}},{key:"x",get:function(){return this._x},set:function(){var t=0this.state.length-this.slidesToShow&&!this.options.centerMode?this.state.next=this.state.index:this.state.next=this.state.index+this.slidesToScroll,this.show()}},{key:"previous",value:function(){this.options.loop||this.options.infinite||0!==this.state.index?this.state.next=this.state.index-this.slidesToScroll:this.state.next=this.state.index,this.show()}},{key:"start",value:function(){this._autoplay.start()}},{key:"pause",value:function(){this._autoplay.pause()}},{key:"stop",value:function(){this._autoplay.stop()}},{key:"show",value:function(t){var e=1this.options.slidesToShow&&(this.options.slidesToScroll=this.slidesToShow),this._breakpoint.init(),this.state.index>=this.state.length&&0!==this.state.index&&(this.state.index=this.state.index-this.slidesToScroll),this.state.length<=this.slidesToShow&&(this.state.index=0),this._ui.wrapper.appendChild(this._navigation.init().render()),this._ui.wrapper.appendChild(this._pagination.init().render()),this.options.navigationSwipe?this._swipe.bindEvents():this._swipe._bindEvents(),this._breakpoint.apply(),this._slides.forEach(function(t){return e._ui.container.appendChild(t)}),this._transitioner.init().apply(!0,this._setHeight.bind(this)),this.options.autoplay&&this._autoplay.init().start()}},{key:"destroy",value:function(){var e=this;this._unbindEvents(),this._items.forEach(function(t){e.element.appendChild(t)}),this.node.remove()}},{key:"id",get:function(){return this._id}},{key:"index",set:function(t){this._index=t},get:function(){return this._index}},{key:"length",set:function(t){this._length=t},get:function(){return this._length}},{key:"slides",get:function(){return this._slides},set:function(t){this._slides=t}},{key:"slidesToScroll",get:function(){return"translate"===this.options.effect?this._breakpoint.getSlidesToScroll():1}},{key:"slidesToShow",get:function(){return"translate"===this.options.effect?this._breakpoint.getSlidesToShow():1}},{key:"direction",get:function(){return"rtl"===this.element.dir.toLowerCase()||"rtl"===this.element.style.direction?"rtl":"ltr"}},{key:"wrapper",get:function(){return this._ui.wrapper}},{key:"wrapperWidth",get:function(){return this._wrapperWidth||0}},{key:"container",get:function(){return this._ui.container}},{key:"containerWidth",get:function(){return this._containerWidth||0}},{key:"slideWidth",get:function(){return 
this._slideWidth||0}},{key:"transitioner",get:function(){return this._transitioner}}],[{key:"attach",value:function(){var i=this,t=0>t/4).toString(16)})}},function(t,e,i){"use strict";var n=i(3),s=i(8),r=function(){function n(t,e){for(var i=0;i=t.slider.state.length-t.slider.slidesToShow&&!t.slider.options.loop&&!t.slider.options.infinite?t.stop():t.slider.next())},this.slider.options.autoplaySpeed))}},{key:"stop",value:function(){this._interval=clearInterval(this._interval),this.emit("stop",this)}},{key:"pause",value:function(){var t=this,e=0parseInt(e.changePoint,10)}),this._currentBreakpoint=this._getActiveBreakpoint(),this}},{key:"destroy",value:function(){this._unbindEvents()}},{key:"_bindEvents",value:function(){window.addEventListener("resize",this[s]),window.addEventListener("orientationchange",this[s])}},{key:"_unbindEvents",value:function(){window.removeEventListener("resize",this[s]),window.removeEventListener("orientationchange",this[s])}},{key:"_getActiveBreakpoint",value:function(){var t=!0,e=!1,i=void 0;try{for(var n,s=this.options.breakpoints[Symbol.iterator]();!(t=(n=s.next()).done);t=!0){var r=n.value;if(r.changePoint>=window.innerWidth)return r}}catch(t){e=!0,i=t}finally{try{!t&&s.return&&s.return()}finally{if(e)throw i}}return this._defaultBreakpoint}},{key:"getSlidesToShow",value:function(){return this._currentBreakpoint?this._currentBreakpoint.slidesToShow:this._defaultBreakpoint.slidesToShow}},{key:"getSlidesToScroll",value:function(){return this._currentBreakpoint?this._currentBreakpoint.slidesToScroll:this._defaultBreakpoint.slidesToScroll}},{key:"apply",value:function(){this.slider.state.index>=this.slider.state.length&&0!==this.slider.state.index&&(this.slider.state.index=this.slider.state.index-this._currentBreakpoint.slidesToScroll),this.slider.state.length<=this._currentBreakpoint.slidesToShow&&(this.slider.state.index=0),this.options.loop&&this.slider._loop.init().apply(),this.options.infinite&&this.slider._infinite.init().apply(),this.slider._setDimensions(),this.slider._transitioner.init().apply(!0,this.slider._setHeight.bind(this.slider)),this.slider._setClasses(),this.slider._navigation.refresh(),this.slider._pagination.refresh()}},{key:s,value:function(t){var e=this._getActiveBreakpoint();e.slidesToShow!==this._currentBreakpoint.slidesToShow&&(this._currentBreakpoint=e,this.apply())}}]),e}();e.a=r},function(t,e,i){"use strict";var n=function(){function n(t,e){for(var i=0;ithis.slider.state.length-1-this._infiniteCount;i-=1)e=i-1,t.unshift(this._cloneSlide(this.slider.slides[e],e-this.slider.state.length));for(var n=[],s=0;s=this.slider.state.length?(this.slider.state.index=this.slider.state.next=this.slider.state.next-this.slider.state.length,this.slider.transitioner.apply(!0)):this.slider.state.next<0&&(this.slider.state.index=this.slider.state.next=this.slider.state.length+this.slider.state.next,this.slider.transitioner.apply(!0)))}},{key:"_cloneSlide",value:function(t,e){var i=t.cloneNode(!0);return i.dataset.sliderIndex=e,i.dataset.cloned=!0,(i.querySelectorAll("[id]")||[]).forEach(function(t){t.setAttribute("id","")}),i}}]),e}();e.a=s},function(t,e,i){"use strict";var n=i(12),s=function(){function n(t,e){for(var 
i=0;ithis.slider.state.length-this.slider.slidesToShow&&Object(n.a)(this.slider._slides[this.slider.state.length-1],this.slider.wrapper)?this.slider.state.next=0:this.slider.state.next=Math.min(Math.max(this.slider.state.next,0),this.slider.state.length-this.slider.slidesToShow):this.slider.state.next=0:this.slider.state.next<=0-this.slider.slidesToScroll?this.slider.state.next=this.slider.state.length-this.slider.slidesToShow:this.slider.state.next=0)}}]),e}();e.a=r},function(t,e,i){"use strict";i.d(e,"a",function(){return n});var n=function(t,e){var i=t.getBoundingClientRect();return e=e||document.documentElement,0<=i.top&&0<=i.left&&i.bottom<=(window.innerHeight||e.clientHeight)&&i.right<=(window.innerWidth||e.clientWidth)}},function(t,e,i){"use strict";var n=i(14),s=i(1),r=function(){function n(t,e){for(var i=0;ithis.slider.slidesToShow?(this._ui.previous.classList.remove("is-hidden"),this._ui.next.classList.remove("is-hidden"),0===this.slider.state.next?(this._ui.previous.classList.add("is-hidden"),this._ui.next.classList.remove("is-hidden")):this.slider.state.next>=this.slider.state.length-this.slider.slidesToShow&&!this.slider.options.centerMode?(this._ui.previous.classList.remove("is-hidden"),this._ui.next.classList.add("is-hidden")):this.slider.state.next>=this.slider.state.length-1&&this.slider.options.centerMode&&(this._ui.previous.classList.remove("is-hidden"),this._ui.next.classList.add("is-hidden"))):(this._ui.previous.classList.add("is-hidden"),this._ui.next.classList.add("is-hidden")))}},{key:"render",value:function(){return this.node}}]),e}();e.a=o},function(t,e,i){"use strict";e.a=function(t){return'
'+t.previous+'
\n
'+t.next+"
"}},function(t,e,i){"use strict";var n=i(16),s=i(17),r=i(1),o=function(){function n(t,e){for(var i=0;ithis.slider.slidesToShow){for(var t=0;t<=this._count;t++){var e=document.createRange().createContextualFragment(Object(s.a)()).firstChild;e.dataset.index=t*this.slider.slidesToScroll,this._pages.push(e),this._ui.container.appendChild(e)}this._bindEvents()}}},{key:"onPageClick",value:function(t){this._supportsPassive||t.preventDefault(),this.slider.state.next=t.currentTarget.dataset.index,this.slider.show()}},{key:"onResize",value:function(){this._draw()}},{key:"refresh",value:function(){var e=this,t=void 0;(t=this.slider.options.infinite?Math.ceil(this.slider.state.length-1/this.slider.slidesToScroll):Math.ceil((this.slider.state.length-this.slider.slidesToShow)/this.slider.slidesToScroll))!==this._count&&(this._count=t,this._draw()),this._pages.forEach(function(t){t.classList.remove("is-active"),parseInt(t.dataset.index,10)===e.slider.state.next%e.slider.state.length&&t.classList.add("is-active")})}},{key:"render",value:function(){return this.node}}]),e}();e.a=a},function(t,e,i){"use strict";e.a=function(){return'
'}},function(t,e,i){"use strict";e.a=function(){return'
'}},function(t,e,i){"use strict";var n=i(4),s=i(1),r=function(){function n(t,e){for(var i=0;iMath.abs(this._lastTranslate.y)&&(this._supportsPassive||t.preventDefault(),t.stopPropagation())}}},{key:"onStopDrag",value:function(t){this._origin&&this._lastTranslate&&(Math.abs(this._lastTranslate.x)>.2*this.width?this._lastTranslate.x<0?this.slider.next():this.slider.previous():this.slider.show(!0)),this._origin=null,this._lastTranslate=null}}]),e}();e.a=o},function(t,e,i){"use strict";var n=i(20),s=i(21),r=function(){function n(t,e){for(var i=0;it.x?(s.x=0,this.slider.state.next=0):this.options.vertical&&Math.abs(this._position.y)>t.y&&(s.y=0,this.slider.state.next=0)),this._position.x=s.x,this._position.y=s.y,this.options.centerMode&&(this._position.x=this._position.x+this.slider.wrapperWidth/2-Object(o.e)(i)/2),"rtl"===this.slider.direction&&(this._position.x=-this._position.x,this._position.y=-this._position.y),this.slider.container.style.transform="translate3d("+this._position.x+"px, "+this._position.y+"px, 0)",n.x>t.x&&this.slider.transitioner.end()}}},{key:"onTransitionEnd",value:function(t){"translate"===this.options.effect&&(this.transitioner.isAnimating()&&t.target==this.slider.container&&this.options.infinite&&this.slider._infinite.onTransitionEnd(t),this.transitioner.end())}}]),n}();e.a=n},function(t,e,i){"use strict";e.a={initialSlide:0,slidesToScroll:1,slidesToShow:1,navigation:!0,navigationKeys:!0,navigationSwipe:!0,pagination:!0,loop:!1,infinite:!1,effect:"translate",duration:300,timing:"ease",autoplay:!1,autoplaySpeed:3e3,pauseOnHover:!0,breakpoints:[{changePoint:480,slidesToShow:1,slidesToScroll:1},{changePoint:640,slidesToShow:2,slidesToScroll:2},{changePoint:768,slidesToShow:3,slidesToScroll:3}],onReady:null,icons:{previous:'\n \n ',next:'\n \n '}}},function(t,e,i){"use strict";e.a=function(t){return'
\n
\n
'}},function(t,e,i){"use strict";e.a=function(){return'
'}}]).default}); -------------------------------------------------------------------------------- /docs/static/js/bulma-slider.js: -------------------------------------------------------------------------------- 1 | (function webpackUniversalModuleDefinition(root, factory) { 2 | if(typeof exports === 'object' && typeof module === 'object') 3 | module.exports = factory(); 4 | else if(typeof define === 'function' && define.amd) 5 | define([], factory); 6 | else if(typeof exports === 'object') 7 | exports["bulmaSlider"] = factory(); 8 | else 9 | root["bulmaSlider"] = factory(); 10 | })(typeof self !== 'undefined' ? self : this, function() { 11 | return /******/ (function(modules) { // webpackBootstrap 12 | /******/ // The module cache 13 | /******/ var installedModules = {}; 14 | /******/ 15 | /******/ // The require function 16 | /******/ function __webpack_require__(moduleId) { 17 | /******/ 18 | /******/ // Check if module is in cache 19 | /******/ if(installedModules[moduleId]) { 20 | /******/ return installedModules[moduleId].exports; 21 | /******/ } 22 | /******/ // Create a new module (and put it into the cache) 23 | /******/ var module = installedModules[moduleId] = { 24 | /******/ i: moduleId, 25 | /******/ l: false, 26 | /******/ exports: {} 27 | /******/ }; 28 | /******/ 29 | /******/ // Execute the module function 30 | /******/ modules[moduleId].call(module.exports, module, module.exports, __webpack_require__); 31 | /******/ 32 | /******/ // Flag the module as loaded 33 | /******/ module.l = true; 34 | /******/ 35 | /******/ // Return the exports of the module 36 | /******/ return module.exports; 37 | /******/ } 38 | /******/ 39 | /******/ 40 | /******/ // expose the modules object (__webpack_modules__) 41 | /******/ __webpack_require__.m = modules; 42 | /******/ 43 | /******/ // expose the module cache 44 | /******/ __webpack_require__.c = installedModules; 45 | /******/ 46 | /******/ // define getter function for harmony exports 47 | /******/ __webpack_require__.d = function(exports, name, getter) { 48 | /******/ if(!__webpack_require__.o(exports, name)) { 49 | /******/ Object.defineProperty(exports, name, { 50 | /******/ configurable: false, 51 | /******/ enumerable: true, 52 | /******/ get: getter 53 | /******/ }); 54 | /******/ } 55 | /******/ }; 56 | /******/ 57 | /******/ // getDefaultExport function for compatibility with non-harmony modules 58 | /******/ __webpack_require__.n = function(module) { 59 | /******/ var getter = module && module.__esModule ? 
60 | /******/ function getDefault() { return module['default']; } : 61 | /******/ function getModuleExports() { return module; }; 62 | /******/ __webpack_require__.d(getter, 'a', getter); 63 | /******/ return getter; 64 | /******/ }; 65 | /******/ 66 | /******/ // Object.prototype.hasOwnProperty.call 67 | /******/ __webpack_require__.o = function(object, property) { return Object.prototype.hasOwnProperty.call(object, property); }; 68 | /******/ 69 | /******/ // __webpack_public_path__ 70 | /******/ __webpack_require__.p = ""; 71 | /******/ 72 | /******/ // Load entry module and return exports 73 | /******/ return __webpack_require__(__webpack_require__.s = 0); 74 | /******/ }) 75 | /************************************************************************/ 76 | /******/ ([ 77 | /* 0 */ 78 | /***/ (function(module, __webpack_exports__, __webpack_require__) { 79 | 80 | "use strict"; 81 | Object.defineProperty(__webpack_exports__, "__esModule", { value: true }); 82 | /* harmony export (binding) */ __webpack_require__.d(__webpack_exports__, "isString", function() { return isString; }); 83 | /* harmony import */ var __WEBPACK_IMPORTED_MODULE_0__events__ = __webpack_require__(1); 84 | var _extends = Object.assign || function (target) { for (var i = 1; i < arguments.length; i++) { var source = arguments[i]; for (var key in source) { if (Object.prototype.hasOwnProperty.call(source, key)) { target[key] = source[key]; } } } return target; }; 85 | 86 | var _createClass = function () { function defineProperties(target, props) { for (var i = 0; i < props.length; i++) { var descriptor = props[i]; descriptor.enumerable = descriptor.enumerable || false; descriptor.configurable = true; if ("value" in descriptor) descriptor.writable = true; Object.defineProperty(target, descriptor.key, descriptor); } } return function (Constructor, protoProps, staticProps) { if (protoProps) defineProperties(Constructor.prototype, protoProps); if (staticProps) defineProperties(Constructor, staticProps); return Constructor; }; }(); 87 | 88 | var _typeof = typeof Symbol === "function" && typeof Symbol.iterator === "symbol" ? function (obj) { return typeof obj; } : function (obj) { return obj && typeof Symbol === "function" && obj.constructor === Symbol && obj !== Symbol.prototype ? "symbol" : typeof obj; }; 89 | 90 | function _classCallCheck(instance, Constructor) { if (!(instance instanceof Constructor)) { throw new TypeError("Cannot call a class as a function"); } } 91 | 92 | function _possibleConstructorReturn(self, call) { if (!self) { throw new ReferenceError("this hasn't been initialised - super() hasn't been called"); } return call && (typeof call === "object" || typeof call === "function") ? call : self; } 93 | 94 | function _inherits(subClass, superClass) { if (typeof superClass !== "function" && superClass !== null) { throw new TypeError("Super expression must either be null or a function, not " + typeof superClass); } subClass.prototype = Object.create(superClass && superClass.prototype, { constructor: { value: subClass, enumerable: false, writable: true, configurable: true } }); if (superClass) Object.setPrototypeOf ? Object.setPrototypeOf(subClass, superClass) : subClass.__proto__ = superClass; } 95 | 96 | 97 | 98 | var isString = function isString(unknown) { 99 | return typeof unknown === 'string' || !!unknown && (typeof unknown === 'undefined' ? 
'undefined' : _typeof(unknown)) === 'object' && Object.prototype.toString.call(unknown) === '[object String]'; 100 | }; 101 | 102 | var bulmaSlider = function (_EventEmitter) { 103 | _inherits(bulmaSlider, _EventEmitter); 104 | 105 | function bulmaSlider(selector) { 106 | var options = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {}; 107 | 108 | _classCallCheck(this, bulmaSlider); 109 | 110 | var _this = _possibleConstructorReturn(this, (bulmaSlider.__proto__ || Object.getPrototypeOf(bulmaSlider)).call(this)); 111 | 112 | _this.element = typeof selector === 'string' ? document.querySelector(selector) : selector; 113 | // An invalid selector or non-DOM node has been provided. 114 | if (!_this.element) { 115 | throw new Error('An invalid selector or non-DOM node has been provided.'); 116 | } 117 | 118 | _this._clickEvents = ['click']; 119 | /// Set default options and merge with instance defined 120 | _this.options = _extends({}, options); 121 | 122 | _this.onSliderInput = _this.onSliderInput.bind(_this); 123 | 124 | _this.init(); 125 | return _this; 126 | } 127 | 128 | /** 129 | * Initiate all DOM element containing selector 130 | * @method 131 | * @return {Array} Array of all slider instances 132 | */ 133 | 134 | 135 | _createClass(bulmaSlider, [{ 136 | key: 'init', 137 | 138 | 139 | /** 140 | * Initiate plugin 141 | * @method init 142 | * @return {void} 143 | */ 144 | value: function init() { 145 | this._id = 'bulmaSlider' + new Date().getTime() + Math.floor(Math.random() * Math.floor(9999)); 146 | this.output = this._findOutputForSlider(); 147 | 148 | this._bindEvents(); 149 | 150 | if (this.output) { 151 | if (this.element.classList.contains('has-output-tooltip')) { 152 | // Get new output position 153 | var newPosition = this._getSliderOutputPosition(); 154 | 155 | // Set output position 156 | this.output.style['left'] = newPosition.position; 157 | } 158 | } 159 | 160 | this.emit('bulmaslider:ready', this.element.value); 161 | } 162 | }, { 163 | key: '_findOutputForSlider', 164 | value: function _findOutputForSlider() { 165 | var _this2 = this; 166 | 167 | var result = null; 168 | var outputs = document.getElementsByTagName('output') || []; 169 | 170 | Array.from(outputs).forEach(function (output) { 171 | if (output.htmlFor == _this2.element.getAttribute('id')) { 172 | result = output; 173 | return true; 174 | } 175 | }); 176 | return result; 177 | } 178 | }, { 179 | key: '_getSliderOutputPosition', 180 | value: function _getSliderOutputPosition() { 181 | // Update output position 182 | var newPlace, minValue; 183 | 184 | var style = window.getComputedStyle(this.element, null); 185 | // Measure width of range input 186 | var sliderWidth = parseInt(style.getPropertyValue('width'), 10); 187 | 188 | // Figure out placement percentage between left and right of input 189 | if (!this.element.getAttribute('min')) { 190 | minValue = 0; 191 | } else { 192 | minValue = this.element.getAttribute('min'); 193 | } 194 | var newPoint = (this.element.value - minValue) / (this.element.getAttribute('max') - minValue); 195 | 196 | // Prevent bubble from going beyond left or right (unsupported browsers) 197 | if (newPoint < 0) { 198 | newPlace = 0; 199 | } else if (newPoint > 1) { 200 | newPlace = sliderWidth; 201 | } else { 202 | newPlace = sliderWidth * newPoint; 203 | } 204 | 205 | return { 206 | 'position': newPlace + 'px' 207 | }; 208 | } 209 | 210 | /** 211 | * Bind all events 212 | * @method _bindEvents 213 | * @return {void} 214 | */ 215 | 216 | }, { 217 | key: 
'_bindEvents', 218 | value: function _bindEvents() { 219 | if (this.output) { 220 | // Add event listener to update output when slider value change 221 | this.element.addEventListener('input', this.onSliderInput, false); 222 | } 223 | } 224 | }, { 225 | key: 'onSliderInput', 226 | value: function onSliderInput(e) { 227 | e.preventDefault(); 228 | 229 | if (this.element.classList.contains('has-output-tooltip')) { 230 | // Get new output position 231 | var newPosition = this._getSliderOutputPosition(); 232 | 233 | // Set output position 234 | this.output.style['left'] = newPosition.position; 235 | } 236 | 237 | // Check for prefix and postfix 238 | var prefix = this.output.hasAttribute('data-prefix') ? this.output.getAttribute('data-prefix') : ''; 239 | var postfix = this.output.hasAttribute('data-postfix') ? this.output.getAttribute('data-postfix') : ''; 240 | 241 | // Update output with slider value 242 | this.output.value = prefix + this.element.value + postfix; 243 | 244 | this.emit('bulmaslider:ready', this.element.value); 245 | } 246 | }], [{ 247 | key: 'attach', 248 | value: function attach() { 249 | var _this3 = this; 250 | 251 | var selector = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : 'input[type="range"].slider'; 252 | var options = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {}; 253 | 254 | var instances = new Array(); 255 | 256 | var elements = isString(selector) ? document.querySelectorAll(selector) : Array.isArray(selector) ? selector : [selector]; 257 | elements.forEach(function (element) { 258 | if (typeof element[_this3.constructor.name] === 'undefined') { 259 | var instance = new bulmaSlider(element, options); 260 | element[_this3.constructor.name] = instance; 261 | instances.push(instance); 262 | } else { 263 | instances.push(element[_this3.constructor.name]); 264 | } 265 | }); 266 | 267 | return instances; 268 | } 269 | }]); 270 | 271 | return bulmaSlider; 272 | }(__WEBPACK_IMPORTED_MODULE_0__events__["a" /* default */]); 273 | 274 | /* harmony default export */ __webpack_exports__["default"] = (bulmaSlider); 275 | 276 | /***/ }), 277 | /* 1 */ 278 | /***/ (function(module, __webpack_exports__, __webpack_require__) { 279 | 280 | "use strict"; 281 | var _createClass = function () { function defineProperties(target, props) { for (var i = 0; i < props.length; i++) { var descriptor = props[i]; descriptor.enumerable = descriptor.enumerable || false; descriptor.configurable = true; if ("value" in descriptor) descriptor.writable = true; Object.defineProperty(target, descriptor.key, descriptor); } } return function (Constructor, protoProps, staticProps) { if (protoProps) defineProperties(Constructor.prototype, protoProps); if (staticProps) defineProperties(Constructor, staticProps); return Constructor; }; }(); 282 | 283 | function _classCallCheck(instance, Constructor) { if (!(instance instanceof Constructor)) { throw new TypeError("Cannot call a class as a function"); } } 284 | 285 | var EventEmitter = function () { 286 | function EventEmitter() { 287 | var listeners = arguments.length > 0 && arguments[0] !== undefined ? 
arguments[0] : []; 288 | 289 | _classCallCheck(this, EventEmitter); 290 | 291 | this._listeners = new Map(listeners); 292 | this._middlewares = new Map(); 293 | } 294 | 295 | _createClass(EventEmitter, [{ 296 | key: "listenerCount", 297 | value: function listenerCount(eventName) { 298 | if (!this._listeners.has(eventName)) { 299 | return 0; 300 | } 301 | 302 | var eventListeners = this._listeners.get(eventName); 303 | return eventListeners.length; 304 | } 305 | }, { 306 | key: "removeListeners", 307 | value: function removeListeners() { 308 | var _this = this; 309 | 310 | var eventName = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : null; 311 | var middleware = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : false; 312 | 313 | if (eventName !== null) { 314 | if (Array.isArray(eventName)) { 315 | name.forEach(function (e) { 316 | return _this.removeListeners(e, middleware); 317 | }); 318 | } else { 319 | this._listeners.delete(eventName); 320 | 321 | if (middleware) { 322 | this.removeMiddleware(eventName); 323 | } 324 | } 325 | } else { 326 | this._listeners = new Map(); 327 | } 328 | } 329 | }, { 330 | key: "middleware", 331 | value: function middleware(eventName, fn) { 332 | var _this2 = this; 333 | 334 | if (Array.isArray(eventName)) { 335 | name.forEach(function (e) { 336 | return _this2.middleware(e, fn); 337 | }); 338 | } else { 339 | if (!Array.isArray(this._middlewares.get(eventName))) { 340 | this._middlewares.set(eventName, []); 341 | } 342 | 343 | this._middlewares.get(eventName).push(fn); 344 | } 345 | } 346 | }, { 347 | key: "removeMiddleware", 348 | value: function removeMiddleware() { 349 | var _this3 = this; 350 | 351 | var eventName = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : null; 352 | 353 | if (eventName !== null) { 354 | if (Array.isArray(eventName)) { 355 | name.forEach(function (e) { 356 | return _this3.removeMiddleware(e); 357 | }); 358 | } else { 359 | this._middlewares.delete(eventName); 360 | } 361 | } else { 362 | this._middlewares = new Map(); 363 | } 364 | } 365 | }, { 366 | key: "on", 367 | value: function on(name, callback) { 368 | var _this4 = this; 369 | 370 | var once = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : false; 371 | 372 | if (Array.isArray(name)) { 373 | name.forEach(function (e) { 374 | return _this4.on(e, callback); 375 | }); 376 | } else { 377 | name = name.toString(); 378 | var split = name.split(/,|, | /); 379 | 380 | if (split.length > 1) { 381 | split.forEach(function (e) { 382 | return _this4.on(e, callback); 383 | }); 384 | } else { 385 | if (!Array.isArray(this._listeners.get(name))) { 386 | this._listeners.set(name, []); 387 | } 388 | 389 | this._listeners.get(name).push({ once: once, callback: callback }); 390 | } 391 | } 392 | } 393 | }, { 394 | key: "once", 395 | value: function once(name, callback) { 396 | this.on(name, callback, true); 397 | } 398 | }, { 399 | key: "emit", 400 | value: function emit(name, data) { 401 | var _this5 = this; 402 | 403 | var silent = arguments.length > 2 && arguments[2] !== undefined ? 
arguments[2] : false; 404 | 405 | name = name.toString(); 406 | var listeners = this._listeners.get(name); 407 | var middlewares = null; 408 | var doneCount = 0; 409 | var execute = silent; 410 | 411 | if (Array.isArray(listeners)) { 412 | listeners.forEach(function (listener, index) { 413 | // Start Middleware checks unless we're doing a silent emit 414 | if (!silent) { 415 | middlewares = _this5._middlewares.get(name); 416 | // Check and execute Middleware 417 | if (Array.isArray(middlewares)) { 418 | middlewares.forEach(function (middleware) { 419 | middleware(data, function () { 420 | var newData = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : null; 421 | 422 | if (newData !== null) { 423 | data = newData; 424 | } 425 | doneCount++; 426 | }, name); 427 | }); 428 | 429 | if (doneCount >= middlewares.length) { 430 | execute = true; 431 | } 432 | } else { 433 | execute = true; 434 | } 435 | } 436 | 437 | // If Middleware checks have been passed, execute 438 | if (execute) { 439 | if (listener.once) { 440 | listeners[index] = null; 441 | } 442 | listener.callback(data); 443 | } 444 | }); 445 | 446 | // Dirty way of removing used Events 447 | while (listeners.indexOf(null) !== -1) { 448 | listeners.splice(listeners.indexOf(null), 1); 449 | } 450 | } 451 | } 452 | }]); 453 | 454 | return EventEmitter; 455 | }(); 456 | 457 | /* harmony default export */ __webpack_exports__["a"] = (EventEmitter); 458 | 459 | /***/ }) 460 | /******/ ])["default"]; 461 | }); -------------------------------------------------------------------------------- /docs/static/js/bulma-slider.min.js: -------------------------------------------------------------------------------- 1 | !function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.bulmaSlider=e():t.bulmaSlider=e()}("undefined"!=typeof self?self:this,function(){return function(n){var r={};function i(t){if(r[t])return r[t].exports;var e=r[t]={i:t,l:!1,exports:{}};return n[t].call(e.exports,e,e.exports,i),e.l=!0,e.exports}return i.m=n,i.c=r,i.d=function(t,e,n){i.o(t,e)||Object.defineProperty(t,e,{configurable:!1,enumerable:!0,get:n})},i.n=function(t){var e=t&&t.__esModule?function(){return t.default}:function(){return t};return i.d(e,"a",e),e},i.o=function(t,e){return Object.prototype.hasOwnProperty.call(t,e)},i.p="",i(i.s=0)}([function(t,e,n){"use strict";Object.defineProperty(e,"__esModule",{value:!0}),n.d(e,"isString",function(){return l});var r=n(1),i=Object.assign||function(t){for(var e=1;e=l.length&&(s=!0)):s=!0),s&&(t.once&&(u[e]=null),t.callback(r))});-1!==u.indexOf(null);)u.splice(u.indexOf(null),1)}}]),e}();e.a=i}]).default}); -------------------------------------------------------------------------------- /docs/static/js/index.js: -------------------------------------------------------------------------------- 1 | window.HELP_IMPROVE_VIDEOJS = false; 2 | 3 | $(document).ready(function() { 4 | // Check for click events on the navbar burger icon 5 | $(".navbar-burger").click(function() { 6 | // Toggle the "is-active" class on both the "navbar-burger" and the "navbar-menu" 7 | $(".navbar-burger").toggleClass("is-active"); 8 | $(".navbar-menu").toggleClass("is-active"); 9 | 10 | }); 11 | 12 | var options = { 13 | slidesToScroll: 1, 14 | slidesToShow: 3, 15 | loop: true, 16 | infinite: true, 17 | autoplay: false, 18 | autoplaySpeed: 3000, 19 | } 20 | 21 | // Initialize all div with carousel class 22 | var carousels = 
bulmaCarousel.attach('.carousel', options); 23 | 24 | // Loop on each carousel initialized 25 | for(var i = 0; i < carousels.length; i++) { 26 | // Add listener to event 27 | carousels[i].on('before:show', state => { 28 | console.log(state); 29 | }); 30 | } 31 | 32 | // Access to bulmaCarousel instance of an element 33 | var element = document.querySelector('#my-element'); 34 | if (element && element.bulmaCarousel) { 35 | // bulmaCarousel instance is available as element.bulmaCarousel 36 | element.bulmaCarousel.on('before-show', function(state) { 37 | console.log(state); 38 | }); 39 | } 40 | 41 | bulmaSlider.attach(); 42 | 43 | }) 44 | -------------------------------------------------------------------------------- /docs/static/ppt/20220809_dense_depth_priors_nerf.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/barbararoessle/dense_depth_priors_nerf/5bb060894984ec51952142b7526990990dc87618/docs/static/ppt/20220809_dense_depth_priors_nerf.pptx -------------------------------------------------------------------------------- /metric/__init__.py: -------------------------------------------------------------------------------- 1 | from .rmse import compute_rmse 2 | -------------------------------------------------------------------------------- /metric/rmse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def compute_rmse(prediction, target): 4 | return torch.sqrt((prediction - target).pow(2).mean()) -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .run_nerf_helpers import NeRF, get_embedder, get_rays, sample_pdf, img2mse, mse2psnr, to8b, to16b, \ 2 | precompute_quadratic_samples, compute_depth_loss, select_coordinates 3 | from .cspn import resnet18_skip 4 | -------------------------------------------------------------------------------- /model/cspn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Sat Feb 3 15:32:49 2018 3 | 4 | @author: norbot 5 | """ 6 | 7 | import os 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.utils.model_zoo as model_zoo 12 | from .cspn_affinity import Affinity_Propagate 13 | import torch.nn.functional as F 14 | 15 | # memory analyze 16 | import gc 17 | 18 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 19 | 'resnet152'] 20 | 21 | model_urls = { 22 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 23 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 24 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 25 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 26 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 27 | } 28 | 29 | model_path ={ 30 | 'resnet18': 'resnet18.pth', 31 | 'resnet50': 'resnet50.pth' 32 | } 33 | 34 | # update pretrained model params according to my model params 35 | def update_model(my_model, pretrained_dict): 36 | my_model_dict = my_model.state_dict() 37 | # 1. filter out unnecessary keys 38 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in my_model_dict} 39 | # 2. 
overwrite entries in the existing state dict 40 | my_model_dict.update(pretrained_dict) 41 | 42 | return my_model_dict 43 | 44 | # dont know why my offline saved model has 'module.' in front of all key name 45 | def remove_module(remove_dict): 46 | for k, v in remove_dict.items(): 47 | if 'module' in k : 48 | print("==> model dict with addtional module, remove it...") 49 | removed_dict = { k[7:]: v for k, v in remove_dict.items()} 50 | else: 51 | removed_dict = remove_dict 52 | break 53 | return removed_dict 54 | 55 | def conv3x3(in_planes, out_planes, stride=1): 56 | """3x3 convolution with padding""" 57 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 58 | padding=1, bias=False) 59 | 60 | class Unpool(nn.Module): 61 | # Unpool: 2*2 unpooling with zero padding 62 | def __init__(self, num_channels, stride=2): 63 | super(Unpool, self).__init__() 64 | 65 | self.num_channels = num_channels 66 | self.stride = stride 67 | 68 | # create kernel [1, 0; 0, 0] 69 | self.weights = torch.autograd.Variable(torch.zeros(num_channels, 1, stride, stride).cuda()) # currently not compatible with running on CPU 70 | self.weights[:,:,0,0] = 1 71 | 72 | def forward(self, x): 73 | return F.conv_transpose2d(x, self.weights, stride=self.stride, groups=self.num_channels) 74 | 75 | class BasicBlock(nn.Module): 76 | expansion = 1 77 | 78 | def __init__(self, inplanes, planes, stride=1, downsample=None): 79 | super(BasicBlock, self).__init__() 80 | self.conv1 = conv3x3(inplanes, planes, stride) 81 | self.bn1 = nn.BatchNorm2d(planes) 82 | self.relu = nn.ReLU(inplace=True) 83 | self.conv2 = conv3x3(planes, planes) 84 | self.bn2 = nn.BatchNorm2d(planes) 85 | self.downsample = downsample 86 | self.stride = stride 87 | 88 | def forward(self, x): 89 | residual = x 90 | 91 | out = self.conv1(x) 92 | out = self.bn1(out) 93 | out = self.relu(out) 94 | 95 | out = self.conv2(out) 96 | out = self.bn2(out) 97 | 98 | if self.downsample is not None: 99 | residual = self.downsample(x) 100 | 101 | out += residual 102 | out = self.relu(out) 103 | 104 | return out 105 | 106 | 107 | class Bottleneck(nn.Module): 108 | expansion = 4 109 | 110 | def __init__(self, inplanes, planes, stride=1, downsample=None): 111 | super(Bottleneck, self).__init__() 112 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 113 | self.bn1 = nn.BatchNorm2d(planes) 114 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 115 | padding=1, bias=False) 116 | self.bn2 = nn.BatchNorm2d(planes) 117 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 118 | self.bn3 = nn.BatchNorm2d(planes * 4) 119 | self.relu = nn.ReLU(inplace=True) 120 | self.downsample = downsample 121 | self.stride = stride 122 | 123 | def forward(self, x): 124 | residual = x 125 | 126 | out = self.conv1(x) 127 | out = self.bn1(out) 128 | out = self.relu(out) 129 | 130 | out = self.conv2(out) 131 | out = self.bn2(out) 132 | out = self.relu(out) 133 | 134 | out = self.conv3(out) 135 | out = self.bn3(out) 136 | 137 | if self.downsample is not None: 138 | residual = self.downsample(x) 139 | 140 | out += residual 141 | out = self.relu(out) 142 | 143 | return out 144 | 145 | class UpProj_Block(nn.Module): 146 | def __init__(self, in_channels, out_channels, oheight=0, owidth=0): 147 | super(UpProj_Block, self).__init__() 148 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2, bias=False) 149 | self.bn1 = nn.BatchNorm2d(out_channels) 150 | self.conv2 = nn.Conv2d(out_channels, out_channels, 
kernel_size=3, stride=1, padding=1, bias=False) 151 | self.bn2 = nn.BatchNorm2d(out_channels) 152 | self.sc_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2, bias=False) 153 | self.sc_bn1 = nn.BatchNorm2d(out_channels) 154 | self.relu = nn.ReLU(inplace=True) 155 | self.oheight = oheight 156 | self.owidth = owidth 157 | self._up_pool = Unpool(in_channels) 158 | 159 | def _up_pooling(self, x, scale): 160 | oheight = 0 161 | owidth = 0 162 | if self.oheight == 0 and self.owidth == 0: 163 | oheight = scale * x.size(2) 164 | owidth = scale * x.size(3) 165 | x = self._up_pool(x) 166 | else: 167 | oheight = self.oheight 168 | owidth = self.owidth 169 | x = self._up_pool(x) 170 | return x 171 | 172 | def forward(self, x): 173 | x = self._up_pooling(x, 2) 174 | out = self.relu(self.bn1(self.conv1(x))) 175 | out = self.bn2(self.conv2(out)) 176 | short_cut = self.sc_bn1(self.sc_conv1(x)) 177 | out += short_cut 178 | out = self.relu(out) 179 | return out 180 | 181 | class Simple_Gudi_UpConv_Block(nn.Module): 182 | def __init__(self, in_channels, out_channels, oheight=0, owidth=0): 183 | super(Simple_Gudi_UpConv_Block, self).__init__() 184 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2, bias=False) 185 | self.bn1 = nn.BatchNorm2d(out_channels) 186 | self.relu = nn.ReLU(inplace=True) 187 | self.oheight = oheight 188 | self.owidth = owidth 189 | self._up_pool = Unpool(in_channels) 190 | 191 | 192 | def _up_pooling(self, x, scale): 193 | 194 | x = self._up_pool(x) 195 | if self.oheight !=0 and self.owidth !=0: 196 | x = x.narrow(2,0,self.oheight) 197 | x = x.narrow(3,0,self.owidth) 198 | return x 199 | 200 | 201 | def forward(self, x): 202 | x = self._up_pooling(x, 2) 203 | out = self.relu(self.bn1(self.conv1(x))) 204 | return out 205 | 206 | class Simple_Gudi_UpConv_Block_Last_Layer(nn.Module): 207 | def __init__(self, in_channels, out_channels, oheight=0, owidth=0, bias=False): 208 | super(Simple_Gudi_UpConv_Block_Last_Layer, self).__init__() 209 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=bias) 210 | self.oheight = oheight 211 | self.owidth = owidth 212 | self._up_pool = Unpool(in_channels) 213 | 214 | def _up_pooling(self, x, scale): 215 | 216 | x = self._up_pool(x) 217 | if self.oheight != 0 and self.owidth != 0: 218 | x = x.narrow(2, 0, self.oheight) 219 | x = x.narrow(3, 0, self.owidth) 220 | return x 221 | 222 | def forward(self, x): 223 | x = self._up_pooling(x, 2) 224 | out = self.conv1(x) 225 | return out 226 | 227 | class Gudi_UpProj_Block(nn.Module): 228 | def __init__(self, in_channels, out_channels, oheight=0, owidth=0, bias=False): 229 | super(Gudi_UpProj_Block, self).__init__() 230 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2, bias=bias) 231 | self.bn1 = nn.BatchNorm2d(out_channels) 232 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=bias) 233 | self.bn2 = nn.BatchNorm2d(out_channels) 234 | self.sc_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2, bias=bias) 235 | self.sc_bn1 = nn.BatchNorm2d(out_channels) 236 | self.relu = nn.ReLU(inplace=True) 237 | self.oheight = oheight 238 | self.owidth = owidth 239 | 240 | def _up_pooling(self, x, scale): 241 | 242 | x = nn.Upsample(scale_factor=scale, mode='nearest')(x) 243 | if self.oheight !=0 and self.owidth !=0: 244 | x = x[:,:,0:self.oheight, 0:self.owidth] 245 | mask = torch.zeros_like(x) 246 | for h in 
range(0, self.oheight, 2): 247 | for w in range(0, self.owidth, 2): 248 | mask[:,:,h,w] = 1 249 | x = torch.mul(mask, x) 250 | return x 251 | 252 | def forward(self, x): 253 | x = self._up_pooling(x, 2) 254 | out = self.relu(self.bn1(self.conv1(x))) 255 | out = self.bn2(self.conv2(out)) 256 | short_cut = self.sc_bn1(self.sc_conv1(x)) 257 | out += short_cut 258 | out = self.relu(out) 259 | return out 260 | 261 | 262 | class Gudi_UpProj_Block_Cat(nn.Module): 263 | def __init__(self, in_channels, out_channels, oheight=0, owidth=0, bias=False): 264 | super(Gudi_UpProj_Block_Cat, self).__init__() 265 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2, bias=bias) 266 | self.bn1 = nn.BatchNorm2d(out_channels) 267 | self.conv1_1 = nn.Conv2d(out_channels*2, out_channels, kernel_size=3, stride=1, padding=1, bias=bias) 268 | self.bn1_1 = nn.BatchNorm2d(out_channels) 269 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=bias) 270 | self.bn2 = nn.BatchNorm2d(out_channels) 271 | self.sc_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2, bias=bias) 272 | self.sc_bn1 = nn.BatchNorm2d(out_channels) 273 | self.relu = nn.ReLU(inplace=True) 274 | self.oheight = oheight 275 | self.owidth = owidth 276 | self._up_pool = Unpool(in_channels) 277 | 278 | def _up_pooling(self, x, scale): 279 | 280 | x = self._up_pool(x) 281 | if self.oheight !=0 and self.owidth !=0: 282 | x = x.narrow(2, 0, self.oheight) 283 | x = x.narrow(3, 0, self.owidth) 284 | return x 285 | 286 | def forward(self, x, side_input): 287 | x = self._up_pooling(x, 2) 288 | out = self.relu(self.bn1(self.conv1(x))) 289 | out = torch.cat((out, side_input), 1) 290 | out = self.relu(self.bn1_1(self.conv1_1(out))) 291 | out = self.bn2(self.conv2(out)) 292 | short_cut = self.sc_bn1(self.sc_conv1(x)) 293 | out += short_cut 294 | out = self.relu(out) 295 | return out 296 | 297 | class ResNet(nn.Module): 298 | def __init__(self, block, layers, up_proj_block, cspn_config=None, input_size=(240, 320)): 299 | self.inplanes = 64 300 | iterations = 48 301 | std_iterations = 24 302 | cspn_config_default = {'step': iterations, 'kernel': 3, 'norm_type': '8sum'} 303 | if not (cspn_config is None): 304 | cspn_config_default.update(cspn_config) 305 | print(cspn_config_default) 306 | 307 | super(ResNet, self).__init__() 308 | in_channels = 4 309 | self.conv1_1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 310 | self.bn1 = nn.BatchNorm2d(64) 311 | self.relu = nn.ReLU(inplace=True) 312 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 313 | self.layer1 = self._make_layer(block, 64, layers[0]) 314 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 315 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 316 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 317 | self.mid_channel = 256*block.expansion 318 | self.conv2 = nn.Conv2d(512*block.expansion, 512*block.expansion, kernel_size=3, 319 | stride=1, padding=1, bias=False) 320 | self.bn2 = nn.BatchNorm2d(512*block.expansion) 321 | 322 | h_2, w_2 = input_size[0] // 2, input_size[1] // 2 323 | h_4, w_4 = h_2 // 2, w_2 // 2 324 | h_8, w_8 = h_4 // 2, w_4 // 2 325 | h_16, w_16 = h_8 // 2, w_8 // 2 326 | self.post_process_layer = self._make_post_process_layer(cspn_config_default) 327 | 328 | # depth branch 329 | self.gud_up_proj_layer1 = self._make_gud_up_conv_layer(Gudi_UpProj_Block, 512 * block.expansion, 256 * block.expansion, 
h_16, w_16) 330 | self.gud_up_proj_layer2 = self._make_gud_up_conv_layer(Gudi_UpProj_Block_Cat, 256 * block.expansion, 128 * block.expansion, h_8, w_8) 331 | self.gud_up_proj_layer3 = self._make_gud_up_conv_layer(Gudi_UpProj_Block_Cat, 128 * block.expansion, 64 * block.expansion, h_4, w_4) 332 | self.gud_up_proj_layer4 = self._make_gud_up_conv_layer(Gudi_UpProj_Block_Cat, 64 * block.expansion, 64, h_2, w_2) 333 | self.gud_up_proj_layer5 = self._make_gud_up_conv_layer(Simple_Gudi_UpConv_Block_Last_Layer, 64, 1, input_size[0], input_size[1]) 334 | self.gud_up_proj_layer6 = self._make_gud_up_conv_layer(Simple_Gudi_UpConv_Block_Last_Layer, 64, 8, input_size[0], input_size[1]) 335 | 336 | # standard deviation branch 337 | self.gud_up_proj_layer1_std = self._make_gud_up_conv_layer(Gudi_UpProj_Block, 512 * block.expansion, 256 * block.expansion, h_16, w_16) 338 | self.gud_up_proj_layer2_std = self._make_gud_up_conv_layer(Gudi_UpProj_Block_Cat, 256 * block.expansion, 128 * block.expansion, h_8, w_8) 339 | self.gud_up_proj_layer3_std = self._make_gud_up_conv_layer(Gudi_UpProj_Block_Cat, 128 * block.expansion, 64 * block.expansion, h_4, w_4) 340 | self.gud_up_proj_layer4_std = self._make_gud_up_conv_layer(Gudi_UpProj_Block_Cat, 64 * block.expansion, 64, h_2, w_2) 341 | self.gud_up_proj_layer5_std = self._make_gud_up_conv_layer(Simple_Gudi_UpConv_Block_Last_Layer, 64, 1, input_size[0], input_size[1]) 342 | self.gud_up_proj_layer6_std = self._make_gud_up_conv_layer(Simple_Gudi_UpConv_Block_Last_Layer, 64, 8, input_size[0], input_size[1]) 343 | cspn_config_std = {'step': std_iterations, 'kernel': 3, 'norm_type': '8sum_abs'} 344 | self.post_process_layer_std = self._make_post_process_layer(cspn_config_std) 345 | 346 | def _make_layer(self, block, planes, blocks, stride=1): 347 | downsample = None 348 | if stride != 1 or self.inplanes != planes * block.expansion: 349 | downsample = nn.Sequential( 350 | nn.Conv2d(self.inplanes, planes * block.expansion, 351 | kernel_size=1, stride=stride, bias=False), 352 | nn.BatchNorm2d(planes * block.expansion), 353 | ) 354 | 355 | layers = [] 356 | layers.append(block(self.inplanes, planes, stride, downsample)) 357 | self.inplanes = planes * block.expansion 358 | for i in range(1, blocks): 359 | layers.append(block(self.inplanes, planes)) 360 | 361 | return nn.Sequential(*layers) 362 | 363 | def _make_up_conv_layer(self, up_proj_block, in_channels, out_channels): 364 | return up_proj_block(in_channels, out_channels) 365 | 366 | def _make_gud_up_conv_layer(self, up_proj_block, in_channels, out_channels, oheight, owidth, bias=False): 367 | return up_proj_block(in_channels, out_channels, oheight, owidth, bias) 368 | 369 | def _make_post_process_layer(self, cspn_config=None): 370 | return Affinity_Propagate(cspn_config['step'], 371 | cspn_config['kernel'], 372 | norm_type=cspn_config['norm_type']) 373 | 374 | def forward(self, x): 375 | [batch_size, channel, height, width] = x.size() 376 | sparse_depth = x.narrow(1,channel - 1,1).clone() 377 | x = self.conv1_1(x) 378 | skip4 = x 379 | 380 | x = self.bn1(x) 381 | x = self.relu(x) 382 | x = self.maxpool(x) 383 | x = self.layer1(x) 384 | skip3 = x 385 | 386 | x = self.layer2(x) 387 | skip2 = x 388 | 389 | x = self.layer3(x) 390 | x = self.layer4(x) 391 | x = self.bn2(self.conv2(x)) 392 | 393 | std = self.gud_up_proj_layer1_std(x) 394 | std = self.gud_up_proj_layer2_std(std, skip2) 395 | std = self.gud_up_proj_layer3_std(std, skip3) 396 | std = self.gud_up_proj_layer4_std(std, skip4) 397 | guidance_std = 
self.gud_up_proj_layer6_std(std) 398 | std = self.gud_up_proj_layer5_std(std) 399 | std = F.softplus(self.post_process_layer_std(guidance_std, std), beta=20) 400 | 401 | x = self.gud_up_proj_layer1(x) 402 | x = self.gud_up_proj_layer2(x, skip2) 403 | x = self.gud_up_proj_layer3(x, skip3) 404 | x = self.gud_up_proj_layer4(x, skip4) 405 | guidance = self.gud_up_proj_layer6(x) 406 | x = self.gud_up_proj_layer5(x) 407 | x = self.post_process_layer(guidance, x, sparse_depth) 408 | 409 | return x, std 410 | 411 | def resnet18_skip(pretrained=False, pretrained_path='', map_location=None, **kwargs): 412 | """Constructs a ResNet-18 model. 413 | Args: 414 | pretrained (bool): If True, returns a model pre-trained on ImageNet 415 | """ 416 | model = ResNet(BasicBlock, [2, 2, 2, 2], UpProj_Block, **kwargs) 417 | if pretrained: 418 | print('==> Load pretrained model..') 419 | pretrained_dict = torch.load(pretrained_path, map_location=map_location) 420 | model.load_state_dict(update_model(model, pretrained_dict)) 421 | return model 422 | 423 | def resnet34(pretrained=False, **kwargs): 424 | """Constructs a ResNet-34 model. 425 | Args: 426 | pretrained (bool): If True, returns a model pre-trained on ImageNet 427 | """ 428 | model = ResNet(BasicBlock, [3, 4, 6, 3], UpProj_Block, **kwargs) 429 | if pretrained: 430 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 431 | return model 432 | 433 | 434 | def resnet50_skip(pretrained=False, checkpoint_dir='', **kwargs): 435 | """Constructs a ResNet-50 model. 436 | Args: 437 | pretrained (bool): If True, returns a model pre-trained on ImageNet 438 | """ 439 | model = ResNet(Bottleneck, [3, 4, 6, 3], UpProj_Block, **kwargs) 440 | if pretrained: 441 | print('==> Load pretrained model from ', model_path['resnet50']) 442 | pretrained_dict = torch.load(os.path.join(checkpoint_dir, model_path['resnet50'])) 443 | model.load_state_dict(update_model(model, pretrained_dict)) 444 | return model 445 | 446 | 447 | def resnet101(pretrained=False, **kwargs): 448 | """Constructs a ResNet-101 model. 449 | Args: 450 | pretrained (bool): If True, returns a model pre-trained on ImageNet 451 | """ 452 | model = ResNet(Bottleneck, [3, 4, 23, 3], UpProj_Block, **kwargs) 453 | if pretrained: 454 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 455 | return model 456 | 457 | 458 | def resnet152(pretrained=False, **kwargs): 459 | """Constructs a ResNet-152 model. 
460 | Args: 461 | pretrained (bool): If True, returns a model pre-trained on ImageNet 462 | """ 463 | model = ResNet(Bottleneck, [3, 8, 36, 3], UpProj_Block, **kwargs) 464 | if pretrained: 465 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 466 | return model 467 | -------------------------------------------------------------------------------- /model/cspn_affinity.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Xinjing Cheng & Peng Wang 3 | 4 | """ 5 | 6 | import torch.nn as nn 7 | import torch 8 | 9 | class Affinity_Propagate(nn.Module): 10 | 11 | def __init__(self, 12 | prop_time, 13 | prop_kernel, 14 | norm_type='8sum'): 15 | """ 16 | 17 | Inputs: 18 | prop_time: how many steps for CSPN to perform 19 | prop_kernel: the size of kernel (current only support 3x3) 20 | way to normalize affinity 21 | '8sum': normalize using 8 surrounding neighborhood 22 | '8sum_abs': normalization enforcing affinity to be positive 23 | This will lead the center affinity to be 0 24 | """ 25 | super(Affinity_Propagate, self).__init__() 26 | self.prop_time = prop_time 27 | self.prop_kernel = prop_kernel 28 | assert prop_kernel == 3, 'this version only support 8 (3x3 - 1) neighborhood' 29 | 30 | self.norm_type = norm_type 31 | assert norm_type in ['8sum', '8sum_abs'] 32 | 33 | self.in_feature = 1 34 | self.out_feature = 1 35 | 36 | 37 | def forward(self, guidance, blur_depth, sparse_depth=None): 38 | 39 | self.sum_conv = nn.Conv3d(in_channels=8, 40 | out_channels=1, 41 | kernel_size=(1, 1, 1), 42 | stride=1, 43 | padding=0, 44 | bias=False) 45 | weight = torch.ones(1, 8, 1, 1, 1).cuda() 46 | self.sum_conv.weight = nn.Parameter(weight) 47 | for param in self.sum_conv.parameters(): 48 | param.requires_grad = False 49 | 50 | gate_wb, gate_sum = self.affinity_normalization(guidance) 51 | 52 | # pad input and convert to 8 channel 3D features 53 | raw_depth_input = blur_depth 54 | 55 | #blur_depht_pad = nn.ZeroPad2d((1,1,1,1)) 56 | result_depth = blur_depth 57 | 58 | if sparse_depth is not None: 59 | sparse_mask = sparse_depth.sign() 60 | 61 | for i in range(self.prop_time): 62 | # one propagation 63 | spn_kernel = self.prop_kernel 64 | result_depth = self.pad_blur_depth(result_depth) 65 | neigbor_weighted_sum = self.sum_conv(gate_wb * result_depth) 66 | neigbor_weighted_sum = neigbor_weighted_sum.squeeze(1) 67 | neigbor_weighted_sum = neigbor_weighted_sum[:, :, 1:-1, 1:-1] 68 | result_depth = neigbor_weighted_sum 69 | 70 | if '8sum' in self.norm_type: 71 | result_depth = (1.0 - gate_sum) * raw_depth_input + result_depth 72 | else: 73 | raise ValueError('unknown norm %s' % self.norm_type) 74 | 75 | if sparse_depth is not None: 76 | result_depth = (1 - sparse_mask) * result_depth + sparse_mask * raw_depth_input 77 | 78 | return result_depth 79 | 80 | def affinity_normalization(self, guidance): 81 | 82 | # normalize features 83 | if 'abs' in self.norm_type: 84 | guidance = torch.abs(guidance) 85 | 86 | gate1_wb_cmb = guidance.narrow(1, 0 , self.out_feature) 87 | gate2_wb_cmb = guidance.narrow(1, 1 * self.out_feature, self.out_feature) 88 | gate3_wb_cmb = guidance.narrow(1, 2 * self.out_feature, self.out_feature) 89 | gate4_wb_cmb = guidance.narrow(1, 3 * self.out_feature, self.out_feature) 90 | gate5_wb_cmb = guidance.narrow(1, 4 * self.out_feature, self.out_feature) 91 | gate6_wb_cmb = guidance.narrow(1, 5 * self.out_feature, self.out_feature) 92 | gate7_wb_cmb = guidance.narrow(1, 6 * self.out_feature, self.out_feature) 93 | gate8_wb_cmb = 
guidance.narrow(1, 7 * self.out_feature, self.out_feature) 94 | 95 | # gate1:left_top, gate2:center_top, gate3:right_top 96 | # gate4:left_center, , gate5: right_center 97 | # gate6:left_bottom, gate7: center_bottom, gate8: right_bottm 98 | 99 | # top pad 100 | left_top_pad = nn.ZeroPad2d((0,2,0,2)) 101 | gate1_wb_cmb = left_top_pad(gate1_wb_cmb).unsqueeze(1) 102 | 103 | center_top_pad = nn.ZeroPad2d((1,1,0,2)) 104 | gate2_wb_cmb = center_top_pad(gate2_wb_cmb).unsqueeze(1) 105 | 106 | right_top_pad = nn.ZeroPad2d((2,0,0,2)) 107 | gate3_wb_cmb = right_top_pad(gate3_wb_cmb).unsqueeze(1) 108 | 109 | # center pad 110 | left_center_pad = nn.ZeroPad2d((0,2,1,1)) 111 | gate4_wb_cmb = left_center_pad(gate4_wb_cmb).unsqueeze(1) 112 | 113 | right_center_pad = nn.ZeroPad2d((2,0,1,1)) 114 | gate5_wb_cmb = right_center_pad(gate5_wb_cmb).unsqueeze(1) 115 | 116 | # bottom pad 117 | left_bottom_pad = nn.ZeroPad2d((0,2,2,0)) 118 | gate6_wb_cmb = left_bottom_pad(gate6_wb_cmb).unsqueeze(1) 119 | 120 | center_bottom_pad = nn.ZeroPad2d((1,1,2,0)) 121 | gate7_wb_cmb = center_bottom_pad(gate7_wb_cmb).unsqueeze(1) 122 | 123 | right_bottm_pad = nn.ZeroPad2d((2,0,2,0)) 124 | gate8_wb_cmb = right_bottm_pad(gate8_wb_cmb).unsqueeze(1) 125 | 126 | gate_wb = torch.cat((gate1_wb_cmb,gate2_wb_cmb,gate3_wb_cmb,gate4_wb_cmb, 127 | gate5_wb_cmb,gate6_wb_cmb,gate7_wb_cmb,gate8_wb_cmb), 1) 128 | 129 | # normalize affinity using their abs sum 130 | gate_wb_abs = torch.abs(gate_wb) 131 | abs_weight = self.sum_conv(gate_wb_abs) 132 | 133 | gate_wb = torch.div(gate_wb, abs_weight.clamp(min=1e-6)) 134 | gate_sum = self.sum_conv(gate_wb) 135 | 136 | gate_sum = gate_sum.squeeze(1) 137 | gate_sum = gate_sum[:, :, 1:-1, 1:-1] 138 | 139 | return gate_wb, gate_sum 140 | 141 | def pad_blur_depth(self, blur_depth): 142 | # top pad 143 | left_top_pad = nn.ZeroPad2d((0,2,0,2)) 144 | blur_depth_1 = left_top_pad(blur_depth).unsqueeze(1) 145 | center_top_pad = nn.ZeroPad2d((1,1,0,2)) 146 | blur_depth_2 = center_top_pad(blur_depth).unsqueeze(1) 147 | right_top_pad = nn.ZeroPad2d((2,0,0,2)) 148 | blur_depth_3 = right_top_pad(blur_depth).unsqueeze(1) 149 | 150 | # center pad 151 | left_center_pad = nn.ZeroPad2d((0,2,1,1)) 152 | blur_depth_4 = left_center_pad(blur_depth).unsqueeze(1) 153 | right_center_pad = nn.ZeroPad2d((2,0,1,1)) 154 | blur_depth_5 = right_center_pad(blur_depth).unsqueeze(1) 155 | 156 | # bottom pad 157 | left_bottom_pad = nn.ZeroPad2d((0,2,2,0)) 158 | blur_depth_6 = left_bottom_pad(blur_depth).unsqueeze(1) 159 | center_bottom_pad = nn.ZeroPad2d((1,1,2,0)) 160 | blur_depth_7 = center_bottom_pad(blur_depth).unsqueeze(1) 161 | right_bottm_pad = nn.ZeroPad2d((2,0,2,0)) 162 | blur_depth_8 = right_bottm_pad(blur_depth).unsqueeze(1) 163 | 164 | result_depth = torch.cat((blur_depth_1, blur_depth_2, blur_depth_3, blur_depth_4, 165 | blur_depth_5, blur_depth_6, blur_depth_7, blur_depth_8), 1) 166 | return result_depth 167 | 168 | 169 | def normalize_gate(self, guidance): 170 | gate1_x1_g1 = guidance.narrow(1,0,1) 171 | gate1_x1_g2 = guidance.narrow(1,1,1) 172 | gate1_x1_g1_abs = torch.abs(gate1_x1_g1) 173 | gate1_x1_g2_abs = torch.abs(gate1_x1_g2) 174 | elesum_gate1_x1 = torch.add(gate1_x1_g1_abs, gate1_x1_g2_abs) 175 | gate1_x1_g1_cmb = torch.div(gate1_x1_g1, elesum_gate1_x1) 176 | gate1_x1_g2_cmb = torch.div(gate1_x1_g2, elesum_gate1_x1) 177 | return gate1_x1_g1_cmb, gate1_x1_g2_cmb 178 | 179 | 180 | def max_of_4_tensor(self, element1, element2, element3, element4): 181 | max_element1_2 = torch.max(element1, element2) 182 | 
max_element3_4 = torch.max(element3, element4) 183 | return torch.max(max_element1_2, max_element3_4) 184 | 185 | def max_of_8_tensor(self, element1, element2, element3, element4, element5, element6, element7, element8): 186 | max_element1_2 = self.max_of_4_tensor(element1, element2, element3, element4) 187 | max_element3_4 = self.max_of_4_tensor(element5, element6, element7, element8) 188 | return torch.max(max_element1_2, max_element3_4) 189 | -------------------------------------------------------------------------------- /model/run_nerf_helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | # Misc 7 | img2mse = lambda x, y : torch.mean((x - y) ** 2) 8 | mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.full((1,), 10., device=x.device)) 9 | to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8) 10 | to16b = lambda x : ((2**16 - 1) * np.clip(x,0,1)).astype(np.uint16) 11 | 12 | def precompute_quadratic_samples(near, far, num_samples): 13 | # normal parabola between 0.1 and 1, shifted and scaled to have y range between near and far 14 | start = 0.1 15 | x = torch.linspace(0, 1, num_samples) 16 | c = near 17 | a = (far - near)/(1. + 2. * start) 18 | b = 2. * start * a 19 | return a * x.pow(2) + b * x + c 20 | 21 | def is_not_in_expected_distribution(depth_mean, depth_var, depth_measurement_mean, depth_measurement_std): 22 | delta_greater_than_expected = ((depth_mean - depth_measurement_mean).abs() - depth_measurement_std) > 0. 23 | var_greater_than_expected = depth_measurement_std.pow(2) < depth_var 24 | return torch.logical_or(delta_greater_than_expected, var_greater_than_expected) 25 | 26 | def compute_depth_loss(depth_map, z_vals, weights, target_depth, target_valid_depth): 27 | pred_mean = depth_map[target_valid_depth] 28 | if pred_mean.shape[0] == 0: 29 | return torch.zeros((1,), device=depth_map.device, requires_grad=True) 30 | pred_var = ((z_vals[target_valid_depth] - pred_mean.unsqueeze(-1)).pow(2) * weights[target_valid_depth]).sum(-1) + 1e-5 31 | target_mean = target_depth[..., 0][target_valid_depth] 32 | target_std = target_depth[..., 1][target_valid_depth] 33 | apply_depth_loss = is_not_in_expected_distribution(pred_mean, pred_var, target_mean, target_std) 34 | pred_mean = pred_mean[apply_depth_loss] 35 | if pred_mean.shape[0] == 0: 36 | return torch.zeros((1,), device=depth_map.device, requires_grad=True) 37 | pred_var = pred_var[apply_depth_loss] 38 | target_mean = target_mean[apply_depth_loss] 39 | target_std = target_std[apply_depth_loss] 40 | f = nn.GaussianNLLLoss(eps=0.001) 41 | return float(pred_mean.shape[0]) / float(target_valid_depth.shape[0]) * f(pred_mean, target_mean, pred_var) 42 | 43 | class DenseLayer(nn.Linear): 44 | def __init__(self, in_dim: int, out_dim: int, activation: str = "relu", *args, **kwargs) -> None: 45 | self.activation = activation 46 | super().__init__(in_dim, out_dim, *args, **kwargs) 47 | 48 | def reset_parameters(self) -> None: 49 | torch.nn.init.xavier_uniform_(self.weight, gain=torch.nn.init.calculate_gain(self.activation)) 50 | if self.bias is not None: 51 | torch.nn.init.zeros_(self.bias) 52 | 53 | # Positional encoding (section 5.1) 54 | class Embedder: 55 | def __init__(self, **kwargs): 56 | self.kwargs = kwargs 57 | self.create_embedding_fn() 58 | 59 | def create_embedding_fn(self): 60 | embed_fns = [] 61 | d = self.kwargs['input_dims'] 62 | out_dim = 0 63 | if 
self.kwargs['include_input']: 64 | embed_fns.append(lambda x : x) 65 | out_dim += d 66 | 67 | max_freq = self.kwargs['max_freq_log2'] 68 | N_freqs = self.kwargs['num_freqs'] 69 | 70 | if self.kwargs['log_sampling']: 71 | freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs) 72 | else: 73 | freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs) 74 | 75 | for freq in freq_bands: 76 | for p_fn in self.kwargs['periodic_fns']: 77 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * np.pi * freq)) 78 | out_dim += d 79 | 80 | self.embed_fns = embed_fns 81 | self.out_dim = out_dim 82 | 83 | def embed(self, inputs): 84 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 85 | 86 | def get_embedder(multires, i=0): 87 | if i == -1: 88 | return nn.Identity(), 3 89 | 90 | embed_kwargs = { 91 | 'include_input' : True, 92 | 'input_dims' : 3, 93 | 'max_freq_log2' : multires-1, 94 | 'num_freqs' : multires, 95 | 'log_sampling' : True, 96 | 'periodic_fns' : [torch.sin, torch.cos], 97 | } 98 | 99 | embedder_obj = Embedder(**embed_kwargs) 100 | embed = lambda x, eo=embedder_obj : eo.embed(x) 101 | return embed, embedder_obj.out_dim 102 | 103 | 104 | # Model 105 | class NeRF(nn.Module): 106 | def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, input_ch_cam=0, output_ch=4, skips=[4], use_viewdirs=False): 107 | """ 108 | """ 109 | super(NeRF, self).__init__() 110 | self.D = D 111 | self.W = W 112 | self.input_ch = input_ch 113 | self.input_ch_views = input_ch_views 114 | self.input_ch_cam = input_ch_cam 115 | self.skips = skips 116 | self.use_viewdirs = use_viewdirs 117 | 118 | self.pts_linears = nn.ModuleList( 119 | [DenseLayer(input_ch, W, activation="relu")] + [DenseLayer(W, W, activation="relu") if i not in self.skips else DenseLayer(W + input_ch, W, activation="relu") for i in range(D-1)]) 120 | 121 | ### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105) 122 | self.views_linears = nn.ModuleList([DenseLayer(input_ch_views + input_ch_cam + W, W//2, activation="relu")]) 123 | 124 | ### Implementation according to the paper 125 | # self.views_linears = nn.ModuleList( 126 | # [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)]) 127 | 128 | if use_viewdirs: 129 | self.feature_linear = DenseLayer(W, W, activation="linear") 130 | self.alpha_linear = DenseLayer(W, 1, activation="linear") 131 | self.rgb_linear = DenseLayer(W//2, 3, activation="linear") 132 | else: 133 | self.output_linear = DenseLayer(W, output_ch, activation="linear") 134 | 135 | def forward(self, x): 136 | input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views + self.input_ch_cam], dim=-1) 137 | h = input_pts 138 | for i, l in enumerate(self.pts_linears): 139 | h = self.pts_linears[i](h) 140 | h = F.relu(h) 141 | if i in self.skips: 142 | h = torch.cat([input_pts, h], -1) 143 | 144 | if self.use_viewdirs: 145 | alpha = self.alpha_linear(h) 146 | feature = self.feature_linear(h) 147 | h = torch.cat([feature, input_views], -1) 148 | 149 | for i, l in enumerate(self.views_linears): 150 | h = self.views_linears[i](h) 151 | h = F.relu(h) 152 | 153 | rgb = self.rgb_linear(h) 154 | outputs = torch.cat([rgb, F.softplus(alpha, beta=10)], -1) 155 | else: 156 | outputs = self.output_linear(h) 157 | outputs = torch.cat([outputs[..., :3], F.softplus(outputs[..., 3:], beta=10)], -1) 158 | 159 | return outputs 160 | 161 | def load_weights_from_keras(self, weights): 162 | assert 
self.use_viewdirs, "Not implemented if use_viewdirs=False" 163 | 164 | # Load pts_linears 165 | for i in range(self.D): 166 | idx_pts_linears = 2 * i 167 | self.pts_linears[i].weight.data = torch.from_numpy(np.transpose(weights[idx_pts_linears])) 168 | self.pts_linears[i].bias.data = torch.from_numpy(np.transpose(weights[idx_pts_linears+1])) 169 | 170 | # Load feature_linear 171 | idx_feature_linear = 2 * self.D 172 | self.feature_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_feature_linear])) 173 | self.feature_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_feature_linear+1])) 174 | 175 | # Load views_linears 176 | idx_views_linears = 2 * self.D + 2 177 | self.views_linears[0].weight.data = torch.from_numpy(np.transpose(weights[idx_views_linears])) 178 | self.views_linears[0].bias.data = torch.from_numpy(np.transpose(weights[idx_views_linears+1])) 179 | 180 | # Load rgb_linear 181 | idx_rbg_linear = 2 * self.D + 4 182 | self.rgb_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear])) 183 | self.rgb_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear+1])) 184 | 185 | # Load alpha_linear 186 | idx_alpha_linear = 2 * self.D + 6 187 | self.alpha_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear])) 188 | self.alpha_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear+1])) 189 | 190 | def select_coordinates(coords, N_rand): 191 | coords = torch.reshape(coords, [-1,2]) # (H * W, 2) 192 | select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,) 193 | select_coords = coords[select_inds].long() # (N_rand, 2) 194 | return select_coords 195 | 196 | def get_ray_dirs(H, W, intrinsic, c2w, coords=None): 197 | device = intrinsic.device 198 | fx, fy, cx, cy = intrinsic[0], intrinsic[1], intrinsic[2], intrinsic[3] 199 | if coords is None: 200 | i, j = torch.meshgrid(torch.linspace(0, W-1, W, device=device), torch.linspace(0, H-1, H, device=device), indexing='ij') # pytorch's meshgrid has indexing='ij' 201 | i = i.t() 202 | j = j.t() 203 | else: 204 | i, j = coords[:, 1], coords[:, 0] 205 | # conversion from pixels to camera coordinates 206 | dirs = torch.stack([((i + 0.5)-cx)/fx, (H - (j + 0.5) - cy)/fy, -torch.ones_like(i)], -1) # center of pixel 207 | # Rotate ray directions from camera frame to the world frame 208 | rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 209 | return rays_d 210 | 211 | # Ray helpers 212 | def get_rays(H, W, intrinsic, c2w, coords=None): 213 | rays_d = get_ray_dirs(H, W, intrinsic, c2w, coords) 214 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 215 | rays_o = c2w[:3,-1].expand(rays_d.shape) 216 | return rays_o, rays_d 217 | 218 | def get_rays_np(H, W, focal, c2w): 219 | i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy') 220 | dirs = np.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -np.ones_like(i)], -1) 221 | # Rotate ray directions from camera frame to the world frame 222 | rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 223 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 
224 | rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d)) 225 | return rays_o, rays_d 226 | 227 | def ndc_rays(H, W, focal, near, rays_o, rays_d): 228 | # Shift ray origins to near plane 229 | t = -(near + rays_o[...,2]) / rays_d[...,2] 230 | rays_o = rays_o + t[...,None] * rays_d 231 | 232 | # Projection 233 | o0 = -1./(W/(2.*focal)) * rays_o[...,0] / rays_o[...,2] 234 | o1 = -1./(H/(2.*focal)) * rays_o[...,1] / rays_o[...,2] 235 | o2 = 1. + 2. * near / rays_o[...,2] 236 | 237 | d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - rays_o[...,0]/rays_o[...,2]) 238 | d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - rays_o[...,1]/rays_o[...,2]) 239 | d2 = -2. * near / rays_o[...,2] 240 | 241 | rays_o = torch.stack([o0,o1,o2], -1) 242 | rays_d = torch.stack([d0,d1,d2], -1) 243 | 244 | return rays_o, rays_d 245 | 246 | 247 | # Hierarchical sampling (section 5.2) 248 | def sample_pdf(bins, weights, N_samples, det=False, pytest=False): 249 | # Get pdf 250 | weights = weights + 1e-5 # prevent nans 251 | pdf = weights / torch.sum(weights, -1, keepdim=True) 252 | cdf = torch.cumsum(pdf, -1) 253 | cdf = torch.cat([torch.zeros_like(cdf[...,:1]), cdf], -1) # (batch, len(bins)) 254 | 255 | # Take uniform samples 256 | if det: 257 | u = torch.linspace(0., 1., steps=N_samples, device=bins.device) 258 | u = u.expand(list(cdf.shape[:-1]) + [N_samples]) 259 | else: 260 | u = torch.rand(list(cdf.shape[:-1]) + [N_samples], device=bins.device) 261 | 262 | # Pytest, overwrite u with numpy's fixed random numbers 263 | if pytest: 264 | np.random.seed(0) 265 | new_shape = list(cdf.shape[:-1]) + [N_samples] 266 | if det: 267 | u = np.linspace(0., 1., N_samples) 268 | u = np.broadcast_to(u, new_shape) 269 | else: 270 | u = np.random.rand(*new_shape) 271 | u = torch.Tensor(u) 272 | 273 | # Invert CDF 274 | u = u.contiguous() 275 | inds = torch.searchsorted(cdf, u, right=True) 276 | below = torch.max(torch.zeros_like(inds-1), inds-1) 277 | above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds) 278 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) 279 | 280 | # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 281 | # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 282 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 283 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 284 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 285 | 286 | denom = (cdf_g[...,1]-cdf_g[...,0]) 287 | denom = torch.where(denom<1e-5, torch.ones_like(denom), denom) 288 | t = (u-cdf_g[...,0])/denom 289 | samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0]) 290 | 291 | return samples 292 | -------------------------------------------------------------------------------- /preprocessing/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.11) 2 | 3 | project(ogl_renderer LANGUAGES CXX) 4 | 5 | find_package(OpenCV REQUIRED) 6 | find_package(Boost COMPONENTS filesystem REQUIRED) 7 | find_package(glm REQUIRED) 8 | find_package(RapidJSON REQUIRED) 9 | 10 | include_directories(${OpenCV_INCLUDE_DIRS} ${Boost_INCLUDE_DIR}) 11 | 12 | add_executable(extract_scannet_scene 13 | extract_scannet_scene.cpp 14 | ) 15 | 16 | target_link_libraries(extract_scannet_scene 17 | io 18 | io_colmap 19 | camera 20 | ${Boost_FILESYSTEM_LIBRARY} 21 | ) 22 | 23 | add_subdirectory(io) 24 | add_subdirectory(io_colmap) 25 | 
add_subdirectory(camera) 26 | -------------------------------------------------------------------------------- /preprocessing/camera/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(camera 2 | include/camera.h 3 | src/camera.cpp 4 | ) 5 | target_include_directories(camera PUBLIC 6 | ${CMAKE_CURRENT_SOURCE_DIR}/include 7 | ) 8 | -------------------------------------------------------------------------------- /preprocessing/camera/include/camera.h: -------------------------------------------------------------------------------- 1 | #ifndef CAMERA_H 2 | #define CAMERA_H 3 | 4 | #include 5 | 6 | #include 7 | 8 | #include 9 | #include 10 | #include 11 | 12 | // The camera frame is defined with z pointing in negative viewing direction. This corresponds to the OpenGL camera/ view space 13 | // The origin of the image frame is in the bottom left corner. 14 | class Camera 15 | { 16 | public: 17 | Camera(); 18 | Camera(float f_x, float f_y, float c_x, float c_y, const glm::mat4& world2cam); 19 | const glm::mat4& GetWorld2Cam() const; 20 | const glm::mat4 GetK() const; 21 | float GetFx() const; 22 | float GetFy() const; 23 | float GetCx() const; 24 | float GetCy() const; 25 | glm::vec3 GetPose() const; 26 | 27 | private: 28 | float f_x_; 29 | float f_y_; 30 | float c_x_; 31 | float c_y_; 32 | glm::mat4 world2cam_; 33 | }; 34 | 35 | int Y2Row(float y, int height); 36 | 37 | template 38 | float ComputeDepth(const std::vector& point_cloud, float depth_scaling_factor, float max_depth, const Camera& camera, 39 | cv::Mat& depth_map, Functor& Verify) 40 | { 41 | const auto height = depth_map.rows; 42 | const auto width = depth_map.cols; 43 | 44 | const auto& K(camera.GetK()); 45 | const auto& world2cam(camera.GetWorld2Cam()); 46 | const auto rot_world2cam = glm::mat3(world2cam); 47 | 48 | int point_idx{0}; 49 | std::size_t count{0}; 50 | for (const auto& point : point_cloud) 51 | { 52 | if (Verify(point_idx)) 53 | { 54 | const glm::vec4 point_world(point, 1.f); 55 | glm::vec3 point_cam = glm::vec3(world2cam * point_world); 56 | float z_cam(-point_cam[2]); 57 | if (z_cam > 0.f) 58 | { 59 | glm::vec3 point_img = glm::mat3(K) * point_cam; 60 | point_img = point_img / point_img[2]; 61 | if (point_img[0] >= 0 && point_img[0] < width && point_img[1] > 0 && point_img[1] <= height) 62 | { 63 | const auto r = Y2Row(point_img[1], height); 64 | assert(r >= 0 && r < height); 65 | const auto c = static_cast(point_img[0]); 66 | assert(c >= 0 && c < width); 67 | 68 | if (z_cam <= max_depth) 69 | { 70 | const auto z_cam_scaled = z_cam * depth_scaling_factor; 71 | auto z_cam_integer = static_cast(z_cam_scaled); 72 | 73 | const auto z_before = depth_map.at(r, c); 74 | if (z_before == 0 || z_cam_integer < z_before) 75 | { 76 | depth_map.at(r, c) = z_cam_integer; 77 | if (z_before == 0) 78 | { 79 | ++count; 80 | } 81 | } 82 | } 83 | else 84 | { 85 | std::cout << "Warning: Depth " << z_cam << " is ignored, because it is > maximal depth " << max_depth << std::endl; 86 | } 87 | } 88 | } 89 | } 90 | ++point_idx; 91 | } 92 | return static_cast(count) / static_cast(width * height); 93 | } 94 | 95 | std::pair ComputeMeanScaling(const cv::Mat& int_matrix_0, const cv::Mat& int_matrix_1); 96 | 97 | #endif // CAMERA_H 98 | -------------------------------------------------------------------------------- /preprocessing/camera/src/camera.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | 6 | 
#include 7 | 8 | Camera::Camera() 9 | : f_x_(0.f), f_y_(0.f), c_x_(0.f), c_y_(0.f), world2cam_(1.f) 10 | {} 11 | 12 | Camera::Camera(float f_x, float f_y, float c_x, float c_y, const glm::mat4 &world2cam) 13 | : f_x_(f_x), f_y_(f_y), c_x_(c_x), c_y_(c_y), world2cam_(world2cam) 14 | {} 15 | 16 | const glm::mat4& Camera::GetWorld2Cam() const 17 | { 18 | return world2cam_; 19 | } 20 | 21 | const glm::mat4 Camera::GetK() const 22 | { 23 | glm::mat4 K(1.f); 24 | K[0][0] = f_x_; 25 | K[1][1] = f_y_; 26 | K[2][0] = -c_x_; 27 | K[2][1] = -c_y_; 28 | K[2][2] = -1.f; 29 | return K; 30 | } 31 | 32 | float Camera::GetFx() const 33 | { 34 | return f_x_; 35 | } 36 | 37 | float Camera::GetFy() const 38 | { 39 | return f_y_; 40 | } 41 | 42 | float Camera::GetCx() const 43 | { 44 | return c_x_; 45 | } 46 | 47 | float Camera::GetCy() const 48 | { 49 | return c_y_; 50 | } 51 | 52 | glm::vec3 Camera::GetPose() const 53 | { 54 | const glm::mat3 rot_world2cam(world2cam_); 55 | const glm::mat3 rot_cam2world = glm::transpose(rot_world2cam); 56 | const glm::vec3 pose(- rot_cam2world * glm::vec3(world2cam_[3])); 57 | return pose; 58 | } 59 | 60 | int Y2Row(float y, int height) 61 | { 62 | return static_cast(static_cast(height) - y); 63 | } 64 | 65 | std::pair ComputeMeanScaling(const cv::Mat& int_matrix_0, const cv::Mat& int_matrix_1) 66 | { 67 | assert(int_matrix_0.size() == int_matrix_1.size()); 68 | double sum_scaling_factors{0}; 69 | std::size_t valid_count{0}; 70 | 71 | cv::MatConstIterator_ it_0(int_matrix_0.begin()), it_1(int_matrix_1.begin()), et_0(int_matrix_0.end()); 72 | for(; it_0 != et_0; ++it_0, ++it_1) 73 | { 74 | // both are valid 75 | if (*it_0 != 0 && *it_1 != 0) 76 | { 77 | sum_scaling_factors += static_cast(*it_0) / static_cast(*it_1); 78 | ++valid_count; 79 | } 80 | } 81 | return std::make_pair(sum_scaling_factors / static_cast(valid_count), valid_count); 82 | } 83 | 84 | -------------------------------------------------------------------------------- /preprocessing/extract_scannet_scene.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #include 17 | #include 18 | #include 19 | #include 20 | 21 | struct VisibilityCheck 22 | { 23 | bool operator() (int point_idx) const 24 | { 25 | const auto& image_idxs = visibility_[point_idx]; 26 | return image_idxs.cend() != std::find(image_idxs.cbegin(), image_idxs.cend(), image_idx_); 27 | } 28 | 29 | const int image_idx_{0}; 30 | const std::vector>& visibility_{}; 31 | }; 32 | 33 | struct ColmapHelper 34 | { 35 | void Init(const boost::filesystem::path& path2colmap, float dist2m) 36 | { 37 | const auto path2sparse_train = path2colmap / "sparse_train" / "0"; 38 | train_camera_configs_ = ReadCameras((path2sparse_train / "cameras.txt").string(), (path2sparse_train / "images.txt").string(), 39 | dist2m); 40 | const auto path2sparse = path2colmap / "sparse" / "0"; 41 | camera_configs_ = ReadCameras((path2sparse / "cameras.txt").string(), (path2sparse / "images.txt").string(), dist2m); 42 | float max_reproj_error{std::numeric_limits::max()}; 43 | unsigned int min_track_length{0}; 44 | ReadSparsePointCloud((path2sparse_train / "points3D.txt").string(), dist2m, max_reproj_error, min_track_length, 45 | train_point_cloud_, train_visibility_); 46 | ReadSparsePointCloud((path2sparse / "points3D.txt").string(), dist2m, max_reproj_error, 
min_track_length, 47 | point_cloud_, visibility_); 48 | } 49 | 50 | CameraConfig GetCameraConfig(const std::string& filename, bool test_camera = true) const 51 | { 52 | bool found{false}; 53 | CameraConfig camera_config{}; 54 | for (const auto& config : (test_camera ? camera_configs_ : train_camera_configs_)) 55 | { 56 | if (config.rgb_image == filename) 57 | { 58 | found = true; 59 | camera_config = config; 60 | break; 61 | } 62 | } 63 | if (!found) 64 | { 65 | std::cout << "Error: Camera config " << filename << " was not found in Colmap reconstruction" << std::endl; 66 | } 67 | return camera_config; 68 | } 69 | void GetSparseDepth(const std::string& filename, const std::string& dataset_type, cv::Mat& sparse_depth, float max_depth, 70 | float depth_scaling_factor, const cv::Mat& target_depth) 71 | { 72 | // compute sparse depth to determine scaling using sparse reconstruction from all images 73 | const auto config = GetCameraConfig(filename); 74 | const auto curr_colmap_id = config.id; 75 | VisibilityCheck visibility_check{curr_colmap_id, visibility_}; 76 | ComputeDepth(point_cloud_, depth_scaling_factor, max_depth, config.camera, sparse_depth, visibility_check); 77 | // compute scaling between colmap and target depth 78 | const auto scaling_result = ComputeMeanScaling(target_depth, sparse_depth); 79 | const auto scaling = scaling_result.first; 80 | const auto count = scaling_result.second; 81 | if (count > 0) 82 | { 83 | global_scaling_ = static_cast(global_scaling_count_) / static_cast(global_scaling_count_ + count) * global_scaling_ + 84 | static_cast(count) / static_cast(global_scaling_count_ + count) * scaling; 85 | } 86 | global_scaling_count_ += count; 87 | 88 | // compute train sparse depth using sparse reconstruction from train images 89 | if (dataset_type == "train") 90 | { 91 | const auto train_config = GetCameraConfig(filename, false); 92 | const auto train_curr_colmap_id = train_config.id; 93 | VisibilityCheck train_visibility_check{train_curr_colmap_id, train_visibility_}; 94 | sparse_depth = cv::Mat::zeros(sparse_depth.rows, sparse_depth.cols, CV_16UC1); // reset to zero 95 | const auto percent = ComputeDepth(train_point_cloud_, depth_scaling_factor, max_depth, train_config.camera, 96 | sparse_depth, train_visibility_check); 97 | sum_percent_ += percent; 98 | ++count_; 99 | if (percent < 1e-6) 100 | { 101 | std::cout << "Warning: No train sparse depth in " << filename << std::endl; 102 | } 103 | } 104 | } 105 | float GetPercentValid() const 106 | { 107 | return sum_percent_ / static_cast(count_); 108 | } 109 | double GetScaling() const 110 | { 111 | return global_scaling_; 112 | } 113 | private: 114 | // sparse reconstruction from train images 115 | std::vector train_camera_configs_{}; 116 | std::vector train_point_cloud_{}; 117 | std::vector> train_visibility_{}; 118 | // sparse reconstruction from all images 119 | std::vector camera_configs_{}; 120 | std::vector point_cloud_{}; 121 | std::vector> visibility_{}; 122 | int count_{0}; // count processed train files 123 | float sum_percent_{0.f}; // sum pf percentage of train sparse depth 124 | double global_scaling_{0.}; 125 | std::size_t global_scaling_count_{0}; 126 | }; 127 | 128 | struct SceneConfig 129 | { 130 | std::string kName{}; // name of the scene 131 | float kMaxDepth{}; // maximal depth value in the scene, larger values are invalidated 132 | float kDist2M{}; // scaling factor that scales the sparse reconstruction to meters 133 | bool kRgbOnly{}; // write rgb only, f.ex. 
to get input for colmap 134 | }; 135 | 136 | SceneConfig LoadConfig(const boost::filesystem::path& path2scene) 137 | { 138 | std::cout << "Loading config for " << path2scene.string() << std::endl; 139 | std::ifstream ifs((path2scene / "config.json").string()); 140 | if (!ifs.is_open()) 141 | { 142 | std::cout << "Error: Could not open config.json" << std::endl; 143 | } 144 | rapidjson::IStreamWrapper isw(ifs); 145 | 146 | rapidjson::Document d; 147 | d.ParseStream(isw); 148 | 149 | return SceneConfig{d["name"].GetString(), static_cast(d["max_depth"].GetDouble()), 150 | static_cast(d["dist2m"].GetDouble()), d["rgb_only"].GetBool()}; 151 | } 152 | 153 | int main(int argc, char** argv) 154 | { 155 | if (argc != 3) 156 | { 157 | std::cout << "Usage: ./extract_scannet_scene " << std::endl; 158 | return 0; 159 | } 160 | boost::filesystem::path path2scene(argv[1]); 161 | boost::filesystem::path path2scannet(argv[2]); 162 | const auto& config = LoadConfig(path2scene); 163 | boost::filesystem::path path2scannetscene(path2scannet / "scans_test" / config.kName); 164 | 165 | // constants across all scenes 166 | constexpr float kDepthScalingFactor{1000.f}; 167 | constexpr int kWidth{640}; 168 | constexpr int kHeight{480}; 169 | 170 | // read reconstruction 171 | ColmapHelper recon; 172 | if (!config.kRgbOnly) 173 | { 174 | recon.Init(path2scene / "colmap", config.kDist2M); 175 | } 176 | 177 | float max_depth{0.f}; 178 | std::unordered_map> camera_frames{{"train", {}}, {"test", {}}}; 179 | 180 | for (const std::string& dataset_type : {"train", "test"}) 181 | { 182 | const boost::filesystem::path path2scene_type(path2scene / dataset_type); 183 | boost::filesystem::create_directory(path2scene_type); 184 | for (const std::string& subdir : {"rgb", "depth", "target_depth"}) 185 | { 186 | const boost::filesystem::path path2subdir(path2scene_type / subdir); 187 | boost::filesystem::create_directory(path2subdir); 188 | } 189 | const boost::filesystem::path path2csv(path2scene / (dataset_type + "_set.csv")); 190 | std::vector filenames = ReadSequence(path2csv.string()); 191 | 192 | for(const auto& filename : filenames) 193 | { 194 | // read rgb 195 | auto rgb = ReadRgb((path2scannetscene / "color" / filename).string()); 196 | 197 | // fix different aspect ratio between rgb and depth 198 | if (rgb.cols == 1296 && rgb.rows == 968) 199 | { 200 | int border = 2; 201 | cv::Mat rgb_padded(rgb.rows + border * 2, rgb.cols, rgb.depth()); // construct with padding 202 | cv::copyMakeBorder(rgb, rgb_padded, border, border, 0, 0, cv::BORDER_CONSTANT); 203 | rgb = rgb_padded; 204 | } 205 | 206 | // resize rgb to depth size 207 | int orig_width{rgb.cols}, orig_height{rgb.rows}; 208 | cv::resize(rgb, rgb, cv::Size(kWidth, kHeight), 0, 0, cv::INTER_AREA); 209 | auto depth = ReadDepth((path2scannetscene / "depth" / filename).replace_extension(".png").string()); 210 | 211 | // crop black areas from calibration 212 | int h_crop = 6; 213 | int w_crop = 8; 214 | rgb = cv::Mat(rgb, cv::Rect(w_crop, h_crop, rgb.cols - 2 * w_crop, rgb.rows - 2 * h_crop)); 215 | depth = cv::Mat(depth, cv::Rect(w_crop, h_crop, depth.cols - 2 * w_crop, depth.rows - 2 * h_crop)); 216 | 217 | // compute maximum 218 | double curr_min_depth{0.}, curr_max_depth{0.}; 219 | cv::minMaxLoc(depth, &curr_min_depth, &curr_max_depth); 220 | max_depth = std::max(static_cast(curr_max_depth) / kDepthScalingFactor, max_depth); 221 | if (max_depth > config.kMaxDepth) 222 | { 223 | const auto max_depth_int(static_cast(config.kMaxDepth * kDepthScalingFactor)); 224 | 
depth.forEach([max_depth_int](unsigned short& pixel, const int[]) -> void { 225 | pixel = (pixel >= max_depth_int) ? 0 : pixel; 226 | }); 227 | std::cout << "Warning: " << filename << " maximal depth " << max_depth << " invalidate values >= " 228 | << config.kMaxDepth << std::endl; 229 | max_depth = config.kMaxDepth; 230 | } 231 | 232 | // write rgb 233 | std::string rgb_file_rel{(boost::filesystem::path(dataset_type) / "rgb" / filename).string()}; 234 | WriteRgb(rgb, (path2scene / rgb_file_rel).string()); 235 | if (config.kRgbOnly) 236 | { 237 | continue; 238 | } 239 | 240 | // write sparse depth 241 | std::string depth_file_rel{(boost::filesystem::path(dataset_type) / "depth" / filename).replace_extension(".png").string()}; 242 | cv::Mat sparse_depth = cv::Mat::zeros(depth.rows, depth.cols, CV_16UC1); 243 | recon.GetSparseDepth(filename, dataset_type, sparse_depth, config.kMaxDepth, kDepthScalingFactor, depth); 244 | WriteDepth(sparse_depth, (path2scene / depth_file_rel).string()); 245 | 246 | // write target depth 247 | std::string target_depth_file_rel{(boost::filesystem::path(dataset_type) / "target_depth" / filename).replace_extension(".png").string()}; 248 | WriteDepth(depth, (path2scene / target_depth_file_rel).string()); 249 | 250 | // set camera 251 | auto camera = recon.GetCameraConfig(filename).camera; 252 | camera_frames[dataset_type].emplace_back(CameraFrame{rgb_file_rel, depth_file_rel, camera}); 253 | } 254 | std::cout << "Processed " << filenames.size() << " " << dataset_type << " views" << std::endl; 255 | } 256 | if (config.kRgbOnly) 257 | { 258 | return 0; 259 | } 260 | 261 | // write camera pose files 262 | const float far = max_depth * 1.025f; 263 | for (const std::string& dataset_type : {"train", "test"}) 264 | { 265 | const std::string camera_frame_file((path2scene / ("transforms_" + dataset_type + ".json")).string()); 266 | WriteCameraFrames(camera_frame_file, camera_frames[dataset_type], 0.1f, far, kDepthScalingFactor); 267 | } 268 | std::cout << "Set far plane to " << far << std::endl; 269 | std::cout << "Percent valid depth in train views: " << recon.GetPercentValid() * 100.f << std::endl; 270 | std::cout << "Scaling between sparse depth and target depth: " << recon.GetScaling() << std::endl; 271 | 272 | return 0; 273 | } 274 | -------------------------------------------------------------------------------- /preprocessing/io/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(io 2 | include/rgbd.h 3 | src/rgbd.cpp 4 | include/file_utils.h 5 | src/file_utils.cpp 6 | include/camera_frames.h 7 | src/camera_frames.cpp 8 | ) 9 | target_include_directories(io PUBLIC 10 | ${CMAKE_CURRENT_SOURCE_DIR}/include 11 | ) 12 | target_link_libraries(io 13 | camera 14 | ${OpenCV_LIBS} 15 | ) 16 | -------------------------------------------------------------------------------- /preprocessing/io/include/camera_frames.h: -------------------------------------------------------------------------------- 1 | #ifndef CAMERA_FRAMES_H 2 | #define CAMERA_FRAMES_H 3 | 4 | #include 5 | #include 6 | 7 | #include 8 | 9 | struct CameraFrame 10 | { 11 | std::string rgb_file_path{}; 12 | std::string depth_file_path{}; 13 | Camera camera; 14 | }; 15 | 16 | void WriteCameraFrames(const std::string& file, const std::vector& camera_frames, float camera_near, 17 | float camera_far, float depth_scaling_factor); 18 | 19 | #endif // CAMERA_FRAMES_H 20 | -------------------------------------------------------------------------------- 
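Note on the camera frames format: the sketch below illustrates, in Python, how the JSON written by WriteCameraFrames (declared in camera_frames.h above and implemented in camera_frames.cpp further down) could be read back. The keys (near, far, depth_scaling_factor, frames, file_path, depth_file_path, fx, fy, cx, cy, transform_matrix) and the file name transforms_train.json follow the preprocessing code in this repository; the example paths and values are placeholders, and this snippet is only an assumed illustration, not the repository's own loader (data/load_scene.py).

import json
import numpy as np

# Minimal sketch of parsing the transforms file produced by WriteCameraFrames.
# File name and placeholder values are assumptions for illustration only.
with open("transforms_train.json", "r") as f:
    meta = json.load(f)

near, far = meta["near"], meta["far"]            # camera near/far planes
depth_scaling = meta["depth_scaling_factor"]     # scaling of the 16-bit depth PNGs

for frame in meta["frames"]:
    rgb_path = frame["file_path"]                # e.g. "train/rgb/<frame>.jpg" (placeholder)
    depth_path = frame["depth_file_path"]        # sparse depth PNG written by the extractor
    fx, fy, cx, cy = frame["fx"], frame["fy"], frame["cx"], frame["cy"]
    cam2world = np.array(frame["transform_matrix"])  # 4x4, inverse of world2cam
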
/preprocessing/io/include/file_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef FILE_UTILS_H 2 | #define FILE_UTILS_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | 12 | std::vector SplitByChar(const std::string &s, char c, bool allow_empty = false); 13 | 14 | void StringToType(const std::string& s, float& instance); 15 | void StringToType(const std::string& s, int& instance); 16 | void StringToType(const std::string& s, std::string& instance); 17 | 18 | template 19 | std::vector ReadSequence(const std::string& path, char delimiter = ' ') 20 | { 21 | std::vector result{}; 22 | std::ifstream file(path); 23 | std::string line; 24 | while (std::getline(file, line)) 25 | { 26 | const auto& elements = SplitByChar(line, delimiter); 27 | for (const auto e : elements) 28 | { 29 | T element{}; 30 | StringToType(e, element); 31 | result.emplace_back(element); 32 | } 33 | } 34 | return result; 35 | } 36 | 37 | #endif // FILE_UTILS_H 38 | -------------------------------------------------------------------------------- /preprocessing/io/include/rgbd.h: -------------------------------------------------------------------------------- 1 | #ifndef RGBD_WRITER_H 2 | #define RGBD_WRITER_H 3 | 4 | #include 5 | 6 | #include 7 | 8 | cv::Mat ReadRgb(const std::string& file); 9 | 10 | cv::Mat ReadDepth(const std::string& file); 11 | 12 | void WriteDepth(const cv::Mat& depth_image, const std::string& file); 13 | 14 | void WriteRgb(const cv::Mat& image, const std::string& file); 15 | 16 | #endif // RGBD_WRITER_H 17 | -------------------------------------------------------------------------------- /preprocessing/io/src/camera_frames.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | 12 | template 13 | void AddFloat(const std::string& name, float value, rapidjson::Document& json, T& object) 14 | { 15 | rapidjson::Value key(name.c_str(), json.GetAllocator()); 16 | rapidjson::Value val; 17 | val.SetFloat(value); 18 | object.AddMember(key, val, json.GetAllocator()); 19 | } 20 | 21 | template 22 | void AddNonEmptyString(const std::string& name, const std::string& value, rapidjson::Document& json, T& object) 23 | { 24 | if (!name.empty()) 25 | { 26 | rapidjson::Value key(name.c_str(), json.GetAllocator()); 27 | rapidjson::Value val(value.c_str(), json.GetAllocator()); 28 | object.AddMember(key, val, json.GetAllocator()); 29 | } 30 | } 31 | 32 | template 33 | void AddMatrix4(const std::string& name, const glm::mat4& matrix, rapidjson::Document& json, T& object) 34 | { 35 | rapidjson::Value key(name.c_str(), json.GetAllocator()); 36 | rapidjson::Value mat(rapidjson::Type::kArrayType); 37 | for (int r(0); r != 4; ++r) 38 | { 39 | rapidjson::Value mat_row(rapidjson::Type::kArrayType); 40 | const auto row = glm::row(matrix, r); 41 | for (int c(0); c != 4; ++c) 42 | { 43 | rapidjson::Value element; 44 | element.SetFloat(row[c]); 45 | mat_row.PushBack(element, json.GetAllocator()); 46 | } 47 | mat.PushBack(mat_row, json.GetAllocator()); 48 | } 49 | 50 | object.AddMember(key, mat, json.GetAllocator()); 51 | } 52 | 53 | void WriteCameraFrames(const std::string& file, const std::vector& camera_frames, float camera_near, 54 | float camera_far, float depth_scaling_factor) 55 | { 56 | rapidjson::Document json; 57 | json.SetObject(); 58 | 59 | AddFloat("near", camera_near, json, json); 60 
| AddFloat("far", camera_far, json, json); 61 | AddFloat("depth_scaling_factor", depth_scaling_factor, json, json); 62 | 63 | // frames 64 | rapidjson::Value frames(rapidjson::Type::kArrayType); 65 | for (const auto& camera_frame : camera_frames) 66 | { 67 | rapidjson::Value frame(rapidjson::Type::kObjectType); 68 | 69 | // file paths 70 | AddNonEmptyString("file_path", camera_frame.rgb_file_path, json, frame); 71 | AddNonEmptyString("depth_file_path", camera_frame.depth_file_path, json, frame); 72 | 73 | // intrinsics 74 | AddFloat("fx", camera_frame.camera.GetFx(), json, frame); 75 | AddFloat("fy", camera_frame.camera.GetFy(), json, frame); 76 | AddFloat("cx", camera_frame.camera.GetCx(), json, frame); 77 | AddFloat("cy", camera_frame.camera.GetCy(), json, frame); 78 | 79 | // transform matrix 80 | AddMatrix4("transform_matrix", glm::inverse(camera_frame.camera.GetWorld2Cam()), json, frame); 81 | 82 | frames.PushBack(frame, json.GetAllocator()); 83 | } 84 | json.AddMember("frames", frames, json.GetAllocator()); 85 | 86 | std::ofstream ofs(file); 87 | rapidjson::OStreamWrapper osw(ofs); 88 | 89 | rapidjson::PrettyWriter writer(osw); 90 | json.Accept(writer); 91 | } 92 | -------------------------------------------------------------------------------- /preprocessing/io/src/file_utils.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | std::vector SplitByChar(const std::string &s, char c, bool allow_empty) 6 | { 7 | std::vector result{}; 8 | std::size_t end = s.find(c); 9 | std::size_t start = 0; 10 | for (; end != std::string::npos;) 11 | { 12 | if (allow_empty || start != end) 13 | { 14 | result.push_back( s.substr(start, end - start)); 15 | } 16 | 17 | start = end + 1; 18 | end = s.find(c, start); 19 | } 20 | 21 | // add the rest 22 | if (start != s.size()) 23 | { 24 | result.push_back(s.substr(start, s.size() - start)); 25 | } 26 | 27 | return result; 28 | } 29 | 30 | void StringToType(const std::string& s, float& instance) 31 | { 32 | instance = std::stof(s); 33 | } 34 | 35 | void StringToType(const std::string& s, int& instance) 36 | { 37 | instance = std::stoi(s); 38 | } 39 | 40 | void StringToType(const std::string& s, std::string& instance) 41 | { 42 | instance = s; 43 | } 44 | -------------------------------------------------------------------------------- /preprocessing/io/src/rgbd.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include 6 | 7 | cv::Mat ReadRgb(const std::string& file) 8 | { 9 | cv::Mat image = cv::imread(file, cv::IMREAD_COLOR); 10 | cv::cvtColor(image, image, cv::COLOR_BGR2RGB); 11 | return image; 12 | } 13 | 14 | cv::Mat ReadDepth(const std::string& file) 15 | { 16 | return cv::imread(file, cv::IMREAD_ANYDEPTH); 17 | } 18 | 19 | void WriteDepth(const cv::Mat& depth_image, const std::string& file) 20 | { 21 | cv::imwrite(file, depth_image); 22 | } 23 | 24 | void WriteRgb(const cv::Mat& image, const std::string& file) 25 | { 26 | cv::cvtColor(image, image, (image.channels() == 3) ? 
cv::COLOR_RGB2BGR : cv::COLOR_RGBA2BGRA); 27 | cv::imwrite(file, image); 28 | } 29 | -------------------------------------------------------------------------------- /preprocessing/io_colmap/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(io_colmap 2 | include/colmap_reader.h 3 | src/colmap_reader.cpp 4 | ) 5 | target_include_directories(io_colmap PUBLIC 6 | ${CMAKE_CURRENT_SOURCE_DIR}/include 7 | ) 8 | target_link_libraries(io_colmap 9 | camera 10 | io 11 | ) 12 | -------------------------------------------------------------------------------- /preprocessing/io_colmap/include/colmap_reader.h: -------------------------------------------------------------------------------- 1 | #ifndef COLMAP_READER_H 2 | #define COLMAP_READER_H 3 | 4 | #include 5 | #include 6 | 7 | #include 8 | 9 | struct CameraConfig 10 | { 11 | int id{0}; 12 | std::string rgb_image{}; 13 | int width{0}; 14 | int height{0}; 15 | Camera camera; 16 | }; 17 | 18 | std::vector ReadCameras(const std::string& cameras_txt, const std::string& images_txt, float dist2m); 19 | 20 | void ReadSparsePointCloud(const std::string& points3D_txt, float scale, float max_reprojection_error, unsigned int min_track_length, 21 | std::vector& point_cloud, std::vector>& visibility); 22 | 23 | #endif // COLMAP_READER_H 24 | -------------------------------------------------------------------------------- /preprocessing/io_colmap/src/colmap_reader.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | 11 | std::unordered_map>> ReadIntrinsics(const std::string& cameras_txt) 12 | { 13 | std::unordered_map>> result{}; 14 | std::ifstream file(cameras_txt); 15 | std::string line; 16 | while (std::getline(file, line)) 17 | { 18 | const auto& elements = SplitByChar(line, ' '); 19 | if (elements[0] == std::string("#")) 20 | { 21 | continue; 22 | } 23 | if ("PINHOLE" == elements[1]) 24 | { 25 | if (elements.size() == 8) 26 | { 27 | // width, height, fx, fy, cx, cy 28 | result[std::stoi(elements[0])] = std::make_pair(std::string{"PINHOLE"}, 29 | std::array{std::stof(elements[2]), std::stof(elements[3]), std::stof(elements[4]), 30 | std::stof(elements[5]), std::stof(elements[6]), std::stof(elements[7])}); 31 | } 32 | else 33 | { 34 | std::cout << "Error: Invalid pinhole camera" << std::endl; 35 | } 36 | } 37 | if ("SIMPLE_PINHOLE" == elements[1]) 38 | { 39 | if (elements.size() == 7) 40 | { 41 | // width, height, f, cx, cy 42 | result[std::stoi(elements[0])] = std::make_pair(std::string{"SIMPLE_PINHOLE"}, 43 | std::array{std::stof(elements[2]), std::stof(elements[3]), std::stof(elements[4]), 44 | std::stof(elements[5]), std::stof(elements[6]), 0.f}); 45 | } 46 | else 47 | { 48 | std::cout << "Error: Invalid simple pinhole camera" << std::endl; 49 | } 50 | } 51 | else 52 | { 53 | std::cout << "Warning: Camera model " << elements[1] << " not implemented." 
<< std::endl; 54 | } 55 | } 56 | file.close(); 57 | if (result.empty()) 58 | { 59 | std::cout << "Error: Could not read intrinsics" << std::endl; 60 | } 61 | return result; 62 | } 63 | 64 | std::vector ReadCameras(const std::string& cameras_txt, const std::string& images_txt, float dist2m) 65 | { 66 | std::vector result{}; 67 | 68 | const auto& intrinsics = ReadIntrinsics(cameras_txt); 69 | 70 | std::ifstream file(images_txt); 71 | std::string line; 72 | while (std::getline(file, line)) 73 | { 74 | const auto& elements = SplitByChar(line, ' '); 75 | if (elements.size() != 10 || elements[0] == std::string("#")) 76 | { 77 | continue; 78 | } 79 | 80 | glm::quat quaternion(std::stof(elements[1]), std::stof(elements[2]), std::stof(elements[3]), std::stof(elements[4])); 81 | // rotation from colmap camera frame (z in positive viewing direction and image origin in top left corner) 82 | // to internal camera frame (z in negative viewing direction and image origin in bottom left corner) 83 | const glm::mat4 rot_cam2cam = glm::rotate(glm::mat4(1.0f), glm::radians(180.f), glm::vec3(1.0f, 0.0f, 0.0f)); 84 | glm::mat4 rot_world2cam = glm::mat4_cast(quaternion); 85 | glm::mat4 tra_world2cam = glm::mat4(1.f); 86 | tra_world2cam[3] = glm::vec4(dist2m * std::stof(elements[5]), dist2m * std::stof(elements[6]), dist2m * std::stof(elements[7]), 1.f); 87 | const glm::mat4 world2cam = rot_cam2cam * tra_world2cam * rot_world2cam; 88 | 89 | const auto& model_and_intrinsic = intrinsics.at(std::stoi(elements[8])); 90 | const auto& model = model_and_intrinsic.first; 91 | const auto& intrinsic = model_and_intrinsic.second; 92 | const float w = intrinsic[0]; 93 | const float h = intrinsic[1]; 94 | const float f_x = intrinsic[2]; 95 | float f_y(0.f), c_x(0.f), c_y(0.f); 96 | if (model == "SIMPLE_PINHOLE") 97 | { 98 | f_y = f_x; 99 | c_x = intrinsic[3]; 100 | c_y = intrinsic[4]; 101 | } 102 | else if (model == "PINHOLE") 103 | { 104 | f_y = intrinsic[3]; 105 | c_x = intrinsic[4]; 106 | c_y = intrinsic[5]; 107 | } 108 | 109 | c_y = h - c_y; 110 | Camera camera(f_x, f_y, c_x, c_y, world2cam); 111 | result.emplace_back(CameraConfig{std::stoi(elements[0]), elements[9], static_cast(w), static_cast(h), camera}); 112 | } 113 | 114 | file.close(); 115 | 116 | return result; 117 | } 118 | 119 | void ReadSparsePointCloud(const std::string& points3D_txt, float scale, float max_reprojection_error, unsigned int min_track_length, 120 | std::vector& point_cloud, std::vector>& visibility) 121 | { 122 | std::ifstream file(points3D_txt); 123 | std::string line; 124 | while (std::getline(file, line)) 125 | { 126 | const auto& elements = SplitByChar(line, ' '); 127 | if (elements[0] == std::string("#")) 128 | { 129 | continue; 130 | } 131 | const auto reproj_err = std::stof(elements[7]); 132 | if (reproj_err <= max_reprojection_error) 133 | { 134 | std::vector point_visibility{}; 135 | for (auto it(elements.cbegin() + 8); it < elements.end();) 136 | { 137 | point_visibility.emplace_back(std::stoi(*it)); 138 | it = it + 2; 139 | } 140 | if (point_visibility.size() >= min_track_length) 141 | { 142 | point_cloud.emplace_back(glm::vec3{scale * std::stof(elements[1]), scale * std::stof(elements[2]), scale * std::stof(elements[3])}); 143 | visibility.emplace_back(point_visibility); 144 | } 145 | } 146 | } 147 | } 148 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | torch 3 | tensorboard 4 | 
torchvision 5 | opencv-python 6 | pandas 7 | scikit-image 8 | lpips 9 | configargparse -------------------------------------------------------------------------------- /run_depth_completion.py: -------------------------------------------------------------------------------- 1 | import time 2 | import datetime 3 | import os.path 4 | from argparse import ArgumentParser 5 | 6 | import numpy as np 7 | import torch 8 | from torch.utils.tensorboard import SummaryWriter 9 | from torch.utils.data import DataLoader, Subset 10 | from train_utils import MeanTracker 11 | import cv2 12 | 13 | from data import ScanNetDataset, convert_depth_completion_scaling_to_m, create_random_subsets 14 | from train_utils import print_network_info, get_hours_mins, MeanTracker, make_image_grid, apply_max_filter, \ 15 | update_learning_rate 16 | from model import resnet18_skip 17 | from metric import compute_rmse 18 | 19 | def write_batch(batch, path): 20 | bgr = cv2.cvtColor((batch.permute(1, 2, 0).numpy() * 255.).astype(np.uint8), cv2.COLOR_RGB2BGR) 21 | bgr_width = bgr.shape[1] // 5 22 | depth_columns = cv2.applyColorMap(bgr[:, bgr_width:, :], cv2.COLORMAP_VIRIDIS) 23 | cv2.imwrite(path, np.concatenate((bgr[:, :bgr_width, :], depth_columns), 1)) 24 | 25 | def make_grid(input, pred_x, pred_std, target, unnormalize): 26 | input_grid = make_image_grid(input, unnormalize) 27 | pred_x_grid = make_image_grid(pred_x) 28 | pred_std_grid = make_image_grid(pred_std) 29 | target_grid = make_image_grid(target) 30 | return torch.cat((input_grid, pred_x_grid, pred_std_grid, target_grid), 2) 31 | 32 | def batch2grid(input, pred, target, unnormalize, n_samples): 33 | input = input[:n_samples, ...] 34 | pred_x = pred[0][:n_samples, ...] 35 | target = target[:n_samples, ...] 36 | # clamp at 0.5m and normalize 37 | pred_std = convert_depth_completion_scaling_to_m(pred[1][:n_samples, ...]).clamp(max=0.5) / 0.5 38 | return make_grid(apply_max_filter(input, 3), pred_x, pred_std, target, unnormalize) 39 | 40 | def get_load_path(args): 41 | return os.path.join(args.exp_dir, args.expname + '.tar') 42 | 43 | def load_net(args): 44 | load_path = get_load_path(args) 45 | if os.path.exists(load_path): 46 | load_pretrained = False 47 | else: 48 | load_pretrained = True 49 | 50 | net = resnet18_skip(pretrained=load_pretrained, pretrained_path=args.pretrained_resnet_path) 51 | 52 | print_network_info(net) 53 | 54 | if not load_pretrained: 55 | ckpt = torch.load(load_path) 56 | missing_keys, unexpected_keys = net.load_state_dict(ckpt['network_state_dict'], strict=False) 57 | print("Loading model: \n missing keys: {}\n unexpected keys: {}".format(missing_keys, unexpected_keys)) 58 | 59 | return net 60 | 61 | def load_train_state(args, optimizer): 62 | load_path = get_load_path(args) 63 | start_epoch = 1 64 | min_val_rmse = 1e6 65 | if os.path.exists(load_path): 66 | ckpt = torch.load(load_path) 67 | optimizer.load_state_dict(ckpt['optimizer_state_dict']) 68 | if 'lr' in ckpt: 69 | new_lr = ckpt['lr'] 70 | update_learning_rate(optimizer, new_lr) 71 | print("Set learning rate to {}".format(new_lr)) 72 | 73 | start_epoch = ckpt['epoch'] + 1 74 | return optimizer, start_epoch, min_val_rmse 75 | 76 | def get_device(): 77 | if torch.cuda.is_available(): 78 | device = torch.device("cuda") 79 | print("Training on GPU") 80 | else: 81 | device = torch.device("cpu") 82 | print("Training on CPU") 83 | return device 84 | 85 | class Validator: 86 | def __init__(self, val_dataset, unnormalize, min_val_rmse, device): 87 | self.device = device 88 | self.unnormalize = 
unnormalize 89 | self.min_val_rmse = min_val_rmse 90 | validate_on_at_least_n_samples = 20000 91 | val_sample_count = len(val_dataset) 92 | if val_sample_count < validate_on_at_least_n_samples: 93 | self.val_subsets = [val_dataset,] 94 | print("Small validation set -> no need to create subsets") 95 | else: 96 | self.val_subsets = create_random_subsets(val_dataset, validate_on_at_least_n_samples) 97 | print("Create {} validation subsets with length {} or {}".format(len(self.val_subsets), len(self.val_subsets[0]), \ 98 | len(self.val_subsets[-1]))) 99 | self.val_subset_index = 0 100 | 101 | def next_subset_index(self): 102 | curr_subset_index = self.val_subset_index 103 | self.val_subset_index += 1 104 | if self.val_subset_index == len(self.val_subsets): 105 | self.val_subset_index = 0 106 | return curr_subset_index 107 | 108 | def validate(self, net, optimizer, args, tb, epoch, step): 109 | with torch.no_grad(): 110 | net.eval() 111 | val_metrics = MeanTracker() 112 | val_start_time = time.time() 113 | for i, data in enumerate(DataLoader(dataset=self.val_subsets[self.next_subset_index()], batch_size=args.batch_size, \ 114 | shuffle=False, num_workers=4, drop_last=True)): 115 | batch_start_time = time.time() 116 | 117 | # move data to gpu and predict 118 | valid_target = data['target_valid_depth'].to(self.device) 119 | if valid_target.sum() <= 0: 120 | continue 121 | input = data['rgbd'].to(self.device) 122 | target = data['target_depth'].to(self.device) 123 | pred = net(input) 124 | 125 | # compute metrics 126 | val_l1_loss = convert_depth_completion_scaling_to_m(torch.nn.functional.l1_loss(pred[0][valid_target], target[valid_target])) 127 | val_rmse = convert_depth_completion_scaling_to_m(compute_rmse(pred[0][valid_target], target[valid_target])) 128 | val_loss = 0.01 * torch.nn.functional.gaussian_nll_loss(pred[0][valid_target], target[valid_target], pred[1][valid_target].pow(2)) 129 | curr_val_metrics = {"l1" : val_l1_loss.item(), "rmse" : val_rmse.item(), "gnll" : val_loss.item(), \ 130 | "batch_time" : time.time() - batch_start_time} 131 | val_metrics.add(curr_val_metrics) 132 | 133 | # visualize the first batch 134 | if i == 0: 135 | batch_grid = batch2grid(input, pred, target, self.unnormalize, 8) 136 | tb.add_image('val_image', batch_grid, step) 137 | 138 | # print statistics 139 | mean_it_time = (time.time() - val_start_time) / (i + 1) 140 | mean_val_rmse = val_metrics.get("rmse") 141 | mean_val_l1_loss = val_loss = val_metrics.get("l1") 142 | tb.add_scalars('l1', {'val': mean_val_l1_loss}, step) 143 | mean_val_gnll = val_loss = val_metrics.get("gnll") 144 | tb.add_scalars('gnll', {'val': mean_val_gnll}, step) 145 | tb.add_scalar('rmse', mean_val_rmse, step) 146 | print("Validate, it_time={:.3f}s, batch_time={:.3f}s, val_metric={:.4f}".format(mean_it_time, val_metrics.get("batch_time"), val_loss)) 147 | 148 | # save checkpoint 149 | if mean_val_rmse < self.min_val_rmse: 150 | self.min_val_rmse = mean_val_rmse 151 | filename = args.expname + '.tar' 152 | os.makedirs(os.path.join(args.exp_dir), exist_ok=True) 153 | path = os.path.join(args.exp_dir, filename) 154 | save_dict = { 155 | 'epoch': epoch, 156 | 'lr' : optimizer.param_groups[0]['lr'], 157 | 'mean_val_rmse': mean_val_rmse, 158 | 'network_state_dict': net.state_dict(), 159 | 'optimizer_state_dict': optimizer.state_dict(),} 160 | torch.save(save_dict, path) 161 | print('Saved checkpoints at', path) 162 | net.train() 163 | return val_loss 164 | 165 | def train_depth_completion(args): 166 | np.random.seed(0) 167 | 
torch.manual_seed(0) 168 | torch.cuda.manual_seed(0) 169 | device = get_device() 170 | 171 | # load network and optimizer 172 | net = load_net(args).to(device) 173 | 174 | optimizer = torch.optim.Adam(list(net.parameters()), lr=args.lr) 175 | optimizer, start_epoch, min_val_rmse = load_train_state(args, optimizer) 176 | 177 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.8, patience=3, verbose=True) 178 | 179 | tb = SummaryWriter(log_dir=os.path.join("runs", args.expname)) 180 | 181 | # create datasets 182 | train_dataset = ScanNetDataset(args.dataset_dir, "train", args.db_path, random_rot=args.random_rot, horizontal_flip=True, \ 183 | color_jitter=args.color_jitter, depth_noise=True, missing_depth_percent=args.missing_depth_percent) 184 | val_dataset = ScanNetDataset(args.dataset_dir, "val", args.db_path, depth_noise=True, missing_depth_percent=args.missing_depth_percent) 185 | unnormalize = train_dataset.unnormalize 186 | train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=6, drop_last=True) 187 | args.i_val = min(args.i_val, len(train_loader)) 188 | print("Train on {} samples".format(len(train_dataset))) 189 | validator = Validator(val_dataset, unnormalize, min_val_rmse, device) 190 | print("Validate on {} samples".format(len(val_dataset))) 191 | 192 | # start training 193 | train_batch_count = len(train_loader) 194 | train_metrics = MeanTracker() 195 | for epoch in range(start_epoch, args.n_epochs + 1): 196 | net.train() # switch to train mode 197 | epoch_start_time = time.time() 198 | for i, data in enumerate(train_loader): 199 | batch_start_time = time.time() 200 | step = (epoch - 1) * train_batch_count + i + 1 201 | 202 | # move data to gpu and predict 203 | valid_target = data['target_valid_depth'].to(device) 204 | if valid_target.sum() <= 0: 205 | continue 206 | input = data['rgbd'].to(device) 207 | target = data['target_depth'].to(device) 208 | pred = net(input) 209 | 210 | # compute loss and metrics, update network parameters 211 | l1_loss = torch.nn.functional.l1_loss(pred[0][valid_target], target[valid_target]) 212 | curr_train_metrics = {"l1" : convert_depth_completion_scaling_to_m(l1_loss.item()),} 213 | train_loss = 0.01 * torch.nn.functional.gaussian_nll_loss(pred[0][valid_target], target[valid_target], pred[1][valid_target].pow(2)) 214 | curr_train_metrics["gnll"] = train_loss.item() 215 | optimizer.zero_grad() 216 | train_loss.backward() 217 | torch.nn.utils.clip_grad_value_(net.parameters(), 0.1) 218 | optimizer.step() 219 | curr_train_metrics["batch_time"] = time.time() - batch_start_time 220 | train_metrics.add(curr_train_metrics) 221 | 222 | # log results 223 | if (i+1)%args.i_print == 0: 224 | mean_it_time = (time.time() - epoch_start_time) / (i + 1) 225 | portion_of_epoch = (i + 1) / float(train_batch_count) 226 | hours, mins = get_hours_mins(epoch_start_time, time.time()) 227 | print("Epoch {}/{}: {:.2f}% in {:02d}:{:02d}, it_time={:.3f}s, batch_time={:.3f}s, l1={:.4f}".format(epoch, args.n_epochs, \ 228 | 100. 
* portion_of_epoch, hours, mins, mean_it_time, train_metrics.get("batch_time"), train_metrics.get("l1"))) 229 | tb.add_scalars('l1', {'train': train_metrics.get("l1")}, step) 230 | tb.add_scalars('gnll', {'train': train_metrics.get("gnll")}, step) 231 | train_metrics.reset() 232 | 233 | if (i+1)%args.i_img == 0: 234 | batch_grid = batch2grid(input, pred, target, unnormalize, 8) 235 | tb.add_image('train_image', batch_grid, step) 236 | 237 | if (i+1)%args.i_val == 0: 238 | val_loss = validator.validate(net, optimizer, args, tb, epoch, step) 239 | # update lr 240 | scheduler.step(val_loss) 241 | 242 | tb.flush() 243 | 244 | def main(): 245 | parser = ArgumentParser() 246 | parser.add_argument('task', type=str, help='one out of: "train", "test"') 247 | parser.add_argument("--expname", type=str, default=None, \ 248 | help='specify the experiment, required for "test" or to resume "train"') 249 | 250 | # data 251 | parser.add_argument("--dataset_dir", type=str, default="", \ 252 | help="dataset directory") 253 | parser.add_argument("--db_path", type=str, default="scannet_sift_database.db", \ 254 | help='path to the sift database') 255 | parser.add_argument("--pretrained_resnet_path", type=str, default="resnet18.pth", \ 256 | help='path to the pretrained resnet weights') 257 | parser.add_argument("--ckpt_dir", type=str, default="", \ 258 | help='checkpoint directory') 259 | 260 | # training 261 | parser.add_argument("--missing_depth_percent", type=float, default=0.998, \ 262 | help='portion of missing depth in sparse depth input, value between 0 and 1') 263 | parser.add_argument("--random_rot", type=float, default=10., \ 264 | help='random rotation in degree as data augmentation') 265 | parser.add_argument("--color_jitter", type=float, default=0.4, \ 266 | help='add color jitter as data augmentation, set None to deactivate') 267 | parser.add_argument("--batch_size", type=int, default=8, \ 268 | help='batch size') 269 | parser.add_argument("--n_epochs", type=int, default=12, \ 270 | help='number of epochs') 271 | parser.add_argument("--lr", type=float, default=1e-4, \ 272 | help='learning rate') 273 | 274 | # logging 275 | parser.add_argument("--i_print", type=int, default=1000, \ 276 | help='log train loss every ith batch') 277 | parser.add_argument("--i_img", type=int, default=10000, \ 278 | help='log train images every ith batch') 279 | parser.add_argument("--i_val", type=int, default=25000, \ 280 | help='validate every ith batch or every epoch if the train set is smaller') 281 | 282 | args = parser.parse_args() 283 | 284 | print(args) 285 | 286 | if args.expname is None: 287 | args.expname = "{}".format(datetime.datetime.fromtimestamp(time.time()).strftime('%Y%m%d_%H%M%S')) 288 | args.exp_dir = os.path.join(args.ckpt_dir, args.expname) 289 | 290 | device = get_device() 291 | 292 | if args.task == "test": 293 | # load network weights 294 | net = load_net(args).to(device) 295 | 296 | result_dir = os.path.join(args.exp_dir, "test_results") 297 | os.makedirs(os.path.join(result_dir), exist_ok=True) 298 | 299 | # create dataset 300 | test_dataset = ScanNetDataset(args.dataset_dir, "test", args.db_path, depth_noise=True, missing_depth_percent=args.missing_depth_percent) 301 | unnormalize = test_dataset.unnormalize 302 | test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=6, drop_last=True) 303 | print("Test on {} samples".format(len(test_dataset))) 304 | 305 | visu_sample_count = len(test_dataset) 306 | number_visu_images = 40 # number of images to 
visualize 307 | visu_samples = range(0, visu_sample_count, visu_sample_count // number_visu_images) 308 | visu_loader = DataLoader(dataset=Subset(test_dataset, visu_samples), batch_size=args.batch_size, shuffle=False, num_workers=2, drop_last=True) 309 | 310 | with torch.no_grad(): 311 | net.eval() 312 | test_metrics = MeanTracker() 313 | for i, data in enumerate(test_loader): 314 | 315 | # move data to gpu and predict 316 | valid_target = data['target_valid_depth'].to(device) 317 | if valid_target.sum() <= 0: 318 | continue 319 | input = data['rgbd'].to(device) 320 | target = data['target_depth'].to(device) 321 | pred = net(input) 322 | 323 | # compute test metrics 324 | pred_depth_m = convert_depth_completion_scaling_to_m(pred[0]) 325 | valid_pred_depth_m = pred_depth_m[valid_target] 326 | target_depth_m = convert_depth_completion_scaling_to_m(target[valid_target]) 327 | mae = torch.nn.functional.l1_loss(valid_pred_depth_m, target_depth_m) 328 | rmse = compute_rmse(valid_pred_depth_m, target_depth_m) 329 | curr_metrics = {"mae" : mae.item(), "rmse" : rmse.item()} 330 | pred_std_m = convert_depth_completion_scaling_to_m(pred[1]) 331 | curr_metrics["std"] = pred_std_m.mean() 332 | test_metrics.add(curr_metrics) 333 | if (i % 1000) == 0: 334 | print("{}/{}".format(i, len(test_loader))) 335 | 336 | with open(os.path.join(result_dir, 'metrics.txt'), 'w') as f: 337 | test_metrics.print(f) 338 | test_metrics.print() 339 | 340 | # write visualization samples 341 | for i, data in enumerate(visu_loader): 342 | valid_target = data['target_valid_depth'].to(device) 343 | input = data['rgbd'].to(device) 344 | target = data['target_depth'].to(device) 345 | pred = net(input) 346 | batch_grid = batch2grid(input, pred, target, unnormalize, args.batch_size) 347 | write_batch(batch_grid.cpu(), os.path.join(result_dir, str(i) + ".jpg")) 348 | exit() 349 | else: 350 | train_depth_completion(args) 351 | 352 | if __name__ == "__main__": 353 | main() -------------------------------------------------------------------------------- /train_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .hyperparameter_update import update_learning_rate 2 | from .logging import MeanTracker, print_network_info, get_hours_mins, make_image_grid, apply_max_filter -------------------------------------------------------------------------------- /train_utils/hyperparameter_update.py: -------------------------------------------------------------------------------- 1 | 2 | def update_learning_rate(optimizer, learning_rate): 3 | for param_group in optimizer.param_groups: 4 | param_group['lr'] = learning_rate 5 | -------------------------------------------------------------------------------- /train_utils/logging.py: -------------------------------------------------------------------------------- 1 | from scipy import ndimage 2 | import torch 3 | import torchvision 4 | 5 | class MeanTracker(object): 6 | def __init__(self): 7 | self.reset() 8 | 9 | def add(self, input, weight=1.): 10 | for key, l in input.items(): 11 | if not key in self.mean_dict: 12 | self.mean_dict[key] = 0 13 | self.mean_dict[key] = (self.mean_dict[key] * self.total_weight + l) / (self.total_weight + weight) 14 | self.total_weight += weight 15 | 16 | def has(self, key): 17 | return (key in self.mean_dict) 18 | 19 | def get(self, key): 20 | return self.mean_dict[key] 21 | 22 | def as_dict(self): 23 | return self.mean_dict 24 | 25 | def reset(self): 26 | self.mean_dict = dict() 27 | self.total_weight = 0 28 | 29 | 
def print(self, f=None): 30 | for key, l in self.mean_dict.items(): 31 | if f is not None: 32 | print("{}: {}".format(key, l), file=f) 33 | else: 34 | print("{}: {}".format(key, l)) 35 | 36 | def get_hours_mins(start_time, end_time): 37 | dt = end_time - start_time 38 | hours = int(dt // 3600) 39 | mins = int((dt // 60) % 60) 40 | return hours, mins 41 | 42 | def apply_max_filter(batch, channel, kernel=3): 43 | batch_local = batch.detach().clone() 44 | for i, image in enumerate(batch_local.cpu().numpy()): 45 | batch_local[i, channel, :, :] = torch.tensor(ndimage.maximum_filter(image[channel, :, :], \ 46 | size=kernel)).type(torch.FloatTensor) 47 | return batch_local 48 | 49 | def make_image_grid(data, unnormalize=None): 50 | if data.shape[1] == 1: 51 | return torchvision.utils.make_grid(data, nrow=1) 52 | elif data.shape[1] == 3: 53 | return torchvision.utils.make_grid(data if unnormalize is None else unnormalize['rgb'](data), nrow=1) 54 | elif data.shape[1] == 4: 55 | unnormalized = data if unnormalize is None else unnormalize['rgbd'](data) 56 | rgb_grid = torchvision.utils.make_grid(unnormalized[:, :3, :, :], nrow=1) 57 | depth_grid = torchvision.utils.make_grid(torch.unsqueeze(unnormalized[:, 3, :, :], 1), nrow=1) 58 | return torch.cat((rgb_grid, depth_grid), 2) 59 | 60 | def print_network_info(net): 61 | num_params = 0 62 | for param in net.parameters(): 63 | num_params += param.numel() 64 | print(net) 65 | print('Number of model parameters: %.3f M' % (num_params / 1e6)) 66 | --------------------------------------------------------------------------------
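Usage note for the training utilities above: MeanTracker keeps weighted running means of named metrics and is used in run_depth_completion.py to aggregate per-batch values before logging; get_hours_mins converts an elapsed time into hours and minutes. The fragment below is a hypothetical loop showing the intended call pattern; the metric names mirror those used in run_depth_completion.py, but the numeric values and the loop itself are made up for illustration.

import time
from train_utils import MeanTracker, get_hours_mins

# Hypothetical training-loop fragment combining the logging helpers.
metrics = MeanTracker()
start = time.time()
for step in range(100):
    # placeholder per-batch values; in run_depth_completion.py these come from the loss terms
    metrics.add({"l1": 0.05, "gnll": -1.2, "batch_time": 0.1})

hours, mins = get_hours_mins(start, time.time())
print("trained for {:02d}:{:02d}, mean l1 = {:.4f}".format(hours, mins, metrics.get("l1")))
metrics.print()   # prints every tracked mean, one per line
metrics.reset()   # clear the accumulators, e.g. between logging intervals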