├── .gitignore ├── colabs ├── gaussian_splatting_colab.ipynb └── HyperNerf.ipynb └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | scripts -------------------------------------------------------------------------------- /colabs/gaussian_splatting_colab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "id": "VjYy0F2gZIPR" 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "%cd /content\n", 12 | "!git clone --recursive https://github.com/camenduru/gaussian-splatting\n", 13 | "!pip install -q plyfile\n", 14 | "\n", 15 | "%cd /content/gaussian-splatting\n", 16 | "!pip install -q https://huggingface.co/camenduru/gaussian-splatting/resolve/main/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl\n", 17 | "!pip install -q https://huggingface.co/camenduru/gaussian-splatting/resolve/main/simple_knn-0.0.0-cp310-cp310-linux_x86_64.whl\n", 18 | "\n", 19 | "!wget -nc https://huggingface.co/camenduru/gaussian-splatting/resolve/main/tandt_db.zip # -nc: on re-runs keep the existing archive instead of downloading tandt_db.zip.1\n", 20 | "!unzip -o tandt_db.zip # -o: overwrite without the interactive prompt that would hang a re-run\n", 21 | "\n", 22 | "!python train.py -s /content/gaussian-splatting/tandt/train\n", 23 | "\n", 24 | "# !wget -nc https://huggingface.co/camenduru/gaussian-splatting/resolve/main/GaussianViewTest.zip\n", 25 | "# !unzip -o GaussianViewTest.zip\n", 26 | "# !python render.py -m /content/gaussian-splatting/GaussianViewTest/model\n", 27 | "# !ffmpeg -framerate 3 -i /content/gaussian-splatting/GaussianViewTest/model/train/ours_30000/renders/%05d.png -vf \"pad=ceil(iw/2)*2:ceil(ih/2)*2\" -c:v libx264 -r 3 -pix_fmt yuv420p /content/renders.mp4 -y\n", 28 | "# !ffmpeg -framerate 3 -i /content/gaussian-splatting/GaussianViewTest/model/train/ours_30000/gt/%05d.png -vf \"pad=ceil(iw/2)*2:ceil(ih/2)*2\" -c:v libx264 -r 3 -pix_fmt yuv420p /content/gt.mp4 -y" 29 | ] 30 | } 31 | ], 32 | "metadata": { 33 | "accelerator": "GPU", 34 |
"colab": { 35 | "gpuType": "T4", 36 | "provenance": [] 37 | }, 38 | "kernelspec": { 39 | "display_name": "Python 3", 40 | "name": "python3" 41 | }, 42 | "language_info": { 43 | "name": "python" 44 | } 45 | }, 46 | "nbformat": 4, 47 | "nbformat_minor": 0 48 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dynamic NeRF 2 | 3 | Verified: Papers listed with [+] have been verfied by myself or colleagues. The code is runnable. Please leave an issue if you need help on setting up. 4 | 5 | # 1. Datasets 6 | ## Custom Data Preparation 7 | - [Monocular Dynamic View Synthesis: A Reality Check](https://github.com/KAIR-BAIR/dycheck/blob/main/docs/RECORD3D_CAPTURE.md) 8 | - [Process a video into a Nerfie dataset](https://colab.research.google.com/github/google/nerfies/blob/main/notebooks/Nerfies_Capture_Processing.ipynb) 9 | - [Robust Dynamic Radiance Fields](https://github.com/facebookresearch/robust-dynrf) 10 | Estimate monocular depth, Predict optical flows, Obtain motion mask. 11 | - [Neural Scene Flow Fields](https://github.com/zhengqili/Neural-Scene-Flow-Fields/tree/main) 12 | Instructions for custom data. 13 | 14 | ### Synthetic 15 | - [D-Nerf Dataset](https://www.albertpumarola.com/research/D-NeRF/index.html) 16 | 17 | 18 | ### Real 19 | - [Plenoptic Dataset](https://github.com/facebookresearch/Neural_3D_Video/releases/tag/v1.0) 20 | - [Hypernerf Dataset](https://github.com/google/hypernerf/releases/tag/v0.1) 21 | - [Nerfies Dataset](https://github.com/google/nerfies/releases/download/0.1/nerfies-vrig-dataset-v0.1.zip) 22 | - [Dynamic NeRF](https://github.com/gaochen315/DynamicNeRF) 23 | Balloon1, Balloon2, Jumping, Playground, Skating, Truck, Umbrella 24 | 25 | 26 | # 2. My Notebooks 27 | - Robust Dynamic Radiance Fields, Liu et. al., CVPR, 2023. [[Kaggle](https://www.kaggle.com/code/declanide/robust-nerf)] 28 | 29 | # 3. 
Papers 30 | ## 2024 31 | - Dynamic 3D Gaussians: Tracking by Persistent Dynamic View Synthesis, Luiten et al., International Conference on 3D Vision (3DV), 2024. [[Paper](https://dynamic3dgaussians.github.io/paper.pdf) | [Project Page](https://dynamic3dgaussians.github.io/) | [Code](https://github.com/JonathonLuiten/Dynamic3DGaussians) | [Explanation Video](https://www.youtube.com/live/hDuy1TgD8I4?si=6oGN0IYnPRxOibpg)] 32 | - Sync-NeRF: Generalizing Dynamic NeRFs to Unsynchronized Videos, AAAI 2024. [[Paper](https://arxiv.org/abs/2310.13356) | [Code](https://github.com/seoha-kim/Sync-NeRF)] 33 | - Endo-4DGS: Endoscopic Monocular Scene Reconstruction with 4D Gaussian Splatting. [[Paper](https://arxiv.org/abs/2401.16416) | [Code](https://github.com/lastbasket/Endo-4DGS)] 34 | - DaReNeRF: Direction-aware Representation for Dynamic Scenes, CVPR 2024. 35 | - Sync-NeRF: Generalizing Dynamic NeRFs to Unsynchronized Videos, AAAI 2024. [[Code](https://github.com/seoha-kim/Sync-NeRF)] 36 | - SC-GS: Sparse-Controlled Gaussian Splatting for Editable Dynamic Scenes. [[Code](https://github.com/yihua7/SC-GS)] 37 | - InstantSplat: Unbounded Sparse-view Pose-free Gaussian Splatting in 40 Seconds. [Project](https://instantsplat.github.io/) 38 | - GaussianFlow: Splatting Gaussian Dynamics for 4D Content Creation 39 | - Entity-NeRF: Detecting and Removing Moving Entities in Urban Scenes, CVPR 2024. [Project](https://otonari726.github.io/entitynerf/) 40 | - Ced-NeRF: A Compact and Efficient Method for Dynamic Neural Radiance Fields, AAAI 2024. [Paper](https://ojs.aaai.org/index.php/AAAI/article/view/28138) 41 | - Shape of Motion: 4D Reconstruction from a Single Video, 2024. [[Project](https://shape-of-motion.github.io/) | [Code](https://github.com/vye16/shape-of-motion/)] 42 | - MoSca: Dynamic Gaussian Fusion from Casual Videos via 4D Motion Scaffolds, 2024.
[Project](https://www.cis.upenn.edu/~leijh/projects/mosca/) 43 | - Dynamic Gaussian Marbles for Novel View Synthesis of Casual Monocular Videos, 2024. 44 | - Dynamic Gaussian Mesh: Consistent Mesh Reconstruction from Monocular Videos, 2024. [Project](https://www.liuisabella.com/DG-Mesh/) 45 | - DyNeRFactor: Temporally consistent intrinsic scene decomposition for dynamic NeRFs, 2024. [Paper](https://www.sciencedirect.com/science/article/pii/S0097849324001195?dgcid=rss_sd_all) 46 | - DynVideo-E: Harnessing Dynamic NeRF for Large-Scale Motion- and View-Change Human-Centric Video Editing, CVPR 2024. [Project](https://showlab.github.io/DynVideo-E/) 47 | - Point-DynRF: Point-Based Dynamic Radiance Fields From a Monocular Video, WACV 2024. [Paper](https://openaccess.thecvf.com/content/WACV2024/papers/Park_Point-DynRF_Point-Based_Dynamic_Radiance_Fields_From_a_Monocular_Video_WACV_2024_paper.pdf) 48 | 49 | - FPO++: efficient encoding and rendering of dynamic neural radiance fields by analyzing and enhancing Fourier PlenOctrees, The Visual Computer, 2024. 50 | - EvDNeRF: Reconstructing event data with dynamic neural radiance fields, WACV 2024. [Code](https://github.com/anish-bhattacharya/EvDNeRF) 51 | - CTNeRF: Cross-time Transformer for dynamic neural radiance field from monocular video, Pattern Recognition, 2024. [Code](https://github.com/xingy038/ctnerf) 52 | - DynamicSurf: Dynamic Neural RGB-D Surface Reconstruction with an Optimizable Feature Grid, International Conference on 3D Vision (3DV) 2024. [Code](https://github.com/Mirgahney/dynsurf) 53 | - [+] Spacetime Gaussian Feature Splatting for Real-Time Dynamic View Synthesis, CVPR 2024. [Code](https://github.com/oppo-us-research/SpacetimeGaussians) 54 | 55 | ## 2023 56 | - DynIBaR: Neural Dynamic Image-Based Rendering, CVPR, 2023. [[Project Page](https://dynibar.github.io/)] 57 | - Tensor4D: Efficient Neural 4D Decomposition for High-fidelity Dynamic Reconstruction and Rendering, Shao et al., CVPR, 2023.
[[Paper](https://arxiv.org/abs/2211.11610) | [Code](https://github.com/DSaurus/Tensor4D)] 58 | - HyperReel: High-Fidelity 6-DoF Video with Ray-Conditioned Sampling, CVPR 2023 (Highlight). [Code](https://github.com/facebookresearch/hyperreel) 59 | - HexPlane: A Fast Representation for Dynamic Scenes, Cao et al., CVPR, 2023. [[Paper](https://caoang327.github.io/HexPlane/HexPlane.pdf) | [Project Page](https://caoang327.github.io/HexPlane/) | [Code](https://github.com/Caoang327/HexPlane)] 60 | - Robust Dynamic Radiance Fields, Liu et al., CVPR, 2023. [[Code](https://github.com/facebookresearch/robust-dynrf) | [Kaggle](https://www.kaggle.com/code/declanide/robust-nerf)] 61 | - V4D: Voxel for 4D Novel View Synthesis, Gan et al., IEEE Transactions on Visualization and Computer Graphics, 2023. [[Paper](https://arxiv.org/abs/2205.14332) | [Code](https://github.com/GANWANSHUI/V4D)] (instructions for custom data) 62 | - Dynamic Mesh-Aware Radiance Fields, ICCV, 2023. [[Project Page](https://mesh-aware-rf.github.io/) | [Code](https://github.com/YilingQiao/DMRF)] 63 | - NeRFPlayer: A Streamable Dynamic Scene Representation with Decomposed Neural Radiance Fields, IEEE Transactions on Visualization and Computer Graphics, vol. 29(5), 2023. [[Code](https://github.com/lsongx/nerfplayer-nerfstudio)] 64 | - Deformable 3D Gaussians for High-Fidelity Monocular Dynamic Scene Reconstruction, Yang et al., ACM Transactions on Graphics, 2023. [[Paper](https://arxiv.org/pdf/2309.13101.pdf) | [Project Page](https://ingra14m.github.io/Deformable-Gaussians/) | [Code](https://github.com/ingra14m/Deformable-3D-Gaussians)] 65 | - V4D: Voxel for 4D Novel View Synthesis, Gan et al., IEEE Transactions on Visualization and Computer Graphics, 2023. [[Code](https://github.com/GANWANSHUI/V4D)] 66 | - MixVoxels: Mixed Neural Voxels for Fast Multi-view Video Synthesis, ICCV 2023 (Oral). [Code](https://github.com/fengres/mixvoxels) 67 | - DynPoint: Dynamic Neural Point For View Synthesis, NeurIPS 2023.
68 | 69 | ## 2022 70 | - Fourier PlenOctrees for Dynamic Radiance Field Rendering in Real-time, CVPR 2022. [[Project Page](https://aoliao12138.github.io/FPO/)] 71 | - D2NeRF: Self-Supervised Decoupling of Dynamic and Static Objects from a Monocular Video, NeurIPS, 2022. [[Project Page](https://d2nerf.github.io/) | [Code](https://github.com/ChikaYan/d2nerf)] 72 | - Monocular Dynamic View Synthesis: A Reality Check, Gao et al., NeurIPS 2022. [[Project Page](https://hangg7.com/dycheck/)] 73 | - TiNeuVox: Fast Dynamic Radiance Fields with Time-Aware Neural Voxels, Fang et al., ACM SIGGRAPH Asia 2022. [[Project Page](https://jaminfong.cn/tineuvox/) | [Code](https://github.com/hustvl/TiNeuVox)] 74 | - Fourier PlenOctrees for Dynamic Radiance Field Rendering in Real-time, CVPR 2022. [Project](https://aoliao12138.github.io/FPO/) 75 | 76 | ## 2021 77 | - Nerfies: Deformable Neural Radiance Fields, ICCV, 2021. [[Code](https://github.com/google/nerfies)] (instructions for **custom data**; this is the one everyone refers to) 78 | - Dynamic View Synthesis from Dynamic Monocular Video, ICCV, 2021. [[Code](https://github.com/gaochen315/DynamicNeRF)] 79 | - HyperNeRF: A Higher-Dimensional Representation for Topologically Varying Neural Radiance Fields, ACM Trans. Graph, 2021. [[Code](https://github.com/google/hyperNeRF) | [Project Page](https://hypernerf.github.io/) | [Colab](./colabs/HyperNerf.ipynb)] (instructions for custom data) 80 | - BARF: Bundle-Adjusting Neural Radiance Fields, Lin et al., ICCV 2021 (Oral). [[Code](https://github.com/chenhsuanlin/bundle-adjusting-NeRF)] 81 | 82 | ## 2020 83 | - D-NeRF: Neural Radiance Fields for Dynamic Scenes, Pumarola et al., CVPR 2021.
[[Project Page](https://www.albertpumarola.com/research/D-NeRF/index.html) | [Code](https://github.com/albertpumarola/D-NeRF)] -------------------------------------------------------------------------------- /colabs/HyperNerf.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "HyperNeRF Training.ipynb", 7 | "private_outputs": true, 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "accelerator": "TPU" 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "id": "EZ_wkNVdTz-C" 22 | }, 23 | "source": [ 24 | "# Let's train HyperNeRF!\n", 25 | "\n", 26 | "**Author**: [Keunhong Park](https://keunhong.com)\n", 27 | "\n", 28 | "[[Project Page](https://hypernerf.github.io)]\n", 29 | "[[Paper](https://arxiv.org/abs/2106.13228)]\n", 30 | "[[GitHub](https://github.com/google/hypernerf)]\n", 31 | "\n", 32 | "This notebook provides a demo for training HyperNeRF.\n", 33 | "\n", 34 | "### Instructions\n", 35 | "\n", 36 | "1. Convert a video into our dataset format using the Nerfies [dataset processing notebook](https://colab.sandbox.google.com/github/google/nerfies/blob/main/notebooks/Nerfies_Capture_Processing.ipynb).\n", 37 | "2. Set the `data_dir` below to where you saved the dataset.\n", 38 | "3. Come back to this notebook to train HyperNeRF.\n", 39 | "\n", 40 | "\n", 41 | "### Notes\n", 42 | " * To accommodate the limited compute power of Colab runtimes, this notebook defaults to a \"toy\" version of our method.
The number of samples has been reduced and the elastic regularization turned off.\n", 43 | "\n", 44 | " * To train a high-quality model, please look at the CLI options we provide in the [GitHub repository](https://github.com/google/hypernerf).\n", 45 | "\n", 46 | "\n", 47 | "\n", 48 | " * Please report issues on the [GitHub issue tracker](https://github.com/google/hypernerf/issues).\n", 49 | "\n", 50 | "\n", 51 | "If you find this work useful, please consider citing:\n", 52 | "```bibtex\n", 53 | "@article{park2021hypernerf,\n", 54 | " author = {Park, Keunhong and Sinha, Utkarsh and Hedman, Peter and Barron, Jonathan T. and Bouaziz, Sofien and Goldman, Dan B and Martin-Brualla, Ricardo and Seitz, Steven M.},\n", 55 | " title = {HyperNeRF: A Higher-Dimensional Representation for Topologically Varying Neural Radiance Fields},\n", 56 | " journal = {arXiv preprint arXiv:2106.13228},\n", 57 | " year = {2021},\n", 58 | "}\n", 59 | "```\n" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": { 65 | "id": "OlW1gF_djH6H" 66 | }, 67 | "source": [ 68 | "## Environment Setup" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "metadata": { 74 | "id": "I6Jbspl7TnIX" 75 | }, 76 | "source": [ 77 | "!pip install flax immutabledict mediapy\n", 78 | "!pip install --upgrade git+https://github.com/google/hypernerf" 79 | ], 80 | "execution_count": null, 81 | "outputs": [] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "metadata": { 86 | "id": "zGJux-m5Xp3Z", 87 | "cellView": "form" 88 | }, 89 | "source": [ 90 | "# @title Configure notebook runtime\n", 91 | "# @markdown If you would like to use a GPU runtime instead, change the runtime type by going to `Runtime > Change runtime type`.
\n", 92 | "# @markdown You will have to use a smaller batch size on GPU.\n", 93 | "\n", 94 | "runtime_type = 'tpu' # @param ['gpu', 'tpu']\n", 95 | "if runtime_type == 'tpu':\n", 96 | " import jax.tools.colab_tpu\n", 97 | " jax.tools.colab_tpu.setup_tpu()\n", 98 | "import jax # Also needed on GPU, where the TPU-only import above is skipped.\n", 99 | "print('Detected Devices:', jax.devices())" 100 | ], 101 | "execution_count": null, 102 | "outputs": [] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "metadata": { 107 | "id": "afUtLfRWULEi", 108 | "cellView": "form" 109 | }, 110 | "source": [ 111 | "# @title Mount Google Drive\n", 112 | "# @markdown Mount Google Drive onto `/content/gdrive`. You can skip this if running locally.\n", 113 | "\n", 114 | "from google.colab import drive\n", 115 | "drive.mount('/content/gdrive')" 116 | ], 117 | "execution_count": null, 118 | "outputs": [] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "metadata": { 123 | "id": "ENOfbG3AkcVN", 124 | "cellView": "form" 125 | }, 126 | "source": [ 127 | "# @title Define imports and utility functions.\n", 128 | "\n", 129 | "import jax\n", 130 | "from jax.config import config as jax_config\n", 131 | "import jax.numpy as jnp\n", 132 | "from jax import grad, jit, vmap\n", 133 | "from jax import random\n", 134 | "\n", 135 | "import flax\n", 136 | "import flax.linen as nn\n", 137 | "from flax import jax_utils\n", 138 | "from flax import optim\n", 139 | "from flax.metrics import tensorboard\n", 140 | "from flax.training import checkpoints\n", 141 | "jax_config.enable_omnistaging() # Linen requires enabling omnistaging\n", 142 | "\n", 143 | "from absl import logging\n", 144 | "from io import BytesIO\n", 145 | "import random as pyrandom\n", 146 | "import numpy as np\n", 147 | "import PIL.Image # show_image uses PIL.Image; a bare `import PIL` does not load that submodule.\n", 148 | "import IPython\n", 149 | "\n", 150 | "\n", 151 | "# Monkey patch logging.\n", 152 | "def myprint(msg, *args, **kwargs):\n", 153 | " print(msg % args)\n", 154 | "\n", 155 | "logging.info = myprint\n", 156 | "logging.warn = myprint\n", 157 | "logging.error = myprint\n", 158 |
"\n", 159 | "\n", 160 | "def show_image(image, fmt='png'):\n", 161 | " image = image_utils.image_to_uint8(image)\n", 162 | " f = BytesIO()\n", 163 | " PIL.Image.fromarray(image).save(f, fmt)\n", 164 | " IPython.display.display(IPython.display.Image(data=f.getvalue()))\n", 165 | "\n" 166 | ], 167 | "execution_count": null, 168 | "outputs": [] 169 | }, 170 | { 171 | "cell_type": "markdown", 172 | "metadata": { 173 | "id": "wW7FsSB-jORB" 174 | }, 175 | "source": [ 176 | "## Configuration" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "metadata": { 182 | "id": "rz7wRm7YT9Ka" 183 | }, 184 | "source": [ 185 | "# @title Model and dataset configuration\n", 186 | "\n", 187 | "from pathlib import Path\n", 188 | "from pprint import pprint\n", 189 | "import gin\n", 190 | "from IPython.display import display, Markdown\n", 191 | "\n", 192 | "from hypernerf import models\n", 193 | "from hypernerf import modules\n", 194 | "from hypernerf import warping\n", 195 | "from hypernerf import datasets\n", 196 | "from hypernerf import configs\n", 197 | "\n", 198 | "\n", 199 | "# @markdown The working directory.\n", 200 | "train_dir = '/content/gdrive/My Drive/nerfies/hypernerf_experiments/capture1/exp1' # @param {type: \"string\"}\n", 201 | "# @markdown The directory to the dataset capture.\n", 202 | "data_dir = '/content/gdrive/My Drive/nerfies/captures/capture1' # @param {type: \"string\"}\n", 203 | "\n", 204 | "# @markdown Training configuration.\n", 205 | "max_steps = 100000 # @param {type: 'number'}\n", 206 | "batch_size = 4096 # @param {type: 'number'}\n", 207 | "image_scale = 8 # @param {type: 'number'}\n", 208 | "\n", 209 | "# @markdown Model configuration.\n", 210 | "use_viewdirs = True #@param {type: 'boolean'}\n", 211 | "use_appearance_metadata = True #@param {type: 'boolean'}\n", 212 | "num_coarse_samples = 64 # @param {type: 'number'}\n", 213 | "num_fine_samples = 64 # @param {type: 'number'}\n", 214 | "\n", 215 | "# @markdown Deformation configuration.\n", 216 |
"use_warp = True #@param {type: 'boolean'}\n", 217 | "warp_field_type = '@SE3Field' #@param['@SE3Field', '@TranslationField']\n", 218 | "warp_min_deg = 0 #@param{type:'number'}\n", 219 | "warp_max_deg = 6 #@param{type:'number'}\n", 220 | "\n", 221 | "# @markdown Hyper-space configuration.\n", 222 | "hyper_num_dims = 8 #@param{type:'number'}\n", 223 | "hyper_point_min_deg = 0 #@param{type:'number'}\n", 224 | "hyper_point_max_deg = 1 #@param{type:'number'}\n", 225 | "hyper_slice_method = 'bendy_sheet' #@param['none', 'axis_aligned_plane', 'bendy_sheet']\n", 226 | "\n", 227 | "\n", 228 | "checkpoint_dir = Path(train_dir, 'checkpoints')\n", 229 | "checkpoint_dir.mkdir(exist_ok=True, parents=True)\n", 230 | "\n", 231 | "config_str = f\"\"\"\n", 232 | "DELAYED_HYPER_ALPHA_SCHED = {{\n", 233 | " 'type': 'piecewise',\n", 234 | " 'schedules': [\n", 235 | " (1000, ('constant', 0.0)),\n", 236 | " (0, ('linear', 0.0, %hyper_point_max_deg, 10000))\n", 237 | " ],\n", 238 | "}}\n", 239 | "\n", 240 | "ExperimentConfig.image_scale = {image_scale}\n", 241 | "ExperimentConfig.datasource_cls = @NerfiesDataSource\n", 242 | "NerfiesDataSource.data_dir = '{data_dir}'\n", 243 | "NerfiesDataSource.image_scale = {image_scale}\n", 244 | "\n", 245 | "NerfModel.use_viewdirs = {int(use_viewdirs)}\n", 246 | "NerfModel.use_rgb_condition = {int(use_appearance_metadata)}\n", 247 | "NerfModel.num_coarse_samples = {num_coarse_samples}\n", 248 | "NerfModel.num_fine_samples = {num_fine_samples}\n", 249 | "\n", 250 | "# NerfModel.use_viewdirs is bound once, from the use_viewdirs form field above.\n", 251 | "NerfModel.use_stratified_sampling = True\n", 252 | "NerfModel.use_posenc_identity = False\n", 253 | "NerfModel.nerf_trunk_width = 128\n", 254 | "NerfModel.nerf_trunk_depth = 8\n", 255 | "\n", 256 | "TrainConfig.max_steps = {max_steps}\n", 257 | "TrainConfig.batch_size = {batch_size}\n", 258 | "TrainConfig.print_every = 100\n", 259 | "TrainConfig.use_elastic_loss = False\n", 260 | "TrainConfig.use_background_loss = False\n", 261 | "\n", 262 | "# Warp
configs.\n", 263 | "warp_min_deg = {warp_min_deg}\n", 264 | "warp_max_deg = {warp_max_deg}\n", 265 | "NerfModel.use_warp = {use_warp}\n", 266 | "SE3Field.min_deg = %warp_min_deg\n", 267 | "SE3Field.max_deg = %warp_max_deg\n", 268 | "SE3Field.use_posenc_identity = False\n", 269 | "NerfModel.warp_field_cls = {warp_field_type}\n", 270 | "\n", 271 | "TrainConfig.warp_alpha_schedule = {{\n", 272 | " 'type': 'linear',\n", 273 | " 'initial_value': {warp_min_deg},\n", 274 | " 'final_value': {warp_max_deg},\n", 275 | " 'num_steps': {int(max_steps*0.8)},\n", 276 | "}}\n", 277 | "\n", 278 | "# Hyper configs.\n", 279 | "hyper_num_dims = {hyper_num_dims}\n", 280 | "hyper_point_min_deg = {hyper_point_min_deg}\n", 281 | "hyper_point_max_deg = {hyper_point_max_deg}\n", 282 | "\n", 283 | "NerfModel.hyper_embed_cls = @hyper/GLOEmbed\n", 284 | "hyper/GLOEmbed.num_dims = %hyper_num_dims\n", 285 | "NerfModel.hyper_point_min_deg = %hyper_point_min_deg\n", 286 | "NerfModel.hyper_point_max_deg = %hyper_point_max_deg\n", 287 | "\n", 288 | "TrainConfig.hyper_alpha_schedule = %DELAYED_HYPER_ALPHA_SCHED\n", 289 | "\n", 290 | "hyper_sheet_min_deg = 0\n", 291 | "hyper_sheet_max_deg = 6\n", 292 | "HyperSheetMLP.min_deg = %hyper_sheet_min_deg\n", 293 | "HyperSheetMLP.max_deg = %hyper_sheet_max_deg\n", 294 | "HyperSheetMLP.output_channels = %hyper_num_dims\n", 295 | "\n", 296 | "NerfModel.hyper_slice_method = '{hyper_slice_method}'\n", 297 | "NerfModel.hyper_sheet_mlp_cls = @HyperSheetMLP\n", 298 | "NerfModel.hyper_use_warp_embed = True\n", 299 | "\n", 300 | "TrainConfig.hyper_sheet_alpha_schedule = ('constant', %hyper_sheet_max_deg)\n", 301 | "\"\"\"\n", 302 | "\n", 303 | "gin.parse_config(config_str)\n", 304 | "\n", 305 | "config_path = Path(train_dir, 'config.gin')\n", 306 | "with open(config_path, 'w') as f:\n", 307 | " logging.info('Saving config to %s', config_path)\n", 308 | " f.write(config_str)\n", 309 | "\n", 310 | "exp_config = configs.ExperimentConfig()\n", 311 | "train_config =
configs.TrainConfig()\n", 312 | "eval_config = configs.EvalConfig()\n", 313 | "\n", 314 | "display(Markdown(\n", 315 | " gin.config.markdown(gin.config_str())))" 316 | ], 317 | "execution_count": null, 318 | "outputs": [] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "metadata": { 323 | "id": "r872r6hiVUVS", 324 | "cellView": "form" 325 | }, 326 | "source": [ 327 | "# @title Create datasource and show an example.\n", 328 | "\n", 329 | "from hypernerf import datasets\n", 330 | "from hypernerf import image_utils\n", 331 | "\n", 332 | "dummy_model = models.NerfModel({}, 0, 0)\n", 333 | "datasource = exp_config.datasource_cls(\n", 334 | " image_scale=exp_config.image_scale,\n", 335 | " random_seed=exp_config.random_seed,\n", 336 | " # Enable metadata based on model needs.\n", 337 | " use_warp_id=dummy_model.use_warp,\n", 338 | " use_appearance_id=(\n", 339 | " dummy_model.nerf_embed_key == 'appearance'\n", 340 | " or dummy_model.hyper_embed_key == 'appearance'),\n", 341 | " use_camera_id=dummy_model.nerf_embed_key == 'camera',\n", 342 | " use_time=dummy_model.warp_embed_key == 'time')\n", 343 | "\n", 344 | "show_image(datasource.load_rgb(datasource.train_ids[0]))" 345 | ], 346 | "execution_count": null, 347 | "outputs": [] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "metadata": { 352 | "id": "XC3PIY74XB05", 353 | "cellView": "form" 354 | }, 355 | "source": [ 356 | "# @title Create training iterators\n", 357 | "\n", 358 | "devices = jax.local_devices()\n", 359 | "\n", 360 | "train_iter = datasource.create_iterator(\n", 361 | " datasource.train_ids,\n", 362 | " flatten=True,\n", 363 | " shuffle=True,\n", 364 | " batch_size=train_config.batch_size,\n", 365 | " prefetch_size=3,\n", 366 | " shuffle_buffer_size=train_config.shuffle_buffer_size,\n", 367 | " devices=devices,\n", 368 | ")\n", 369 | "\n", 370 | "def shuffled(l):\n", 371 | " import random as r\n", 372 | " import copy\n", 373 | " l = copy.copy(l)\n", 374 | " r.shuffle(l)\n", 375 | " return l\n", 376 |
"\n", 377 | "train_eval_iter = datasource.create_iterator(\n", 378 | " shuffled(datasource.train_ids), batch_size=0, devices=devices)\n", 379 | "val_eval_iter = datasource.create_iterator(\n", 380 | " shuffled(datasource.val_ids), batch_size=0, devices=devices)" 381 | ], 382 | "execution_count": null, 383 | "outputs": [] 384 | }, 385 | { 386 | "cell_type": "markdown", 387 | "metadata": { 388 | "id": "erY9l66KjYYW" 389 | }, 390 | "source": [ 391 | "## Training" 392 | ] 393 | }, 394 | { 395 | "cell_type": "code", 396 | "metadata": { 397 | "id": "nZnS8BhcXe5E", 398 | "cellView": "form" 399 | }, 400 | "source": [ 401 | "# @title Initialize model\n", 402 | "# @markdown Defines the model and initializes its parameters.\n", 403 | "\n", 404 | "from flax.training import checkpoints\n", 405 | "from hypernerf import models\n", 406 | "from hypernerf import model_utils\n", 407 | "from hypernerf import schedules\n", 408 | "from hypernerf import training\n", 409 | "\n", 410 | "# @markdown Restore a checkpoint if one exists.\n", 411 | "restore_checkpoint = False # @param{type:'boolean'}\n", 412 | "\n", 413 | "\n", 414 | "rng = random.PRNGKey(exp_config.random_seed)\n", 415 | "np.random.seed(exp_config.random_seed + jax.process_index())\n", 416 | "devices_to_use = jax.devices()\n", 417 | "\n", 418 | "learning_rate_sched = schedules.from_config(train_config.lr_schedule)\n", 419 | "nerf_alpha_sched = schedules.from_config(train_config.nerf_alpha_schedule)\n", 420 | "warp_alpha_sched = schedules.from_config(train_config.warp_alpha_schedule)\n", 421 | "elastic_loss_weight_sched = schedules.from_config(\n", 422 | " train_config.elastic_loss_weight_schedule)\n", 423 | "hyper_alpha_sched = schedules.from_config(train_config.hyper_alpha_schedule)\n", 424 | "hyper_sheet_alpha_sched = schedules.from_config(\n", 425 | " train_config.hyper_sheet_alpha_schedule)\n", 426 | "\n", 427 | "rng, key = random.split(rng)\n", 428 | "params = {}\n", 429 | "model, params['model'] =
models.construct_nerf(\n", 430 | " key,\n", 431 | " batch_size=train_config.batch_size,\n", 432 | " embeddings_dict=datasource.embeddings_dict,\n", 433 | " near=datasource.near,\n", 434 | " far=datasource.far)\n", 435 | "\n", 436 | "optimizer_def = optim.Adam(learning_rate_sched(0))\n", 437 | "optimizer = optimizer_def.create(params)\n", 438 | "\n", 439 | "state = model_utils.TrainState(\n", 440 | " optimizer=optimizer,\n", 441 | " nerf_alpha=nerf_alpha_sched(0),\n", 442 | " warp_alpha=warp_alpha_sched(0),\n", 443 | " hyper_alpha=hyper_alpha_sched(0),\n", 444 | " hyper_sheet_alpha=hyper_sheet_alpha_sched(0))\n", 445 | "scalar_params = training.ScalarParams(\n", 446 | " learning_rate=learning_rate_sched(0),\n", 447 | " elastic_loss_weight=elastic_loss_weight_sched(0),\n", 448 | " warp_reg_loss_weight=train_config.warp_reg_loss_weight,\n", 449 | " warp_reg_loss_alpha=train_config.warp_reg_loss_alpha,\n", 450 | " warp_reg_loss_scale=train_config.warp_reg_loss_scale,\n", 451 | " background_loss_weight=train_config.background_loss_weight,\n", 452 | " hyper_reg_loss_weight=train_config.hyper_reg_loss_weight)\n", 453 | "\n", 454 | "if restore_checkpoint:\n", 455 | " logging.info('Restoring checkpoint from %s', checkpoint_dir)\n", 456 | " state = checkpoints.restore_checkpoint(checkpoint_dir, state)\n", 457 | "step = state.optimizer.state.step + 1\n", 458 | "state = jax_utils.replicate(state, devices=devices)\n", 459 | "del params" 460 | ], 461 | "execution_count": null, 462 | "outputs": [] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "metadata": { 467 | "id": "at2CL5DRZ7By", 468 | "cellView": "form" 469 | }, 470 | "source": [ 471 | "# @title Define pmapped functions\n", 472 | "# @markdown This parallelizes the training and evaluation step functions using `jax.pmap`.\n", 473 | "\n", 474 | "import functools\n", 475 | "from hypernerf import evaluation\n", 476 | "\n", 477 | "\n", 478 | "def _model_fn(key_0, key_1, params, rays_dict, extra_params):\n", 479 | " out =
model.apply({'params': params},\n", 480 | " rays_dict,\n", 481 | " extra_params=extra_params,\n", 482 | " rngs={\n", 483 | " 'coarse': key_0,\n", 484 | " 'fine': key_1\n", 485 | " },\n", 486 | " mutable=False)\n", 487 | " return jax.lax.all_gather(out, axis_name='batch')\n", 488 | "\n", 489 | "pmodel_fn = jax.pmap(\n", 490 | " # Note rng_keys are useless in eval mode since there's no randomness.\n", 491 | " _model_fn,\n", 492 | " in_axes=(0, 0, 0, 0, 0), # Every argument is mapped over its leading (device) axis.\n", 493 | " devices=devices_to_use,\n", 494 | " axis_name='batch',\n", 495 | ")\n", 496 | "\n", 497 | "render_fn = functools.partial(evaluation.render_image,\n", 498 | " model_fn=pmodel_fn,\n", 499 | " device_count=len(devices),\n", 500 | " chunk=eval_config.chunk)\n", 501 | "train_step = functools.partial(\n", 502 | " training.train_step,\n", 503 | " model,\n", 504 | " elastic_reduce_method=train_config.elastic_reduce_method,\n", 505 | " elastic_loss_type=train_config.elastic_loss_type,\n", 506 | " use_elastic_loss=train_config.use_elastic_loss,\n", 507 | " use_background_loss=train_config.use_background_loss,\n", 508 | " use_warp_reg_loss=train_config.use_warp_reg_loss,\n", 509 | " use_hyper_reg_loss=train_config.use_hyper_reg_loss,\n", 510 | ")\n", 511 | "ptrain_step = jax.pmap(\n", 512 | " train_step,\n", 513 | " axis_name='batch',\n", 514 | " devices=devices,\n", 515 | " # rng_key, state, batch, scalar_params.\n", 516 | " in_axes=(0, 0, 0, None),\n", 517 | " # scalar_params (in_axes=None) is broadcast unmapped to every device.\n", 518 | " donate_argnums=(2,), # Donate the 'batch' argument.\n", 519 | ")" 520 | ], 521 | "execution_count": null, 522 | "outputs": [] 523 | }, 524 | { 525 | "cell_type": "code", 526 | "metadata": { 527 | "id": "vbc7cMr5aR_1", 528 | "cellView": "form" 529 | }, 530 | "source": [ 531 | "# @title Train!\n", 532 | "# @markdown This runs the training loop!\n", 533 | "\n", 534 | "import mediapy\n", 535 | "from hypernerf import utils\n", 536 | "from hypernerf import
visualization as viz\n", 537 | "\n", 538 | "\n", 539 | "print_every_n_iterations = 100 # @param{type:'number'}\n", 540 | "visualize_results_every_n_iterations = 500 # @param{type:'number'}\n", 541 | "save_checkpoint_every_n_iterations = 1000 # @param{type:'number'}\n", 542 | "\n", 543 | "\n", 544 | "logging.info('Starting training')\n", 545 | "rng = rng + jax.process_index() # Make random seed separate across hosts.\n", 546 | "keys = random.split(rng, len(devices))\n", 547 | "time_tracker = utils.TimeTracker()\n", 548 | "time_tracker.tic('data', 'total')\n", 549 | "\n", 550 | "for step, batch in zip(range(step, train_config.max_steps + 1), train_iter):\n", 551 | " time_tracker.toc('data')\n", 552 | " scalar_params = scalar_params.replace(\n", 553 | " learning_rate=learning_rate_sched(step),\n", 554 | " elastic_loss_weight=elastic_loss_weight_sched(step))\n", 555 | " # pytype: enable=attribute-error\n", 556 | " nerf_alpha = jax_utils.replicate(nerf_alpha_sched(step), devices)\n", 557 | " warp_alpha = jax_utils.replicate(warp_alpha_sched(step), devices)\n", 558 | " hyper_alpha = jax_utils.replicate(hyper_alpha_sched(step), devices)\n", 559 | " hyper_sheet_alpha = jax_utils.replicate(\n", 560 | " hyper_sheet_alpha_sched(step), devices)\n", 561 | " state = state.replace(nerf_alpha=nerf_alpha,\n", 562 | " warp_alpha=warp_alpha,\n", 563 | " hyper_alpha=hyper_alpha,\n", 564 | " hyper_sheet_alpha=hyper_sheet_alpha)\n", 565 | "\n", 566 | " with time_tracker.record_time('train_step'):\n", 567 | " state, stats, keys, _ = ptrain_step(keys, state, batch, scalar_params)\n", 568 | " time_tracker.toc('total')\n", 569 | "\n", 570 | " if step % print_every_n_iterations == 0:\n", 571 | " logging.info(\n", 572 | " 'step=%d, warp_alpha=%.04f, hyper_alpha=%.04f, hyper_sheet_alpha=%.04f, %s',\n", 573 | " step,\n", 574 | " warp_alpha_sched(step),\n", 575 | " hyper_alpha_sched(step),\n", 576 | " hyper_sheet_alpha_sched(step),\n", 577 | " time_tracker.summary_str('last'))\n", 578 |
" coarse_metrics_str = ', '.join(\n", 579 | " [f'{k}={v.mean():.04f}' for k, v in stats['coarse'].items()])\n", 580 | " logging.info('\\tcoarse metrics: %s', coarse_metrics_str)\n", 581 | " if 'fine' in stats:\n", 582 | " fine_metrics_str = ', '.join(\n", 583 | " [f'{k}={v.mean():.04f}' for k, v in stats['fine'].items()])\n", 584 | " logging.info('\\tfine metrics: %s', fine_metrics_str)\n", 585 | " \n", 586 | " if step % visualize_results_every_n_iterations == 0:\n", 587 | " print(f'[step={step}] Training set visualization')\n", 588 | " eval_batch = next(train_eval_iter)\n", 589 | " render = render_fn(state, eval_batch, rng=rng)\n", 590 | " rgb = render['rgb']\n", 591 | " acc = render['acc']\n", 592 | " depth_exp = render['depth']\n", 593 | " depth_med = render['med_depth']\n", 594 | " rgb_target = eval_batch['rgb']\n", 595 | " depth_med_viz = viz.colorize(depth_med, cmin=datasource.near, cmax=datasource.far)\n", 596 | " mediapy.show_images([rgb_target, rgb, depth_med_viz],\n", 597 | " titles=['GT RGB', 'Pred RGB', 'Pred Depth'])\n", 598 | "\n", 599 | " print(f'[step={step}] Validation set visualization')\n", 600 | " eval_batch = next(val_eval_iter)\n", 601 | " render = render_fn(state, eval_batch, rng=rng)\n", 602 | " rgb = render['rgb']\n", 603 | " acc = render['acc']\n", 604 | " depth_exp = render['depth']\n", 605 | " depth_med = render['med_depth']\n", 606 | " rgb_target = eval_batch['rgb']\n", 607 | " depth_med_viz = viz.colorize(depth_med, cmin=datasource.near, cmax=datasource.far)\n", 608 | " mediapy.show_images([rgb_target, rgb, depth_med_viz],\n", 609 | " titles=['GT RGB', 'Pred RGB', 'Pred Depth'])\n", 610 | "\n", 611 | " if step % save_checkpoint_every_n_iterations == 0:\n", 612 | " training.save_checkpoint(checkpoint_dir, state)\n", 613 | "\n", 614 | " time_tracker.tic('data', 'total')\n" 615 | ], 616 | "execution_count": null, 617 | "outputs": [] 618 | }, 619 | { 620 | "cell_type": "code", 621 | "metadata": { 622 | "id": "o69auGWvdyyd" 623 | }, 624 |
"source": [ 625 | "" 626 | ], 627 | "execution_count": null, 628 | "outputs": [] 629 | } 630 | ] 631 | } --------------------------------------------------------------------------------