├── .gitignore ├── LICENSE ├── README.md ├── colmap_converter ├── __init__.py ├── __main__.py ├── colmap_utils.py ├── frames.py └── meta.py ├── dataset ├── __init__.py ├── annotations.py ├── rays.py └── utils.py ├── environment.yaml ├── evaluate.py ├── evaluation ├── __init__.py ├── metrics.py ├── segmentation.py ├── utils.py └── video.py ├── get_started.sh ├── loss.py ├── model ├── __init__.py ├── embedding.py ├── neuraldiff.py └── rendering.py ├── notebook └── eval.ipynb ├── opt.py ├── scripts ├── eval.sh └── train.sh ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | logs/ 2 | ckpts/ 3 | results/ 4 | data/ 5 | *.mp4 6 | *.pt 7 | 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | MIT License 3 | 4 | Copyright (c) 2021 Vadim Tschernezki 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | 24 | MIT License 25 | 26 | Copyright (c) 2020 Quei-An Chen 27 | 28 | Permission is hereby granted, free of charge, to any person obtaining a copy 29 | of this software and associated documentation files (the "Software"), to deal 30 | in the Software without restriction, including without limitation the rights 31 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 32 | copies of the Software, and to permit persons to whom the Software is 33 | furnished to do so, subject to the following conditions: 34 | 35 | The above copyright notice and this permission notice shall be included in all 36 | copies or substantial portions of the Software. 37 | 38 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 39 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 40 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 41 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 42 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 43 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 44 | SOFTWARE. 
45 | 
46 | MIT License
47 | 
48 | Copyright (c) 2020 bmild
49 | 
50 | Permission is hereby granted, free of charge, to any person obtaining a copy
51 | of this software and associated documentation files (the "Software"), to deal
52 | in the Software without restriction, including without limitation the rights
53 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
54 | copies of the Software, and to permit persons to whom the Software is
55 | furnished to do so, subject to the following conditions:
56 | 
57 | The above copyright notice and this permission notice shall be included in all
58 | copies or substantial portions of the Software.
59 | 
60 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
61 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
62 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
63 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
64 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
65 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
66 | SOFTWARE.
67 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | # NeuralDiff: Segmenting 3D objects that move in egocentric videos
3 | 
4 | ## [Project Page](https://www.robots.ox.ac.uk/~vadim/neuraldiff/) | [Paper + Supplementary](https://www.robots.ox.ac.uk/~vgg/publications/2021/Tschernezki21/tschernezki21.pdf) | [Video](https://www.youtube.com/watch?v=0J98WqHMSm4)
5 | 
6 | ![teaser](https://user-images.githubusercontent.com/12436822/147008441-f294a1e1-1de6-4ee1-b7c0-9872cac4f953.gif)
7 | 
8 | ## Updates
9 | 
10 | 18.05.22: You can now convert your own COLMAP models to our data format and start training with them straight away. See [here](https://github.com/dichotomies/NeuralDiff#your-own-data) for more details.
11 | 
12 | ## About
13 | 
14 | This repository contains the official implementation of the paper *NeuralDiff: Segmenting 3D objects that move in egocentric videos* by [Vadim Tschernezki](https://github.com/dichotomies), [Diane Larlus](https://dlarlus.github.io/) and [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/). Published at 3DV 2021.
15 | 
16 | Given a raw video sequence taken from a freely-moving camera, we study the problem of decomposing the observed 3D scene into a static background and a dynamic foreground containing the objects that move in the video sequence. This task is reminiscent of the classic background subtraction problem, but is significantly harder because all parts of the scene, static and dynamic, generate a large apparent motion due to the camera's large viewpoint change. In particular, we consider egocentric videos and further separate the dynamic component into objects and the actor that observes and moves them. We achieve this factorization by reconstructing the video via a triple-stream neural rendering network that explains the different motions based on corresponding inductive biases. We demonstrate that our method can successfully separate the different types of motion, outperforming recent neural rendering baselines at this task, and can accurately segment moving objects. We do so by assessing the method empirically on challenging videos from the EPIC-KITCHENS dataset, which we augment with appropriate annotations to create a new benchmark for the task of dynamic object segmentation on unconstrained video sequences in complex 3D environments.
17 | 
18 | ## Getting started
19 | 
20 | We provide an environment config file for [anaconda](https://www.anaconda.com/). You can install and activate it with the following commands:
21 | 
22 | ```
23 | conda env create -f environment.yaml
24 | conda activate neuraldiff
25 | ```
26 | 
27 | After that, you can initialise the repository with pretrained models and the data through `sh get_started.sh` (this will download and extract everything you need to train and evaluate models). You can then proceed directly to **Reproducing results**. If you run into any trouble, initialise the repository manually as described in the next sections (**Dataset**, **Pretrained models**).
28 | 
29 | ## Dataset
30 | 
31 | ### EPIC-Diff
32 | 
33 | The EPIC-Diff dataset can be downloaded [here](https://www.robots.ox.ac.uk/~vadim/neuraldiff/release/EPIC-Diff-annotations.tar.gz).
34 | 
35 | After downloading, move the compressed dataset to the directory of the cloned repository (e.g. `NeuralDiff`). Then, run the following commands:
36 | 
37 | ```
38 | mkdir data
39 | mv EPIC-Diff-annotations.tar.gz data
40 | cd data
41 | tar -xzvf EPIC-Diff-annotations.tar.gz
42 | ```
43 | 
44 | The RGB frames are hosted separately as a subset of the [EPIC-Kitchens](https://epic-kitchens.github.io/2022) dataset. The data are available at the University of Bristol [data repository](https://doi.org/10.5523/bris.296c4vv03j7lb2ejq3874ej3vm), data.bris. Once downloaded, move the folders into the same directory as mentioned before (`data/EPIC-Diff`).
45 | 
46 | ### Your own data
47 | 
48 | We include a script that converts COLMAP models to our data format. An example command would be:
49 | 
50 | ```
51 | python -m colmap_converter --colmap_dir <colmap_dir> --scale=8
52 | ```
53 | 
54 | The COLMAP directory should contain the folder `images` and the sparse model in `sparse/0`. You can also downscale the images via the `--scale` argument.
55 | 
56 | To speed up the conversion, the script switches to GPU computation if you are using more than 1000 images.
57 | 
58 | ## Pretrained models
59 | 
60 | We provide model checkpoints for all 10 scenes. You can use these to
61 | - evaluate the models with the annotations from the EPIC-Diff benchmark
62 | - create a summary video like the one at the top of this README to visualise the separation of the video into background, foreground and actor
63 | 
64 | The models can be downloaded [here](https://www.robots.ox.ac.uk/~vadim/neuraldiff/release/ckpts.tar.gz) (about 50MB in total).
65 | 
66 | Once downloaded, place `ckpts.tar.gz` into the main directory. Then execute `tar -xzvf ckpts.tar.gz`. This will create a folder `ckpts` with the pretrained models.
67 | 
68 | ## Reproducing results
69 | 
70 | ### Visualisations and metrics per scene
71 | 
72 | To evaluate the scene with Video ID `P01_01`, use the following command:
73 | 
74 | ```
75 | sh scripts/eval.sh rel P01_01 rel 'masks' 0 0
76 | ```
77 | 
78 | The results are saved in `results/rel`. The subfolders contain a text file with the mAP and PSNR scores per scene, as well as visualisations per sample.
79 | 
80 | You can find all scene IDs in the EPIC-Diff data folder (e.g. `P01_01`, `P03_04`, ... `P21_01`).
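The shell script above is a thin wrapper around `evaluate.py`. If you prefer to drive the evaluation from Python (e.g. in a notebook), the following minimal sketch reproduces the same call sequence; the checkpoint path is a placeholder and should point to the pretrained model of the scene inside the extracted `ckpts` folder:

```python
import os

import evaluation
import utils
from dataset import EPICDiff, MaskLoader

# Scene and pretrained model (the checkpoint path below is a placeholder).
dataset = EPICDiff("P01_01", root="data/EPIC-Diff")
model = utils.init_model("ckpts/P01_01.ckpt", dataset)

# Ground truth masks for the annotated test frames of this scene.
maskloader = MaskLoader(dataset=dataset)

# Render and score the test frames; metrics.txt and per-sample
# visualisations are written to `save_dir`.
save_dir = "results/rel/P01_01/masks"
os.makedirs(save_dir, exist_ok=True)
results = evaluation.evaluate(
    dataset,
    model,
    maskloader,
    vis_i=1,
    save_dir=save_dir,
    save=True,
    vid="P01_01",
    image_ids=dataset.img_ids_test,
)
print(results["metrics"])
```

`results["metrics"]` holds the same scores that are written to `metrics.txt`: the mean average precision (`avgpre`) and PSNR over the evaluated test frames.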
81 | 
82 | ### Average metrics over all scenes
83 | 
84 | You can calculate the average of the metrics over all scenes (Table 1 in the paper) with the following command:
85 | 
86 | ```
87 | sh scripts/eval.sh rel 0 0 'average' 0 0
88 | ```
89 | 
90 | Make sure that you have calculated the metrics per scene before running this command (it simply reads the produced per-scene metrics and averages them).
91 | 
92 | ### Rendering a video with separation of background, foreground and actor
93 | 
94 | To visualise the different model components of a reconstructed video (as shown at the top of this page) from
95 | 1) the ground truth camera poses corresponding to the time of the video
96 | and 2) a fixed viewpoint,
97 | use the following command:
98 | 
99 | ```
100 | sh scripts/eval.sh rel P01_01 rel 'summary' 0 0
101 | ```
102 | 
103 | This will result in a corresponding video in the folder `results/rel/P01_01/summary`.
104 | 
105 | The fixed viewpoints are pre-defined and correspond to the ones that we used in the videos provided in the supplementary material. You can adjust the viewpoints in `dataset/__init__.py`.
106 | 
107 | ## Training
108 | 
109 | We provide training scripts for the proposed model (including colour normalisation). To train a model for scene `P01_01`, use the following command:
110 | 
111 | ```
112 | sh scripts/train.sh P01_01
113 | ```
114 | 
115 | You can visualise the training with tensorboard. The logs are stored in `logs`.
116 | 
117 | ## Citation
118 | 
119 | If you find our code or paper useful, please cite our work as follows.
120 | 
121 | ```bibtex
122 | @inproceedings{tschernezki21neuraldiff,
123 |   author = {Vadim Tschernezki and Diane Larlus and
124 |             Andrea Vedaldi},
125 |   booktitle = {Proceedings of the International Conference
126 |                on {3D} Vision (3DV)},
127 |   title = {{NeuralDiff}: Segmenting {3D} objects that
128 |            move in egocentric videos},
129 |   year = {2021}
130 | }
131 | ```
132 | 
133 | ## Acknowledgements
134 | 
135 | This implementation is based on [this repository](https://github.com/bmild/nerf) (official NeRF) and [this repository](https://github.com/kwea123/nerf_pl/tree/nerfw) (unofficial NeRF-W).
136 | 
137 | Our dataset is based on a subset of frames from [EPIC-Kitchens](https://epic-kitchens.github.io/2022). [COLMAP](https://colmap.github.io) was used for computing 3D information for these frames and [VGG Image Annotator (VIA)](https://www.robots.ox.ac.uk/~vgg/software/via/) was used for annotating them.
138 | 
--------------------------------------------------------------------------------
/colmap_converter/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | 
4 | def convert_file_format(x, format_trg="bmp"):
5 |     if format_trg is None:
6 |         return x
7 |     fn, ext = os.path.splitext(x)
8 |     return fn + os.path.extsep + format_trg
--------------------------------------------------------------------------------
/colmap_converter/__main__.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from .meta import *
3 | from .frames import *
4 | 
5 | 
6 | parser = argparse.ArgumentParser()
7 | 
8 | 
9 | def parse_args():
10 | 
11 |     parser = argparse.ArgumentParser()
12 | 
13 |     parser.add_argument(
14 |         "--colmap_dir",
15 |         type=str,
16 |         help="Root directory of the COLMAP project, which contains `sparse/0`.",
17 |     )
18 | 
19 |     parser.add_argument(
20 |         "--scale", default=1, type=int, help="Downscaling factor for images."
21 | ) 22 | 23 | parser.add_argument( 24 | "--dir_dst", default='data/custom', type=str, help="Destination directory for converted dataset." 25 | ) 26 | 27 | parser.add_argument( 28 | "--split_nth", default=0, type=int, help="select every n-th frame as validation and every other n-th frame as test frame." 29 | ) 30 | 31 | args = parser.parse_args() 32 | 33 | return args 34 | 35 | def run(args): 36 | colmap_model_dir = os.path.join(args.colmap_dir, 'sparse/0') 37 | colmap = load_colmap(colmap_model_dir) 38 | meta = calc_meta(colmap, split_nth=args.split_nth) 39 | frames_dir_src = os.path.join(args.colmap_dir, 'images') 40 | dataset_id = os.path.split(os.path.normpath(args.colmap_dir))[1] 41 | dataset_dir = os.path.join(args.dir_dst, dataset_id) 42 | frames_dir_dst = os.path.join(dataset_dir, 'images') 43 | os.makedirs(frames_dir_dst) 44 | 45 | save_meta(dataset_dir, meta) 46 | save_frames(frames_dir_src, frames_dir_dst, meta) 47 | 48 | if __name__ == '__main__': 49 | args = parse_args() 50 | run(args) 51 | -------------------------------------------------------------------------------- /colmap_converter/colmap_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # 14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 15 | # its contributors may be used to endorse or promote products derived 16 | # from this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | # POSSIBILITY OF SUCH DAMAGE. 29 | # 30 | # Author: Johannes L. 
Schoenberger (jsch at inf.ethz.ch) 31 | 32 | import os 33 | import sys 34 | import collections 35 | import numpy as np 36 | import struct 37 | 38 | 39 | CameraModel = collections.namedtuple( 40 | "CameraModel", ["model_id", "model_name", "num_params"]) 41 | Camera = collections.namedtuple( 42 | "Camera", ["id", "model", "width", "height", "params"]) 43 | BaseImage = collections.namedtuple( 44 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 45 | Point3D = collections.namedtuple( 46 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 47 | 48 | class Image(BaseImage): 49 | def qvec2rotmat(self): 50 | return qvec2rotmat(self.qvec) 51 | 52 | 53 | CAMERA_MODELS = { 54 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 55 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 56 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 57 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 58 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 59 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 60 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 61 | CameraModel(model_id=7, model_name="FOV", num_params=5), 62 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 63 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 64 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 65 | } 66 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) \ 67 | for camera_model in CAMERA_MODELS]) 68 | 69 | 70 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 71 | """Read and unpack the next bytes from a binary file. 72 | :param fid: 73 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 74 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 75 | :param endian_character: Any of {@, =, <, >, !} 76 | :return: Tuple of read and unpacked values. 
77 | """ 78 | data = fid.read(num_bytes) 79 | return struct.unpack(endian_character + format_char_sequence, data) 80 | 81 | 82 | def read_cameras_text(path): 83 | """ 84 | see: src/base/reconstruction.cc 85 | void Reconstruction::WriteCamerasText(const std::string& path) 86 | void Reconstruction::ReadCamerasText(const std::string& path) 87 | """ 88 | cameras = {} 89 | with open(path, "r") as fid: 90 | while True: 91 | line = fid.readline() 92 | if not line: 93 | break 94 | line = line.strip() 95 | if len(line) > 0 and line[0] != "#": 96 | elems = line.split() 97 | camera_id = int(elems[0]) 98 | model = elems[1] 99 | width = int(elems[2]) 100 | height = int(elems[3]) 101 | params = np.array(tuple(map(float, elems[4:]))) 102 | cameras[camera_id] = Camera(id=camera_id, model=model, 103 | width=width, height=height, 104 | params=params) 105 | return cameras 106 | 107 | 108 | def read_cameras_binary(path_to_model_file): 109 | """ 110 | see: src/base/reconstruction.cc 111 | void Reconstruction::WriteCamerasBinary(const std::string& path) 112 | void Reconstruction::ReadCamerasBinary(const std::string& path) 113 | """ 114 | cameras = {} 115 | with open(path_to_model_file, "rb") as fid: 116 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 117 | for camera_line_index in range(num_cameras): 118 | camera_properties = read_next_bytes( 119 | fid, num_bytes=24, format_char_sequence="iiQQ") 120 | camera_id = camera_properties[0] 121 | model_id = camera_properties[1] 122 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 123 | width = camera_properties[2] 124 | height = camera_properties[3] 125 | num_params = CAMERA_MODEL_IDS[model_id].num_params 126 | params = read_next_bytes(fid, num_bytes=8*num_params, 127 | format_char_sequence="d"*num_params) 128 | cameras[camera_id] = Camera(id=camera_id, 129 | model=model_name, 130 | width=width, 131 | height=height, 132 | params=np.array(params)) 133 | assert len(cameras) == num_cameras 134 | return cameras 135 | 136 | 137 | def read_images_text(path): 138 | """ 139 | see: src/base/reconstruction.cc 140 | void Reconstruction::ReadImagesText(const std::string& path) 141 | void Reconstruction::WriteImagesText(const std::string& path) 142 | """ 143 | images = {} 144 | with open(path, "r") as fid: 145 | while True: 146 | line = fid.readline() 147 | if not line: 148 | break 149 | line = line.strip() 150 | if len(line) > 0 and line[0] != "#": 151 | elems = line.split() 152 | image_id = int(elems[0]) 153 | qvec = np.array(tuple(map(float, elems[1:5]))) 154 | tvec = np.array(tuple(map(float, elems[5:8]))) 155 | camera_id = int(elems[8]) 156 | image_name = elems[9] 157 | elems = fid.readline().split() 158 | xys = np.column_stack([tuple(map(float, elems[0::3])), 159 | tuple(map(float, elems[1::3]))]) 160 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 161 | images[image_id] = Image( 162 | id=image_id, qvec=qvec, tvec=tvec, 163 | camera_id=camera_id, name=image_name, 164 | xys=xys, point3D_ids=point3D_ids) 165 | return images 166 | 167 | 168 | def read_images_binary(path_to_model_file): 169 | """ 170 | see: src/base/reconstruction.cc 171 | void Reconstruction::ReadImagesBinary(const std::string& path) 172 | void Reconstruction::WriteImagesBinary(const std::string& path) 173 | """ 174 | images = {} 175 | with open(path_to_model_file, "rb") as fid: 176 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 177 | for image_index in range(num_reg_images): 178 | binary_image_properties = read_next_bytes( 179 | fid, num_bytes=64, format_char_sequence="idddddddi") 
180 | image_id = binary_image_properties[0] 181 | qvec = np.array(binary_image_properties[1:5]) 182 | tvec = np.array(binary_image_properties[5:8]) 183 | camera_id = binary_image_properties[8] 184 | image_name = "" 185 | current_char = read_next_bytes(fid, 1, "c")[0] 186 | while current_char != b"\x00": # look for the ASCII 0 entry 187 | image_name += current_char.decode("utf-8") 188 | current_char = read_next_bytes(fid, 1, "c")[0] 189 | num_points2D = read_next_bytes(fid, num_bytes=8, 190 | format_char_sequence="Q")[0] 191 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 192 | format_char_sequence="ddq"*num_points2D) 193 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 194 | tuple(map(float, x_y_id_s[1::3]))]) 195 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 196 | images[image_id] = Image( 197 | id=image_id, qvec=qvec, tvec=tvec, 198 | camera_id=camera_id, name=image_name, 199 | xys=xys, point3D_ids=point3D_ids) 200 | return images 201 | 202 | 203 | def read_points3D_text(path): 204 | """ 205 | see: src/base/reconstruction.cc 206 | void Reconstruction::ReadPoints3DText(const std::string& path) 207 | void Reconstruction::WritePoints3DText(const std::string& path) 208 | """ 209 | points3D = {} 210 | with open(path, "r") as fid: 211 | while True: 212 | line = fid.readline() 213 | if not line: 214 | break 215 | line = line.strip() 216 | if len(line) > 0 and line[0] != "#": 217 | elems = line.split() 218 | point3D_id = int(elems[0]) 219 | xyz = np.array(tuple(map(float, elems[1:4]))) 220 | rgb = np.array(tuple(map(int, elems[4:7]))) 221 | error = float(elems[7]) 222 | image_ids = np.array(tuple(map(int, elems[8::2]))) 223 | point2D_idxs = np.array(tuple(map(int, elems[9::2]))) 224 | points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb, 225 | error=error, image_ids=image_ids, 226 | point2D_idxs=point2D_idxs) 227 | return points3D 228 | 229 | 230 | def read_points3d_binary(path_to_model_file): 231 | """ 232 | see: src/base/reconstruction.cc 233 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 234 | void Reconstruction::WritePoints3DBinary(const std::string& path) 235 | """ 236 | points3D = {} 237 | with open(path_to_model_file, "rb") as fid: 238 | num_points = read_next_bytes(fid, 8, "Q")[0] 239 | for point_line_index in range(num_points): 240 | binary_point_line_properties = read_next_bytes( 241 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 242 | point3D_id = binary_point_line_properties[0] 243 | xyz = np.array(binary_point_line_properties[1:4]) 244 | rgb = np.array(binary_point_line_properties[4:7]) 245 | error = np.array(binary_point_line_properties[7]) 246 | track_length = read_next_bytes( 247 | fid, num_bytes=8, format_char_sequence="Q")[0] 248 | track_elems = read_next_bytes( 249 | fid, num_bytes=8*track_length, 250 | format_char_sequence="ii"*track_length) 251 | image_ids = np.array(tuple(map(int, track_elems[0::2]))) 252 | point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) 253 | points3D[point3D_id] = Point3D( 254 | id=point3D_id, xyz=xyz, rgb=rgb, 255 | error=error, image_ids=image_ids, 256 | point2D_idxs=point2D_idxs) 257 | return points3D 258 | 259 | 260 | def read_model(path, ext): 261 | if ext == ".txt": 262 | cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) 263 | images = read_images_text(os.path.join(path, "images" + ext)) 264 | points3D = read_points3D_text(os.path.join(path, "points3D") + ext) 265 | else: 266 | cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) 267 | 
images = read_images_binary(os.path.join(path, "images" + ext)) 268 | points3D = read_points3d_binary(os.path.join(path, "points3D") + ext) 269 | return cameras, images, points3D 270 | 271 | 272 | def qvec2rotmat(qvec): 273 | return np.array([ 274 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 275 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 276 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 277 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 278 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 279 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 280 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 281 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 282 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 283 | 284 | 285 | def rotmat2qvec(R): 286 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 287 | K = np.array([ 288 | [Rxx - Ryy - Rzz, 0, 0, 0], 289 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 290 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 291 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 292 | eigvals, eigvecs = np.linalg.eigh(K) 293 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 294 | if qvec[0] < 0: 295 | qvec *= -1 296 | return qvec -------------------------------------------------------------------------------- /colmap_converter/frames.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from PIL import Image 4 | from tqdm import tqdm 5 | 6 | from . import convert_file_format 7 | 8 | 9 | def resize_image(PIL_image, h, w): 10 | if (w, h) != PIL_image.size: 11 | im = PIL_image.resize((int(w), int(h)), Image.LANCZOS) 12 | else: 13 | im = PIL_image 14 | return im 15 | 16 | 17 | def load_image(path): 18 | im = Image.open(path).convert("RGB") 19 | return im 20 | 21 | 22 | def save_frames(dir_src, dir_dst, meta, format_src=None, format_trg=None): 23 | for k in tqdm(meta["ids_all"]): 24 | name_src = convert_file_format(meta["images"][k], format_src) 25 | path_src = os.path.join(dir_src, name_src) 26 | name_trg = convert_file_format(name_src, format_trg) 27 | path_trg = os.path.join(dir_dst, name_trg) 28 | frame = resize_image(load_image(path_src), meta["image_h"], meta["image_w"]) 29 | frame.save(path_trg) 30 | 31 | 32 | def save_annotations(root, dataset, maskloader): 33 | """NOTE: not used for now.""" 34 | root = os.path.join(root, dataset.vid, "annotations") 35 | os.makedirs(root) 36 | 37 | for k in dataset.img_ids_test: 38 | mask_orig = maskloader[k, 0, 1][0] 39 | mask = PIL.Image.fromarray(mask_orig) 40 | sample = dataset[k] 41 | im_path = sample["im_path"] 42 | path = os.path.join(root, im_path.replace("jpg", "bmp")) 43 | mask.save(path, format="bmp") 44 | mask = np.array(PIL.Image.open(path)) 45 | assert (mask == mask_orig).all() 46 | -------------------------------------------------------------------------------- /colmap_converter/meta.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from copy import deepcopy 4 | from os.path import join 5 | 6 | import numpy as np 7 | import torch 8 | from . 
import convert_file_format 9 | 10 | from .colmap_utils import (read_cameras_binary, read_images_binary, 11 | read_points3d_binary) 12 | 13 | 14 | def compare_meta(meta1, meta2): 15 | assert len(meta1) == len(meta2) 16 | for k in meta1: 17 | if any([x in k for x in ["ids", "image"]]): 18 | assert meta1[k] == meta2[k] 19 | elif k in ["poses", "nears", "fars"]: 20 | assert ( 21 | np.array(list(meta1["poses"].values())) 22 | == np.array(list(meta2["poses"].values())) 23 | ).all() 24 | else: 25 | assert (meta1[k] == meta2[k]).all() 26 | 27 | 28 | def load_colmap(path_model): 29 | colmap = {} 30 | colmap["images"] = read_images_binary(join(path_model, "images.bin")) 31 | colmap["cameras"] = read_cameras_binary(join(path_model, "cameras.bin")) 32 | colmap["pts3d"] = read_points3d_binary(join(path_model, "points3D.bin")) 33 | return colmap 34 | 35 | 36 | def split_ids(ids, nth_image=8): 37 | 38 | if nth_image != 0: 39 | ids_te = ids[::nth_image][::2] 40 | ids_vl = ids[::nth_image][1::2] 41 | else: 42 | ids_te = [] 43 | ids_vl = [] 44 | ids_tr = list(sorted(set(ids).difference(set(ids_vl).union(ids_te)))) 45 | 46 | return ids_tr, ids_vl, ids_te 47 | 48 | 49 | def calculate_intrinsics(colmap_camera, downscale=1): 50 | K = np.zeros((3, 3), dtype=np.float32) 51 | s = 1 / downscale 52 | K[0, 0] = colmap_camera.params[0] * s # fx 53 | K[1, 1] = colmap_camera.params[0] * s # fy 54 | K[0, 2] = colmap_camera.params[1] * s # cx 55 | K[1, 2] = colmap_camera.params[2] * s # cy 56 | K[2, 2] = 1 57 | return K 58 | 59 | 60 | def reduce_ids(meta): 61 | """Reduce large filename IDs to small ones for compatibility with time embedding.""" 62 | min_id = min(meta['ids_all']) 63 | fn_reduce = lambda x, min_id: (x - min_id) // 50 64 | assert len(set([fn_reduce(x, min_id) for x in meta['ids_all']])) == len(set(meta['ids_all'])) 65 | for k in meta: 66 | if 'ids' in k: 67 | meta[k] = [fn_reduce(x, min_id) for x in meta[k]] 68 | elif k in ['nears', 'fars', 'poses', 'images']: 69 | meta[k] = {fn_reduce(k, min_id): v for k, v in meta[k].items()} 70 | 71 | 72 | def calc_meta(colmap, image_downscale=1, with_cuda=True, split_nth=8): 73 | 74 | # e.g., `IMG_0000000200.bmp` to 200 75 | fn2int = lambda fn: int(os.path.splitext(fn.split("_")[1])[0]) 76 | 77 | # 1: load cameras and sort COLMAP indices 78 | # (COLMAP indices are not necessariliy sorted) 79 | 80 | colmap2fn = dict( 81 | sorted( 82 | [(k, colmap["images"][k].name) for k in colmap["images"]], 83 | key=lambda x: x[0], 84 | ) 85 | ) 86 | 87 | assert list(colmap2fn.keys()) == sorted(list(colmap2fn.keys())) 88 | fn2colmap = {v: k for k, v in colmap2fn.items()} 89 | colmap2sortedfn = dict(zip(colmap2fn, sorted(colmap2fn.values()))) 90 | colmap2sortedcolmap = {k: fn2colmap[v] for k, v in colmap2sortedfn.items()} 91 | colmap2sortedframeid = {k: fn2int(fn) for k, fn in colmap2sortedfn.items()} 92 | assert list(colmap2sortedframeid.values()) == sorted( 93 | list(colmap2sortedframeid.values()) 94 | ) 95 | 96 | colmap["images"] = { 97 | fn2int(colmap["images"][colmap2sortedcolmap[i]].name): colmap["images"][ 98 | colmap2sortedcolmap[i] 99 | ] 100 | for i in colmap["images"] 101 | } 102 | 103 | meta = {"ids_all": []} 104 | 105 | meta["images"] = {} # {id: filename} 106 | for k, v in colmap["images"].items(): 107 | filename = v.name 108 | meta["images"][k] = filename 109 | meta["ids_all"] += [k] 110 | 111 | # 2: read and rescale camera intrinsics 112 | 113 | assert len(colmap["cameras"]) == 1 114 | colmap_camera = list(colmap["cameras"].values())[0] 115 | meta["intrinsics"] = 
calculate_intrinsics(colmap_camera, downscale=image_downscale) 116 | meta['camera'] = {} 117 | for k in ['model', 'width', 'height', 'params']: 118 | if k == 'params': 119 | meta['camera'][k] = getattr(colmap_camera, k).tolist() 120 | else: 121 | meta['camera'][k] = getattr(colmap_camera, k) 122 | 123 | meta["image_w"] = colmap_camera.params[1] * 2 * (1 / image_downscale) 124 | meta["image_h"] = colmap_camera.params[2] * 2 * (1 / image_downscale) 125 | 126 | # 3: read w2c and initialise c2w (poses) from w2c 127 | 128 | w2c_mats = [] 129 | bottom = np.array([0, 0, 0, 1.0]).reshape(1, 4) 130 | for id_ in meta["ids_all"]: 131 | im = colmap["images"][id_] 132 | R = im.qvec2rotmat() 133 | t = im.tvec.reshape(3, 1) 134 | w2c_mats += [np.concatenate([np.concatenate([R, t], 1), bottom], 0)] 135 | w2c_mats = np.stack(w2c_mats, 0) # (N_images, 4, 4) 136 | poses = np.linalg.inv(w2c_mats)[:, :3] # (N_images, 3, 4) 137 | # poses has rotation in form "right down front", change to "right up back" 138 | poses[..., 1:3] *= -1 139 | 140 | xyz_world = np.array([colmap["pts3d"][i].xyz for i in colmap["pts3d"]]) 141 | xyz_world_h = np.concatenate([xyz_world, np.ones((len(xyz_world), 1))], -1) 142 | 143 | # 4: near and far bounds for each image 144 | 145 | from tqdm.notebook import tqdm 146 | 147 | # speed up if 1000s of images to be used 148 | if with_cuda: 149 | to = lambda x: x.cuda() 150 | else: 151 | to = lambda x: x 152 | 153 | w2c_mats_pt = torch.from_numpy(w2c_mats) 154 | xyz_world_h_pt = to(torch.from_numpy(xyz_world_h)) 155 | 156 | meta["nears"], meta["fars"] = {}, {} 157 | n_ids = len(meta["ids_all"]) 158 | for i, id_ in tqdm(enumerate(meta["ids_all"]), total=n_ids, disable=n_ids < 1000): 159 | xyz_cam_i = (xyz_world_h_pt @ (to(w2c_mats_pt[i].T)))[:, :3] 160 | xyz_cam_i = xyz_cam_i[xyz_cam_i[:, 2] > 0] 161 | meta["nears"][id_] = torch.quantile((xyz_cam_i[:, 2]), 0.1 / 100).item() 162 | meta["fars"][id_] = torch.quantile((xyz_cam_i[:, 2]), 99.9 / 100).item() 163 | 164 | meta["poses"] = {id_: poses[i] for i, id_ in enumerate(meta["ids_all"])} 165 | meta["ids_train"], meta["ids_val"], meta["ids_test"] = split_ids( 166 | meta["ids_all"], nth_image=split_nth 167 | ) 168 | 169 | # reduce_ids(meta) 170 | 171 | return meta 172 | 173 | 174 | def update_format(meta, image_format=None): 175 | if format is not None: 176 | for k in meta['images']: 177 | meta['images'][k] = convert_file_format(meta['images'][k], image_format) 178 | 179 | 180 | def load_meta(directory, name="meta.json"): 181 | path = os.path.join(directory, name) 182 | with open(path, "r") as fp: 183 | meta = json.load(fp) 184 | for k in ["nears", "fars", "images"]: 185 | meta[k] = {int(i): meta[k][i] for i in meta[k]} 186 | meta["poses"] = {int(i): np.array(meta["poses"][i]) for i in meta["poses"]} 187 | meta["intrinsics"] = np.array(meta["intrinsics"]) 188 | return meta 189 | 190 | 191 | def save_meta(directory, meta, name="meta.json", deepcopy=False): 192 | tolist = lambda x: x if type(x) is list else x.tolist() 193 | path = os.path.join(directory, name) 194 | if deepcopy: 195 | meta = deepcopy(meta) 196 | 197 | meta["poses"] = {k: tolist(meta["poses"][k]) for k in meta["poses"]} 198 | meta["intrinsics"] = tolist(meta["intrinsics"]) 199 | with open(path, "w") as f: 200 | json.dump(meta, f, indent=2) 201 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .annotations import MaskLoader 2 | from .rays 
import EPICDiff 3 | 4 | VIDEO_IDS = [ 5 | "P01_01", 6 | "P03_04", 7 | "P04_01", 8 | "P05_01", 9 | "P06_03", 10 | "P08_01", 11 | "P09_02", 12 | "P13_03", 13 | "P16_01", 14 | "P21_01", 15 | ] 16 | 17 | # e.g. for summary video in `evaluate.py` or for debugging 18 | SAMPLE_IDS = { 19 | "P01_01": 716, 20 | "P03_04": 702, 21 | "P04_01": 745, 22 | "P05_01": 237, 23 | "P06_03": 957, 24 | "P08_01": 217, 25 | "P09_02": 89, 26 | "P13_03": 884, 27 | "P16_01": 76, 28 | "P21_01": 238, 29 | } -------------------------------------------------------------------------------- /dataset/annotations.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import PIL.Image 6 | 7 | 8 | def blend_mask(im, mask, colour, alpha, show_im=False): 9 | """Blend an image with a mask (colourised via `colour` and `alpha`).""" 10 | im = im.copy().astype(np.float) / 255 11 | for ch, rgb_v in zip([0, 1, 2], colour): 12 | im[:, :, ch][mask == 1] = im[:, :, ch][mask == 1] * (1 - alpha) + rgb_v * alpha 13 | if show_im: 14 | plt.imshow(im) 15 | plt.axis("off") 16 | plt.show() 17 | return im 18 | 19 | 20 | class MaskLoader: 21 | """Loads masks for a dataset initialised with a video ID.""" 22 | 23 | def __init__(self, dataset, is_debug=False): 24 | self.frames_dir = os.path.join(dataset.root, "frames") 25 | self.annotations_dir = os.path.join(dataset.root, "annotations") 26 | self.image_paths = dataset.image_paths 27 | 28 | self.mask_colour = [1, 0, 0] 29 | self.mask_alpha = 0.5 30 | 31 | self.is_debug = is_debug 32 | 33 | print(f"ID of loaded scene: {dataset.vid}.") 34 | print(f"Number of annotations: {len(os.listdir(self.annotations_dir))}.") 35 | 36 | def __getitem__(self, sample_id): 37 | image_id, image_ext = self.image_paths[sample_id].split(".") 38 | 39 | im = plt.imread(os.path.join(self.frames_dir, image_id + "." + image_ext)) 40 | mask = np.array( 41 | PIL.Image.open( 42 | os.path.join(self.annotations_dir, image_id + "." 
+ image_ext) 43 | ) 44 | ) 45 | 46 | if self.is_debug: 47 | blend_mask(im, mask, self.mask_colour, self.mask_alpha, True) 48 | 49 | return mask, im 50 | -------------------------------------------------------------------------------- /dataset/rays.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import torch 8 | import torchvision.transforms 9 | from PIL import Image 10 | from torch.utils.data import Dataset 11 | 12 | from .utils import * 13 | 14 | 15 | def load_meta(root, name="meta.json"): 16 | """Load meta information per scene and frame (nears, fars, poses etc.).""" 17 | path = os.path.join(root, name) 18 | with open(path, "r") as fp: 19 | ds = json.load(fp) 20 | for k in ["nears", "fars", "images", "poses"]: 21 | ds[k] = {int(i): ds[k][i] for i in ds[k]} 22 | if k == "poses": 23 | ds[k] = {i: np.array(ds[k][i]) for i in ds[k]} 24 | ds["intrinsics"] = np.array(ds["intrinsics"]) 25 | return ds 26 | 27 | 28 | class EPICDiff(Dataset): 29 | def __init__(self, vid, root="data/EPIC-Diff", split=None): 30 | 31 | self.root = os.path.join(root, vid) 32 | self.vid = vid 33 | self.img_w = 228 34 | self.img_h = 128 35 | self.split = split 36 | self.val_num = 1 37 | self.transform = torchvision.transforms.ToTensor() 38 | self.init_meta() 39 | 40 | def imshow(self, index): 41 | plt.imshow(self.imread(index)) 42 | plt.axis("off") 43 | plt.show() 44 | 45 | def imread(self, index): 46 | return plt.imread(os.path.join(self.root, "frames", self.image_paths[index])) 47 | 48 | def x2im(self, x, type_="np"): 49 | """Convert numpy or torch tensor to numpy or torch 'image'.""" 50 | w = self.img_w 51 | h = self.img_h 52 | if len(x.shape) == 2 and x.shape[1] == 3: 53 | x = x.reshape(h, w, 3) 54 | else: 55 | x = x.reshape(h, w) 56 | if type(x) == torch.Tensor: 57 | x = x.detach().cpu() 58 | if type_ == "np": 59 | x = x.numpy() 60 | elif type(x) == np.array: 61 | if type_ == "pt": 62 | x = torch.from_numpy(x) 63 | return x 64 | 65 | def rays_per_image(self, idx, pose=None): 66 | """Return sample with rays, frame index etc.""" 67 | sample = {} 68 | if pose is None: 69 | sample["c2w"] = c2w = torch.FloatTensor(self.poses_dict[idx]) 70 | else: 71 | sample["c2w"] = c2w = pose 72 | 73 | sample["im_path"] = self.image_paths[idx] 74 | 75 | img = Image.open(os.path.join(self.root, "frames", self.image_paths[idx])) 76 | img_w, img_h = img.size 77 | img = self.transform(img) # (3, h, w) 78 | img = img.view(3, -1).permute(1, 0) # (h*w, 3) RGB 79 | 80 | directions = get_ray_directions(img_h, img_w, self.K) 81 | rays_o, rays_d = get_rays(directions, c2w) 82 | 83 | c2c = torch.zeros(3, 4).to(c2w.device) 84 | c2c[:3, :3] = torch.eye(3, 3).to(c2w.device) 85 | rays_o_c, rays_d_c = get_rays(directions, c2c) 86 | 87 | rays_t = idx * torch.ones(len(rays_o), 1).long() 88 | 89 | rays = torch.cat( 90 | [ 91 | rays_o, 92 | rays_d, 93 | self.nears[idx] * torch.ones_like(rays_o[:, :1]), 94 | self.fars[idx] * torch.ones_like(rays_o[:, :1]), 95 | rays_o_c, 96 | rays_d_c, 97 | ], 98 | 1, 99 | ) 100 | 101 | sample["rays"] = rays 102 | sample["img_wh"] = torch.LongTensor([img_w, img_h]) 103 | sample["ts"] = rays_t 104 | sample["rgbs"] = img 105 | 106 | return sample 107 | 108 | def init_meta(self): 109 | """Load meta information, e.g. 
intrinsics, train, test, val split etc.""" 110 | meta = load_meta(self.root) 111 | self.img_ids = meta["ids_all"] 112 | self.img_ids_train = meta["ids_train"] 113 | self.img_ids_test = meta["ids_test"] 114 | self.img_ids_val = meta["ids_val"] 115 | self.poses_dict = meta["poses"] 116 | self.nears = meta["nears"] 117 | self.fars = meta["fars"] 118 | self.image_paths = meta["images"] 119 | self.K = meta["intrinsics"] 120 | 121 | if self.split == "train": 122 | # create buffer of all rays and rgb data 123 | self.rays = [] 124 | self.rgbs = [] 125 | self.ts = [] 126 | 127 | for idx in self.img_ids_train: 128 | sample = self.rays_per_image(idx) 129 | self.rgbs += [sample["rgbs"]] 130 | self.rays += [sample["rays"]] 131 | self.ts += [sample["ts"]] 132 | 133 | self.rays = torch.cat(self.rays, 0) # ((N_images-1)*h*w, 8) 134 | self.rgbs = torch.cat(self.rgbs, 0) # ((N_images-1)*h*w, 3) 135 | self.ts = torch.cat(self.ts, 0) 136 | 137 | def __len__(self): 138 | if self.split == "train": 139 | # rays are stored concatenated 140 | return len(self.rays) 141 | if self.split == "val": 142 | # evaluate only one image, sampled from val img ids 143 | return 1 144 | else: 145 | # choose any image index 146 | return max(self.img_ids) 147 | 148 | def __getitem__(self, idx, pose=None): 149 | 150 | if self.split == "train": 151 | # samples selected from prefetched train data 152 | sample = { 153 | "rays": self.rays[idx], 154 | "ts": self.ts[idx, 0].long(), 155 | "rgbs": self.rgbs[idx], 156 | } 157 | 158 | elif self.split == "val": 159 | # for tuning hyperparameters, tensorboard samples 160 | idx = random.choice(self.img_ids_val) 161 | sample = self.rays_per_image(idx, pose) 162 | 163 | elif self.split == "test": 164 | # evaluating according to table in paper, chosen index must be in test ids 165 | assert idx in self.img_ids_test 166 | sample = self.rays_per_image(idx, pose) 167 | 168 | else: 169 | # for arbitrary samples, e.g. 
summary video when rendering over all images 170 | sample = self.rays_per_image(idx, pose) 171 | 172 | return sample 173 | -------------------------------------------------------------------------------- /dataset/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from kornia import create_meshgrid 3 | 4 | 5 | def get_ray_directions(H, W, K): 6 | grid = create_meshgrid(H, W, normalized_coordinates=False)[0] 7 | i, j = grid.unbind(-1) 8 | fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2] 9 | directions = torch.stack( 10 | [(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1 11 | ) # (H, W, 3) 12 | 13 | return directions 14 | 15 | 16 | def get_rays(directions, c2w): 17 | # Rotate ray directions from camera coordinate to the world coordinate 18 | rays_d = directions @ c2w[:, :3].T # (H, W, 3) 19 | rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) 20 | # The origin of all rays is the camera origin in world coordinate 21 | rays_o = c2w[:, 3].expand(rays_d.shape) # (H, W, 3) 22 | 23 | rays_d = rays_d.view(-1, 3) 24 | rays_o = rays_o.view(-1, 3) 25 | 26 | return rays_o, rays_d 27 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: neuraldiff 2 | channels: 3 | - plotly 4 | - conda-forge 5 | - anaconda 6 | - defaults 7 | dependencies: 8 | - pip=21.0.1=py37h06a4308_0 9 | - matplotlib-base=3.3.3=py37h4f6019d_0 10 | - ffmpeg=4.2.2=h20bf706_0 11 | - pip: 12 | - numpy==1.20.1 13 | - pytorch-lightning==1.1.5 14 | - torch==1.7.1 15 | - torchvision==0.8.2 16 | - gitpython==3.1.14 17 | - einops==0.3.0 18 | - cvbase==0.5.5 19 | - opencv-python==4.5.1.48 20 | - kornia==0.4.1 21 | - imageio==2.9.0 22 | - scikit-image==0.18.1 23 | - scikit-learn==0.24.1 24 | - scipy==1.6.1 25 | - test-tube==0.7.5 26 | - imageio-ffmpeg==0.4.2 27 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import numpy as np 5 | 6 | import evaluation 7 | import utils 8 | from dataset import SAMPLE_IDS, VIDEO_IDS, EPICDiff, MaskLoader 9 | 10 | 11 | def parse_args(path=None, vid=None, exp=None): 12 | 13 | parser = argparse.ArgumentParser() 14 | 15 | parser.add_argument("--path", type=str, default=path, help="Path to model.") 16 | 17 | parser.add_argument("--vid", type=str, default=vid, help="Video ID of dataset.") 18 | 19 | parser.add_argument("--exp", type=str, default=exp, help="Experiment name.") 20 | 21 | parser.add_argument( 22 | "--outputs", 23 | default=["masks"], 24 | type=str, 25 | nargs="+", 26 | help="Evaluation output. Select `masks` or `summary` or both.", 27 | ) 28 | 29 | parser.add_argument( 30 | "--masks_n_samples", 31 | type=int, 32 | default=0, 33 | help="Select number of samples for evaluation. If kept at 0, then all test samples are evaluated.", 34 | ) 35 | 36 | parser.add_argument( 37 | "--summary_n_samples", 38 | type=int, 39 | default=0, 40 | help="Number of samples to evaluate for summary video. If 0 is selected, then the video is rendered with all frames from the dataset.", 41 | ) 42 | 43 | parser.add_argument( 44 | "--root_data", type=str, default="data/EPIC-Diff", help="Root of the dataset." 45 | ) 46 | 47 | parser.add_argument( 48 | "--suppress_person", 49 | default=False, 50 | action="store_true", 51 | help="Disables person, e.g. 
for visualising complete foreground without parts missing where person occludes the foreground.", 52 | ) 53 | 54 | # for opt.py 55 | parser.add_argument("--is_eval_script", default=True, action="store_true") 56 | 57 | args = parser.parse_args() 58 | 59 | return args 60 | 61 | 62 | def init(args): 63 | 64 | dataset = EPICDiff(args.vid, root=args.root_data) 65 | 66 | model = utils.init_model(args.path, dataset) 67 | 68 | # update parameters of loaded models 69 | model.hparams["suppress_person"] = args.suppress_person 70 | model.hparams["inference"] = True 71 | 72 | return model, dataset 73 | 74 | 75 | def eval_masks(args, model, dataset, root): 76 | """Evaluate masks to produce mAP (and PSNR) scores.""" 77 | root = os.path.join(root, "masks") 78 | os.makedirs(root) 79 | 80 | maskloader = MaskLoader(dataset=dataset) 81 | 82 | image_ids = evaluation.utils.sample_linear( 83 | dataset.img_ids_test, args.masks_n_samples 84 | )[0] 85 | 86 | results = evaluation.evaluate( 87 | dataset, 88 | model, 89 | maskloader, 90 | vis_i=1, 91 | save_dir=root, 92 | save=True, 93 | vid=args.vid, 94 | image_ids=image_ids, 95 | ) 96 | 97 | 98 | def eval_masks_average(args): 99 | """Calculate average of `eval_masks` results for all 10 scenes.""" 100 | scores = [] 101 | for vid in VIDEO_IDS: 102 | path_metrics = os.path.join("results", args.exp, vid, 'masks', 'metrics.txt') 103 | with open(f'results/rel/{vid}/masks/metrics.txt') as f: 104 | lines = f.readlines() 105 | score_map, score_psnr = [float(s) for s in lines[2].split('\t')[:2]] 106 | scores.append([score_map, score_psnr]) 107 | scores = np.array(scores).mean(axis=0) 108 | print('Average for all 10 scenes:') 109 | print(f'mAP: {(scores[0]*100).round(2)}, PSNR: {scores[1].round(2)}') 110 | 111 | 112 | def render_video(args, model, dataset, root, save_cache=False): 113 | """Render a summary video like shown on the project page.""" 114 | root = os.path.join(root, "summary") 115 | os.makedirs(root) 116 | 117 | sid = SAMPLE_IDS[args.vid] 118 | 119 | top = evaluation.video.render( 120 | dataset, model, n_images=args.summary_n_samples 121 | ) 122 | bot = evaluation.video.render( 123 | dataset, model, sid, n_images=args.summary_n_samples 124 | ) 125 | 126 | if save_cache: 127 | evaluation.video.save_to_cache( 128 | args.vid, sid, root=root, top=top, bot=bot 129 | ) 130 | 131 | ims_cat = [ 132 | evaluation.video.convert_rgb( 133 | evaluation.video.cat_sample(top[k], bot[k]) 134 | ) 135 | for k in bot.keys() 136 | ] 137 | 138 | utils.write_mp4(f"{root}/cat-{sid}-N{len(ims_cat)}", ims_cat) 139 | 140 | 141 | def run(args, model, dataset, root): 142 | 143 | if "masks" in args.outputs: 144 | # segmentations and renderings with mAP and PSNR 145 | eval_masks(args, model, dataset, root) 146 | 147 | if "summary" in args.outputs: 148 | # summary video 149 | render_video(args, model, dataset, root) 150 | 151 | 152 | if __name__ == "__main__": 153 | args = parse_args() 154 | if 'average' in args.outputs: 155 | # calculate average over all 10 scenes for specific experiment 156 | eval_masks_average(args) 157 | else: 158 | model, dataset = init(args) 159 | root = os.path.join("results", args.exp, args.vid) 160 | run(args, model, dataset, root) 161 | -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import video, segmentation 2 | from .segmentation import evaluate, evaluate_sample 3 | -------------------------------------------------------------------------------- /evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def mse(image_pred, image_gt): 5 | value = (image_pred - image_gt) ** 2 6 | return torch.mean(value) 7 | 8 | 9 | def psnr(image_pred, image_gt): 10 | return -10 * torch.log10(mse(image_pred, image_gt)) 11 | -------------------------------------------------------------------------------- /evaluation/segmentation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluate segmentation capacity of model via mAP, 3 | also includes renderings of segmentations and PSNR evaluation. 4 | """ 5 | import os 6 | from collections import defaultdict 7 | 8 | import git 9 | import matplotlib 10 | matplotlib.use('Agg') 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | import torch 14 | import tqdm 15 | from sklearn.metrics import average_precision_score 16 | 17 | from . import metrics, utils 18 | 19 | 20 | def evaluate_sample( 21 | ds, 22 | sample_id, 23 | t=None, 24 | visualise=True, 25 | gt_masked=None, 26 | model=None, 27 | mask_targ=None, 28 | save=False, 29 | pose=None, 30 | ): 31 | """ 32 | Evaluate one sample of a dataset (ds). Calculate PSNR and mAP, 33 | and visualise different model components for this sample. Additionally, 34 | 1) a different timestep (`t`) can be chosen, which can be different from the 35 | timestep of the sample (useful for rendering the same view over different 36 | timesteps). 37 | """ 38 | if pose is None: 39 | sample = ds[sample_id] 40 | else: 41 | sample = ds.__getitem__(sample_id, pose) 42 | results = model.render(sample, t=t) 43 | figure = None 44 | 45 | output_person = "person_weights_sum" in results 46 | output_transient = "_rgb_fine_transient" in results 47 | 48 | img_wh = tuple(sample["img_wh"].numpy()) 49 | img_gt = ds.x2im(sample["rgbs"], type_="pt") 50 | img_pred = ds.x2im(results["rgb_fine"][:, :3], type_="pt") 51 | 52 | mask_stat = ds.x2im(results["_rgb_fine_static"][:, 3]) 53 | if output_transient: 54 | mask_transient = ds.x2im(results["_rgb_fine_transient"][:, 4]) 55 | mask_pred = mask_transient 56 | if output_person: 57 | mask_person = ds.x2im(results["_rgb_fine_person"][:, 5]) 58 | mask_pred = mask_pred + mask_person 59 | else: 60 | mask_person = np.zeros_like(mask_transient) 61 | 62 | beta = ds.x2im(results["beta"]) 63 | img_pred_static = ds.x2im(results["rgb_fine_static"][:, :3], type_="pt") 64 | img_pred_transient = ds.x2im(results["_rgb_fine_transient"][:, :3]) 65 | if output_person: 66 | img_pred_person = ds.x2im(results["_rgb_fine_person"][:, :3]) 67 | 68 | if mask_targ is not None: 69 | average_precision = average_precision_score( 70 | mask_targ.reshape(-1), mask_pred.reshape(-1) 71 | ) 72 | 73 | psnr = metrics.psnr(img_pred, img_gt).item() 74 | psnr_static = metrics.psnr(img_pred_static, img_gt).item() 75 | 76 | if visualise: 77 | 78 | figure, ax = plt.subplots(figsize=(8, 5)) 79 | figure.suptitle(f"Sample: {sample_id}.\n") 80 | plt.tight_layout() 81 | plt.subplot(331) 82 | plt.title("GT") 83 | if gt_masked is not None: 84 | plt.imshow(torch.from_numpy(gt_masked)) 85 | else: 86 | plt.imshow(img_gt) 87 | plt.axis("off") 88 | plt.subplot(332) 89 | plt.title(f"Pred. 
PSNR: {psnr:.2f}") 90 | plt.imshow(img_pred.clamp(0, 1)) 91 | plt.axis("off") 92 | plt.subplot(333) 93 | plt.axis("off") 94 | 95 | plt.subplot(334) 96 | plt.title(f"Static. PSNR: {psnr_static:.2f}") 97 | plt.imshow(img_pred_static) 98 | plt.axis("off") 99 | plt.subplot(335) 100 | plt.title(f"Transient") 101 | plt.imshow(img_pred_transient) 102 | plt.axis("off") 103 | if "_rgb_fine_person" in results: 104 | plt.subplot(336) 105 | plt.title("Person") 106 | plt.axis("off") 107 | plt.imshow(img_pred_person) 108 | else: 109 | plt.subplot(336) 110 | plt.axis("off") 111 | 112 | plt.subplot(337) 113 | if mask_targ is not None: 114 | plt.title(f"Mask. AP: {average_precision:.4f}") 115 | else: 116 | plt.title("Mask.") 117 | plt.imshow(mask_pred) 118 | plt.axis("off") 119 | plt.subplot(338) 120 | plt.title(f"Mask: Transient.") 121 | plt.imshow(mask_transient) 122 | plt.axis("off") 123 | plt.subplot(339) 124 | plt.title(f"Mask: Person.") 125 | plt.imshow(mask_person) 126 | plt.axis("off") 127 | 128 | if visualise and not save: 129 | plt.show() 130 | 131 | results = {} 132 | 133 | results["figure"] = figure 134 | results["im_tran"] = img_pred_transient 135 | results["im_stat"] = img_pred_static 136 | results["im_pred"] = img_pred 137 | results["im_targ"] = img_gt 138 | results["psnr"] = psnr 139 | results["mask_pred"] = mask_pred 140 | results["mask_stat"] = mask_stat 141 | if output_person: 142 | results["mask_pers"] = mask_person 143 | results["im_pers"] = img_pred_person 144 | results["mask_tran"] = mask_transient 145 | if mask_targ is not None: 146 | results["average_precision"] = average_precision 147 | 148 | for k in results: 149 | if k == "figure": 150 | continue 151 | if type(results[k]) == torch.Tensor: 152 | results[k] = results[k].to("cpu") 153 | 154 | return results 155 | 156 | 157 | def evaluate( 158 | dataset, 159 | model, 160 | mask_loader, 161 | vis_i=5, 162 | save_dir="results/test", 163 | save=False, 164 | vid=None, 165 | epoch=None, 166 | timestep_const=None, 167 | image_ids=None, 168 | ): 169 | """ 170 | Like `evaluate_sample`, but evaluates over all selected image_ids. 171 | Saves also visualisations and average scores of the selected samples. 172 | """ 173 | 174 | results = { 175 | k: [] 176 | for k in [ 177 | "avgpre", 178 | "psnr", 179 | "masks", 180 | "out", 181 | "hp", 182 | ] 183 | } 184 | 185 | if image_ids is None: 186 | image_ids = dataset.img_ids_test 187 | 188 | for i, sample_id in utils.tqdm(enumerate(image_ids), total=len(image_ids)): 189 | 190 | do_visualise = i % vis_i == 0 191 | 192 | tqdm.tqdm.write(f"Test sample {i}. 
Frame {sample_id}.") 193 | 194 | mask_targ, im_masked = mask_loader[sample_id] 195 | # ignore evaluation if no mask available 196 | if mask_targ.sum() == 0: 197 | print(f"No annotations for frame {sample_id}, skipping.") 198 | continue 199 | 200 | results["hp"] = model.hparams 201 | results["hp"]["git_eval"] = git.Repo( 202 | search_parent_directories=True 203 | ).head.object.hexsha 204 | 205 | if timestep_const is not None: 206 | timestep = sample_id 207 | sample_id = timestep_const 208 | else: 209 | timestep = sample_id 210 | out = evaluate_sample( 211 | dataset, 212 | sample_id, 213 | model=model, 214 | t=timestep, 215 | visualise=do_visualise, 216 | gt_masked=im_masked, 217 | mask_targ=mask_targ, 218 | save=save, 219 | ) 220 | 221 | if save and do_visualise: 222 | results_im = utils.plt_to_im(out["figure"]) 223 | os.makedirs(f"{save_dir}/per_sample", exist_ok=True) 224 | path = f"{save_dir}/per_sample/{sample_id}.png" 225 | plt.imsave(path, results_im) 226 | 227 | mask_pred = out["mask_pred"] 228 | 229 | results["avgpre"].append(out["average_precision"]) 230 | 231 | results["psnr"].append(out["psnr"]) 232 | results["masks"].append([mask_targ, mask_pred]) 233 | results["out"].append(out) 234 | 235 | metrics_ = { 236 | "avgpre": {}, 237 | "psnr": {}, 238 | } 239 | for metric in metrics_: 240 | metrics_[metric] = np.array( 241 | [x for x in results[metric] if not np.isnan(x)] 242 | ).mean() 243 | 244 | results["metrics"] = metrics_ 245 | 246 | if save: 247 | with open(f"{save_dir}/metrics.txt", "a") as f: 248 | lines = utils.write_summary(results) 249 | f.writelines(f"Epoch: {epoch}.\n") 250 | f.writelines(lines) 251 | 252 | print(f"avgpre: {results['metrics']['avgpre']}, PSNR: {results['metrics']['psnr']}") 253 | 254 | return results 255 | -------------------------------------------------------------------------------- /evaluation/utils.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from PIL import Image 7 | 8 | 9 | def plt_to_im(f, show=False, with_alpha=False): 10 | # f: figure from previous plot (generated with plt.figure()) 11 | buf = io.BytesIO() 12 | buf.seek(0) 13 | plt.savefig(buf, format="png") 14 | if not show: 15 | plt.close(f) 16 | im = Image.open(buf) 17 | # return without alpha channel (contains only 255 values) 18 | return np.array(im)[..., : 3 + with_alpha] 19 | 20 | 21 | def sample_linear(X, n_samples): 22 | if n_samples == 0: 23 | n_samples = len(X) 24 | n_samples = min(len(X), n_samples) 25 | indices = (np.linspace(0, len(X) - 1, n_samples)).round().astype(np.long) 26 | return [X[i] for i in indices], indices 27 | 28 | 29 | def tqdm(x, **kwargs): 30 | import sys 31 | 32 | if "ipykernel_launcher.py" in sys.argv[0]: 33 | # tqdm from notebook 34 | from tqdm.notebook import tqdm 35 | else: 36 | # otherwise default tqdm 37 | from tqdm import tqdm 38 | return tqdm(x, **kwargs) 39 | 40 | 41 | def write_summary(results): 42 | """Log average precision and PSNR score for evaluation.""" 43 | import io 44 | from contextlib import redirect_stdout 45 | 46 | with io.StringIO() as buf, redirect_stdout(buf): 47 | 48 | n = 0 49 | keys = sorted(results["metrics"].keys()) 50 | 51 | for k in keys: 52 | print(k.ljust(n), end="\t") 53 | print() 54 | for k in keys: 55 | trail = -2 if "psnr" in k else 10 56 | lead = 0 if "psnr" in k else 1 57 | print(f'{results["metrics"][k]:.4f}'[lead:trail].ljust(n), end="\t") 58 | print() 59 | 60 | output = buf.getvalue() 61 | 
62 | return output 63 | -------------------------------------------------------------------------------- /evaluation/video.py: -------------------------------------------------------------------------------- 1 | """Render a summary video as shown on the project page.""" 2 | import math 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from . import segmentation, utils 8 | 9 | 10 | def render(dataset, model, sample_id=None, n_images=20): 11 | """ 12 | Render a video for a dataset and model. 13 | If a sample_id is selected, then the view is fixed and images 14 | are rendered for a specific viewpoint over a timerange (the bottom part 15 | of the summary video on the project page). Otherwise, images are rendered 16 | for multiple viewpoints (the top part of the summary video). 17 | """ 18 | 19 | ims = {} 20 | 21 | keys = [ 22 | "mask_pers", 23 | "mask_tran", 24 | "mask_pred", 25 | "im_tran", 26 | "im_stat", 27 | "im_pred", 28 | "im_pers", 29 | "im_targ", 30 | ] 31 | 32 | if n_images > len(dataset.img_ids) or n_images == 0: 33 | n_images = len(dataset.img_ids) 34 | 35 | for i in utils.tqdm(dataset.img_ids[:: math.ceil(len(dataset.img_ids) / n_images)]): 36 | if sample_id is not None: 37 | j = sample_id 38 | else: 39 | j = i 40 | timestep = i 41 | with torch.no_grad(): 42 | x = segmentation.evaluate_sample( 43 | dataset, j, t=timestep, model=model, visualise=False 44 | ) 45 | ims[i] = {k: x[k] for k in x if k in keys} 46 | return ims 47 | 48 | 49 | def cat_sample(top, bot): 50 | """Concatenate images from the top and bottom part of the summary video.""" 51 | keys = ["im_targ", "im_pred", "im_stat", "im_tran", "im_pers"] 52 | top = np.concatenate([(top[k]) for k in keys], axis=1) 53 | bot = np.concatenate([(bot[k]) for k in keys], axis=1) 54 | bot[ 55 | :, 56 | : bot.shape[1] // len(keys), # black background in first column 57 | ] = (0, 0, 0) 58 | z = np.concatenate([top, bot], axis=0) 59 | return z 60 | 61 | 62 | def save_to_cache(vid, sid, root, top=None, bot=None): 63 | """Save the images for rendering the video.""" 64 | if top is not None: 65 | p = f"{root}/images-{sid}-top.pt" 66 | if os.path.exists(p): 67 | print("images exist, aborting.") 68 | return 69 | torch.save(top, p) 70 | if bot is not None: 71 | p = f"{root}/images-{sid}-bot.pt" 72 | if os.path.exists(p): 73 | print("images exist, aborting.") 74 | return 75 | torch.save(bot, p) 76 | 77 | 78 | def load_from_cache(vid, sid, root, version=0): 79 | """Load the images for rendering the video.""" 80 | path_top = f"{root}/images-{sid}-top.pt" 81 | path_bot = f"{root}/images-{sid}-bot.pt" 82 | top = torch.load(path_top) 83 | bot = torch.load(path_bot) 84 | return top, bot 85 | 86 | 87 | def convert_rgb(im): 88 | im[im > 1] = 1 89 | im = (im * 255).astype(np.uint8) 90 | return im 91 | -------------------------------------------------------------------------------- /get_started.sh: -------------------------------------------------------------------------------- 1 | 2 | wget https://www.robots.ox.ac.uk/~vadim/neuraldiff/release/ckpts.tar.gz 3 | 4 | tar -xzvf ckpts.tar.gz 5 | 6 | mkdir data 7 | 8 | cd data 9 | 10 | wget https://www.robots.ox.ac.uk/~vadim/neuraldiff/release/EPIC-Diff-annotations.tar.gz 11 | 12 | tar -xzvf EPIC-Diff-annotations.tar.gz 13 | 14 | wget https://data.bris.ac.uk/datasets/tar/296c4vv03j7lb2ejq3874ej3vm.zip 15 | 16 | unzip 296c4vv03j7lb2ejq3874ej3vm.zip 17 | 18 | export EKPATH=296c4vv03j7lb2ejq3874ej3vm 19 | 20 | for X in $(ls $EKPATH); 21 | do echo $X; 22 | for Z in $(ls $EKPATH/$X); 23 | do echo $Z; 24 | mv 
$PWD/$EKPATH/$X/$Z EPIC-Diff/$X 25 | done; 26 | done 27 | 28 | mv $PWD/$EKPATH/readme.txt EPIC-Diff/README_EPIC-Kitchens.txt 29 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Loss(nn.Module): 6 | """ 7 | Name abbreviations: 8 | c_l: coarse color loss 9 | f_l: fine color loss 10 | b_l: beta loss 11 | s_l: sigma loss 12 | """ 13 | 14 | def __init__(self, lambda_u=0.01): 15 | """ 16 | lambda_u: regularisation for sigmas. 17 | """ 18 | super().__init__() 19 | self.lambda_u = lambda_u 20 | 21 | def forward(self, inputs, targets, is_nerf_eval=False): 22 | ret = {} 23 | ret["c_l"] = 0.5 * ((inputs["rgb_coarse"] - targets) ** 2).mean() 24 | if "rgb_fine" in inputs: 25 | ret["f_l"] = ( 26 | (inputs["rgb_fine"] - targets) ** 2 27 | / (2 * inputs["beta"].unsqueeze(1) ** 2) 28 | ).mean() 29 | ret["b_l"] = torch.log(inputs["beta"]).mean() 30 | ret["s_l"] = self.lambda_u * inputs["transient_sigmas"].mean() 31 | ret["s_l"] = ret["s_l"] + self.lambda_u * inputs["person_sigmas"].mean() 32 | 33 | return ret 34 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .embedding import LREEmbedding, PosEmbedding 2 | from .neuraldiff import NeuralDiff 3 | from .rendering import render_rays 4 | -------------------------------------------------------------------------------- /model/embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class PosEmbedding(torch.nn.Module): 5 | def __init__(self, max_logscale, N_freqs, logscale=True): 6 | """ 7 | Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...) 8 | """ 9 | super().__init__() 10 | self.funcs = [torch.sin, torch.cos] 11 | 12 | if logscale: 13 | self.freqs = 2 ** torch.linspace(0, max_logscale, N_freqs) 14 | else: 15 | self.freqs = torch.linspace(1, 2 ** max_logscale, N_freqs) 16 | 17 | def forward(self, x): 18 | out = [x] 19 | for freq in self.freqs: 20 | for func in self.funcs: 21 | out += [func(freq * x)] 22 | 23 | return torch.cat(out, -1) 24 | 25 | 26 | class LREEmbedding(torch.nn.Module): 27 | """ 28 | As desribed in "Smooth Dynamics", low rank expansion of trajectory states. 
29 | """ 30 | 31 | def __init__(self, N=1000, D=16, K=21): 32 | super().__init__() 33 | self.embedding = PosEmbedding(K // 2 - 1, K // 2) 34 | self.W = W = torch.nn.Parameter(torch.FloatTensor(K, D)) 35 | torch.nn.init.kaiming_uniform_(W) 36 | # normalise input range to [-1, +1] 37 | self.input_range = (torch.arange(0, N).view(-1, 1) / (N - 1) - 1 / 2) * 2 38 | self.P = self.embedding(self.input_range).cuda() 39 | 40 | def __call__(self, indices): 41 | L = self.P @ self.W 42 | return L[indices] 43 | -------------------------------------------------------------------------------- /model/neuraldiff.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class NeuralDiff(nn.Module): 6 | def __init__( 7 | self, 8 | typ, 9 | D=8, 10 | W=256, 11 | skips=[4], 12 | in_channels_xyz=63, 13 | in_channels_dir=27, 14 | encode_dynamic=False, 15 | in_channels_a=48, 16 | in_channels_t=16, 17 | beta_min=0.03, 18 | ): 19 | super().__init__() 20 | self.typ = typ 21 | self.D = D 22 | self.W = W 23 | self.skips = skips 24 | self.in_channels_xyz = in_channels_xyz 25 | self.in_channels_dir = in_channels_dir 26 | 27 | self.encode_dynamic = False if typ == "coarse" else encode_dynamic 28 | self.in_channels_a = in_channels_a if encode_dynamic else 0 29 | self.in_channels_t = in_channels_t 30 | self.beta_min = beta_min 31 | 32 | # xyz encoding layers 33 | for i in range(D): 34 | if i == 0: 35 | layer = nn.Linear(in_channels_xyz, W) 36 | elif i in skips: 37 | layer = nn.Linear(W + in_channels_xyz, W) 38 | else: 39 | layer = nn.Linear(W, W) 40 | layer = nn.Sequential(layer, nn.ReLU(True)) 41 | setattr(self, f"xyz_encoding_{i+1}", layer) 42 | self.xyz_encoding_final = nn.Linear(W, W) 43 | 44 | # direction encoding layers 45 | self.dir_encoding = nn.Sequential( 46 | nn.Linear(W + in_channels_dir + self.in_channels_a, W // 2), nn.ReLU(True) 47 | ) 48 | 49 | # static output layers 50 | self.static_sigma = nn.Sequential(nn.Linear(W, 1), nn.Softplus()) 51 | self.static_rgb = nn.Sequential(nn.Linear(W // 2, 3), nn.Sigmoid()) 52 | 53 | # initialise transient model 54 | if self.encode_dynamic: 55 | # transient encoding layers 56 | in_channels = W + in_channels_t 57 | self.transient_encoding = nn.Sequential( 58 | nn.Linear(in_channels, W // 2), 59 | nn.ReLU(True), 60 | nn.Linear(W // 2, W // 2), 61 | nn.ReLU(True), 62 | nn.Linear(W // 2, W // 2), 63 | nn.ReLU(True), 64 | nn.Linear(W // 2, W // 2), 65 | nn.ReLU(True), 66 | ) 67 | # transient output layers 68 | self.transient_sigma = nn.Sequential(nn.Linear(W // 2, 1), nn.Softplus()) 69 | self.transient_rgb = nn.Sequential(nn.Linear(W // 2, 3), nn.Sigmoid()) 70 | self.transient_beta = nn.Sequential(nn.Linear(W // 2, 1), nn.Softplus()) 71 | 72 | # initialise actor model, same architecture as transient 73 | self.person_encoding = nn.Sequential( 74 | nn.Linear(in_channels, W // 2), 75 | nn.ReLU(True), 76 | nn.Linear(W // 2, W // 2), 77 | nn.ReLU(True), 78 | nn.Linear(W // 2, W // 2), 79 | nn.ReLU(True), 80 | nn.Linear(W // 2, W // 2), 81 | nn.ReLU(True), 82 | ) 83 | # actor output layers 84 | self.person_sigma = nn.Sequential(nn.Linear(W // 2, 1), nn.Softplus()) 85 | self.person_rgb = nn.Sequential(nn.Linear(W // 2, 3), nn.Sigmoid()) 86 | self.person_beta = nn.Sequential(nn.Linear(W // 2, 1), nn.Softplus()) 87 | 88 | def forward(self, x, sigma_only=False, output_dynamic=True): 89 | if sigma_only: 90 | """ 91 | For inference. Inputs for static model. 
We need only sigmas for sampling depths 92 | during inference for fine model. The rendering of the coarse model is not required. 93 | """ 94 | input_xyz = x 95 | elif output_dynamic: 96 | """Inputs when training/inferring with actor volume.""" 97 | input_xyz, input_dir_a, input_t, input_xyz_c = torch.split( 98 | x, 99 | [ 100 | self.in_channels_xyz, 101 | self.in_channels_dir + self.in_channels_a, 102 | self.in_channels_t, 103 | self.in_channels_xyz, 104 | ], 105 | dim=-1, 106 | ) 107 | else: 108 | """ 109 | Inputs for static model during training. Compared to the case of 'sigma_only', 110 | we also need the colours of the coarse model since the final loss depends on the 111 | rendering of the coarse *and* fine model. 112 | """ 113 | input_xyz, input_dir_a = torch.split( 114 | x, 115 | [self.in_channels_xyz, self.in_channels_dir + self.in_channels_a], 116 | dim=-1, 117 | ) 118 | 119 | xyz_ = input_xyz 120 | for i in range(self.D): 121 | if i in self.skips: 122 | xyz_ = torch.cat([input_xyz, xyz_], 1) 123 | xyz_ = getattr(self, f"xyz_encoding_{i+1}")(xyz_) 124 | 125 | static_sigma = self.static_sigma(xyz_) 126 | if sigma_only: 127 | return static_sigma 128 | 129 | xyz_encoding_final = self.xyz_encoding_final(xyz_) 130 | dir_encoding_input = torch.cat([xyz_encoding_final, input_dir_a], 1) 131 | dir_encoding = self.dir_encoding(dir_encoding_input) 132 | static_rgb = self.static_rgb(dir_encoding) 133 | static = torch.cat([static_rgb, static_sigma], 1) 134 | 135 | if not output_dynamic: 136 | # then return only outputs of static model 137 | return static 138 | 139 | # otherwise continue with transient model 140 | transient_encoding_input = torch.cat([xyz_encoding_final, input_t], 1) 141 | transient_encoding = self.transient_encoding(transient_encoding_input) 142 | transient_sigma = self.transient_sigma(transient_encoding) # (B, 1) 143 | transient_rgb = self.transient_rgb(transient_encoding) # (B, 3) 144 | transient_beta = self.transient_beta(transient_encoding) # (B, 1) 145 | transient = torch.cat( 146 | [transient_rgb, transient_sigma, transient_beta], 1 147 | ) # (B, 5) 148 | 149 | # continue with actor model 150 | input_pad = torch.zeros( 151 | len(input_t), 152 | transient_encoding_input.shape[1] 153 | - (input_xyz_c.shape[1] + input_t.shape[1]), 154 | ).to(input_t.device) 155 | person_encoding_input = torch.cat([input_xyz_c, input_t, input_pad], dim=1) 156 | person_encoding = self.person_encoding(person_encoding_input) 157 | person_sigma = self.person_sigma(person_encoding) # (B, 1) 158 | person_rgb = self.person_rgb(person_encoding) # (B, 3) 159 | person_beta = self.person_beta(person_encoding) # (B, 1) 160 | 161 | person = torch.cat([person_rgb, person_sigma, person_beta], 1) # (B, 5) 162 | 163 | # final outputs contain static, transient and person model 164 | return torch.cat([static, transient, person], 1) # (B, 9 + 5) = (B, 14) 165 | -------------------------------------------------------------------------------- /model/rendering.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange, reduce, repeat 3 | 4 | 5 | def sample_pdf(bins, weights, N_importance, det=False, eps=1e-5): 6 | """ 7 | Sample @N_importance samples from @bins with distribution defined by @weights. 
8 | Inputs: 9 | bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2" 10 | weights: (N_rays, N_samples_) 11 | N_importance: the number of samples to draw from the distribution 12 | det: deterministic or not 13 | eps: a small number to prevent division by zero 14 | Outputs: 15 | samples: the sampled samples 16 | """ 17 | N_rays, N_samples_ = weights.shape 18 | weights = weights + eps 19 | pdf = weights / reduce(weights, "n1 n2 -> n1 1", "sum") 20 | cdf = torch.cumsum(pdf, -1) 21 | cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf], -1) 22 | # padded to 0~1 inclusive 23 | 24 | u = torch.rand(N_rays, N_importance, device=bins.device) 25 | if det: 26 | u = torch.linspace(0, 1, N_importance, device=bins.device) 27 | u = u.expand(N_rays, N_importance) 28 | u = u.contiguous() 29 | 30 | inds = torch.searchsorted(cdf, u, right=True) 31 | below = torch.clamp_min(inds - 1, 0) 32 | above = torch.clamp_max(inds, N_samples_) 33 | 34 | inds_sampled = rearrange( 35 | torch.stack([below, above], -1), "n1 n2 c -> n1 (n2 c)", c=2 36 | ) 37 | cdf_g = rearrange(torch.gather(cdf, 1, inds_sampled), "n1 (n2 c) -> n1 n2 c", c=2) 38 | bins_g = rearrange(torch.gather(bins, 1, inds_sampled), "n1 (n2 c) -> n1 n2 c", c=2) 39 | 40 | # denom equals 0 means a bin has weight 0, in which case it will not be sampled 41 | # anyway, therefore any value for it is fine (set to 1 here) 42 | denom = cdf_g[..., 1] - cdf_g[..., 0] 43 | denom[denom < eps] = 1 44 | 45 | samples = bins_g[..., 0] + (u - cdf_g[..., 0]) / denom * ( 46 | bins_g[..., 1] - bins_g[..., 0] 47 | ) 48 | return samples 49 | 50 | 51 | def render_rays( 52 | models, 53 | embeddings, 54 | rays, 55 | ts, 56 | N_samples=64, 57 | perturb=0, 58 | noise_std=1, 59 | N_importance=0, 60 | chunk=1024 * 32, 61 | test_time=False, 62 | **kwargs, 63 | ): 64 | """ 65 | Render rays by computing the output of @model applied on @rays and @ts 66 | Inputs: 67 | models: dict of models (coarse and fine) 68 | embeddings: dict of embedding models of origin and direction 69 | rays: (N_rays, 3+3), ray origins and directions 70 | ts: (N_rays), ray time as embedding index 71 | N_samples: number of coarse samples per ray 72 | perturb: factor to perturb the sampling position on the ray (for coarse model only) 73 | noise_std: factor to perturb the model's prediction of sigma 74 | N_importance: number of fine samples per ray 75 | chunk: the chunk size in batched inference 76 | test_time: whether it is test (inference only) or not. If True, it will not do inference 77 | on coarse rgb to save time 78 | Outputs: 79 | result: dictionary containing final rgb and depth maps for coarse and fine models 80 | """ 81 | 82 | def inference(results, model, xyz, z_vals, test_time=False, xyz_c=None, **kwargs): 83 | """ 84 | Helper function that performs model inference. 85 | Inputs: 86 | results: a dict storing all results 87 | model: coarse or fine model 88 | xyz: (N_rays, N_samples_, 3) sampled positions 89 | z_vals: (N_rays, N_samples_) depths of the sampled positions 90 | test_time: test time or not 91 | xyz_c: (N_rays, N_samples_, 3) sampled positions w.r.t. 
camera coordinates 92 | """ 93 | typ = model.typ 94 | N_samples_ = xyz.shape[1] 95 | xyz_ = rearrange(xyz, "n1 n2 c -> (n1 n2) c", c=3) 96 | xyz_c_ = rearrange(xyz_c, "n1 n2 c -> (n1 n2) c", c=3) 97 | 98 | # Perform model inference to get rgb, sigma 99 | B = xyz_.shape[0] 100 | out_chunks = [] 101 | xyz_c_embedded_ = embedding_xyz(xyz_c_) 102 | if typ == "coarse" and test_time: 103 | for i in range(0, B, chunk): 104 | xyz_embedded = embedding_xyz(xyz_[i : i + chunk]) 105 | out_chunks += [model(xyz_embedded, sigma_only=True)] 106 | out = torch.cat(out_chunks, 0) 107 | static_sigmas = rearrange( 108 | out, "(n1 n2) 1 -> n1 n2", n1=N_rays, n2=N_samples_ 109 | ) 110 | else: 111 | # 112 | dir_embedded_ = repeat(dir_embedded, "n1 c -> (n1 n2) c", n2=N_samples_) 113 | # create other necessary inputs 114 | if output_dynamic: 115 | a_embedded_ = repeat(a_embedded, "n1 c -> (n1 n2) c", n2=N_samples_) 116 | t_embedded_ = repeat(t_embedded, "n1 c -> (n1 n2) c", n2=N_samples_) 117 | for i in range(0, B, chunk): 118 | # inputs for original NeRF 119 | inputs = [ 120 | embedding_xyz(xyz_[i : i + chunk]), 121 | dir_embedded_[i : i + chunk], 122 | ] 123 | # additional inputs 124 | if output_dynamic: 125 | inputs += [a_embedded_[i : i + chunk]] 126 | inputs += [t_embedded_[i : i + chunk]] 127 | inputs += [embedding_xyz(xyz_c_[i : i + chunk])] 128 | out_chunks += [ 129 | model( 130 | torch.cat(inputs, 1), 131 | output_dynamic=output_dynamic, 132 | ) 133 | ] 134 | 135 | out = torch.cat(out_chunks, 0) 136 | out = rearrange(out, "(n1 n2) c -> n1 n2 c", n1=N_rays, n2=N_samples_) 137 | static_rgbs = out[..., :3] # (N_rays, N_samples_, 3) 138 | static_sigmas = out[..., 3] # (N_rays, N_samples_) 139 | if output_dynamic: 140 | transient_rgbs = out[..., 4:7] 141 | transient_sigmas = out[..., 7] 142 | transient_betas = out[..., 8] 143 | person_rgbs = out[..., 9:12] 144 | person_sigmas = out[..., 12] 145 | person_betas = out[..., 13] 146 | 147 | if hp.inference and hp.suppress_person: 148 | # disables person during inference, e.g. 
for visualising videos 149 | person_sigmas[:] = 0 150 | 151 | if test_time: 152 | n_channels = 1 + output_dynamic * 2 153 | stat = torch.zeros([*static_rgbs.shape[:2], n_channels]).to( 154 | static_rgbs.device 155 | ) 156 | stat[:, :, 0] = 1 157 | 158 | static_rgbs = torch.cat([static_rgbs, stat], dim=2) 159 | if output_dynamic: 160 | tran = torch.zeros([*static_rgbs.shape[:2], n_channels]).to( 161 | static_rgbs.device 162 | ) 163 | tran[:, :, 1] = 1 164 | transient_rgbs = torch.cat([transient_rgbs, tran], dim=2) 165 | pers = torch.zeros([*static_rgbs.shape[:2], n_channels]).to( 166 | static_rgbs.device 167 | ) 168 | pers[:, :, 2] = 1 169 | person_rgbs = torch.cat([person_rgbs, pers], dim=2) 170 | 171 | # Convert these values using volume rendering 172 | deltas = z_vals[:, 1:] - z_vals[:, :-1] # (N_rays, N_samples_-1) 173 | delta_inf = 1e2 * torch.ones_like(deltas[:, :1]) 174 | # (N_rays, 1) the last delta is infinity 175 | deltas = torch.cat([deltas, delta_inf], -1) # (N_rays, N_samples_) 176 | 177 | # add RGB noise to last segments of rays to avoid trivial rendering of "black" colours (0,0,0) 178 | # "transmittance fix" 179 | if not typ == "coarse": 180 | if test_time: 181 | pass 182 | else: 183 | static_sigmas[:, -1:] = 100 184 | static_rgbs[:, -1, :3] = torch.rand(static_rgbs.shape[0], 3).to( 185 | static_rgbs.device 186 | ) 187 | 188 | if output_dynamic: 189 | # for colour normalisation as described in "Improved color mixing" 190 | sum_sigmas = static_sigmas + transient_sigmas + person_sigmas 191 | alphas = 1 - torch.exp(-deltas * (sum_sigmas)) 192 | 193 | # ignore normalisation for last value to stabilise inf delta (described above) 194 | static_alphas = static_sigmas / sum_sigmas * alphas 195 | static_alphas[:, -1:] = (1 - torch.exp(-deltas * static_sigmas))[:, -1:] 196 | transient_alphas = transient_sigmas / sum_sigmas * alphas 197 | transient_alphas[:, -1:] = (1 - torch.exp(-deltas * transient_sigmas))[ 198 | :, -1: 199 | ] 200 | person_alphas = person_sigmas / sum_sigmas * alphas 201 | person_alphas[:, -1:] = (1 - torch.exp(-deltas * person_sigmas))[:, -1:] 202 | 203 | results[f"static_alphas_{typ}"] = static_alphas 204 | results[f"transient_alphas_{typ}"] = transient_alphas 205 | results[f"person_alphas_{typ}"] = person_alphas 206 | else: 207 | noise = torch.randn_like(static_sigmas) * noise_std 208 | alphas = 1 - torch.exp(-deltas * torch.relu(static_sigmas + noise)) 209 | 210 | results[f"alphas_{typ}"] = alphas 211 | 212 | alphas_shifted = torch.cat( 213 | [torch.ones_like(alphas[:, :1]), 1 - alphas], -1 214 | ) # [1, 1-a1, 1-a2, ...] 215 | transmittance = torch.cumprod( 216 | alphas_shifted[:, :-1], -1 217 | ) # [1, 1-a1, (1-a1)(1-a2), ...] 
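        # Compositing note: with alpha_i = 1 - exp(-delta_i * sigma_i)
        # (sigma_i being the summed density of all three streams in the
        # dynamic case) and transmittance T_i = prod_{j<i} (1 - alpha_j),
        # each sample contributes weight w_i = alpha_i * T_i, and a colour
        # map is sum_i w_i * c_i. The static, transient and person streams
        # share this transmittance, so their individual renderings sum to
        # the full image ("rgb_fine" below).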
218 | 219 | if not (typ == "coarse" and test_time): 220 | results[f"static_rgbs_{typ}"] = static_rgbs 221 | 222 | results[f"transmittance_{typ}"] = transmittance 223 | 224 | if output_dynamic: 225 | static_weights = static_alphas * transmittance 226 | results[f"static_alphas_{typ}"] = static_alphas 227 | transient_weights = transient_alphas * transmittance 228 | transient_weights_sum = reduce(transient_weights, "n1 n2 -> n1", "sum") 229 | person_weights = person_alphas * transmittance 230 | person_weights_sum = reduce(person_weights, "n1 n2 -> n1", "sum") 231 | 232 | weights = alphas * transmittance 233 | weights_sum = reduce(weights, "n1 n2 -> n1", "sum") 234 | 235 | results[f"weights_{typ}"] = weights 236 | results[f"static_weights_{typ}"] = weights 237 | results[f"static_weights_sum_{typ}"] = weights_sum 238 | results[f"opacity_{typ}"] = weights_sum 239 | results[f"static_sigmas_{typ}"] = static_sigmas 240 | if output_dynamic: 241 | results["transient_sigmas"] = transient_sigmas 242 | results["transient_weights"] = transient_weights 243 | results["transient_weights_sum"] = transient_weights_sum 244 | results["person_sigmas"] = person_sigmas 245 | results["person_weights"] = person_weights 246 | results["person_weights_sum"] = person_weights_sum 247 | if test_time and typ == "coarse": 248 | return 249 | 250 | if output_dynamic: 251 | static_rgb_map = reduce( 252 | rearrange(static_weights, "n1 n2 -> n1 n2 1") * static_rgbs, 253 | "n1 n2 c -> n1 c", 254 | "sum", 255 | ) 256 | 257 | transient_rgb_map = reduce( 258 | rearrange(transient_weights, "n1 n2 -> n1 n2 1") * transient_rgbs, 259 | "n1 n2 c -> n1 c", 260 | "sum", 261 | ) 262 | results["beta"] = reduce( 263 | transient_weights * transient_betas, "n1 n2 -> n1", "sum" 264 | ) 265 | 266 | # the rgb maps here are when both fields exist 267 | results["_rgb_fine_static"] = static_rgb_map 268 | results["_rgb_fine_transient"] = transient_rgb_map 269 | results["rgb_fine"] = static_rgb_map + transient_rgb_map 270 | 271 | person_rgb_map = reduce( 272 | rearrange(person_weights, "n1 n2 -> n1 n2 1") * person_rgbs, 273 | "n1 n2 c -> n1 c", 274 | "sum", 275 | ) 276 | results["beta"] = results["beta"] + reduce( 277 | person_weights * person_betas, "n1 n2 -> n1", "sum" 278 | ) 279 | 280 | # the rgb maps here are when both fields exist 281 | results["_rgb_fine_person"] = person_rgb_map 282 | results["rgb_fine"] = results["rgb_fine"] + person_rgb_map 283 | 284 | results["beta"] += model.beta_min 285 | 286 | if test_time: 287 | static_alphas_shifted = torch.cat( 288 | [torch.ones_like(static_alphas[:, :1]), 1 - static_alphas], -1 289 | ) 290 | static_transmittance = torch.cumprod(static_alphas_shifted[:, :-1], -1) 291 | static_weights_ = static_alphas * static_transmittance 292 | static_rgb_map_ = reduce( 293 | rearrange(static_weights_, "n1 n2 -> n1 n2 1") * static_rgbs, 294 | "n1 n2 c -> n1 c", 295 | "sum", 296 | ) 297 | results["rgb_fine_static"] = static_rgb_map_ 298 | results["depth_fine_static"] = reduce( 299 | static_weights_ * z_vals, "n1 n2 -> n1", "sum" 300 | ) 301 | 302 | transient_alphas_shifted = torch.cat( 303 | [torch.ones_like(transient_alphas[:, :1]), 1 - transient_alphas], -1 304 | ) 305 | transient_transmittance = torch.cumprod( 306 | transient_alphas_shifted[:, :-1], -1 307 | ) 308 | transient_weights_ = transient_alphas * transient_transmittance 309 | results["rgb_fine_transient"] = reduce( 310 | rearrange(transient_weights_, "n1 n2 -> n1 n2 1") * transient_rgbs, 311 | "n1 n2 c -> n1 c", 312 | "sum", 313 | ) 314 | 
results["depth_fine_transient"] = reduce( 315 | transient_weights_ * z_vals, "n1 n2 -> n1", "sum" 316 | ) 317 | 318 | person_alphas_shifted = torch.cat( 319 | [torch.ones_like(person_alphas[:, :1]), 1 - person_alphas], -1 320 | ) 321 | person_transmittance = torch.cumprod( 322 | person_alphas_shifted[:, :-1], -1 323 | ) 324 | person_weights_ = person_alphas * person_transmittance 325 | results["rgb_fine_person"] = reduce( 326 | rearrange(person_weights_, "n1 n2 -> n1 n2 1") * person_rgbs, 327 | "n1 n2 c -> n1 c", 328 | "sum", 329 | ) 330 | results["depth_fine_person"] = reduce( 331 | person_weights_ * z_vals, "n1 n2 -> n1", "sum" 332 | ) 333 | 334 | else: # no transient field 335 | rgb_map = reduce( 336 | rearrange(weights, "n1 n2 -> n1 n2 1") * static_rgbs, 337 | "n1 n2 c -> n1 c", 338 | "sum", 339 | ) 340 | results[f"rgb_{typ}"] = rgb_map 341 | 342 | results[f"depth_{typ}"] = reduce(weights * z_vals, "n1 n2 -> n1", "sum") 343 | return 344 | 345 | hp = kwargs["hp"] 346 | 347 | embedding_xyz, embedding_dir = embeddings["xyz"], embeddings["dir"] 348 | 349 | # separate input into ray origins, directions, near, far bounds etc. 350 | N_rays = rays.shape[0] 351 | rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] 352 | near, far = rays[:, 6:7], rays[:, 7:8] 353 | rays_o_c, rays_d_c = rays[:, 8:11], rays[:, 11:14] 354 | dir_embedded = embedding_dir(kwargs.get("view_dir", rays_d)) 355 | 356 | rays_o = rearrange(rays_o, "n1 c -> n1 1 c") 357 | rays_d = rearrange(rays_d, "n1 c -> n1 1 c") 358 | 359 | rays_o_c = rearrange(rays_o_c, "n1 c -> n1 1 c") 360 | rays_d_c = rearrange(rays_d_c, "n1 c -> n1 1 c") 361 | 362 | # sample depth points 363 | z_steps = torch.linspace(0, 1, N_samples, device=rays.device) 364 | z_vals = near * (1 - z_steps) + far * z_steps 365 | 366 | z_vals = z_vals.expand(N_rays, N_samples) 367 | 368 | # perturb sampling depths (z_vals) 369 | perturb_rand = perturb * torch.rand_like(z_vals) 370 | if perturb > 0: 371 | # (N_rays, N_samples-1) interval mid points 372 | z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:]) 373 | # get intervals between samples 374 | upper = torch.cat([z_vals_mid, z_vals[:, -1:]], -1) 375 | lower = torch.cat([z_vals[:, :1], z_vals_mid], -1) 376 | 377 | z_vals = lower + (upper - lower) * perturb_rand 378 | 379 | results = {} 380 | xyz_coarse = rays_o + rays_d * rearrange(z_vals, "n1 n2 -> n1 n2 1") 381 | xyz_coarse_c = rays_o_c + rays_d_c * rearrange(z_vals, "n1 n2 -> n1 n2 1") 382 | 383 | # disable transient and person model for coarse model 384 | output_dynamic = False 385 | inference( 386 | results, 387 | models["coarse"], 388 | xyz_coarse, 389 | z_vals, 390 | test_time, 391 | xyz_c=xyz_coarse_c, 392 | **kwargs, 393 | ) 394 | 395 | # then continue with fine model by sampling from z_vals of coarse model 396 | if N_importance > 0: 397 | z_vals_mid = 0.5 * ( 398 | z_vals[:, :-1] + z_vals[:, 1:] 399 | ) # (N_rays, N_samples-1) interval mid points 400 | z_vals_ = sample_pdf( 401 | z_vals_mid, 402 | results["weights_coarse"][:, 1:-1].detach(), 403 | N_importance, 404 | det=(perturb == 0), 405 | ) 406 | # detach so that grad doesn't propogate to weights_coarse from here 407 | z_vals = torch.sort(torch.cat([z_vals, z_vals_], -1), -1)[0] 408 | xyz_fine = rays_o + rays_d * rearrange(z_vals, "n1 n2 -> n1 n2 1") 409 | xyz_fine_c = rays_o_c + rays_d_c * rearrange(z_vals, "n1 n2 -> n1 n2 1") 410 | 411 | model = models["fine"] 412 | output_dynamic = True 413 | t_embedded = embeddings["t"](ts) 414 | a_embedded = embeddings["a"](ts) 415 | inference( 416 | results, model, 
xyz_fine, z_vals, test_time, xyz_c=xyz_fine_c, **kwargs 417 | ) 418 | 419 | return results 420 | -------------------------------------------------------------------------------- /notebook/eval.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "os.chdir('../')" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import torch\n", 20 | "import utils" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "# Initialise model and dataset" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "vid = 'P01_01'\n", 37 | "epoch = 9" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "from dataset import EPICDiff" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "dataset = EPICDiff(vid, split='test')" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "ckpt_path = f'ckpts/rel/{vid}/epoch={epoch}.ckpt'" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "models = utils.init_model(ckpt_path, dataset)" 74 | ] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "metadata": {}, 79 | "source": [ 80 | "# Evaluate first 5 test images of scene" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "from dataset import MaskLoader\n", 90 | "import evaluation" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "maskloader = MaskLoader(\n", 100 | " dataset=dataset,\n", 101 | " is_debug=True\n", 102 | ")" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": { 109 | "scrolled": true 110 | }, 111 | "outputs": [], 112 | "source": [ 113 | "results = evaluation.evaluate(\n", 114 | " dataset,\n", 115 | " models,\n", 116 | " maskloader,\n", 117 | " vis_i=1,\n", 118 | " save=True,\n", 119 | " save_dir='results/test',\n", 120 | " vid=vid,\n", 121 | " image_ids=dataset.img_ids_test[:5]\n", 122 | ")" 123 | ] 124 | } 125 | ], 126 | "metadata": { 127 | "kernelspec": { 128 | "display_name": "nerfw3.7", 129 | "language": "python", 130 | "name": "nerfw3.7" 131 | }, 132 | "language_info": { 133 | "codemirror_mode": { 134 | "name": "ipython", 135 | "version": 3 136 | }, 137 | "file_extension": ".py", 138 | "mimetype": "text/x-python", 139 | "name": "python", 140 | "nbconvert_exporter": "python", 141 | "pygments_lexer": "ipython3", 142 | "version": "3.7.12" 143 | } 144 | }, 145 | "nbformat": 4, 146 | "nbformat_minor": 4 147 | } 148 | -------------------------------------------------------------------------------- /opt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import git 4 | 5 | 6 | def get_opts(vid=None, root="data/EPIC-Diff"): 7 | parser = argparse.ArgumentParser() 8 | 9 | 
parser.add_argument( 10 | "--root", type=str, default=root, help="Root directory of dataset." 11 | ) 12 | parser.add_argument( 13 | "--N_emb_xyz", type=int, default=10, help="Number of xyz embedding frequencies." 14 | ) 15 | parser.add_argument( 16 | "--N_emb_dir", 17 | type=int, 18 | default=4, 19 | help="Number of direction embedding frequencies.", 20 | ) 21 | parser.add_argument( 22 | "--N_samples", type=int, default=64, help="Number of coarse samples." 23 | ) 24 | parser.add_argument( 25 | "--N_importance", 26 | type=int, 27 | default=64, 28 | help="Number of additional fine samples.", 29 | ) 30 | parser.add_argument( 31 | "--perturb", 32 | type=float, 33 | default=1.0, 34 | help="Factor to perturb depth sampling points.", 35 | ) 36 | parser.add_argument( 37 | "--noise_std", 38 | type=float, 39 | default=1.0, 40 | help="Std dev of noise added to regularize sigma.", 41 | ) 42 | parser.add_argument( 43 | "--N_vocab", 44 | type=int, 45 | default=1000, 46 | help="Number of frames (max. 1000 for our dataset).", 47 | ) 48 | parser.add_argument( 49 | "--N_a", type=int, default=48, help="Embedding size for appearance encoding." 50 | ) 51 | parser.add_argument( 52 | "--N_tau", 53 | type=int, 54 | default=17, 55 | help="Embedding size for transient encoding.", 56 | ) 57 | parser.add_argument( 58 | "--beta_min", 59 | type=float, 60 | default=0.03, 61 | help="Minimum color variance for loss.", 62 | ) 63 | parser.add_argument("--batch_size", type=int, default=1024, help="Batch size.") 64 | parser.add_argument( 65 | "--chunk", 66 | type=int, 67 | default=32 * 1024, 68 | help="Chunk size to split the input to avoid reduce memory footprint.", 69 | ) 70 | parser.add_argument( 71 | "--num_epochs", type=int, default=10, help="Number of training epochs." 72 | ) 73 | parser.add_argument("--num_gpus", type=int, default=1, help="Number of gpus.") 74 | parser.add_argument( 75 | "--ckpt_path", 76 | type=str, 77 | default=None, 78 | help="Pretrained checkpoint path to load.", 79 | ) 80 | parser.add_argument("--lr", type=float, default=5e-4, help="Learning rate.") 81 | parser.add_argument("--weight_decay", type=float, default=0, help="Weight decay.") 82 | parser.add_argument("--exp_name", type=str, default="exp", help="Experiment name.") 83 | parser.add_argument( 84 | "--refresh_every", 85 | type=int, 86 | default=1, 87 | help="print the progress bar every X steps", 88 | ) 89 | parser.add_argument("-f", type=str, default="", help="For Jupyter.") 90 | parser.add_argument( 91 | "--lowpass_K", 92 | type=int, 93 | default=21, 94 | help="K for low rank expansion of transient encoding.", 95 | ) 96 | parser.add_argument( 97 | "--train_ratio", 98 | type=float, 99 | default=1.0, 100 | help="Fraction of train dataset to use per epoch. For debugging.", 101 | ) 102 | parser.add_argument( 103 | "--model_width", type=int, default=256, help="Width of model (units per layer)." 104 | ) 105 | parser.add_argument( 106 | "--num_workers", type=int, default=4, help="Number of workers for dataloaders." 107 | ) 108 | parser.add_argument("--vid", type=str, default=vid, help="Video ID of dataset.") 109 | parser.add_argument( 110 | "--deterministic", default=True, action="store_true", help="Reproducibility." 111 | ) 112 | parser.add_argument( 113 | "--inference", 114 | default=False, 115 | action="store_true", 116 | help="For compatibility with evaluation script.", 117 | ) 118 | 119 | hparams, unknown = parser.parse_known_args() 120 | if unknown: 121 | # for compabitibility with evaluation script. 
122 | if "--is_eval_script" not in unknown: 123 | print("--- unrecognised arguments ---") 124 | print(unknown) 125 | exit() 126 | hparams.git_train = git.Repo(search_parent_directories=True).head.object.hexsha 127 | # placeholders for eval script 128 | hparams.git_eval = "" 129 | hparams.ckpt_path_eval = "" 130 | 131 | return hparams 132 | 133 | 134 | if __name__ == "__main__": 135 | hparams = get_opts("example") 136 | 137 | print("Argparse options:") 138 | for k, v in hparams.__dict__.items(): 139 | print(f"{k}: {v}") 140 | -------------------------------------------------------------------------------- /scripts/eval.sh: -------------------------------------------------------------------------------- 1 | 2 | CKP=ckpts/$1 3 | VID=$2 4 | EXP=$3 5 | OUT=$4 6 | MASKS_N_SAMPLES=$5 7 | SUMMARY_N_SAMPLES=$6 8 | 9 | EPOCH=9 10 | 11 | CUDA_VISIBLE_DEVICES=0 python evaluate.py \ 12 | --path $CKP\/$VID\/epoch\=$EPOCH\.ckpt \ 13 | --vid $VID --exp $EXP \ 14 | --is_eval_script \ 15 | --outputs $OUT \ 16 | --masks_n_samples $MASKS_N_SAMPLES \ 17 | --summary_n_samples $SUMMARY_N_SAMPLES 18 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | 2 | VID=$1; CUDA_VISIBLE_DEVICES=0 python train.py \ 3 | --vid $VID \ 4 | --exp_name rel/$VID \ 5 | --train_ratio 1 --num_epochs 10 6 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | 4 | import matplotlib.pyplot as plt 5 | import pytorch_lightning 6 | import torch 7 | from torch.optim import Adam 8 | from torch.optim.lr_scheduler import CosineAnnealingLR 9 | from torch.utils.data import DataLoader 10 | 11 | import dataset 12 | import model 13 | import utils 14 | from evaluation.metrics import * 15 | from loss import Loss 16 | from opt import get_opts 17 | from utils import * 18 | 19 | 20 | class NeuralDiffSystem(pytorch_lightning.LightningModule): 21 | def __init__(self, hparams, train_dataset=None, val_dataset=None): 22 | super().__init__() 23 | self.hparams = hparams 24 | if self.hparams.deterministic: 25 | utils.set_deterministic() 26 | 27 | # for avoiding reinitialization of dataloaders when debugging/using notebook 28 | self.train_dataset = train_dataset 29 | self.val_dataset = val_dataset 30 | 31 | self.loss = Loss() 32 | 33 | self.models_to_train = [] 34 | self.embedding_xyz = model.PosEmbedding( 35 | hparams.N_emb_xyz - 1, hparams.N_emb_xyz 36 | ) 37 | self.embedding_dir = model.PosEmbedding( 38 | hparams.N_emb_dir - 1, hparams.N_emb_dir 39 | ) 40 | 41 | self.embeddings = { 42 | "xyz": self.embedding_xyz, 43 | "dir": self.embedding_dir, 44 | } 45 | 46 | self.embedding_t = model.LREEmbedding( 47 | N=hparams.N_vocab, D=hparams.N_tau, K=hparams.lowpass_K 48 | ) 49 | self.embeddings["t"] = self.embedding_t 50 | self.models_to_train += [self.embedding_t] 51 | 52 | self.embedding_a = torch.nn.Embedding(hparams.N_vocab, hparams.N_a) 53 | self.embeddings["a"] = self.embedding_a 54 | self.models_to_train += [self.embedding_a] 55 | 56 | self.nerf_coarse = model.NeuralDiff( 57 | "coarse", 58 | in_channels_xyz=6 * hparams.N_emb_xyz + 3, 59 | in_channels_dir=6 * hparams.N_emb_dir + 3, 60 | W=hparams.model_width, 61 | ) 62 | self.models = {"coarse": self.nerf_coarse} 63 | if hparams.N_importance > 0: 64 | self.nerf_fine = model.NeuralDiff( 65 | "fine", 66 | 
in_channels_xyz=6 * hparams.N_emb_xyz + 3, 67 | in_channels_dir=6 * hparams.N_emb_dir + 3, 68 | encode_dynamic=True, 69 | in_channels_a=hparams.N_a, 70 | in_channels_t=hparams.N_tau, 71 | beta_min=hparams.beta_min, 72 | W=hparams.model_width, 73 | ) 74 | self.models["fine"] = self.nerf_fine 75 | self.models_to_train += [self.models] 76 | 77 | def get_progress_bar_dict(self): 78 | items = super().get_progress_bar_dict() 79 | items.pop("v_num", None) 80 | return items 81 | 82 | def forward(self, rays, ts, test_time=False, disable_perturb=False): 83 | perturb = 0 if test_time or disable_perturb else self.hparams.perturb 84 | noise_std = 0 if test_time or disable_perturb else self.hparams.noise_std 85 | B = rays.shape[0] 86 | results = defaultdict(list) 87 | for i in range(0, B, self.hparams.chunk): 88 | rendered_ray_chunks = model.render_rays( 89 | models=self.models, 90 | embeddings=self.embeddings, 91 | rays=rays[i : i + self.hparams.chunk], 92 | ts=ts[i : i + self.hparams.chunk], 93 | N_samples=self.hparams.N_samples, 94 | perturb=perturb, 95 | noise_std=noise_std, 96 | N_importance=self.hparams.N_importance, 97 | chunk=self.hparams.chunk, 98 | hp=self.hparams, 99 | test_time=test_time, 100 | ) 101 | 102 | for k, v in rendered_ray_chunks.items(): 103 | results[k] += [v] 104 | 105 | for k, v in results.items(): 106 | results[k] = torch.cat(v, 0) 107 | return results 108 | 109 | def setup(self, stage, reset_dataset=False): 110 | kwargs = {"root": self.hparams.root} 111 | kwargs["vid"] = self.hparams.vid 112 | if (self.train_dataset is None and self.val_dataset is None) or reset_dataset: 113 | self.train_dataset = dataset.EPICDiff(split="train", **kwargs) 114 | self.val_dataset = dataset.EPICDiff(split="val", **kwargs) 115 | 116 | def configure_optimizers(self): 117 | eps = 1e-8 118 | self.optimizer = Adam( 119 | get_parameters(self.models_to_train), 120 | lr=self.hparams.lr, 121 | eps=eps, 122 | weight_decay=self.hparams.weight_decay, 123 | ) 124 | scheduler = CosineAnnealingLR( 125 | self.optimizer, T_max=self.hparams.num_epochs, eta_min=eps 126 | ) 127 | return [self.optimizer], [scheduler] 128 | 129 | def train_dataloader(self): 130 | return DataLoader( 131 | self.train_dataset, 132 | shuffle=True, 133 | num_workers=self.hparams.num_workers, 134 | batch_size=self.hparams.batch_size, 135 | pin_memory=True, 136 | ) 137 | 138 | def val_dataloader(self): 139 | # batch_size=1 for validating one image (H*W rays) at a time 140 | return DataLoader( 141 | self.val_dataset, 142 | shuffle=False, 143 | num_workers=self.hparams.num_workers, 144 | batch_size=1, 145 | pin_memory=True, 146 | ) 147 | 148 | def training_step(self, batch, batch_nb): 149 | rays, rgbs, ts = batch["rays"], batch["rgbs"], batch["ts"] 150 | results = self(rays, ts) 151 | loss_d = self.loss(results, rgbs) 152 | loss = sum(l for l in loss_d.values()) 153 | 154 | with torch.no_grad(): 155 | psnr_ = psnr(results["rgb_fine"], rgbs) 156 | 157 | self.log("lr", self.optimizer.param_groups[0]["lr"]) 158 | self.log("train/loss", loss) 159 | for k, v in loss_d.items(): 160 | self.log(f"train/{k}", v, prog_bar=True) 161 | self.log("train/psnr", psnr_, prog_bar=True) 162 | 163 | return loss 164 | 165 | def render(self, sample, t=None, device=None): 166 | 167 | rays, rgbs, ts = ( 168 | sample["rays"].cuda(), 169 | sample["rgbs"].cuda(), 170 | sample["ts"].cuda(), 171 | ) 172 | 173 | if t is not None: 174 | if type(t) is torch.Tensor: 175 | t = t.cuda() 176 | ts = torch.ones_like(ts) * t 177 | 178 | rays = rays.squeeze() # (H*W, 3) 179 | rgbs 
= rgbs.squeeze() # (H*W, 3) 180 | ts = ts.squeeze() # (H*W) 181 | with torch.no_grad(): 182 | results = self(rays, ts, test_time=True) 183 | 184 | if device is not None: 185 | for k in results: 186 | results[k] = results[k].to(device) 187 | 188 | return results 189 | 190 | def validation_step(self, batch, batch_nb, is_debug=False): 191 | rays, rgbs, ts = batch["rays"], batch["rgbs"], batch["ts"] 192 | 193 | rays = rays.squeeze() # (H*W, 3) 194 | rgbs = rgbs.squeeze() # (H*W, 3) 195 | ts = ts.squeeze() # (H*W) 196 | # disable perturb (used during training), but keep loss for tensorboard 197 | results = self(rays, ts, disable_perturb=True) 198 | loss_d = self.loss(results, rgbs) 199 | loss = sum(l for l in loss_d.values()) 200 | log = {"val_loss": loss} 201 | 202 | if batch_nb == 0: 203 | WH = batch["img_wh"].view(1, 2) 204 | W, H = WH[0, 0].item(), WH[0, 1].item() 205 | img = ( 206 | results["rgb_fine"].view(H, W, 3)[:, :, :3].permute(2, 0, 1).cpu() 207 | ) # (3, H, W) 208 | img_gt = rgbs.view(H, W, 3).permute(2, 0, 1).cpu() # (3, H, W) 209 | depth = visualize_depth(results["depth_fine"].view(H, W)) # (3, H, W) 210 | stack = torch.stack([img_gt, img, depth]) # (3, 3, H, W) 211 | if self.logger is not None: 212 | self.logger.experiment.add_images( 213 | "val/GT_pred_depth", stack, self.global_step 214 | ) 215 | 216 | psnr_ = psnr(results["rgb_fine"], rgbs) 217 | log["val_psnr"] = psnr_ 218 | if is_debug: 219 | # then visualise in jupyter 220 | log["images"] = stack 221 | log["results"] = results 222 | 223 | f, p = plt.subplots(1, 3, figsize=(15, 15)) 224 | for i in range(3): 225 | im = stack[i] 226 | p[i].imshow(im.permute(1, 2, 0).cpu()) 227 | p[i].axis("off") 228 | plt.show() 229 | 230 | return log 231 | 232 | def validation_epoch_end(self, outputs): 233 | mean_loss = torch.stack([x["val_loss"] for x in outputs]).mean() 234 | mean_psnr = torch.stack([x["val_psnr"] for x in outputs]).mean() 235 | 236 | self.log("val/loss", mean_loss) 237 | self.log("val/psnr", mean_psnr, prog_bar=True) 238 | 239 | 240 | def init_trainer(hparams, logger=None, checkpoint_callback=None): 241 | if checkpoint_callback is None: 242 | checkpoint_callback = pytorch_lightning.callbacks.ModelCheckpoint( 243 | filepath=os.path.join(f"ckpts/{hparams.exp_name}", "{epoch:d}"), 244 | monitor="val/psnr", 245 | mode="max", 246 | save_top_k=-1, 247 | ) 248 | 249 | logger = pytorch_lightning.loggers.TestTubeLogger( 250 | save_dir="logs", 251 | name=hparams.exp_name, 252 | debug=False, 253 | create_git_tag=False, 254 | log_graph=False, 255 | ) 256 | 257 | trainer = pytorch_lightning.Trainer( 258 | max_epochs=hparams.num_epochs, 259 | checkpoint_callback=checkpoint_callback, 260 | resume_from_checkpoint=hparams.ckpt_path, 261 | logger=logger, 262 | weights_summary=None, 263 | progress_bar_refresh_rate=hparams.refresh_every, 264 | gpus=hparams.num_gpus, 265 | accelerator="ddp" if hparams.num_gpus > 1 else None, 266 | num_sanity_val_steps=1, 267 | benchmark=True, 268 | limit_train_batches=hparams.train_ratio, 269 | profiler="simple" if hparams.num_gpus == 1 else None, 270 | ) 271 | 272 | return trainer 273 | 274 | 275 | def main(hparams): 276 | system = NeuralDiffSystem(hparams) 277 | trainer = init_trainer(hparams) 278 | trainer.fit(system) 279 | 280 | 281 | if __name__ == "__main__": 282 | hparams = get_opts() 283 | main(hparams) 284 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 
| 3 | import cv2 4 | import imageio 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import skimage.io 8 | import skimage.transform 9 | import torch 10 | import torchvision 11 | from PIL import Image 12 | 13 | import model 14 | import opt 15 | import train 16 | from dataset import EPICDiff 17 | from evaluation.utils import tqdm 18 | 19 | 20 | def set_deterministic(): 21 | 22 | import random 23 | 24 | import numpy 25 | import torch 26 | 27 | torch.manual_seed(0) 28 | random.seed(0) 29 | numpy.random.seed(0) 30 | torch.backends.cudnn.benchmark = False 31 | 32 | 33 | def adjust_jupyter_argv(): 34 | import sys 35 | 36 | sys.argv = sys.argv[:1] 37 | 38 | 39 | def write_mp4(name, frames, fps=10): 40 | imageio.mimwrite(name + ".mp4", frames, "mp4", fps=fps) 41 | 42 | 43 | def overlay_image(im, im_overlay, coord=(100, 70)): 44 | # assumes that im is 3 channel and im_overlay 4 (with alpha) 45 | alpha = im_overlay[:, :, 3] 46 | offset_rows = im_overlay.shape[0] 47 | offset_cols = im_overlay.shape[1] 48 | row = coord[0] 49 | col = coord[1] 50 | im[row : row + offset_rows, col : col + offset_cols, :] = ( 51 | 1 - alpha[:, :, None] 52 | ) * im[row : row + offset_rows, col : col + offset_cols, :] + alpha[ 53 | :, :, None 54 | ] * im_overlay[ 55 | :, :, :3 56 | ] 57 | return im 58 | 59 | 60 | def get_parameters(models): 61 | """Get all model parameters recursively.""" 62 | parameters = [] 63 | if isinstance(models, list): 64 | for model in models: 65 | parameters += get_parameters(model) 66 | elif isinstance(models, dict): 67 | for model in models.values(): 68 | parameters += get_parameters(model) 69 | else: 70 | # single pytorch model 71 | parameters += list(models.parameters()) 72 | return parameters 73 | 74 | 75 | def visualize_depth(depth, cmap=cv2.COLORMAP_JET): 76 | x = depth.cpu().numpy() 77 | x = np.nan_to_num(x) # change nan to 0 78 | mi = np.min(x) # get minimum depth 79 | ma = np.max(x) 80 | x = (x - mi) / (ma - mi + 1e-8) # normalize to 0~1 81 | x = (255 * x).astype(np.uint8) 82 | x_ = Image.fromarray(cv2.applyColorMap(x, cmap)) 83 | x_ = torchvision.transforms.ToTensor()(x_) # (3, H, W) 84 | return x_ 85 | 86 | 87 | def assign_appearance(ids_train, ids_unassigned): 88 | # described in experiments, (3) NeRF-W: reassign each test embedding to closest train embedding 89 | ids = sorted(ids_train + ids_unassigned) 90 | g = {} 91 | for id in ids_unassigned: 92 | pos = ids.index(id) 93 | if pos == 0: 94 | # then only possible to assign to next embedding 95 | id_reassign = ids[1] 96 | elif pos == len(ids) - 1: 97 | # then only possible to assign to previous embedding 98 | id_reassign = ids[pos - 1] 99 | else: 100 | # otherwise the one that is closes according to frame index 101 | id_prev = ids[pos - 1] 102 | id_next = ids[pos + 1] 103 | id_reassign = min( 104 | (abs(ids[pos] - id_prev), id_prev), (abs(ids[pos] - id_next), id_next) 105 | )[1] 106 | g[ids[pos]] = id_reassign 107 | return g 108 | 109 | 110 | def init_model(ckpt_path, dataset): 111 | ckpt = torch.load(ckpt_path, map_location="cpu") 112 | opt_hp = opt.get_opts(dataset.vid) 113 | for j in ckpt["hyper_parameters"]: 114 | setattr(opt_hp, j, ckpt["hyper_parameters"][j]) 115 | model = train.NeuralDiffSystem( 116 | opt_hp, train_dataset=dataset, val_dataset=dataset 117 | ).cuda() 118 | model.load_state_dict(ckpt["state_dict"]) 119 | 120 | g_test = assign_appearance(dataset.img_ids_train, dataset.img_ids_test) 121 | g_val = assign_appearance(dataset.img_ids_train, dataset.img_ids_val) 122 | 123 | for g in [g_test, g_val]: 124 | 
for i, i_train in g.items(): 125 | model.embedding_a.weight.data[i] = model.embedding_a.weight.data[ 126 | i_train 127 | ] 128 | 129 | return model 130 | --------------------------------------------------------------------------------
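Putting the pieces together: the snippet below is a minimal evaluation sketch that mirrors notebook/eval.ipynb and scripts/eval.sh; all names come from the repository itself. It assumes the checkpoints and EPIC-Diff annotations were fetched with get_started.sh and that a CUDA device is available (init_model and the renderer move the model and rays to the GPU).

import evaluation
import utils
from dataset import EPICDiff, MaskLoader

vid = "P01_01"  # scene ID from the released EPIC-Diff data
epoch = 9       # last epoch of the 10-epoch schedule used in scripts/train.sh

# test split of the selected scene and the corresponding released checkpoint
dataset = EPICDiff(vid, split="test")
models = utils.init_model(f"ckpts/rel/{vid}/epoch={epoch}.ckpt", dataset)

# ground-truth masks used for the average-precision metric
maskloader = MaskLoader(dataset=dataset, is_debug=True)  # as in notebook/eval.ipynb

# evaluate the first 5 test frames; per-sample visualisations and metrics.txt
# are written to save_dir when save=True
results = evaluation.evaluate(
    dataset,
    models,
    maskloader,
    vis_i=1,
    save=True,
    save_dir="results/test",
    vid=vid,
    image_ids=dataset.img_ids_test[:5],
)
print(results["metrics"])  # {"avgpre": ..., "psnr": ...}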