├── .gitignore
├── README.md
├── convert_dataset
│   ├── blender.py
│   ├── blender_script.py
│   ├── blender_script_random.py
│   └── llff.py
├── go.mod
├── go.sum
├── learn_nerf
│   ├── __init__.py
│   ├── dataset.py
│   ├── instant_ngp.py
│   ├── model.py
│   ├── ref_nerf.py
│   ├── render.py
│   ├── scripts
│   │   ├── check_bbox.py
│   │   ├── cv_nerf.py
│   │   ├── marching_cubes.py
│   │   ├── plot_log.py
│   │   ├── render_nerf.py
│   │   ├── render_nerf_interactive.ipynb
│   │   ├── render_nerf_pan.py
│   │   ├── render_nerf_spin.py
│   │   ├── render_new_dataset.py
│   │   └── train_nerf.py
│   ├── test_dataset.py
│   └── train.py
├── point_cloud
│   └── main.go
├── setup.py
└── simple_dataset
    ├── camera_gen.go
    ├── main.go
    └── vector_flag.go
/.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info/ 2 | __pycache__/ 3 | data/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # learn-nerf 2 | 3 | This is a JAX implementation of [Neural Radiance Fields](https://arxiv.org/abs/2003.08934) for learning purposes. 4 | 5 | I've been curious about NeRF and its follow-up work for a while, but don't have much time to explore it. I learn best by doing, so I'll be implementing stuff here to try to get a feel for it. 6 | 7 | # Usage 8 | 9 | The steps to using this codebase are as follows: 10 | 11 | 1. [Generate a dataset](#generating-a-dataset) - run a simple Go program to turn any `.stl` 3D model into a series of rendered camera views with associated metadata. 2. [Train a model](#training-a-model) - install the Python dependencies and run the training script. 13 | 3. [Render a novel view](#render-a-novel-view) - render a novel view of the object using a model. 14 | 15 | ## Generating a dataset 16 | 17 | I use a simple format for storing rendered views of the scene. Each frame is stored as a PNG file, and each PNG has an accompanying JSON file describing the camera view. 18 | 19 | For easy experimentation, I created a Go program to render an arbitrary `.stl` file as a collection of views in the supported data format. To run this program, install [Go](https://go.dev/doc/install) and run `go get .` inside of [simple_dataset/](simple_dataset) to get the dependencies. Next, run 20 | 21 | ``` 22 | $ go run . /path/to/model.stl data_dir 23 | ``` 24 | 25 | This will create a directory `data_dir` containing rendered views of `/path/to/model.stl`. 26 | 27 | ## Training a model 28 | 29 | First, install the `learn_nerf` package by running `pip install -e .` inside this repository. You should separately make sure [jax](https://github.com/google/jax) and [Flax](https://github.com/google/flax) are installed in your environment. 30 | 31 | The training script is [learn_nerf/scripts/train_nerf.py](learn_nerf/scripts/train_nerf.py). Here's an example of running this script: 33 | ``` 34 | python learn_nerf/scripts/train_nerf.py \ 35 | --lr 1e-5 \ 36 | --batch_size 1024 \ 37 | --save_path model_weights.pkl \ 38 | /path/to/data_dir 39 | ``` 40 | 41 | This will periodically save model weights to `model_weights.pkl`. The script may get stuck on `training...` while it shuffles the dataset and compiles the training graph. Wait a minute or two, and losses should start printing out as training ramps up. 42 | 43 | If you get a `Segmentation fault` on CPU, this may be because you don't have enough memory to run batch size 1024; try something lower.
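The `.pkl` file is presumably an ordinary Python pickle, so you can inspect a checkpoint outside of the training script. A minimal sketch (the exact object layout is whatever `train_nerf.py` chose to save, so check it rather than assuming a structure):

```
import pickle

with open("model_weights.pkl", "rb") as f:
    weights = pickle.load(f)

# The layout of `weights` is defined by train_nerf.py (hypothetically a
# dict or tuple of Flax parameter trees); print it to see what is stored.
print(type(weights))
```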
44 | 45 | ## Render a novel view 46 | 47 | To render a view from a trained NeRF model, use [learn_nerf/scripts/render_nerf.py](learn_nerf/scripts/render_nerf.py). Here's an example of the usage: 48 | 49 | ``` 50 | python learn_nerf/scripts/render_nerf.py \ 51 | --batch_size 1024 \ 52 | --model_path model_weights.pkl \ 53 | --width 128 \ 54 | --height 128 \ 55 | /path/to/data_dir/0000.json \ 56 | output.png 57 | ``` 58 | 59 | In the above example, we will render the camera view described by `/path/to/data_dir/0000.json`. Note that the camera view can be from the training set, but doesn't need to be as long as it's in the correct JSON format. 60 | -------------------------------------------------------------------------------- /convert_dataset/blender.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert a Blender dataset from the original NeRF repo into the format used by 3 | this repository. 4 | """ 5 | 6 | import argparse 7 | import json 8 | import math 9 | import os 10 | import shutil 11 | 12 | import numpy as np 13 | from PIL import Image 14 | 15 | 16 | def main(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--split", type=str, default="train") 19 | parser.add_argument("input_dir", type=str) 20 | parser.add_argument("output_dir", type=str) 21 | args = parser.parse_args() 22 | 23 | if os.path.exists(args.output_dir): 24 | raise FileExistsError(f"output path exists: {args.output_dir}") 25 | os.mkdir(args.output_dir) 26 | 27 | json_path = os.path.join(args.input_dir, f"transforms_{args.split}.json") 28 | with open(json_path, "r") as f: 29 | info = json.load(f) 30 | 31 | x_fov = info["camera_angle_x"] 32 | for i, frame in enumerate(info["frames"]): 33 | img_path = os.path.join(args.input_dir, frame["file_path"] + ".png") 34 | img_width, img_height = Image.open(img_path).size 35 | 36 | matrix = np.array(frame["transform_matrix"]) 37 | origin = matrix[:3, -1] 38 | rot = matrix[:3, :3] 39 | x = rot @ np.array([1.0, 0.0, 0.0]) 40 | y = rot @ np.array([0.0, -1.0, 0.0]) 41 | z = rot @ np.array([0.0, 0.0, -1.0]) 42 | y_fov = 2 * math.atan(math.tan(x_fov / 2) * img_height / img_width) 43 | 44 | out_base = os.path.join(args.output_dir, f"{i:04}") 45 | with open(out_base + ".json", "w") as f: 46 | json.dump( 47 | dict( 48 | origin=origin.tolist(), 49 | x_fov=x_fov, 50 | y_fov=y_fov, 51 | x=x.tolist(), 52 | y=y.tolist(), 53 | z=z.tolist(), 54 | ), 55 | f, 56 | ) 57 | shutil.copyfile(img_path, out_base + ".png") 58 | 59 | with open(os.path.join(args.output_dir, "metadata.json"), "w") as f: 60 | json.dump(dict(min=[-1.0] * 3, max=[1.0] * 3), f) 61 | 62 | 63 | if __name__ == "__main__": 64 | main() 65 | -------------------------------------------------------------------------------- /convert_dataset/blender_script.py: -------------------------------------------------------------------------------- 1 | """ 2 | Save Blender animations as a NeRF dataset. Each frame of the animation is 3 | saved as an image with an accompanying metadata file. 4 | 5 | To use this script, set the OUTPUT_DIR variable near the top of the file 6 | to point to a directory where the dataset should be saved, then run the 7 | script inside of the "Scripting" tab in Blender.
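Note on conventions: Blender cameras look down their local -Z axis with +Y up, so the exported `z` axis is the negated third column of the camera matrix (pointing into the scene) and `y` is the negated second column (pointing down the image), matching the camera format consumed by learn_nerf/dataset.py.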
8 | """ 9 | 10 | import json 11 | import math 12 | import os 13 | 14 | import bpy 15 | 16 | OUTPUT_DIR = None 17 | assert OUTPUT_DIR is not None, "must set OUTPUT_DIR" 18 | os.makedirs(OUTPUT_DIR, exist_ok=True) 19 | 20 | 21 | def scene_bbox(): 22 | large = 100000.0 23 | bbox_min = (large,) * 3 24 | bbox_max = (-large,) * 3 25 | for obj in bpy.context.scene.objects.values(): 26 | if isinstance(obj.data, (bpy.types.Camera, bpy.types.Light)): 27 | continue 28 | for coord in obj.bound_box: 29 | bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord)) 30 | bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord)) 31 | return dict(min=bbox_min, max=bbox_max) 32 | 33 | 34 | def scene_fov(): 35 | x_fov = scene.camera.data.angle_x 36 | y_fov = scene.camera.data.angle_y 37 | width = bpy.context.scene.render.resolution_x 38 | height = bpy.context.scene.render.resolution_y 39 | if scene.camera.data.angle == x_fov: 40 | y_fov = 2 * math.atan(math.tan(x_fov / 2) * height / width) 41 | else: 42 | x_fov = 2 * math.atan(math.tan(y_fov / 2) * width / height) 43 | return x_fov, y_fov 44 | 45 | 46 | with open(os.path.join(OUTPUT_DIR, "metadata.json"), "w") as f: 47 | json.dump(scene_bbox(), f) 48 | 49 | scene = bpy.context.scene 50 | backup_path = scene.render.filepath 51 | backup_format = scene.render.image_settings.file_format 52 | try: 53 | scene.render.image_settings.file_format = "PNG" 54 | for i, frame in enumerate(range(scene.frame_start, scene.frame_end)): 55 | scene.frame_set(frame) 56 | scene.render.filepath = os.path.join(OUTPUT_DIR, f"{i:05}") 57 | 58 | x_fov, y_fov = scene_fov() 59 | matrix = scene.camera.matrix_world 60 | with open(scene.render.filepath + ".json", "w") as f: 61 | json.dump( 62 | dict( 63 | origin=list(matrix.col[3])[:3], 64 | x_fov=x_fov, 65 | y_fov=y_fov, 66 | x=list(matrix.col[0])[:3], 67 | y=list(-matrix.col[1])[:3], 68 | z=list(-matrix.col[2])[:3], 69 | ), 70 | f, 71 | ) 72 | 73 | bpy.ops.render.render(write_still=True) 74 | finally: 75 | scene.render.filepath = backup_path 76 | scene.render.image_settings.file_format = backup_format 77 | -------------------------------------------------------------------------------- /convert_dataset/blender_script_random.py: -------------------------------------------------------------------------------- 1 | """ 2 | Save random views of objects in Blender as a NERF dataset. This is similar to 3 | blender_script.py, but uses random views rather than animation views. 
4 | """ 5 | 6 | import json 7 | import math 8 | import os 9 | 10 | import bpy 11 | from mathutils import Vector 12 | from mathutils.noise import random_unit_vector 13 | 14 | NUM_FRAMES = 100 15 | OUTPUT_DIR = None 16 | assert OUTPUT_DIR is not None, "must set OUTPUT_DIR" 17 | os.makedirs(OUTPUT_DIR, exist_ok=True) 18 | 19 | 20 | def scene_bbox(): 21 | large = 100000.0 22 | bbox_min = (large,) * 3 23 | bbox_max = (-large,) * 3 24 | for obj in bpy.context.scene.objects.values(): 25 | if isinstance(obj.data, (bpy.types.Camera, bpy.types.Light)): 26 | continue 27 | for coord in obj.bound_box: 28 | bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord)) 29 | bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord)) 30 | return dict(min=bbox_min, max=bbox_max) 31 | 32 | 33 | def scene_center(): 34 | bbox = scene_bbox() 35 | return (Vector(bbox["min"]) + Vector(bbox["max"])) / 2 36 | 37 | 38 | def scene_fov(): 39 | x_fov = scene.camera.data.angle_x 40 | y_fov = scene.camera.data.angle_y 41 | width = bpy.context.scene.render.resolution_x 42 | height = bpy.context.scene.render.resolution_y 43 | if scene.camera.data.angle == x_fov: 44 | y_fov = 2 * math.atan(math.tan(x_fov / 2) * height / width) 45 | else: 46 | x_fov = 2 * math.atan(math.tan(y_fov / 2) * width / height) 47 | return x_fov, y_fov 48 | 49 | 50 | with open(os.path.join(OUTPUT_DIR, "metadata.json"), "w") as f: 51 | json.dump(scene_bbox(), f) 52 | 53 | scene = bpy.context.scene 54 | backup_matrix = scene.camera.matrix_world.copy() 55 | camera_dist = (backup_matrix.to_translation() - scene_center()).length 56 | backup_path = scene.render.filepath 57 | backup_format = scene.render.image_settings.file_format 58 | try: 59 | scene.render.image_settings.file_format = "PNG" 60 | for i in range(NUM_FRAMES): 61 | scene.render.filepath = os.path.join(OUTPUT_DIR, f"{i:05}") 62 | 63 | x_fov, y_fov = scene_fov() 64 | 65 | direction = random_unit_vector() 66 | camera_pos = scene_center() - camera_dist * direction 67 | scene.camera.location = camera_pos 68 | 69 | # https://blender.stackexchange.com/questions/5210/pointing-the-camera-in-a-particular-direction-programmatically 70 | rot_quat = direction.to_track_quat("-Z", "Y") 71 | scene.camera.rotation_euler = rot_quat.to_euler() 72 | 73 | bpy.context.view_layer.update() 74 | matrix = scene.camera.matrix_world 75 | with open(scene.render.filepath + ".json", "w") as f: 76 | json.dump( 77 | dict( 78 | origin=list(matrix.col[3])[:3], 79 | x_fov=x_fov, 80 | y_fov=y_fov, 81 | x=list(matrix.col[0])[:3], 82 | y=list(-matrix.col[1])[:3], 83 | z=list(-matrix.col[2])[:3], 84 | ), 85 | f, 86 | ) 87 | 88 | bpy.ops.render.render(write_still=True) 89 | finally: 90 | scene.camera.matrix_world = backup_matrix 91 | bpy.context.view_layer.update() 92 | scene.render.filepath = backup_path 93 | scene.render.image_settings.file_format = backup_format 94 | -------------------------------------------------------------------------------- /convert_dataset/llff.py: -------------------------------------------------------------------------------- 1 | """ 2 | Decode an LLFF dataset. 
3 | """ 4 | 5 | import argparse 6 | import json 7 | import os 8 | from functools import partial 9 | from multiprocessing.pool import ThreadPool 10 | from typing import Tuple 11 | 12 | import numpy as np 13 | from PIL import Image 14 | from tqdm.auto import tqdm 15 | 16 | 17 | def main(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--factor", type=float, default=1.0) 20 | parser.add_argument("input_dir", type=str) 21 | parser.add_argument("output_dir", type=str) 22 | args = parser.parse_args() 23 | 24 | img_dir = os.path.join(args.input_dir, "images") 25 | img_paths = [ 26 | os.path.join(img_dir, x) 27 | for x in sorted(os.listdir(img_dir)) 28 | if os.path.splitext(x)[1].lower() in [".jpg", ".jpeg", ".png"] 29 | ] 30 | 31 | pose_path = os.path.join(args.input_dir, "poses_bounds.npy") 32 | pose_bounds = np.load(pose_path) 33 | assert len(pose_bounds) == len(img_paths), "image count must match pose count" 34 | 35 | os.makedirs(args.output_dir, exist_ok=True) 36 | bbox_min, bbox_max = None, None 37 | with ThreadPool(8) as p: 38 | for local_min, local_max in tqdm( 39 | p.imap_unordered( 40 | partial(process_img, args.output_dir, args.factor), 41 | enumerate(zip(pose_bounds, img_paths)), 42 | ) 43 | ): 44 | if bbox_min is None: 45 | bbox_min, bbox_max = local_min, local_max 46 | else: 47 | bbox_min = np.minimum(bbox_min, local_min) 48 | bbox_max = np.maximum(bbox_max, local_max) 49 | 50 | with open(os.path.join(args.output_dir, "metadata.json"), "w") as f: 51 | bbox_info = dict(min=bbox_min.tolist(), max=bbox_max.tolist()) 52 | json.dump(bbox_info, f) 53 | 54 | 55 | def process_img( 56 | output_dir: str, factor: float, item: Tuple[int, Tuple[np.ndarray, str]] 57 | ): 58 | i, (pose_bound, img_path) = item 59 | info = pose_bound[:15].reshape([3, 5]) 60 | x, y, z, pos, hwf = info.T 61 | h, w, focal = hwf 62 | z_near, z_far = pose_bound[15:] 63 | _ = z_near 64 | 65 | # https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/load_llff.py#L250 66 | x, y, z = y, -x, z 67 | 68 | # Same changes as in blender.py 69 | y = -y 70 | z = -z 71 | 72 | local_min = pos - z_far 73 | local_max = pos + z_far 74 | 75 | info = dict( 76 | origin=pos.tolist(), 77 | x_fov=float(2 * np.arctan(w / (2 * focal))), 78 | y_fov=float(2 * np.arctan(h / (2 * focal))), 79 | x=x.tolist(), 80 | y=y.tolist(), 81 | z=z.tolist(), 82 | ) 83 | with open(os.path.join(output_dir, f"{i:05}.json"), "w") as f: 84 | json.dump(info, f) 85 | img_path_out = os.path.join(output_dir, f"{i:05}.png") 86 | new_img = Image.open(img_path).convert("RGB") 87 | if factor != 1.0: 88 | old_w, old_h = new_img.size 89 | new_img = new_img.resize((round(old_w * factor), round(old_h * factor))) 90 | new_img.save(img_path_out) 91 | 92 | return local_min, local_max 93 | 94 | 95 | if __name__ == "__main__": 96 | main() 97 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/unixpickle/learn-nerf 2 | 3 | go 1.17 4 | 5 | require ( 6 | github.com/unixpickle/essentials v1.3.0 7 | github.com/unixpickle/model3d v0.3.3 8 | ) 9 | 10 | require ( 11 | github.com/pkg/errors v0.9.1 // indirect 12 | github.com/unixpickle/splaytree v0.0.0-20160517015709-ba216b293df0 // indirect 13 | ) 14 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/pkg/errors v0.9.1 
h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 2 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 3 | github.com/unixpickle/essentials v1.3.0 h1:H258Z5Uo1pVzFjxD2rwFWzHPN3s0J0jLs5kuxTRSfCs= 4 | github.com/unixpickle/essentials v1.3.0/go.mod h1:dQ1idvqrgrDgub3mfckQm7osVPzT3u9rB6NK/LEhmtQ= 5 | github.com/unixpickle/model3d v0.3.2 h1:Zf00+2b8JmdebvY8TxpGbo7JgdcOvzwkwbOkmffF1ng= 6 | github.com/unixpickle/model3d v0.3.2/go.mod h1:Xu7k4U/wrdq//+bGAo9QrQ3lrRXA+tiV2FAf4TEf6FE= 7 | github.com/unixpickle/model3d v0.3.3 h1:nq9n+BwkIdJzXvMabNXuokumr5r5/9HLCQVtApIhKa4= 8 | github.com/unixpickle/model3d v0.3.3/go.mod h1:Xu7k4U/wrdq//+bGAo9QrQ3lrRXA+tiV2FAf4TEf6FE= 9 | github.com/unixpickle/splaytree v0.0.0-20160517015709-ba216b293df0 h1:vf24zG+kzuiwC/Y8I4MeIl0C0StIiBAmFPsnpv/BtuA= 10 | github.com/unixpickle/splaytree v0.0.0-20160517015709-ba216b293df0/go.mod h1:GaKWGsPs4eeIaQbzcYyytkXTrMTczow7bvvBlQSKX1c= 11 | -------------------------------------------------------------------------------- /learn_nerf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unixpickle/learn-nerf/a857c84d4dd50314fe5e1d36f043c3e7ddedb8e5/learn_nerf/__init__.py -------------------------------------------------------------------------------- /learn_nerf/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import os 4 | from abc import abstractmethod 5 | from dataclasses import dataclass 6 | from typing import Iterator, List, Tuple 7 | 8 | import jax 9 | import jax.numpy as jnp 10 | import numpy as np 11 | from jax._src.prng import PRNGKeyArray as KeyArray 12 | from PIL import Image 13 | 14 | Vec3 = Tuple[float, float, float] 15 | 16 | 17 | @dataclass 18 | class CameraView: 19 | camera_direction: Vec3 20 | camera_origin: Vec3 21 | x_axis: Vec3 22 | y_axis: Vec3 23 | x_fov: float 24 | y_fov: float 25 | 26 | @classmethod 27 | def from_json(cls, path: str, **kwargs) -> "CameraView": 28 | with open(path, "rb") as f: 29 | camera_info = json.load(f) 30 | return cls( 31 | camera_direction=tuple(camera_info["z"]), 32 | camera_origin=tuple(camera_info["origin"]), 33 | x_axis=tuple(camera_info["x"]), 34 | y_axis=tuple(camera_info["y"]), 35 | x_fov=float(camera_info["x_fov"]), 36 | y_fov=float(camera_info["y_fov"]), 37 | **kwargs, 38 | ) 39 | 40 | def to_json(self) -> str: 41 | return json.dumps( 42 | dict( 43 | z=self.camera_direction, 44 | origin=self.camera_origin, 45 | x=self.x_axis, 46 | y=self.y_axis, 47 | x_fov=self.x_fov, 48 | y_fov=self.y_fov, 49 | ) 50 | ) 51 | 52 | def bare_rays(self, width: int, height: int) -> jnp.ndarray: 53 | """ 54 | Get all of the rays in the view in raster scan order. 55 | 56 | Returns an [N x 2 x 3] array of (origin, direction) pairs. 
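Rays are returned in raster-scan order, so ray i corresponds to pixel (x, y) = (i % width, i // width); directions are normalized to unit length.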
57 | """ 58 | z = jnp.array(self.camera_direction, dtype=jnp.float32) 59 | ys = ( 60 | math.tan(self.y_fov / 2) 61 | * jnp.linspace(-1, 1, num=height)[:, None, None] 62 | * jnp.array(self.y_axis, dtype=jnp.float32) 63 | ) 64 | xs = ( 65 | math.tan(self.x_fov / 2) 66 | * jnp.linspace(-1, 1, num=width)[None, :, None] 67 | * jnp.array(self.x_axis, dtype=jnp.float32) 68 | ) 69 | directions = jnp.reshape(xs + ys + z, [-1, 3]) 70 | directions = directions / jnp.linalg.norm(directions, axis=-1, keepdims=True) 71 | origins = jnp.reshape( 72 | jnp.tile( 73 | jnp.array(self.camera_origin, dtype=jnp.float32)[None, None], 74 | (height, width, 1), 75 | ), 76 | [-1, 3], 77 | ) 78 | return jnp.stack([origins, directions], axis=1) 79 | 80 | 81 | @dataclass 82 | class NeRFView(CameraView): 83 | @abstractmethod 84 | def image(self) -> jnp.ndarray: 85 | """ 86 | Load the image as a [Height x Width x 3] array of uint8 RGB values. 87 | """ 88 | 89 | def rays(self) -> jnp.ndarray: 90 | """ 91 | Get all of the rays in the view with their corresponding colors as a 92 | single compact array. 93 | 94 | Returns an array of shape [N x 3 x 3] where each [3 x 3] element is a 95 | row-major tuple (origin, direction, color). Colors are stored as RGB 96 | values in the range [-1, 1]. 97 | """ 98 | img = self.image() 99 | bare = self.bare_rays(img.shape[1], img.shape[0]) 100 | colors = jnp.reshape(img, [-1, 3]).astype(jnp.float32) / 127.5 - 1 101 | return jnp.concatenate([bare, colors[:, None]], axis=1) 102 | 103 | 104 | @dataclass 105 | class FileNeRFView(NeRFView): 106 | image_path: str 107 | 108 | def image(self) -> jnp.ndarray: 109 | # Premultiply alpha to prevent egregious errors at the border. 110 | rgba = jnp.array(Image.open(self.image_path).convert("RGBA")) 111 | return jnp.round((rgba[:, :, :3] * (rgba[:, :, 3:] / 255))).astype(jnp.uint8) 112 | 113 | 114 | @dataclass 115 | class ModelMetadata: 116 | # Scene/object bounding box. 117 | bbox_min: Vec3 118 | bbox_max: Vec3 119 | 120 | @classmethod 121 | def from_json(cls, path: str) -> "ModelMetadata": 122 | with open(path, "rb") as f: 123 | metadata = json.load(f) 124 | return ModelMetadata( 125 | bbox_min=tuple(metadata["min"]), bbox_max=tuple(metadata["max"]) 126 | ) 127 | 128 | 129 | @dataclass 130 | class NeRFDataset: 131 | metadata: ModelMetadata 132 | views: List[NeRFView] 133 | 134 | def iterate_batches( 135 | self, 136 | dir_path: str, 137 | key: KeyArray, 138 | batch_size: int, 139 | repeat: bool = True, 140 | num_shards: int = 32, 141 | ) -> Iterator[jnp.ndarray]: 142 | """ 143 | Load batches of colored rays from the dataset in a shuffled fashion. 144 | 145 | :param dir_path: directory where the shuffled data is stored. 146 | :param key: the RNG seed for shuffling the data. 147 | :param batch_size: the number of rays to load per batch. 148 | :param repeat: if True, repeat the data after all rays have been 149 | exhausted. If this is False, then the final batch may be 150 | smaller than batch_size. 151 | :param num_shards: the number of temporary files to split the ray data 152 | into while shuffling. Using more shards increases 153 | the number of open file descriptors but reduces the 154 | RAM usage of the dataset. 155 | :return: an iterator over [N x 3 x 3] batches of rays, where each ray 156 | is a tuple (origin, direction, color). 
157 | """ 158 | with ShuffledDataset(dir_path, self, key, num_shards=num_shards) as sd: 159 | yield from sd.iterate_batches(batch_size, repeat=repeat) 160 | 161 | 162 | class ShuffledDataset: 163 | """ 164 | A pre-shuffled version of the rays in a NeRFDataset. 165 | 166 | Uses the Jane Street two-stage shuffle as described in: 167 | https://blog.janestreet.com/how-to-shuffle-a-big-dataset/. 168 | 169 | :param dir_path: the directory to store results. 170 | :param dataset: the dataset to shuffle. 171 | :param key: the RNG key for shuffling. 172 | :param num_shards: the number of files to split rays into. More shards 173 | uses less memory but more file descriptors. 174 | """ 175 | 176 | def __init__( 177 | self, 178 | dir_path: str, 179 | dataset: NeRFDataset, 180 | key: KeyArray, 181 | num_shards: int = 32, 182 | ): 183 | self.num_shards = num_shards 184 | self.shard_key, self.shuffle_key = jax.random.split(key) 185 | if not os.path.exists(dir_path): 186 | os.mkdir(dir_path) 187 | done_path = os.path.join(dir_path, "done") 188 | if os.path.exists(done_path): 189 | self.fds = [ 190 | open(os.path.join(dir_path, f"{i}"), "rb") for i in range(num_shards) 191 | ] 192 | else: 193 | self.fds = [ 194 | open(os.path.join(dir_path, f"{i}"), "wb+") for i in range(num_shards) 195 | ] 196 | self._create_shards(dataset) 197 | with open(done_path, "wb+") as f: 198 | f.write(b"done\n") 199 | 200 | def iterate_batches( 201 | self, batch_size: int, repeat: bool = False 202 | ) -> Iterator[jnp.ndarray]: 203 | """ 204 | Load batches of colored rays from the dataset. 205 | 206 | :param batch_size: the number of rays to load per batch. 207 | :param repeat: if True, repeat the data after all rays have been 208 | exhausted. If this is False, then the final batch may be 209 | smaller than batch_size. 210 | :return: an iterator over [N x 3 x 3] batches of rays, where each ray 211 | is a tuple (origin, direction, color). 
212 | """ 213 | key = self.shuffle_key 214 | cur_batch = None 215 | while True: 216 | key, this_key = jax.random.split(key) 217 | shard_indices = np.array( 218 | jax.random.permutation(this_key, jnp.arange(self.num_shards)) 219 | ).tolist() 220 | for shard in shard_indices: 221 | key, this_key = jax.random.split(key) 222 | shard_rays = jax.random.permutation(this_key, self._read_shard(shard)) 223 | if cur_batch is not None: 224 | cur_batch = jnp.concatenate([cur_batch, shard_rays], axis=0) 225 | else: 226 | cur_batch = shard_rays 227 | while cur_batch.shape[0] >= batch_size: 228 | yield cur_batch[:batch_size] 229 | cur_batch = cur_batch[batch_size:] 230 | if not repeat: 231 | break 232 | if cur_batch.shape[0]: 233 | yield cur_batch 234 | 235 | def __enter__(self): 236 | return self 237 | 238 | def __exit__(self, *args): 239 | for fd in self.fds: 240 | fd.close() 241 | 242 | def _create_shards(self, dataset: NeRFDataset): 243 | key = self.shard_key 244 | for view in dataset.views: 245 | rays = view.rays() 246 | key, this_key = jax.random.split(key) 247 | assignments = jax.random.randint( 248 | this_key, [rays.shape[0]], 0, self.num_shards 249 | ) 250 | for shard in range(self.num_shards): 251 | sub_batch = rays[assignments == shard] 252 | if sub_batch.shape[0]: 253 | self._append_shard(shard, sub_batch) 254 | 255 | def _append_shard(self, shard: int, arr: jnp.ndarray): 256 | data = arr.astype(jnp.float32).tobytes() 257 | self.fds[shard].write(data) 258 | 259 | def _read_shard(self, shard: int) -> jnp.ndarray: 260 | f = self.fds[shard] 261 | f.seek(0) 262 | data = f.read() 263 | return jnp.array(np.frombuffer(data, dtype=jnp.float32).reshape([-1, 3, 3])) 264 | 265 | 266 | def load_dataset(directory: str) -> NeRFDataset: 267 | """ 268 | Load a dataset from a directory on disk. 269 | 270 | The dataset is stored as a combination of png files and json metadata files 271 | for each PNG file. For a file X.png, X.json is a file containing the 272 | following keys: "origin", "x", "y", "z", "x_fov", "y_fov" describing the 273 | camera. There is also a global "metadata.json" file containing the bounding 274 | box of the scene, stored as a dictionary with keys "min" and "max". 275 | """ 276 | dataset = NeRFDataset( 277 | metadata=ModelMetadata.from_json(os.path.join(directory, "metadata.json")), 278 | views=[], 279 | ) 280 | for img_name in os.listdir(directory): 281 | if img_name.startswith(".") or not img_name.endswith(".png"): 282 | continue 283 | img_path = os.path.join(directory, img_name) 284 | json_path = img_path[: -len(".png")] + ".json" 285 | dataset.views.append(FileNeRFView.from_json(json_path, image_path=img_path)) 286 | return dataset 287 | -------------------------------------------------------------------------------- /learn_nerf/instant_ngp.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple JAX re-implementation of Instant NGP: 3 | https://arxiv.org/abs/2201.05989. 4 | """ 5 | 6 | from typing import Dict, List, Tuple 7 | 8 | import flax.linen as nn 9 | import jax 10 | import jax.numpy as jnp 11 | 12 | from .model import ModelBase, sinusoidal_emb 13 | from .ref_nerf import RefNERFBase 14 | 15 | 16 | class InstantNGPModel(ModelBase): 17 | """ 18 | A NeRF model that utilizes a multilevel hash table. 
19 | """ 20 | 21 | table_sizes: List[int] 22 | grid_sizes: List[int] 23 | bbox_min: jnp.ndarray 24 | bbox_max: jnp.ndarray 25 | table_feature_dim: int = 2 26 | table_smooth: bool = False 27 | d_freqs: int = 4 28 | hidden_dim: int = 64 29 | density_dim: int = 16 30 | density_layers: int = 1 31 | color_layers: int = 2 32 | 33 | @nn.compact 34 | def __call__( 35 | self, x: jnp.ndarray, d: jnp.ndarray 36 | ) -> Tuple[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray]]: 37 | d_emb = sinusoidal_emb(d, self.d_freqs) 38 | out = MultiresHashTableEncoding( 39 | table_sizes=self.table_sizes, 40 | grid_sizes=self.grid_sizes, 41 | bbox_min=self.bbox_min, 42 | bbox_max=self.bbox_max, 43 | feature_dim=self.table_feature_dim, 44 | smooth=self.table_smooth, 45 | )(x) 46 | for _ in range(self.density_layers): 47 | out = nn.relu(nn.Dense(self.hidden_dim)(out)) 48 | out = nn.Dense(self.density_dim)(out) 49 | density = jnp.exp(out[:, :1]) 50 | out = jnp.concatenate([d_emb, out], axis=1) 51 | for _ in range(self.color_layers): 52 | out = nn.relu(nn.Dense(self.hidden_dim)(out)) 53 | color = nn.tanh(nn.Dense(3)(out)) 54 | return density, color, {} 55 | 56 | 57 | class InstantNGPRefNERFModel(RefNERFBase): 58 | """ 59 | A NeRF model that utilizes a multilevel hash table. 60 | """ 61 | 62 | table_sizes: List[int] 63 | grid_sizes: List[int] 64 | bbox_min: jnp.ndarray 65 | bbox_max: jnp.ndarray 66 | table_feature_dim: int = 2 67 | d_freqs: int = 4 68 | hidden_dim: int = 64 69 | density_dim: int = 16 70 | density_layers: int = 1 71 | color_layers: int = 2 72 | 73 | def spatial_block(self, x: jnp.ndarray) -> jnp.ndarray: 74 | x = MultiresHashTableEncoding( 75 | table_sizes=self.table_sizes, 76 | grid_sizes=self.grid_sizes, 77 | bbox_min=self.bbox_min, 78 | bbox_max=self.bbox_max, 79 | feature_dim=self.table_feature_dim, 80 | smooth=True, 81 | )(x) 82 | for _ in range(self.density_layers): 83 | x = nn.relu(nn.Dense(self.hidden_dim)(x)) 84 | return nn.Dense(self.density_dim)(x) 85 | 86 | def directional_block(self, x: jnp.ndarray) -> jnp.ndarray: 87 | for _ in range(self.color_layers): 88 | x = nn.relu(nn.Dense(self.hidden_dim)(x)) 89 | return nn.Dense(3)(x) 90 | 91 | 92 | class MultiresHashTableEncoding(nn.Module): 93 | """ 94 | Encode real-valued spatial coordinates using a multiresolution hash table. 95 | """ 96 | 97 | table_sizes: List[int] 98 | grid_sizes: List[int] 99 | bbox_min: jnp.ndarray 100 | bbox_max: jnp.ndarray 101 | feature_dim: int = 2 102 | smooth: bool = False 103 | 104 | @nn.compact 105 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 106 | results = [] 107 | for table_size, grid_size in zip(self.table_sizes, self.grid_sizes): 108 | results.append( 109 | HashTableEncoding( 110 | table_size=table_size, 111 | grid_size=grid_size, 112 | bbox_min=self.bbox_min, 113 | bbox_max=self.bbox_max, 114 | feature_dim=self.feature_dim, 115 | smooth=self.smooth, 116 | )(x) 117 | ) 118 | return jnp.concatenate(results, axis=1) 119 | 120 | 121 | class HashTableEncoding(nn.Module): 122 | """ 123 | Encode real-valued spatial coordinates using a hash table over a fixed-size 124 | grid of coordinates. 125 | """ 126 | 127 | table_size: int 128 | grid_size: int 129 | bbox_min: jnp.ndarray 130 | bbox_max: jnp.ndarray 131 | feature_dim: int = 2 132 | smooth: bool = False 133 | 134 | @nn.compact 135 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 136 | """ 137 | Compute (interpolated) table entries for the coordinates. 
138 | """ 139 | frac = jnp.clip( 140 | (x - self.bbox_min) / (self.bbox_max - self.bbox_min), a_min=0, a_max=1 141 | ) 142 | if self.smooth: 143 | # Shift by half a grid cell so that grid boundaries do not coincide 144 | # at different levels, since boundaries have zero derivatives. 145 | fractional_index = 0.5 + (self.grid_size - 2) * frac 146 | else: 147 | fractional_index = (self.grid_size - 1) * frac 148 | floored = jnp.floor(fractional_index) 149 | 150 | # Avoid out-of-bounds when adding 1 to floor(x) to get ceil(x). 151 | floored = jnp.clip(floored, a_max=self.grid_size - 2) 152 | 153 | ceil_frac = fractional_index - floored 154 | if self.smooth: 155 | ceil_frac = (ceil_frac ** 2) * (3 - 2 * ceil_frac) 156 | 157 | floored = floored.astype(jnp.uint32) 158 | 159 | all_coords = [] 160 | all_weights = [] 161 | for x_offset in [0, 1]: 162 | for y_offset in [0, 1]: 163 | for z_offset in [0, 1]: 164 | offset = jnp.array( 165 | [x_offset, y_offset, z_offset], dtype=floored.dtype 166 | ) 167 | all_coords.append(floored + offset) 168 | all_weights.append( 169 | jnp.prod( 170 | 1 171 | + (2 * ceil_frac - 1) * offset.astype(ceil_frac.dtype) 172 | - ceil_frac, 173 | axis=-1, 174 | keepdims=True, 175 | ) 176 | ) 177 | 178 | if self.grid_size ** 3 > self.table_size: 179 | table = self.param( 180 | "table", 181 | lambda key: 1e-4 182 | * ( 183 | jax.random.uniform(key, (self.table_size, self.feature_dim)) * 2 - 1 184 | ), 185 | ) 186 | all_lookup_results = jnp.concatenate(all_weights) * hash_table_lookup( 187 | table, jnp.concatenate(all_coords) 188 | ) 189 | else: 190 | table = self.param( 191 | "table", 192 | lambda key: 1e-4 193 | * ( 194 | jax.random.uniform(key, (self.grid_size ** 3, self.feature_dim)) * 2 195 | - 1 196 | ), 197 | ) 198 | coords = jnp.concatenate(all_coords) 199 | indices = coords[:, 0] + self.grid_size * ( 200 | coords[:, 1] + self.grid_size * coords[:, 2] 201 | ) 202 | all_lookup_results = ( 203 | jnp.concatenate(all_weights) * jnp.array(table)[indices] 204 | ) 205 | 206 | return jnp.sum( 207 | all_lookup_results.reshape([8, x.shape[0], self.feature_dim]), axis=0 208 | ) 209 | 210 | 211 | def hash_table_lookup(table: jnp.ndarray, coords: jnp.ndarray) -> jnp.ndarray: 212 | """ 213 | Lookup the integer coordinates in a hash table. 214 | 215 | :param table: a [T x F] table of T entries. 216 | :param coords: an [N x 3] batch of 3D integer coordinates. 217 | :return: an [N x F] batch of table lookup results. 218 | """ 219 | coords = coords.astype(jnp.uint32) 220 | # Decorrelate the dimensions with a linear congruential permutation. 221 | indices = ( 222 | coords[:, 0] ^ (19_349_663 * coords[:, 1]) ^ (83_492_791 * coords[:, 2]) 223 | ) % table.shape[0] 224 | return jnp.array(table)[indices] 225 | -------------------------------------------------------------------------------- /learn_nerf/model.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple 2 | 3 | import flax.linen as nn 4 | import jax.numpy as jnp 5 | 6 | 7 | class ModelBase(nn.Module): 8 | """ 9 | Base class used by all NeRF models. 10 | """ 11 | 12 | def __call__( 13 | self, x: jnp.ndarray, d: jnp.ndarray 14 | ) -> Tuple[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray]]: 15 | """ 16 | Predict densities and RGBs for sampled points on rays. 17 | 18 | :param x: an [N x 3] array of coordinates. 19 | :param d: an [N x 3] array of ray directions. 20 | :return: a tuple (density, rgb, aux_losses): 21 | - density: an [N x 1] array of non-negative densities. 
22 | - rgb: an [N x 3] array of RGB values in [-1, 1]. 23 | - aux_losses: a collection of [N] arrays containing per-ray 24 | auxiliary losses. 25 | """ 26 | _ = x, d 27 | raise NotImplementedError 28 | 29 | 30 | class NeRFModel(ModelBase): 31 | """ 32 | A model architecture based directly on Mildenhall et al. (2020). 33 | """ 34 | 35 | input_layers: int = 5 36 | mid_layers: int = 4 37 | hidden_dim: int = 256 38 | color_layer_dim: int = 128 39 | x_freqs: int = 10 40 | d_freqs: int = 4 41 | 42 | @nn.compact 43 | def __call__( 44 | self, x: jnp.ndarray, d: jnp.ndarray 45 | ) -> Tuple[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray]]: 46 | x_emb = sinusoidal_emb(x, self.x_freqs) 47 | d_emb = sinusoidal_emb(d, self.d_freqs) 48 | 49 | z = x_emb 50 | for _ in range(self.input_layers): 51 | z = nn.relu(nn.Dense(self.hidden_dim)(z)) 52 | z = jnp.concatenate([z, x_emb], axis=-1) 53 | for i in range(self.mid_layers): 54 | if i > 0: 55 | z = nn.relu(z) 56 | z = nn.Dense(self.hidden_dim)(z) 57 | density = nn.softplus(nn.Dense(1)(z)) 58 | z = jnp.concatenate([z, d_emb], axis=-1) 59 | z = nn.relu(nn.Dense(self.color_layer_dim)(z)) 60 | rgb = nn.tanh(nn.Dense(3)(z)) 61 | 62 | return density, rgb, {} 63 | 64 | 65 | def sinusoidal_emb(coords: jnp.ndarray, freqs: int) -> jnp.ndarray: 66 | """ 67 | Compute sinusoidal embeddings for some input coordinates. 68 | 69 | :param coords: an [N x D] array of coordinates. 70 | :return: an [N x D*freqs*2] array of embeddings. 71 | """ 72 | coeffs = 2 ** jnp.arange(freqs, dtype=jnp.float32) 73 | inputs = coords[..., None] * coeffs 74 | sines = jnp.sin(inputs) 75 | cosines = jnp.cos(inputs) 76 | combined = jnp.concatenate([sines, cosines], axis=-1) 77 | return combined.reshape(combined.shape[:-2] + (-1,)) 78 | -------------------------------------------------------------------------------- /learn_nerf/ref_nerf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Primitives and helpers for Ref-NeRF: https://arxiv.org/abs/2112.03907. 3 | """ 4 | 5 | import math 6 | from typing import Dict, Tuple 7 | 8 | import flax.linen as nn 9 | import jax 10 | import jax.numpy as jnp 11 | import numpy as np 12 | 13 | from .model import ModelBase, sinusoidal_emb 14 | 15 | HARMONIC_COUNTS = [1, 3, 5, 7, 9, 11, 13, 15] 16 | REF_NERF_OUT_DIM = 9 17 | 18 | 19 | class RefNERFBase(ModelBase): 20 | """ 21 | A base class for Ref-NeRF models. 22 | """ 23 | 24 | sh_degree: int 25 | 26 | def spatial_block(self, x: jnp.ndarray) -> jnp.ndarray: 27 | _ = x 28 | raise NotImplementedError 29 | 30 | def directional_block(self, x: jnp.ndarray) -> jnp.ndarray: 31 | _ = x 32 | raise NotImplementedError 33 | 34 | @nn.compact 35 | def __call__( 36 | self, x: jnp.ndarray, d: jnp.ndarray 37 | ) -> Tuple[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray]]: 38 | def spatial_fn(x): 39 | out = self.spatial_block(x) 40 | return -out[:, 0].sum(), out 41 | 42 | real_normal, spatial_out = jax.grad(spatial_fn, has_aux=True)(x) 43 | real_normal = _safe_normalize(real_normal) 44 | 45 | density, diffuse_color, spectral, roughness, normal, _bottleneck = jnp.split( 46 | spatial_out, np.cumsum([1, 3, 1, 1, 3]).tolist(), axis=-1 47 | ) 48 | density = jnp.exp(density) 49 | 50 | # Designed to initialize diffuse to 0.25, so that initial 51 | # summed color is 0.5. 
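# Concretely, sigmoid(0 - log(3)) = 1 / (1 + 3) = 0.25, while the spectral
# product starts near 0.5 * 0.5 = 0.25, so the sum starts near 0.5.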
52 | diffuse_color = nn.sigmoid(diffuse_color - math.log(3)) 53 | 54 | spectral = nn.sigmoid(spectral) 55 | roughness = nn.softplus(roughness) 56 | normal = _safe_normalize(normal) 57 | 58 | reflection = d - 2 * normal * jnp.sum(d * normal, axis=-1, keepdims=True) 59 | reflection_enc = integrated_directional_encoding( 60 | self.sh_degree, reflection, roughness 61 | ) 62 | normal_dot = jnp.sum(-d * normal, axis=-1, keepdims=True) 63 | dir_input = jnp.concatenate([spatial_out, reflection_enc, normal_dot], axis=1) 64 | dir_output = self.directional_block(dir_input) 65 | spectral_color = nn.sigmoid(dir_output) 66 | 67 | full_color = ( 68 | linear_rgb_to_srgb(_leaky_clip(spectral_color * spectral + diffuse_color)) 69 | * 2 70 | - 1 71 | ) 72 | aux_losses = dict( 73 | normal_mse=jnp.sum((normal - real_normal) ** 2, axis=-1), 74 | neg_normal=jnp.maximum(0.0, jnp.sum(normal * d, axis=-1)) ** 2, 75 | ) 76 | 77 | return density, full_color, aux_losses 78 | 79 | 80 | class RefNERFModel(RefNERFBase): 81 | """ 82 | A Ref-NeRF model built upon the original NeRF architecture. 83 | """ 84 | 85 | input_layers: int = 5 86 | mid_layers: int = 4 87 | hidden_dim: int = 256 88 | color_layer_dim: int = 128 89 | x_freqs: int = 10 90 | d_freqs: int = 4 91 | 92 | def spatial_block(self, x: jnp.ndarray) -> jnp.ndarray: 93 | x_emb = sinusoidal_emb(x, self.x_freqs) 94 | 95 | z = x_emb 96 | for _ in range(self.input_layers): 97 | z = nn.relu(nn.Dense(self.hidden_dim)(z)) 98 | z = jnp.concatenate([z, x_emb], axis=-1) 99 | for i in range(self.mid_layers): 100 | if i > 0: 101 | z = nn.relu(z) 102 | z = nn.Dense(self.hidden_dim)(z) 103 | return z 104 | 105 | def directional_block(self, x: jnp.ndarray) -> jnp.ndarray: 106 | z = nn.relu(nn.Dense(self.color_layer_dim)(x)) 107 | return nn.Dense(3)(z) 108 | 109 | 110 | def linear_rgb_to_srgb(colors: jnp.ndarray): 111 | """ 112 | Perform Gamma compression to convert linear RGB colors to sRGB. 113 | """ 114 | # https://github.com/google/jax/issues/5798 115 | safe_colors = jnp.maximum(1e-5, colors) 116 | return jnp.where( 117 | colors <= 0.0031308, 12.92 * colors, 1.055 * (safe_colors ** (1 / 2.4)) - 0.055 118 | ) 119 | 120 | 121 | def integrated_directional_encoding( 122 | sh_degree: int, coords: jnp.ndarray, roughness: jnp.ndarray 123 | ) -> jnp.ndarray: 124 | """ 125 | Compute the integrated directional encoding for the 3D coordinates given 126 | the corresponding roughness parameter. Intuitively, higher roughness 127 | parameters ignore the coordinates more. 128 | 129 | :param sh_degree: the degree of the harmonics. Should be in range [1, 8]. 130 | :param coords: an [N x 3] array of normalized coordinates. 131 | :param roughness: an [N x 1] array of densities. 132 | :return: an [N x D] array of integrated directional encodings. 133 | """ 134 | assert len(roughness.shape) == 2 and roughness.shape[1] == 1 135 | assert len(coords.shape) == 2 and coords.shape[1] == 3 136 | 137 | levels = jnp.array( 138 | [x for i, y in enumerate(HARMONIC_COUNTS[:sh_degree]) for x in [i] * y], 139 | dtype=roughness.dtype, 140 | ) 141 | attenuation = jnp.exp(-roughness * (levels * (levels + 1)) / 2) 142 | harmonics = spherical_harmonic(sh_degree, coords) 143 | return harmonics * attenuation 144 | 145 | 146 | def spherical_harmonic(sh_degree: int, coords: jnp.ndarray) -> jnp.ndarray: 147 | """ 148 | Compute the spherical harmonic encoding of the 3D coordinates. 149 | 150 | :param sh_degree: the degree of the harmonics. Should be in range [1, 8]. 
151 | :param coords: an [N x 3] array of normalized coordinates. 152 | :return: an [N x D] array of harmonic encodings. 153 | """ 154 | assert sh_degree >= 1 and sh_degree <= 8 155 | # Based on https://github.com/NVlabs/tiny-cuda-nn/blob/8575542682cb67cddfc748cc3d3cfc12593799aa/include/tiny-cuda-nn/encodings/spherical_harmonics.h#L76 156 | x, y, z = coords[:, 0], coords[:, 1], coords[:, 2] 157 | 158 | xy = x * y 159 | xz = x * z 160 | yz = y * z 161 | x2 = x * x 162 | y2 = y * y 163 | z2 = z * z 164 | x4 = x2 * x2 165 | y4 = y2 * y2 166 | z4 = z2 * z2 167 | x6 = x4 * x2 168 | y6 = y4 * y2 169 | z6 = z4 * z2 170 | 171 | out = [None] * 64 172 | 173 | def populate(): 174 | out[0] = jnp.broadcast_to(jnp.array(0.28209479177387814), x.shape) 175 | if sh_degree <= 1: 176 | return 177 | out[1] = -0.48860251190291987 * y 178 | out[2] = 0.48860251190291987 * z 179 | out[3] = -0.48860251190291987 * x 180 | if sh_degree <= 2: 181 | return 182 | out[4] = 1.0925484305920792 * xy 183 | out[5] = -1.0925484305920792 * yz 184 | out[6] = 0.94617469575755997 * z2 - 0.31539156525251999 185 | out[7] = -1.0925484305920792 * xz 186 | out[8] = 0.54627421529603959 * x2 - 0.54627421529603959 * y2 187 | if sh_degree <= 3: 188 | return 189 | out[9] = 0.59004358992664352 * y * (-3.0 * x2 + y2) 190 | out[10] = 2.8906114426405538 * xy * z 191 | out[11] = 0.45704579946446572 * y * (1.0 - 5.0 * z2) 192 | out[12] = 0.3731763325901154 * z * (5.0 * z2 - 3.0) 193 | out[13] = 0.45704579946446572 * x * (1.0 - 5.0 * z2) 194 | out[14] = 1.4453057213202769 * z * (x2 - y2) 195 | out[15] = 0.59004358992664352 * x * (-x2 + 3.0 * y2) 196 | if sh_degree <= 4: 197 | return 198 | out[16] = 2.5033429417967046 * xy * (x2 - y2) 199 | out[17] = 1.7701307697799304 * yz * (-3.0 * x2 + y2) 200 | out[18] = 0.94617469575756008 * xy * (7.0 * z2 - 1.0) 201 | out[19] = 0.66904654355728921 * yz * (3.0 - 7.0 * z2) 202 | out[20] = ( 203 | -3.1735664074561294 * z2 + 3.7024941420321507 * z4 + 0.31735664074561293 204 | ) 205 | out[21] = 0.66904654355728921 * xz * (3.0 - 7.0 * z2) 206 | out[22] = 0.47308734787878004 * (x2 - y2) * (7.0 * z2 - 1.0) 207 | out[23] = 1.7701307697799304 * xz * (-x2 + 3.0 * y2) 208 | out[24] = ( 209 | -3.7550144126950569 * x2 * y2 210 | + 0.62583573544917614 * x4 211 | + 0.62583573544917614 * y4 212 | ) 213 | if sh_degree <= 5: 214 | return 215 | out[25] = 0.65638205684017015 * y * (10.0 * x2 * y2 - 5.0 * x4 - y4) 216 | out[26] = 8.3026492595241645 * xy * z * (x2 - y2) 217 | out[27] = -0.48923829943525038 * y * (3.0 * x2 - y2) * (9.0 * z2 - 1.0) 218 | out[28] = 4.7935367849733241 * xy * z * (3.0 * z2 - 1.0) 219 | out[29] = 0.45294665119569694 * y * (14.0 * z2 - 21.0 * z4 - 1.0) 220 | out[30] = 0.1169503224534236 * z * (-70.0 * z2 + 63.0 * z4 + 15.0) 221 | out[31] = 0.45294665119569694 * x * (14.0 * z2 - 21.0 * z4 - 1.0) 222 | out[32] = 2.3967683924866621 * z * (x2 - y2) * (3.0 * z2 - 1.0) 223 | out[33] = -0.48923829943525038 * x * (x2 - 3.0 * y2) * (9.0 * z2 - 1.0) 224 | out[34] = 2.0756623148810411 * z * (-6.0 * x2 * y2 + x4 + y4) 225 | out[35] = 0.65638205684017015 * x * (10.0 * x2 * y2 - x4 - 5.0 * y4) 226 | if sh_degree <= 6: 227 | return 228 | out[36] = 1.3663682103838286 * xy * (-10.0 * x2 * y2 + 3.0 * x4 + 3.0 * y4) 229 | out[37] = 2.3666191622317521 * yz * (10.0 * x2 * y2 - 5.0 * x4 - y4) 230 | out[38] = 2.0182596029148963 * xy * (x2 - y2) * (11.0 * z2 - 1.0) 231 | out[39] = -0.92120525951492349 * yz * (3.0 * x2 - y2) * (11.0 * z2 - 3.0) 232 | out[40] = 0.92120525951492349 * xy * (-18.0 * z2 + 33.0 * z4 + 1.0) 233 | 
out[41] = 0.58262136251873131 * yz * (30.0 * z2 - 33.0 * z4 - 5.0) 234 | out[42] = ( 235 | 6.6747662381009842 * z2 236 | - 20.024298714302954 * z4 237 | + 14.684485723822165 * z6 238 | - 0.31784601133814211 239 | ) 240 | out[43] = 0.58262136251873131 * xz * (30.0 * z2 - 33.0 * z4 - 5.0) 241 | out[44] = ( 242 | 0.46060262975746175 243 | * (x2 - y2) 244 | * (11.0 * z2 * (3.0 * z2 - 1.0) - 7.0 * z2 + 1.0) 245 | ) 246 | out[45] = -0.92120525951492349 * xz * (x2 - 3.0 * y2) * (11.0 * z2 - 3.0) 247 | out[46] = 0.50456490072872406 * (11.0 * z2 - 1.0) * (-6.0 * x2 * y2 + x4 + y4) 248 | out[47] = 2.3666191622317521 * xz * (10.0 * x2 * y2 - x4 - 5.0 * y4) 249 | out[48] = ( 250 | 10.247761577878714 * x2 * y4 251 | - 10.247761577878714 * x4 * y2 252 | + 0.6831841051919143 * x6 253 | - 0.6831841051919143 * y6 254 | ) 255 | if sh_degree <= 7: 256 | return 257 | out[49] = ( 258 | 0.70716273252459627 * y * (-21.0 * x2 * y4 + 35.0 * x4 * y2 - 7.0 * x6 + y6) 259 | ) 260 | out[50] = 5.2919213236038001 * xy * z * (-10.0 * x2 * y2 + 3.0 * x4 + 3.0 * y4) 261 | out[51] = ( 262 | -0.51891557872026028 263 | * y 264 | * (13.0 * z2 - 1.0) 265 | * (-10.0 * x2 * y2 + 5.0 * x4 + y4) 266 | ) 267 | out[52] = 4.1513246297620823 * xy * z * (x2 - y2) * (13.0 * z2 - 3.0) 268 | out[53] = ( 269 | -0.15645893386229404 270 | * y 271 | * (3.0 * x2 - y2) 272 | * (13.0 * z2 * (11.0 * z2 - 3.0) - 27.0 * z2 + 3.0) 273 | ) 274 | out[54] = 0.44253269244498261 * xy * z * (-110.0 * z2 + 143.0 * z4 + 15.0) 275 | out[55] = ( 276 | 0.090331607582517306 * y * (-135.0 * z2 + 495.0 * z4 - 429.0 * z6 + 5.0) 277 | ) 278 | out[56] = ( 279 | 0.068284276912004949 * z * (315.0 * z2 - 693.0 * z4 + 429.0 * z6 - 35.0) 280 | ) 281 | out[57] = ( 282 | 0.090331607582517306 * x * (-135.0 * z2 + 495.0 * z4 - 429.0 * z6 + 5.0) 283 | ) 284 | out[58] = ( 285 | 0.07375544874083044 286 | * z 287 | * (x2 - y2) 288 | * (143.0 * z2 * (3.0 * z2 - 1.0) - 187.0 * z2 + 45.0) 289 | ) 290 | out[59] = ( 291 | -0.15645893386229404 292 | * x 293 | * (x2 - 3.0 * y2) 294 | * (13.0 * z2 * (11.0 * z2 - 3.0) - 27.0 * z2 + 3.0) 295 | ) 296 | out[60] = ( 297 | 1.0378311574405206 * z * (13.0 * z2 - 3.0) * (-6.0 * x2 * y2 + x4 + y4) 298 | ) 299 | out[61] = ( 300 | -0.51891557872026028 301 | * x 302 | * (13.0 * z2 - 1.0) 303 | * (-10.0 * x2 * y2 + x4 + 5.0 * y4) 304 | ) 305 | out[62] = 2.6459606618019 * z * (15.0 * x2 * y4 - 15.0 * x4 * y2 + x6 - y6) 306 | out[63] = ( 307 | 0.70716273252459627 * x * (-35.0 * x2 * y4 + 21.0 * x4 * y2 - x6 + 7.0 * y6) 308 | ) 309 | 310 | populate() 311 | return jnp.stack([x for x in out if x is not None], axis=1) 312 | 313 | 314 | def _safe_normalize(vs: jnp.ndarray, eps=1e-10) -> jnp.ndarray: 315 | # Using jnp.linalg.norm is not safe at exactly 0. 316 | # https://github.com/google/jax/issues/3058 317 | return vs / jnp.sqrt(jnp.sum(vs ** 2, axis=-1, keepdims=True) + eps) 318 | 319 | 320 | def _leaky_clip(x: jnp.ndarray) -> jnp.ndarray: 321 | """ 322 | Clip x to the range [0, 1] while still allowing gradients to push it back 323 | inside the bounds. 
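In the forward pass this equals jnp.clip(x, 0, 1); the backward pass treats it as the identity (a straight-through estimator), so out-of-range values still receive gradients.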
324 | """ 325 | delta = jax.lax.stop_gradient(jnp.clip(x, 0, 1) - x) 326 | return x + delta 327 | -------------------------------------------------------------------------------- /learn_nerf/render.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Dict, Tuple, Union 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | from jax._src.prng import PRNGKeyArray as KeyArray 7 | 8 | from .model import ModelBase 9 | 10 | 11 | @dataclass 12 | class NeRFRenderer: 13 | """ 14 | A NeRF hierarchy with corresponding settings for rendering rays. 15 | 16 | :param coarse: the coarse model. 17 | :param fine: the fine model. 18 | :param coarse_params: params of the coarse model. 19 | :param fine_params: params of the fine model. 20 | :param background: the [3] RGB array background color. 21 | :param bbox_min: minimum point of the scene bounding box. 22 | :param bbox_max: maximum point of the scene bounding box. 23 | :param coarse_ts: samples per ray for coarse model. 24 | :param fine_ts: additional samples per ray for fine model. 25 | """ 26 | 27 | coarse: ModelBase 28 | fine: ModelBase 29 | coarse_params: Any 30 | fine_params: Any 31 | background: jnp.ndarray 32 | bbox_min: jnp.ndarray 33 | bbox_max: jnp.ndarray 34 | coarse_ts: int 35 | fine_ts: int 36 | 37 | min_t_range: float = 1e-3 38 | 39 | def render_rays( 40 | self, 41 | key: KeyArray, 42 | batch: jnp.ndarray, 43 | ) -> Dict[str, Dict[str, jnp.ndarray]]: 44 | """ 45 | :param key: an RNG key for sampling points along rays. 46 | :param batch: an [N x 2 x 3] batch of (origin, direction) rays. 47 | :return: a dict with keys "fine", "coarse", "fine_aux", "coarse_aux". 48 | The former two keys store render outputs for both layers of 49 | the hierarchy (see render_rays() function for details). 50 | The "fine_aux" and "coarse_aux" keys contain dicts of 51 | (averaged) named auxiliary losses. 52 | """ 53 | t_min, t_max, mask = self.t_range(batch) 54 | 55 | coarse_key, fine_key = jax.random.split(key) 56 | # Evaluate the coarse model using regular stratified sampling. 57 | coarse_ts = RaySamples.stratified_sampling( 58 | t_min=t_min, 59 | t_max=t_max, 60 | mask=mask, 61 | count=self.coarse_ts, 62 | key=coarse_key, 63 | ) 64 | coarse_out, coarse_aux = render_rays( 65 | model=self.coarse, 66 | params=self.coarse_params, 67 | background=self.background, 68 | batch=batch, 69 | ts=coarse_ts, 70 | ) 71 | 72 | # Evaluate the fine model using a combined set of points. 73 | fine_ts = coarse_ts.fine_sampling( 74 | count=self.fine_ts, 75 | key=fine_key, 76 | densities=jax.lax.stop_gradient(coarse_out["densities"]), 77 | ) 78 | fine_out, fine_aux = render_rays( 79 | model=self.fine, 80 | params=self.fine_params, 81 | background=self.background, 82 | batch=batch, 83 | ts=fine_ts, 84 | ) 85 | 86 | return dict( 87 | coarse=coarse_out, 88 | fine=fine_out, 89 | coarse_aux=coarse_aux, 90 | fine_aux=fine_aux, 91 | ) 92 | 93 | def t_range( 94 | self, batch: jnp.ndarray, epsilon: float = 1e-8 95 | ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: 96 | """ 97 | For a batch of rays, compute the t_min and t_max for each ray 98 | according to the scene bounding box. 99 | 100 | :param batch: a batch of rays, each of [N x 2 x 3] rays. 101 | :param epsilon: small offset to add to ray directions to prevent NaNs. 102 | :return: a tuple (t_min, t_max, mask) of [N] arrays. The mask array is 103 | a boolean array, where False means not to render the ray. 
104 | """ 105 | bbox = jnp.stack([self.bbox_min, self.bbox_max]) 106 | bounds, mask = jax.vmap( 107 | lambda ray: ray_t_range( 108 | bbox, ray, min_t_range=self.min_t_range, epsilon=epsilon 109 | ) 110 | )(batch) 111 | return bounds[:, 0], bounds[:, 1], mask 112 | 113 | 114 | @dataclass 115 | class RaySamples: 116 | t_min: jnp.ndarray 117 | t_max: jnp.ndarray 118 | mask: jnp.ndarray 119 | ts: jnp.ndarray 120 | 121 | @classmethod 122 | def stratified_sampling( 123 | cls, 124 | t_min: jnp.ndarray, 125 | t_max: jnp.ndarray, 126 | mask: jnp.ndarray, 127 | count: int, 128 | key: KeyArray, 129 | ) -> "RaySamples": 130 | """ 131 | :param t_min: a batch of minimum values. 132 | :param t_max: a batch of maximum values. 133 | :param mask: a batch of masks for each ray. 134 | :param count: number of samples per batch element. 135 | :param key: RNG key for sampling. 136 | :return: samples along the rays. 137 | """ 138 | bin_size = ((t_max - t_min) / count)[:, None] 139 | bin_starts = ( 140 | jnp.arange(0, count, dtype=jnp.float32)[None] * bin_size + t_min[:, None] 141 | ) 142 | randoms = jax.random.uniform(key, bin_starts.shape) * bin_size 143 | return cls(t_min=t_min, t_max=t_max, mask=mask, ts=bin_starts + randoms) 144 | 145 | def points(self, rays: jnp.ndarray) -> jnp.ndarray: 146 | """ 147 | For each ray, compute the points at all ts. 148 | 149 | :param rays: a batch of rays of shape [N x 2 x 3] where each ray is a 150 | tuple (origin, direction). 151 | :return: a batch of points of shape [N x T x 3]. 152 | """ 153 | return rays[:, :1] + (rays[:, 1:] * self.ts[:, :, None]) 154 | 155 | def render_rays( 156 | self, 157 | densities: jnp.ndarray, 158 | rgbs: jnp.ndarray, 159 | background: jnp.ndarray, 160 | ) -> jnp.ndarray: 161 | """ 162 | Perform volumetric rendering given density and color samples along a batch 163 | of rays. 164 | 165 | :param densities: an [N x T] batch of non-negative density outputs. 166 | :param rgbs: an [N x T x 3] batch of RGB values. 167 | :param background: an RGB background color, of shape [3]. 168 | :return: an [N x 3] batch of RGB values. 169 | """ 170 | probs = self.termination_probs(densities) 171 | colors = jnp.concatenate( 172 | [rgbs, jnp.tile(background[None, None], [rgbs.shape[0], 1, 1])], axis=1 173 | ) 174 | return jnp.where( 175 | self.mask[:, None], jnp.sum(probs[..., None] * colors, axis=1), background 176 | ) 177 | 178 | def render_alpha( 179 | self, 180 | densities: jnp.ndarray, 181 | ) -> jnp.ndarray: 182 | """ 183 | Given density estimates along the rays, compute the probability that 184 | each ray hits the object. 185 | 186 | :param densities: an [N x T] batch of non-negative density outputs. 187 | :return: an [N x 1] batch of alpha values. 188 | """ 189 | probs = self.termination_probs(densities) 190 | return jnp.where(self.mask[:, None], 1 - probs[:, -1:], 0.0) 191 | 192 | def average_aux_losses( 193 | self, 194 | densities: jnp.ndarray, 195 | aux: Dict[str, jnp.ndarray], 196 | ) -> jnp.ndarray: 197 | """ 198 | Compute an average of auxiliary losses over each ray, weighted by the 199 | volume density. 200 | 201 | :param densities: an [N x T] batch of non-negative density outputs. 202 | :param aux: a dict mapping loss names to [N x T] batches. 203 | :return: a dict mapping loss names to mean losses. 
204 | """ 205 | probs = self.termination_probs(densities)[:, :-1] 206 | return { 207 | k: jnp.mean(jnp.where(self.mask[:, None], jnp.sum(v * probs, axis=-1), 0.0)) 208 | for k, v in aux.items() 209 | } 210 | 211 | def fine_sampling( 212 | self, 213 | count: int, 214 | key: KeyArray, 215 | densities: jnp.ndarray, 216 | combine: bool = True, 217 | eps: float = 1e-8, 218 | ) -> "RaySamples": 219 | """ 220 | Sample points along a ray leveraging density information from a 221 | coarsely sampled ray (stored in self). 222 | 223 | :param count: the number of points to sample. 224 | :param key: the RNG key to use for sampling. 225 | :param densities: the sampled non-negative densities for ts. 226 | :param combine: if True, combine the new sampled points with the old 227 | sampled points in one sorted array. 228 | :param eps: a small probability to add to termination probs to avoid 229 | division by zero. 230 | :return: an [N x T'] array of sampled ts, similar to stratified_sampling(). 231 | """ 232 | w = self.termination_probs(densities)[:, :-1] + eps 233 | 234 | # Setup an inverse CDF for inverse transform sampling. 235 | xs = jnp.cumsum(w, axis=1) 236 | xs = jnp.concatenate([self._const_vec(0.0), xs], axis=1) 237 | xs = xs / xs[:, -1:] # normalize 238 | ys = jnp.concatenate( 239 | [self.t_min[:, None], self.ends()], 240 | axis=1, 241 | ) 242 | 243 | # Evaluate the inverse CDF at quasi-random points. 244 | input_samples = self.stratified_sampling( 245 | t_min=jnp.zeros_like(self.t_min), 246 | t_max=jnp.ones_like(self.t_max), 247 | mask=self.mask, 248 | count=count, 249 | key=key, 250 | ) 251 | new_ts = jax.vmap(jnp.interp)(input_samples.ts, xs, ys) 252 | 253 | if combine: 254 | combined = jnp.concatenate([self.ts, new_ts], axis=1) 255 | new_ts = jnp.sort(combined, axis=1) 256 | 257 | return RaySamples(t_min=self.t_min, t_max=self.t_max, mask=self.mask, ts=new_ts) 258 | 259 | def starts(self) -> jnp.ndarray: 260 | t_mid = (self.ts[:, 1:] + self.ts[:, :-1]) / 2 261 | return jnp.concatenate([self.t_min[:, None], t_mid], axis=1) 262 | 263 | def ends(self) -> jnp.ndarray: 264 | t_mid = (self.ts[:, 1:] + self.ts[:, :-1]) / 2 265 | return jnp.concatenate([t_mid, self.t_max[:, None]], axis=1) 266 | 267 | def deltas(self) -> jnp.ndarray: 268 | return self.ends() - self.starts() 269 | 270 | def termination_probs(self, densities: jnp.ndarray): 271 | density_dt = densities * self.deltas() 272 | 273 | # Compute the integral of termination probability over 274 | # time, to get the probability we make it to time t. 275 | acc_densities_cur = jnp.cumsum(density_dt, axis=1) 276 | acc_densities_prev = jnp.concatenate( 277 | [jnp.zeros_like(acc_densities_cur[:, :1]), acc_densities_cur], axis=1 278 | ) 279 | prob_survive = jnp.exp(-acc_densities_prev) 280 | 281 | # Compute the probability of terminating at time t given 282 | # that we made it to time t. 283 | prob_terminate = jnp.concatenate( 284 | [1 - jnp.exp(-density_dt), self._const_vec(1.0)], axis=1 285 | ) 286 | 287 | return prob_survive * prob_terminate 288 | 289 | def _const_vec(self, x: float) -> jnp.ndarray: 290 | return jnp.tile(jnp.array(x).reshape([1, 1]), [self.ts.shape[0], 1]) 291 | 292 | 293 | def render_rays( 294 | model: ModelBase, 295 | params: Any, 296 | background: jnp.ndarray, 297 | batch: jnp.ndarray, 298 | ts: "RaySamples", 299 | ) -> Tuple[Dict[str, jnp.ndarray], Dict[str, jnp.ndarray]]: 300 | """ 301 | Render a batch of rays using a model. 302 | 303 | :param model: the NeRF model to run. 
304 | :param params: the parameter object for the model. 305 | :param background: the [3] array of background color. 306 | :param batch: an [N x 2 x 3] array of (origin, direction) rays. 307 | :param ts: samples along the rays. 308 | :return: a tuple (out, aux). 309 | - out: a dict of results with the following keys: 310 | - outputs: an [N x 3] array of RGB colors. 311 | - rgbs: an [N x T x 3] array of per-point RGB outputs. 312 | - densities: an [N x T] array of model density outputs. 313 | - alphas: an [N x 1] array of hit probabilities. 314 | - coords: an [N x 3] array of ray collision coordinates. 315 | - aux: a dict mapping loss names to (scalar) means of the losses 316 | across all unmasked rays. 317 | """ 318 | all_points = ts.points(batch) 319 | direction_batch = jnp.tile(batch[:, 1:2], [1, all_points.shape[1], 1]) 320 | densities, rgbs, aux = model.apply( 321 | dict(params=params), 322 | all_points.reshape([-1, 3]), 323 | direction_batch.reshape([-1, 3]), 324 | ) 325 | densities = densities.reshape(all_points.shape[:-1]) 326 | rgbs = rgbs.reshape(all_points.shape) 327 | aux = {k: v.reshape(densities.shape) for k, v in aux.items()} 328 | 329 | outputs = ts.render_rays(densities, rgbs, background) 330 | alphas = ts.render_alpha(densities) 331 | coords = ts.render_rays(densities, all_points, jnp.zeros([3], dtype=rgbs.dtype)) 332 | aux_mean = ts.average_aux_losses(densities, aux) 333 | 334 | return ( 335 | dict( 336 | outputs=outputs, 337 | rgbs=rgbs, 338 | densities=densities, 339 | alphas=alphas, 340 | coords=coords, 341 | ), 342 | aux_mean, 343 | ) 344 | 345 | 346 | def ray_t_range( 347 | bbox: jnp.ndarray, 348 | ray: jnp.ndarray, 349 | min_t_range: float = 1e-3, 350 | epsilon: float = 1e-8, 351 | ) -> Tuple[jnp.ndarray, jnp.ndarray]: 352 | """ 353 | For a single ray, compute the t_min and t_max for it and return a mask 354 | indicating whether or not the ray intersects the bounding box at all. 355 | 356 | :param bbox: a [2 x 3] array of (bbox_min, bbox_max). 357 | :param ray: a [2 x 3] array containing ray (origin, direction). 358 | :param min_t_range: the minimum space between t_min and t_max. 359 | :param epsilon: small offset to add to ray directions to prevent NaNs. 360 | :return: a tuple (ts, mask) where ts is of shape [2] storing (t_min, t_max) 361 | and mask is a boolean scalar. 362 | """ 363 | origin = ray[0] 364 | direction = ray[1] 365 | 366 | # Find timesteps of collision on each axis: 367 | # o+t*d=b 368 | # t*d=b-o 369 | # t=(b-o)/d 370 | offsets = bbox - origin 371 | ts = offsets / (direction + epsilon) 372 | 373 | # Sort so that the minimum t always comes first. 374 | ts = jnp.concatenate( 375 | [ 376 | jnp.min(ts, axis=0, keepdims=True), 377 | jnp.max(ts, axis=0, keepdims=True), 378 | ], 379 | axis=0, 380 | ) 381 | 382 | # Find overlapping bounds and apply constraints. 383 | min_t = jnp.maximum(0, jnp.max(ts[0])) 384 | max_t = jnp.min(ts[1]) 385 | max_t_clipped = jnp.maximum(max_t, min_t + min_t_range) 386 | real_range = jnp.stack([min_t, max_t_clipped]) 387 | null_range = jnp.array([0, min_t_range]) 388 | mask = min_t < max_t 389 | return jnp.where(mask, real_range, null_range), mask 390 | -------------------------------------------------------------------------------- /learn_nerf/scripts/check_bbox.py: -------------------------------------------------------------------------------- 1 | """ 2 | Compute (min, max, mean) of the pixels for rays shooting outside of a scene's 3 | bounding box to make sure the bounding box actually covers everything.
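Usage: python learn_nerf/scripts/check_bbox.py /path/to/data_dir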
4 | """ 5 | 6 | import argparse 7 | 8 | import jax 9 | import jax.numpy as jnp 10 | from learn_nerf.dataset import load_dataset 11 | from learn_nerf.render import ray_t_range 12 | from tqdm.auto import tqdm 13 | 14 | 15 | def main(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("data_dir") 18 | args = parser.parse_args() 19 | 20 | dataset = load_dataset(args.data_dir) 21 | 22 | bbox = jnp.array((dataset.metadata.bbox_min, dataset.metadata.bbox_max)) 23 | ray_masks = jax.jit( 24 | lambda rays: jax.vmap(lambda ray: ray_t_range(bbox, ray))(rays)[1] 25 | ) 26 | 27 | min_color = None 28 | max_color = None 29 | color_sum = None 30 | total_colors = 0.0 31 | for view in tqdm(dataset.views): 32 | colored_rays = view.rays() 33 | rays, colors = colored_rays[:, :2], colored_rays[:, 2] 34 | masked_colors = colors[~ray_masks(rays)] 35 | if not jnp.any(masked_colors): 36 | continue 37 | local_min = jnp.min(masked_colors, axis=0) 38 | local_max = jnp.max(masked_colors, axis=0) 39 | local_sum = jnp.sum(masked_colors, axis=0) 40 | if min_color is None: 41 | min_color, max_color, color_sum = local_min, local_max, local_sum 42 | else: 43 | min_color = jnp.minimum(min_color, local_min) 44 | max_color = jnp.maximum(max_color, local_max) 45 | color_sum = color_sum + local_sum 46 | total_colors += masked_colors.shape[0] 47 | mean_color = color_sum / total_colors 48 | print("min color", min_color.tolist()) 49 | print("max color", max_color.tolist()) 50 | print("mean color", mean_color.tolist()) 51 | 52 | 53 | if __name__ == "__main__": 54 | main() 55 | -------------------------------------------------------------------------------- /learn_nerf/scripts/cv_nerf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run a NeRF model with K-fold cross-validation to find the frames which have 3 | the highest validation loss. These might be samples in the dataset with 4 | incorrect camera poses. 
5 | """ 6 | 7 | import argparse 8 | import random 9 | import sys 10 | import tempfile 11 | from typing import Iterator, List, Set 12 | 13 | import jax 14 | import jax.numpy as jnp 15 | from jax._src.prng import PRNGKeyArray as KeyArray 16 | from learn_nerf.dataset import NeRFDataset, load_dataset 17 | from learn_nerf.scripts.train_nerf import add_model_args, create_model 18 | from learn_nerf.train import TrainLoop 19 | from tqdm.auto import tqdm 20 | 21 | 22 | def main(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--seed", type=int, default=0) 25 | parser.add_argument("--lr", type=float, default=1e-4) 26 | parser.add_argument("--batch_size", type=int, default=4096, help="rays per batch") 27 | parser.add_argument( 28 | "--folds", type=int, default=10, help="number of training runs to perform" 29 | ) 30 | parser.add_argument( 31 | "--coarse_samples", type=int, default=64, help="samples per coarse ray" 32 | ) 33 | parser.add_argument( 34 | "--fine_samples", 35 | type=int, 36 | default=128, 37 | help="samples per fine ray (not including coarse samples)", 38 | ) 39 | parser.add_argument("--train_iters", type=int, default=1500) 40 | add_model_args(parser) 41 | parser.add_argument("data_dir", type=str) 42 | args = parser.parse_args() 43 | 44 | print("loading dataset...") 45 | data = load_dataset(args.data_dir) 46 | 47 | global_key = jax.random.PRNGKey( 48 | args.seed if args.seed is not None else random.randint(0, 2 ** 32 - 1) 49 | ) 50 | init_key, shuffle_key, global_key = jax.random.split(global_key, num=3) 51 | shuffle_indices = jax.random.permutation( 52 | shuffle_key, jnp.arange(len(data.views)) 53 | ).tolist() 54 | 55 | for i, valid_indices in enumerate(chunk_indices(args.folds, shuffle_indices)): 56 | print(f"performing cross validation for fold {i}...") 57 | train_data = NeRFDataset( 58 | metadata=data.metadata, 59 | views=[x for i, x in enumerate(data.views) if i not in valid_indices], 60 | ) 61 | valid_data = NeRFDataset( 62 | metadata=data.metadata, 63 | views=[x for i, x in enumerate(data.views) if i in valid_indices], 64 | ) 65 | coarse, fine, train_kwargs = create_model(args, data.metadata) 66 | loop = TrainLoop( 67 | coarse, 68 | fine, 69 | init_rng=init_key, 70 | lr=args.lr, 71 | coarse_ts=args.coarse_samples, 72 | fine_ts=args.fine_samples, 73 | **train_kwargs, 74 | ) 75 | step_fn = loop.step_fn( 76 | jnp.array(data.metadata.bbox_min), 77 | jnp.array(data.metadata.bbox_max), 78 | ) 79 | key = global_key 80 | with tempfile.TemporaryDirectory() as tmp_dir: 81 | data_key, key = jax.random.split(key, 2) 82 | batch_iter = train_data.iterate_batches(tmp_dir, data_key, args.batch_size) 83 | batch = next(batch_iter) 84 | print("dataset shuffling complete.") 85 | for _ in tqdm(range(args.train_iters), file=sys.stderr): 86 | step_key, key = jax.random.split(key, 2) 87 | step_fn(step_key, batch) 88 | batch = next(batch_iter) 89 | valid_results = validation_losses( 90 | key=key, loop=loop, data=valid_data, batch_size=args.batch_size 91 | ) 92 | for view, loss in zip(valid_data.views, valid_results): 93 | print(loss, view.image_path) 94 | 95 | 96 | def validation_losses( 97 | key: KeyArray, loop: TrainLoop, data: NeRFDataset, batch_size: int 98 | ) -> Iterator[float]: 99 | loss_fn = jax.jit( 100 | lambda key, batch, params: loop.losses( 101 | key=key, 102 | bbox_min=jnp.array(data.metadata.bbox_min), 103 | bbox_max=jnp.array(data.metadata.bbox_max), 104 | batch=batch, 105 | params=params, 106 | )[1] 107 | ) 108 | for view in data.views: 109 | rays = view.rays() 110 | 
total_loss = 0.0 111 | for i in range(0, rays.shape[0], batch_size): 112 | test_key, key = jax.random.split(key) 113 | sub_batch = rays[i : i + batch_size] 114 | losses = loss_fn(test_key, sub_batch, loop.state.params) 115 | total_loss += float(losses["fine"]) * len(sub_batch) 116 | yield total_loss / rays.shape[0] 117 | 118 | 119 | def chunk_indices(num_chunks: int, indices: List[int]) -> Iterator[Set[int]]: 120 | chunk_size = len(indices) // num_chunks 121 | extra = len(indices) % num_chunks 122 | offset = 0 123 | for i in range(num_chunks): 124 | if i < extra: 125 | size = chunk_size + 1 126 | else: 127 | size = chunk_size 128 | if not size: 129 | return 130 | yield set(indices[offset : offset + size]) 131 | offset += size 132 | assert offset == len(indices) 133 | 134 | 135 | if __name__ == "__main__": 136 | main() 137 | -------------------------------------------------------------------------------- /learn_nerf/scripts/marching_cubes.py: -------------------------------------------------------------------------------- 1 | """ 2 | Apply marching cubes on a trained NeRF model to reproduce a mesh. 3 | """ 4 | 5 | import argparse 6 | import math 7 | import pickle 8 | import struct 9 | from typing import Sequence 10 | 11 | import jax 12 | import jax.numpy as jnp 13 | import numpy as np 14 | import skimage 15 | from learn_nerf.dataset import ModelMetadata 16 | from learn_nerf.scripts.train_nerf import add_model_args, create_model 17 | from tqdm.auto import tqdm 18 | 19 | 20 | def main(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--batch_size", type=int, default=1024, help="rays per batch") 23 | parser.add_argument( 24 | "--resolution", type=int, default=32, help="steps along each direction" 25 | ) 26 | parser.add_argument("--threshold", type=float, default=0.9) 27 | parser.add_argument("--model_path", type=str, default="nerf.pkl") 28 | add_model_args(parser) 29 | parser.add_argument("metadata_json", type=str) 30 | parser.add_argument("output_obj", type=str) 31 | args = parser.parse_args() 32 | 33 | print("loading metadata...") 34 | metadata = ModelMetadata.from_json(args.metadata_json) 35 | 36 | print("loading model...") 37 | _, fine, _ = create_model(args, metadata) 38 | with open(args.model_path, "rb") as f: 39 | params = pickle.load(f)["fine"] 40 | 41 | density_fn = jax.jit( 42 | lambda coords: ( 43 | 1 44 | - jnp.exp( 45 | -fine.apply(dict(params=params), coords, jnp.zeros_like(coords))[0] 46 | ) 47 | ) 48 | ) 49 | 50 | input_coords = grid_coordinates( 51 | bbox_min=metadata.bbox_min, 52 | bbox_max=metadata.bbox_max, 53 | grid_size=args.resolution, 54 | ).reshape([-1, 3]) 55 | 56 | print("computing densities...") 57 | outputs = [] 58 | for i in tqdm(range(0, input_coords.shape[0], args.batch_size)): 59 | batch = input_coords[i : i + args.batch_size] 60 | density = density_fn(batch) 61 | outputs.append(density) 62 | 63 | volume = np.array(jnp.concatenate(outputs, axis=0).reshape([args.resolution] * 3)) 64 | volume = np.pad(volume, 1, mode="constant", constant_values=0) 65 | 66 | # Adapted from https://scikit-image.org/docs/dev/auto_examples/edges/plot_marching_cubes.html. 
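# The padded volume stores occupancy 1 - exp(-density) in [0, 1), so `level`
# acts as an opacity threshold: for example, --threshold 0.9 extracts the
# surface where the raw density is roughly 2.3, since 1 - exp(-2.3) ~= 0.9.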
67 | verts, faces, normals, _values = skimage.measure.marching_cubes( 68 | volume, level=args.threshold 69 | ) 70 | 71 | verts = flip_x_and_z(verts) 72 | size = np.array(metadata.bbox_max) - np.array(metadata.bbox_min) 73 | verts *= size / args.resolution 74 | verts -= (np.max(verts, axis=0) + np.min(verts, axis=0)) / 2 75 | 76 | if args.output_obj.endswith(".obj"): 77 | write_obj(args.output_obj, verts, faces) 78 | elif args.output_obj.endswith(".stl"): 79 | write_stl(args.output_obj, verts, faces, normals) 80 | 81 | 82 | def flip_x_and_z(tris: np.ndarray) -> np.ndarray: 83 | return np.stack([tris[..., 2], tris[..., 1], tris[..., 0]], axis=-1) 84 | 85 | 86 | def grid_coordinates( 87 | bbox_min: Sequence[float], bbox_max: Sequence[float], grid_size: int 88 | ) -> np.ndarray: 89 | result = np.empty([grid_size] * 3 + [3]) 90 | for i, (bbox_min, bbox_max) in enumerate(zip(bbox_min, bbox_max)): 91 | sub_size = [grid_size if i == j else 1 for j in range(3)] 92 | result[..., i] = np.linspace(bbox_min, bbox_max, num=grid_size).reshape( 93 | sub_size 94 | ) 95 | return result 96 | 97 | 98 | def write_obj(path: str, vertices: np.ndarray, faces: np.ndarray): 99 | vertex_strs = [f"v {x:.5f} {y:.5f} {z:.5f}" for x, y, z in vertices.tolist()] 100 | face_strs = [f"f {x[0]+1} {x[1]+1} {x[2]+1}" for x in faces.tolist()] 101 | with open(path, "w") as f: 102 | f.write("\n".join(vertex_strs) + "\n") 103 | f.write("\n".join(face_strs) + "\n") 104 | 105 | 106 | def write_stl(path: str, vertices: np.ndarray, faces: np.ndarray, normals: np.ndarray): 107 | with open(path, "wb") as f: 108 | f.write(b"\x00" * 80) 109 | f.write(struct.pack("<I", len(faces))) 110 | for face in faces: 111 | # Use the mean of the per-vertex normals as the face normal. 112 | normal = np.mean(normals[face], axis=0) 113 | normal = normal / np.linalg.norm(normal) 114 | f.write(struct.pack("<3f", *normal.tolist())) 115 | for vertex in vertices[face].tolist(): 116 | f.write(struct.pack("<3f", *vertex)) 117 | # Unused two-byte attribute field. 118 | f.write(struct.pack("<H", 0)) 119 | 120 | 121 | if __name__ == "__main__": 122 | main() 123 | -------------------------------------------------------------------------------- /learn_nerf/scripts/plot_log.py: -------------------------------------------------------------------------------- 42 | def name_from_path(path: str) -> str: 43 | name, _ = os.path.splitext(os.path.basename(path)) 44 | return name.replace("_", " ") 45 | 46 | 47 | def read_log(path) -> Dict[str, np.ndarray]: 48 | result = defaultdict(list) 49 | with open(path, "r") as f: 50 | lines = [line for line in f.readlines() if line.startswith("step")] 51 | for line in lines: 52 | for field in (x for x in line.split() if "=" in x): 53 | name, value = field.split("=") 54 | result[name].append(float(value)) 55 | return {k: np.array(v) for k, v in result.items()} 56 | 57 | 58 | if __name__ == "__main__": 59 | main() 60 | -------------------------------------------------------------------------------- /learn_nerf/scripts/render_nerf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Render a view using a NeRF model.
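Multiple view JSON files may be given; each view is rendered in turn and the frames are concatenated horizontally into a single output PNG.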
3 | """ 4 | 5 | import argparse 6 | import pickle 7 | import random 8 | 9 | import jax 10 | import jax.numpy as jnp 11 | import numpy as np 12 | from learn_nerf.dataset import CameraView, ModelMetadata 13 | from learn_nerf.render import NeRFRenderer 14 | from learn_nerf.scripts.train_nerf import add_model_args, create_model 15 | from PIL import Image 16 | from tqdm.auto import tqdm 17 | 18 | 19 | def main(): 20 | parser = argparser() 21 | parser.add_argument("view_json", type=str, nargs="+") 22 | parser.add_argument("output_png", type=str) 23 | args = parser.parse_args() 24 | 25 | renderer = RenderSession(args) 26 | for view_json in args.view_json: 27 | print(f"rendering view {view_json}...") 28 | renderer.render_view(CameraView.from_json(view_json)) 29 | renderer.save(args.output_png) 30 | 31 | 32 | def argparser(): 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument("--seed", type=int, default=None) 35 | parser.add_argument("--batch_size", type=int, default=1024, help="rays per batch") 36 | parser.add_argument( 37 | "--coarse_samples", type=int, default=64, help="samples per coarse ray" 38 | ) 39 | parser.add_argument( 40 | "--fine_samples", 41 | type=int, 42 | default=128, 43 | help="samples per fine ray (not including coarse samples)", 44 | ) 45 | parser.add_argument("--width", type=int, default=512) 46 | parser.add_argument("--height", type=int, default=512) 47 | parser.add_argument("--model_path", type=str, default="nerf.pkl") 48 | add_model_args(parser) 49 | parser.add_argument("metadata_json", type=str) 50 | return parser 51 | 52 | 53 | class RenderSession: 54 | def __init__(self, args: argparse.Namespace): 55 | print("loading metadata...") 56 | self.metadata = ModelMetadata.from_json(args.metadata_json) 57 | 58 | print("loading model...") 59 | coarse, fine, _ = create_model(args, self.metadata) 60 | with open(args.model_path, "rb") as f: 61 | params = pickle.load(f) 62 | 63 | self.renderer = NeRFRenderer( 64 | coarse=coarse, 65 | fine=fine, 66 | coarse_params=params["coarse"], 67 | fine_params=params["fine"], 68 | background=params["background"], 69 | bbox_min=jnp.array(self.metadata.bbox_min, dtype=jnp.float32), 70 | bbox_max=jnp.array(self.metadata.bbox_max, dtype=jnp.float32), 71 | coarse_ts=args.coarse_samples, 72 | fine_ts=args.fine_samples, 73 | ) 74 | self.render_fn = jax.jit( 75 | lambda *args: self.renderer.render_rays(*args)["fine"]["outputs"] 76 | ) 77 | 78 | self.key = jax.random.PRNGKey( 79 | args.seed if args.seed is not None else random.randint(0, 2 ** 32 - 1) 80 | ) 81 | 82 | self.args = args 83 | self.images = [] 84 | 85 | def render_view(self, view: CameraView): 86 | rays = view.bare_rays(self.args.width, self.args.height) 87 | colors = jnp.zeros([0, 3]) 88 | for i in tqdm(range(0, rays.shape[0], self.args.batch_size)): 89 | sub_batch = rays[i : i + self.args.batch_size] 90 | self.key, this_key = jax.random.split(self.key) 91 | sub_colors = self.render_fn(this_key, sub_batch) 92 | colors = jnp.concatenate([colors, sub_colors], axis=0) 93 | image = ( 94 | (np.array(colors).reshape([self.args.height, self.args.width, 3]) + 1) 95 | * 127.5 96 | ).astype(np.uint8) 97 | self.images.append(image) 98 | 99 | def save(self, output_path: str): 100 | image = np.concatenate(self.images, axis=1) 101 | Image.fromarray(image).save(output_path) 102 | 103 | 104 | if __name__ == "__main__": 105 | main() 106 | -------------------------------------------------------------------------------- /learn_nerf/scripts/render_nerf_interactive.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "50a73fd7", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import argparse\n", 11 | "import copy\n", 12 | "import os\n", 13 | "import shlex\n", 14 | "\n", 15 | "from IPython.display import display\n", 16 | "from PIL import Image\n", 17 | "import numpy as np\n", 18 | "\n", 19 | "from learn_nerf.dataset import CameraView\n", 20 | "from learn_nerf.scripts.render_nerf import RenderSession, argparser" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "id": "fb07430e", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "parser = argparser()\n", 31 | "parser.add_argument('start_view', type=str)\n", 32 | "\n", 33 | "# Example of some arguments.\n", 34 | "arg_str = \"--model llff_nerfs/nerf_v14.pkl --instant_ngp --width 256 --height 256 /media/dumpster1/colmap_test/room/nerf_dataset_v14/metadata.json /media/dumpster1/colmap_test/room/nerf_dataset_v14/00000.json\"\n", 35 | "args = parser.parse_args(shlex.split(arg_str))\n", 36 | "renderer = RenderSession(args)" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "id": "ac6b7aeb", 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "LEFT_RIGHT_THETA = 0.2\n", 47 | "UP_DOWN_THETA = 0.2\n", 48 | "FORWARD_DIST = 0.3\n", 49 | "\n", 50 | "saved_views = []\n", 51 | "view = CameraView.from_json(args.start_view)" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "id": "b268723d", 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "def render_loop():\n", 62 | " while True:\n", 63 | " saved_views.append(copy.deepcopy(view))\n", 64 | " renderer.render_view(view)\n", 65 | " display(Image.fromarray(renderer.images[-1]))\n", 66 | " cmds = ''\n", 67 | " while not cmds or any(x not in list('rludfbxo') for x in cmds):\n", 68 | " cmds = input('r=right, l=left, u=up, d=down, f=forward, b=back, o=reorient, x=stop: ')\n", 69 | " for cmd in cmds:\n", 70 | " if cmd == 'x':\n", 71 | " return\n", 72 | " elif cmd == 'r' or cmd == 'l':\n", 73 | " th = -LEFT_RIGHT_THETA\n", 74 | " if cmd == 'l':\n", 75 | " th = -th\n", 76 | " x, z = np.array(view.x_axis), np.array(view.camera_direction)\n", 77 | " view.x_axis = tuple(x*np.cos(th) + z*np.sin(th))\n", 78 | " view.camera_direction = tuple(-x*np.sin(th) + z*np.cos(th))\n", 79 | " elif cmd == 'u' or cmd == 'd':\n", 80 | " th = -UP_DOWN_THETA\n", 81 | " if cmd == 'u':\n", 82 | " th = -th\n", 83 | " x, z = np.array(view.y_axis), np.array(view.camera_direction)\n", 84 | " view.y_axis = tuple(x*np.cos(th) + z*np.sin(th))\n", 85 | " view.camera_direction = tuple(-x*np.sin(th) + z*np.cos(th))\n", 86 | " elif cmd == 'f' or cmd == 'b':\n", 87 | " d = FORWARD_DIST\n", 88 | " if cmd == 'b':\n", 89 | " d = -d\n", 90 | " view.camera_origin = tuple(np.array(view.camera_origin) + np.array(view.camera_direction)*d)\n", 91 | " elif cmd == 'o':\n", 92 | " x, y, z = np.array(view.x_axis), np.array(saved_views[0].y_axis), np.array(view.camera_direction)\n", 93 | " z = z - y*np.dot(z, y)\n", 94 | " z = z / np.linalg.norm(z)\n", 95 | " x = np.cross(y, z)\n", 96 | " x = x / np.linalg.norm(x)\n", 97 | " view.x_axis = tuple(x)\n", 98 | " view.y_axis = tuple(y)\n", 99 | " view.camera_direction = tuple(z)\n", 100 | "\n", 101 | "render_loop()" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "id": 
"c0321759", 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "# Save a low-res version to a single reel file.\n", 112 | "renderer.save('interactive.png')" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "id": "62107323", 119 | "metadata": { 120 | "scrolled": true 121 | }, 122 | "outputs": [], 123 | "source": [ 124 | "# Render in higher resolution for ffmpeg encoding.\n", 125 | "os.makedirs('interactive', exist_ok=True)\n", 126 | "renderer.images = []\n", 127 | "args.width = 384\n", 128 | "args.height = 384\n", 129 | "for i, view in enumerate(saved_views):\n", 130 | " print(f'view {i} of {len(saved_views)}')\n", 131 | " renderer.render_view(view)\n", 132 | " Image.fromarray(renderer.images.pop()).save(f'interactive/{i:04}.png')" 133 | ] 134 | } 135 | ], 136 | "metadata": { 137 | "kernelspec": { 138 | "display_name": "Python 3 (ipykernel)", 139 | "language": "python", 140 | "name": "python3" 141 | }, 142 | "language_info": { 143 | "codemirror_mode": { 144 | "name": "ipython", 145 | "version": 3 146 | }, 147 | "file_extension": ".py", 148 | "mimetype": "text/x-python", 149 | "name": "python", 150 | "nbconvert_exporter": "python", 151 | "pygments_lexer": "ipython3", 152 | "version": "3.8.10" 153 | } 154 | }, 155 | "nbformat": 4, 156 | "nbformat_minor": 5 157 | } 158 | -------------------------------------------------------------------------------- /learn_nerf/scripts/render_nerf_pan.py: -------------------------------------------------------------------------------- 1 | """ 2 | Render a panning view of a NeRF model. 3 | """ 4 | 5 | import math 6 | 7 | import jax.numpy as jnp 8 | import numpy as np 9 | from learn_nerf.dataset import CameraView 10 | from learn_nerf.scripts.render_nerf import RenderSession, argparser 11 | 12 | 13 | def main(): 14 | parser = argparser() 15 | parser.add_argument("--frames", type=int, default=10) 16 | parser.add_argument("--distance", type=float, default=2.0) 17 | parser.add_argument("--random_axis", action="store_true") 18 | parser.add_argument("output_png", type=str) 19 | args = parser.parse_args() 20 | 21 | rs = RenderSession(args) 22 | 23 | scale = float(jnp.linalg.norm(rs.renderer.bbox_min - rs.renderer.bbox_max)) 24 | center = np.array((rs.renderer.bbox_min + rs.renderer.bbox_max) / 2) 25 | 26 | rot_axis = np.array([0.0, 0.0, -1.0]) 27 | rot_basis_1 = np.array([1.0, 0.0, 0.0]) 28 | if args.random_axis: 29 | rot_axis = np.random.normal(size=(3,)) 30 | rot_axis /= np.linalg.norm(rot_axis) 31 | rot_basis_1 = np.array([-rot_axis[2], 0.0, rot_axis[0]]) 32 | rot_basis_1 /= np.linalg.norm(rot_basis_1) 33 | rot_basis_2 = np.cross(rot_axis, rot_basis_1) 34 | 35 | for frame in range(args.frames): 36 | print(f"sampling frame {frame}...") 37 | theta = (frame / args.frames) * math.pi * 2 38 | direction = np.cos(theta) * rot_basis_1 + np.sin(theta) * rot_basis_2 39 | rs.render_view( 40 | CameraView( 41 | camera_direction=tuple(direction), 42 | camera_origin=tuple(-direction * scale * args.distance + center), 43 | x_axis=tuple( 44 | np.cos(theta + np.pi / 2) * rot_basis_1 45 | + np.sin(theta + np.pi / 2) * rot_basis_2 46 | ), 47 | y_axis=tuple(rot_axis), 48 | x_fov=60.0 * math.pi / 180, 49 | y_fov=60.0 * math.pi / 180, 50 | ) 51 | ) 52 | 53 | rs.save(args.output_png) 54 | 55 | 56 | if __name__ == "__main__": 57 | main() 58 | -------------------------------------------------------------------------------- /learn_nerf/scripts/render_nerf_spin.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Spin around the y axis from a fixed camera view. 3 | """ 4 | 5 | import math 6 | 7 | import jax.numpy as jnp 8 | import numpy as np 9 | from learn_nerf.dataset import CameraView 10 | from learn_nerf.scripts.render_nerf import RenderSession, argparser 11 | 12 | 13 | def main(): 14 | parser = argparser() 15 | parser.add_argument("--frames", type=int, default=10) 16 | parser.add_argument("view_json", type=str) 17 | parser.add_argument("output_png", type=str) 18 | args = parser.parse_args() 19 | 20 | rs = RenderSession(args) 21 | 22 | view = CameraView.from_json(args.view_json) 23 | x, z = np.array(view.x_axis), np.array(view.camera_direction) 24 | 25 | for i in range(args.frames): 26 | theta = 2 * math.pi * i / args.frames 27 | sin, cos = math.sin(theta), math.cos(theta) 28 | view.x_axis, view.camera_direction = tuple(cos * x + sin * z), tuple( 29 | -sin * x + cos * z 30 | ) 31 | rs.render_view(view) 32 | 33 | rs.save(args.output_png) 34 | 35 | 36 | if __name__ == "__main__": 37 | main() 38 | -------------------------------------------------------------------------------- /learn_nerf/scripts/render_new_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Create a new NeRF dataset using a trained NeRF model by rendering random 3 | viewing angles. 4 | """ 5 | 6 | import argparse 7 | import math 8 | import os 9 | import pickle 10 | import random 11 | import shutil 12 | 13 | import jax 14 | import jax.numpy as jnp 15 | import numpy as np 16 | from learn_nerf.dataset import CameraView, ModelMetadata 17 | from learn_nerf.render import NeRFRenderer 18 | from learn_nerf.scripts.train_nerf import add_model_args, create_model 19 | from PIL import Image 20 | from tqdm.auto import tqdm 21 | 22 | 23 | def main(): 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("--seed", type=int, default=None) 26 | parser.add_argument("--batch_size", type=int, default=1024, help="rays per batch") 27 | parser.add_argument( 28 | "--coarse_samples", type=int, default=64, help="samples per coarse ray" 29 | ) 30 | parser.add_argument( 31 | "--fine_samples", 32 | type=int, 33 | default=128, 34 | help="samples per fine ray (not including coarse samples)", 35 | ) 36 | parser.add_argument("--num_images", type=int, default=100) 37 | parser.add_argument("--size", type=int, default=512) 38 | parser.add_argument("--distance", type=float, default=1.0) 39 | parser.add_argument("--max_depth", type=float, default=10.0) 40 | parser.add_argument("--model_path", type=str, default="nerf.pkl") 41 | add_model_args(parser) 42 | parser.add_argument("metadata_json", type=str) 43 | parser.add_argument("output_dir", type=str) 44 | args = parser.parse_args() 45 | 46 | if os.path.exists(args.output_dir): 47 | raise FileExistsError(f"output directory exists: {args.output_dir}") 48 | 49 | metadata = ModelMetadata.from_json(args.metadata_json) 50 | 51 | print("loading model...") 52 | coarse, fine, _ = create_model(args, metadata) 53 | with open(args.model_path, "rb") as f: 54 | params = pickle.load(f) 55 | 56 | renderer = NeRFRenderer( 57 | coarse=coarse, 58 | fine=fine, 59 | coarse_params=params["coarse"], 60 | fine_params=params["fine"], 61 | background=params["background"], 62 | bbox_min=jnp.array(metadata.bbox_min, dtype=jnp.float32), 63 | bbox_max=jnp.array(metadata.bbox_max, dtype=jnp.float32), 64 | coarse_ts=args.coarse_samples, 65 | fine_ts=args.fine_samples, 66 | ) 67 | render_fn = 
jax.jit(lambda *args: renderer.render_rays(*args)["fine"]) 68 | 69 | key = jax.random.PRNGKey( 70 | args.seed if args.seed is not None else random.randint(0, 2 ** 32 - 1) 71 | ) 72 | 73 | os.makedirs(args.output_dir) 74 | shutil.copy(args.metadata_json, os.path.join(args.output_dir, "metadata.json")) 75 | 76 | scale = float(jnp.linalg.norm(renderer.bbox_min - renderer.bbox_max)) 77 | center = np.array((renderer.bbox_min + renderer.bbox_max) / 2) 78 | 79 | for frame in range(args.num_images): 80 | print(f"sampling frame {frame}...") 81 | z = np.random.normal(size=(3,)) 82 | z = z / np.linalg.norm(z) 83 | x = np.array([z[1], -z[0], 0.0]) 84 | x = x / np.linalg.norm(x) 85 | y = np.cross(z, x) 86 | view = CameraView( 87 | camera_direction=tuple(z), 88 | camera_origin=tuple(-z * scale * args.distance + center), 89 | x_axis=tuple(x), 90 | y_axis=tuple(y), 91 | x_fov=60.0 * math.pi / 180, 92 | y_fov=60.0 * math.pi / 180, 93 | ) 94 | with open(os.path.join(args.output_dir, f"{frame:05}.json"), "w") as f: 95 | f.write(view.to_json()) 96 | rays = view.bare_rays(args.size, args.size) 97 | colors = jnp.zeros([0, 3]) 98 | depths = jnp.zeros([0, 1]) 99 | for i in tqdm(range(0, rays.shape[0], args.batch_size)): 100 | sub_batch = rays[i : i + args.batch_size] 101 | key, this_key = jax.random.split(key) 102 | results = render_fn(this_key, sub_batch) 103 | 104 | z_depth = ( 105 | jnp.clip( 106 | jnp.where( 107 | results["alphas"] > 0.9, 108 | ( 109 | ( 110 | (results["coords"] - jnp.array(view.camera_origin)) 111 | @ jnp.array(view.camera_direction) 112 | )[:, None] 113 | / (results["alphas"] + 1e-8) 114 | ), 115 | args.max_depth, 116 | ), 117 | 0.0, 118 | args.max_depth, 119 | ) 120 | / args.max_depth 121 | ) 122 | colors = jnp.concatenate([colors, results["outputs"]], axis=0) 123 | depths = jnp.concatenate([depths, z_depth], axis=0) 124 | image = ( 125 | (np.array(colors).reshape([args.size, args.size, 3]) + 1) * 127.5 126 | ).astype(np.uint8) 127 | Image.fromarray(image).save(os.path.join(args.output_dir, f"{frame:05}.png")) 128 | depth_image = ( 129 | np.array(depths).reshape([args.size, args.size]) * 0xFFFF 130 | ).astype(np.uint32) 131 | Image.fromarray(depth_image).save( 132 | os.path.join(args.output_dir, f"{frame:05}_depth.png") 133 | ) 134 | 135 | 136 | if __name__ == "__main__": 137 | main() 138 | -------------------------------------------------------------------------------- /learn_nerf/scripts/train_nerf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a NeRF model on a scene. 
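If --save_path already exists, training resumes from that checkpoint. The --instant_ngp and --ref_nerf flags select the model variant.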
3 | """ 4 | 5 | import argparse 6 | import os 7 | import random 8 | from functools import partial 9 | from typing import Any, Dict, Tuple 10 | 11 | import jax 12 | import jax.numpy as jnp 13 | from learn_nerf.dataset import ModelMetadata, load_dataset 14 | from learn_nerf.instant_ngp import InstantNGPModel, InstantNGPRefNERFModel 15 | from learn_nerf.model import ModelBase, NeRFModel 16 | from learn_nerf.ref_nerf import RefNERFModel 17 | from learn_nerf.train import TrainLoop 18 | 19 | 20 | def main(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--seed", type=int, default=None) 23 | parser.add_argument("--lr", type=float, default=1e-4) 24 | parser.add_argument("--batch_size", type=int, default=4096, help="rays per batch") 25 | parser.add_argument( 26 | "--test_batch_size", type=int, default=None, help="rays per test batch" 27 | ) 28 | parser.add_argument( 29 | "--coarse_samples", type=int, default=64, help="samples per coarse ray" 30 | ) 31 | parser.add_argument( 32 | "--fine_samples", 33 | type=int, 34 | default=128, 35 | help="samples per fine ray (not including coarse samples)", 36 | ) 37 | parser.add_argument( 38 | "--density_penalty", 39 | type=float, 40 | default=None, 41 | help="penalty coefficient for density at random points", 42 | ) 43 | parser.add_argument( 44 | "--density_penalty_batch_size", 45 | type=int, 46 | default=128, 47 | help="batch size for computing density penalty", 48 | ) 49 | parser.add_argument("--save_interval", type=int, default=1000) 50 | parser.add_argument("--save_path", type=str, default="nerf.pkl") 51 | parser.add_argument("--one_view", action="store_true") 52 | parser.add_argument("--test_data_dir", type=str, default=None) 53 | add_model_args(parser) 54 | parser.add_argument("data_dir", type=str) 55 | args = parser.parse_args() 56 | 57 | if args.test_batch_size is None: 58 | args.test_batch_size = args.batch_size 59 | 60 | print("loading dataset...") 61 | data = load_dataset(args.data_dir) 62 | if args.one_view: 63 | data.views = data.views[:1] 64 | 65 | if args.test_data_dir is not None: 66 | print("loading test dataset...") 67 | test_data = load_dataset(args.test_data_dir) 68 | if args.one_view: 69 | test_data.views = test_data.views[:1] 70 | else: 71 | test_data = None 72 | 73 | key = jax.random.PRNGKey( 74 | args.seed if args.seed is not None else random.randint(0, 2 ** 32 - 1) 75 | ) 76 | init_key, key = jax.random.split(key) 77 | 78 | print("creating model and train loop...") 79 | coarse, fine, train_kwargs = create_model(args, data.metadata) 80 | loop = TrainLoop( 81 | coarse, 82 | fine, 83 | init_rng=init_key, 84 | lr=args.lr, 85 | coarse_ts=args.coarse_samples, 86 | fine_ts=args.fine_samples, 87 | density_penalty=args.density_penalty, 88 | density_penalty_batch_size=args.density_penalty_batch_size, 89 | **train_kwargs, 90 | ) 91 | if os.path.exists(args.save_path): 92 | print(f"loading from checkpoint: {args.save_path}") 93 | loop.load(args.save_path) 94 | step_fn = loop.step_fn( 95 | jnp.array(data.metadata.bbox_min), 96 | jnp.array(data.metadata.bbox_max), 97 | ) 98 | if test_data is not None: 99 | loss_fn = jax.jit( 100 | lambda key, batch, params: loop.losses( 101 | key=key, 102 | bbox_min=jnp.array(data.metadata.bbox_min), 103 | bbox_max=jnp.array(data.metadata.bbox_max), 104 | batch=batch, 105 | params=params, 106 | )[1] 107 | ) 108 | 109 | print("training...") 110 | data_key, test_data_key, key = jax.random.split(key, 3) 111 | shuffle_dir = os.path.join(args.data_dir, "shuffled") 112 | if test_data: 113 | 
test_shuffle_dir = os.path.join(args.test_data_dir, "shuffled") 114 | test_iterator = test_data.iterate_batches( 115 | test_shuffle_dir, test_data_key, args.test_batch_size 116 | ) 117 | for i, batch in enumerate( 118 | data.iterate_batches(shuffle_dir, data_key, args.batch_size) 119 | ): 120 | step_key, test_key, key = jax.random.split(key, 3) 121 | if test_data is not None: 122 | test_batch = next(test_iterator) 123 | test_losses = { 124 | f"test_{k}": v 125 | for k, v in loss_fn(test_key, test_batch, loop.state.params).items() 126 | } 127 | losses = step_fn(step_key, batch) 128 | if test_data is not None: 129 | losses.update(test_losses) 130 | loss_str = " ".join(f"{k}={float(v):.05}" for k, v in losses.items()) 131 | print(f"step {i}: {loss_str}") 132 | if i and i % args.save_interval == 0: 133 | loop.save(args.save_path) 134 | 135 | 136 | def add_model_args(parser: argparse.ArgumentParser): 137 | parser.add_argument("--instant_ngp", action="store_true") 138 | parser.add_argument("--ref_nerf", action="store_true") 139 | 140 | 141 | def create_model( 142 | args: argparse.Namespace, metadata: ModelMetadata 143 | ) -> Tuple[ModelBase, ModelBase, Dict[str, Any]]: 144 | if args.instant_ngp: 145 | if args.ref_nerf: 146 | model_cls = partial(InstantNGPRefNERFModel, sh_degree=4) 147 | else: 148 | model_cls = InstantNGPModel 149 | coarse = model_cls( 150 | table_sizes=[2 ** 18] * 6, 151 | grid_sizes=[2 ** (4 + i // 2) for i in range(6)], 152 | bbox_min=jnp.array(metadata.bbox_min), 153 | bbox_max=jnp.array(metadata.bbox_max), 154 | ) 155 | fine = model_cls( 156 | table_sizes=[2 ** 18] * 16, 157 | grid_sizes=[2 ** (4 + i // 2) for i in range(16)], 158 | bbox_min=jnp.array(metadata.bbox_min), 159 | bbox_max=jnp.array(metadata.bbox_max), 160 | ) 161 | train_kwargs = dict(adam_eps=1e-15, adam_b1=0.9, adam_b2=0.99) 162 | else: 163 | if args.ref_nerf: 164 | model_cls = partial(RefNERFModel, sh_degree=4) 165 | else: 166 | model_cls = NeRFModel 167 | coarse = model_cls() 168 | fine = model_cls() 169 | train_kwargs = dict() 170 | return coarse, fine, train_kwargs 171 | 172 | 173 | if __name__ == "__main__": 174 | main() 175 | -------------------------------------------------------------------------------- /learn_nerf/test_dataset.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | from dataclasses import dataclass 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | from .dataset import NeRFDataset, NeRFView 8 | 9 | 10 | @dataclass 11 | class DummyView(NeRFView): 12 | dummy_image: jnp.ndarray 13 | 14 | def image(self) -> jnp.ndarray: 15 | return self.dummy_image 16 | 17 | 18 | def test_nerf_dataset_iterate_batches(): 19 | dataset = NeRFDataset( 20 | views=[ 21 | DummyView( 22 | camera_direction=(0.0, 1.0, 0.0), 23 | camera_origin=(2.0, 2.0, 2.0), 24 | x_axis=(-1.0, 0.0, 0.0), 25 | y_axis=(0.0, 0.0, 1.0), 26 | x_fov=60.0, 27 | y_fov=60.0, 28 | dummy_image=jax.random.uniform(jax.random.PRNGKey(1337), (10, 10, 3)), 29 | ), 30 | DummyView( 31 | camera_direction=(1.0, 0.0, 0.0), 32 | camera_origin=(-2.0, 2.0, 2.0), 33 | x_axis=(-0.0, 0.0, -1.0), 34 | y_axis=(0.0, 1.0, 0.0), 35 | x_fov=60.0, 36 | y_fov=60.0, 37 | dummy_image=jax.random.uniform(jax.random.PRNGKey(1338), (10, 10, 3)), 38 | ), 39 | ], 40 | bbox_min=(0.0, 0.0, 0.0), 41 | bbox_max=(1.0, 1.0, 1.0), 42 | ) 43 | with tempfile.TemporaryDirectory() as tmp_dir: 44 | batches = list( 45 | dataset.iterate_batches( 46 | tmp_dir, jax.random.PRNGKey(1234), batch_size=51, repeat=False 47 | ) 48 | ) 49 
| assert len(batches) == 4, "unexpected number of batches" 50 | assert batches[-1].shape[0] == 200 - 51 * 3, "unexpected last batch size" 51 | 52 | combined = jnp.concatenate(batches, axis=0) 53 | 54 | # Verify origin count. 55 | for view in dataset.views: 56 | origin = jnp.array(view.camera_origin, dtype=jnp.float32) 57 | origins = combined[:, 0] 58 | 59 | view_mask = jnp.sum(jnp.abs(origins - origin), axis=-1) < 1e-5 60 | count = jnp.sum(view_mask) 61 | num_pixels = view.dummy_image.shape[0] * view.dummy_image.shape[1] 62 | assert ( 63 | int(count) == num_pixels 64 | ), f"unexpected number of samples with origin {origin}" 65 | 66 | view_rays = combined[view_mask] 67 | directions = view_rays[:, 1] 68 | mean_direction = jnp.mean(directions, axis=0) 69 | mean_direction = mean_direction / jnp.linalg.norm(mean_direction) 70 | camera_dot = jnp.sum( 71 | mean_direction * jnp.array(view.camera_direction, dtype=jnp.float32) 72 | ) 73 | assert ( 74 | abs(float(camera_dot) - 1) < 1e-5 75 | ), f"mean direction was {mean_direction} but should be {view.camera_direction}" 76 | 77 | colors = view_rays[:, 2] 78 | mean_color = jnp.mean(colors, axis=0) 79 | actual_mean = jnp.mean(view.dummy_image / 127.5 - 1, axis=(0, 1)) 80 | diff = jnp.mean(jnp.abs(mean_color - actual_mean)) 81 | assert float(diff) < 1e-5, "invalid colors for rays" 82 | -------------------------------------------------------------------------------- /learn_nerf/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from functools import partial 4 | from typing import Any, Callable, Dict, Optional, Tuple 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | import optax 9 | from flax.core.scope import VariableDict 10 | from flax.training import train_state 11 | from jax._src.prng import PRNGKeyArray as KeyArray 12 | 13 | from .model import ModelBase 14 | from .render import NeRFRenderer 15 | 16 | 17 | class TrainLoop: 18 | """ 19 | A stateful training loop. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | coarse: ModelBase, 25 | fine: ModelBase, 26 | init_rng: KeyArray, 27 | lr: float, 28 | coarse_ts: int, 29 | fine_ts: int, 30 | adam_b1: float = 0.9, 31 | adam_b2: float = 0.999, 32 | adam_eps: float = 1e-7, 33 | loss_weights: Optional[Dict[str, float]] = None, 34 | density_penalty: Optional[float] = None, 35 | density_penalty_batch_size: int = 128, 36 | ): 37 | self.coarse = coarse 38 | self.fine = fine 39 | self.coarse_ts = coarse_ts 40 | self.fine_ts = fine_ts 41 | self.loss_weights = ( 42 | loss_weights if loss_weights is not None else default_loss_weights() 43 | ) 44 | self.density_penalty = density_penalty 45 | self.density_penalty_batch_size = density_penalty_batch_size 46 | 47 | coarse_rng, fine_rng = jax.random.split(init_rng) 48 | example_batch = jnp.array([[0.0, 0.0, 0.0]]) 49 | coarse_vars = coarse.init(dict(params=coarse_rng), example_batch, example_batch) 50 | fine_vars = fine.init(dict(params=fine_rng), example_batch, example_batch) 51 | self.state = train_state.TrainState.create( 52 | apply_fn=None, 53 | params=dict( 54 | coarse=coarse_vars["params"], 55 | fine=fine_vars["params"], 56 | # Initialize background as all black. 57 | background=jnp.array([-1.0, -1.0, -1.0]), 58 | ), 59 | tx=optax.adam(lr, b1=adam_b1, b2=adam_b2, eps=adam_eps), 60 | ) 61 | 62 | def save(self, path: str): 63 | """ 64 | Save the model parameters to a file.
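The parameters are pickled to `path + ".tmp"` and then renamed over `path`, so an interrupted save cannot corrupt an existing checkpoint.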
65 | """ 66 | tmp_path = path + ".tmp" 67 | with open(tmp_path, "wb") as f: 68 | pickle.dump(self.state.params, f) 69 | os.rename(tmp_path, path) 70 | 71 | def load(self, path: str): 72 | """ 73 | Load the model parameters from a file. 74 | """ 75 | with open(path, "rb") as f: 76 | self.state = self.state.replace(params=pickle.load(f)) 77 | 78 | def step_fn( 79 | self, bbox_min: jnp.ndarray, bbox_max: jnp.ndarray 80 | ) -> Callable[[jax.random.PRNGKey, jnp.ndarray], Dict[str, jnp.ndarray]]: 81 | """ 82 | Create a function that steps in place and returns a logging dict. 83 | """ 84 | 85 | @jax.jit 86 | def step_fn( 87 | state: train_state.TrainState, key: KeyArray, batch: jnp.ndarray 88 | ) -> Tuple[train_state.TrainState, Dict[str, jnp.ndarray]]: 89 | loss_fn = partial(self.losses, key, bbox_min, bbox_max, batch) 90 | grad, values = jax.grad(loss_fn, has_aux=True)(state.params) 91 | 92 | def tree_norm(tree: Any) -> jnp.ndarray: 93 | return jnp.sqrt( 94 | jax.tree_util.tree_reduce( 95 | lambda total, x: total + jnp.sum(x ** 2), tree, jnp.array(0.0) 96 | ) 97 | ) 98 | 99 | values.update( 100 | dict( 101 | grad_norm=tree_norm(grad), 102 | param_norm=tree_norm(state.params), 103 | ) 104 | ) 105 | 106 | return state.apply_gradients(grads=grad), values 107 | 108 | def in_place_step(key: KeyArray, batch: jnp.ndarray) -> Dict[str, jnp.ndarray]: 109 | self.state, ret_val = step_fn(self.state, key, batch) 110 | return ret_val 111 | 112 | return in_place_step 113 | 114 | def losses( 115 | self, 116 | key: KeyArray, 117 | bbox_min: jnp.ndarray, 118 | bbox_max: jnp.ndarray, 119 | batch: jnp.ndarray, 120 | params: VariableDict, 121 | ) -> Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]: 122 | """ 123 | Compute losses and a logging dict for a given batch and settings. 
124 | """ 125 | renderer = NeRFRenderer( 126 | coarse=self.coarse, 127 | fine=self.fine, 128 | coarse_params=params["coarse"], 129 | fine_params=params["fine"], 130 | background=params["background"], 131 | bbox_min=bbox_min, 132 | bbox_max=bbox_max, 133 | coarse_ts=self.coarse_ts, 134 | fine_ts=self.fine_ts, 135 | ) 136 | 137 | key, density_key = jax.random.split(key) 138 | 139 | render_out = renderer.render_rays(key, batch[:, :2]) 140 | targets = batch[:, 2] 141 | coarse_loss = jnp.mean((render_out["coarse"]["outputs"] - targets) ** 2) 142 | fine_loss = jnp.mean((render_out["fine"]["outputs"] - targets) ** 2) 143 | 144 | loss_dict = dict(coarse=coarse_loss, fine=fine_loss) 145 | total_loss = coarse_loss + fine_loss 146 | for name, loss in render_out["coarse_aux"].items(): 147 | loss_dict[f"coarse_{name}"] = loss 148 | total_loss = total_loss + self.loss_weights[name] * loss 149 | for name, loss in render_out["fine_aux"].items(): 150 | loss_dict[f"fine_{name}"] = loss 151 | total_loss = total_loss + self.loss_weights[name] * loss 152 | 153 | if self.density_penalty is not None: 154 | for prefix, model in [("fine", self.fine), ("coarse", self.coarse)]: 155 | penalty = self.average_density( 156 | key=density_key, 157 | model=model, 158 | params=params[prefix], 159 | bbox_min=bbox_min, 160 | bbox_max=bbox_max, 161 | ) 162 | loss_dict[f"{prefix}_density"] = penalty 163 | total_loss = total_loss + self.density_penalty * penalty 164 | 165 | return total_loss, loss_dict 166 | 167 | def average_density( 168 | self, 169 | key: KeyArray, 170 | model: ModelBase, 171 | params: Any, 172 | bbox_min: jnp.ndarray, 173 | bbox_max: jnp.ndarray, 174 | ) -> jnp.ndarray: 175 | coords = ( 176 | jax.random.uniform(key, shape=(self.density_penalty_batch_size, 3)) 177 | * (bbox_max - bbox_min) 178 | + bbox_min 179 | ) 180 | dirs = jax.random.normal(key, shape=(self.density_penalty_batch_size, 3)) 181 | dirs = dirs / jnp.linalg.norm(dirs, axis=-1, keepdims=True) 182 | 183 | densities, _, _ = model.apply(dict(params=params), coords, dirs) 184 | return jnp.mean(densities) 185 | 186 | 187 | def default_loss_weights() -> Dict[str, float]: 188 | return dict( 189 | normal_mse=3e-4, 190 | neg_normal=0.1, 191 | ) 192 | -------------------------------------------------------------------------------- /point_cloud/main.go: -------------------------------------------------------------------------------- 1 | // Command point_cloud reconstructs a point cloud from a dataset with added 2 | // depth images as exported by the render_new_dataset.py script. 
3 | package main 4 | 5 | import ( 6 | "encoding/json" 7 | "errors" 8 | "flag" 9 | "fmt" 10 | "image/color" 11 | "image/png" 12 | "log" 13 | "math" 14 | "math/rand" 15 | "os" 16 | "path/filepath" 17 | 18 | "github.com/unixpickle/essentials" 19 | "github.com/unixpickle/model3d/model3d" 20 | "github.com/unixpickle/model3d/render3d" 21 | "github.com/unixpickle/model3d/toolbox3d" 22 | ) 23 | 24 | func main() { 25 | var maxDepth float64 26 | var thickness float64 27 | var delta float64 28 | var maxPoints int 29 | var sortDensity bool 30 | var sortDensityK int 31 | var dataDir string 32 | var outputPath string 33 | flag.Float64Var(&maxDepth, "max-depth", 10.0, "maximum depth value corresponding to white pixel") 34 | flag.Float64Var(&thickness, "thickness", 0.02, "radius of each point") 35 | flag.Float64Var(&delta, "delta", 0.02, "marching cubes delta") 36 | flag.IntVar(&maxPoints, "max-points", 50000, "maximum points to sample") 37 | flag.BoolVar(&sortDensity, "sort-density", false, "remove lowest density samples first") 38 | flag.IntVar(&sortDensityK, "sort-density-k", 5, "neighbor to use for density estimate") 39 | flag.StringVar(&dataDir, "data-dir", "", "data directory") 40 | flag.StringVar(&outputPath, "output-path", "", "output zipped material OBJ path") 41 | flag.Parse() 42 | if dataDir == "" || outputPath == "" { 43 | essentials.Die("Must specify -data-dir and -output-path") 44 | } 45 | 46 | log.Println("Computing points...") 47 | points := []model3d.Coord3D{} 48 | colors := []render3d.Color{} 49 | for i := 0; true; i++ { 50 | metadataPath := filepath.Join(dataDir, fmt.Sprintf("%05d.json", i)) 51 | depthPath := filepath.Join(dataDir, fmt.Sprintf("%05d_depth.png", i)) 52 | colorPath := filepath.Join(dataDir, fmt.Sprintf("%05d.png", i)) 53 | 54 | if _, err := os.Stat(metadataPath); os.IsNotExist(err) { 55 | break 56 | } 57 | var metadata struct { 58 | Origin [3]float64 `json:"origin"` 59 | XFov float64 `json:"x_fov"` 60 | YFov float64 `json:"y_fov"` 61 | X [3]float64 `json:"x"` 62 | Y [3]float64 `json:"y"` 63 | Z [3]float64 `json:"z"` 64 | } 65 | f, err := os.Open(metadataPath) 66 | essentials.Must(err) 67 | err = json.NewDecoder(f).Decode(&metadata) 68 | f.Close() 69 | essentials.Must(err) 70 | 71 | origin := model3d.NewCoord3DArray(metadata.Origin) 72 | xAxis := model3d.NewCoord3DArray(metadata.X).Scale(math.Tan(metadata.XFov / 2)) 73 | yAxis := model3d.NewCoord3DArray(metadata.Y).Scale(math.Tan(metadata.YFov / 2)) 74 | zAxis := model3d.NewCoord3DArray(metadata.Z) 75 | 76 | err = ReadRGBD(depthPath, colorPath, func(x, y float64, depth uint16, c color.Color) { 77 | zDist := (float64(depth) / 0xffff) * maxDepth 78 | direction := zAxis.Add(xAxis.Scale(x)).Add(yAxis.Scale(y)).Normalize() 79 | scale := zDist / direction.Dot(zAxis) 80 | coord := origin.Add(direction.Scale(scale)) 81 | points = append(points, coord) 82 | r, g, b, _ := color.RGBAModel.Convert(c).RGBA() 83 | colors = append(colors, render3d.NewColorRGB(float64(r)/0xffff, float64(g)/0xffff, float64(b)/0xffff)) 84 | }) 85 | essentials.Must(err) 86 | } 87 | 88 | if len(points) > maxPoints { 89 | log.Printf("Found %d points. 
Reducing to %d...", len(points), maxPoints) 90 | if sortDensity { 91 | SortByDensity(sortDensityK, points, colors) 92 | } else { 93 | rand.Shuffle(len(points), func(i, j int) { 94 | points[i], points[j] = points[j], points[i] 95 | colors[i], colors[j] = colors[j], colors[i] 96 | }) 97 | } 98 | points = points[:maxPoints] 99 | colors = colors[:maxPoints] 100 | } else { 101 | log.Printf("Using all %d points.", len(points)) 102 | } 103 | 104 | log.Println("Constructing solid and color function...") 105 | min := points[0] 106 | max := points[0] 107 | for _, p := range points { 108 | min = min.Min(p) 109 | max = max.Max(p) 110 | } 111 | tree := model3d.NewCoordTree(points) 112 | solid := model3d.CheckedFuncSolid( 113 | min, 114 | max, 115 | func(c model3d.Coord3D) bool { 116 | return tree.Dist(c) < thickness 117 | }, 118 | ) 119 | coordToColor := map[model3d.Coord3D]render3d.Color{} 120 | for i, c := range points { 121 | coordToColor[c] = colors[i] 122 | } 123 | colorFunc := toolbox3d.CoordColorFunc(func(c model3d.Coord3D) render3d.Color { 124 | return coordToColor[tree.NearestNeighbor(c)] 125 | }) 126 | 127 | log.Println("Creating mesh...") 128 | mesh := model3d.MarchingCubesSearch(solid, delta, 8) 129 | 130 | log.Println("Saving mesh...") 131 | mesh.SaveQuantizedMaterialOBJ(outputPath, 128, colorFunc.Cached().TriangleColor) 132 | } 133 | 134 | func ReadRGBD(depthPath, colorPath string, cb func(x, y float64, depth uint16, c color.Color)) error { 135 | f, err := os.Open(depthPath) 136 | if err != nil { 137 | return err 138 | } 139 | depthImg, err := png.Decode(f) 140 | f.Close() 141 | if err != nil { 142 | return err 143 | } 144 | 145 | f, err = os.Open(colorPath) 146 | if err != nil { 147 | return err 148 | } 149 | colorImg, err := png.Decode(f) 150 | f.Close() 151 | if err != nil { 152 | return err 153 | } 154 | 155 | b := depthImg.Bounds() 156 | if b != colorImg.Bounds() { 157 | return errors.New("mismatched size of RGB and depth images") 158 | } 159 | 160 | for y := 0; y < b.Dy(); y++ { 161 | yFrac := 2*float64(y)/float64(b.Dy()-1) - 1 162 | for x := 0; x < b.Dx(); x++ { 163 | xFrac := 2*float64(x)/float64(b.Dx()-1) - 1 164 | 165 | depth, _, _, _ := color.Gray16Model.Convert(depthImg.At(x, y)).RGBA() 166 | if depth == 0xffff { 167 | continue 168 | } 169 | c := colorImg.At(x, y) 170 | cb(xFrac, yFrac, uint16(depth), c) 171 | } 172 | } 173 | 174 | return nil 175 | } 176 | 177 | func SortByDensity(k int, points []model3d.Coord3D, colors []render3d.Color) { 178 | tree := model3d.NewCoordTree(points) 179 | dists := make([]float64, len(points)) 180 | essentials.ConcurrentMap(0, len(dists), func(i int) { 181 | c := points[i] 182 | dists[i] = c.SquaredDist(tree.KNN(k, c)[k-1]) 183 | }) 184 | essentials.VoodooSort(dists, func(i, j int) bool { 185 | return dists[i] < dists[j] 186 | }, points, colors) 187 | } 188 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="learn-nerf", 5 | packages=find_packages(), 6 | install_requires=[], 7 | ) 8 | -------------------------------------------------------------------------------- /simple_dataset/camera_gen.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "math" 5 | 6 | "github.com/unixpickle/model3d/model3d" 7 | "github.com/unixpickle/model3d/render3d" 8 | ) 9 | 10 | type CameraGen interface { 11 | Camera(i,
total int) *render3d.Camera 12 | } 13 | 14 | type RandomCameraGen struct { 15 | Object render3d.Object 16 | Fov float64 17 | } 18 | 19 | func (r *RandomCameraGen) Camera(i, total int) *render3d.Camera { 20 | direction := model3d.NewCoord3DRandUnit() 21 | return render3d.DirectionalCamera(r.Object, direction, r.Fov) // Fov is already in radians 22 | } 23 | 24 | type RotatingCameraGen struct { 25 | Object render3d.Object 26 | Fov float64 27 | Axis model3d.Coord3D 28 | Offset model3d.Coord3D 29 | 30 | total int 31 | furthestScale float64 32 | } 33 | 34 | func (r *RotatingCameraGen) Camera(i, total int) *render3d.Camera { 35 | if r.total != total { 36 | r.updateCache(total) 37 | } 38 | dir := r.direction(i, total) 39 | center := r.Object.Min().Mid(r.Object.Max()) 40 | cam := render3d.NewCameraAt(center.Add(dir.Scale(r.furthestScale)), center, r.Fov) 41 | return cam 42 | } 43 | 44 | func (r *RotatingCameraGen) updateCache(total int) { 45 | r.total = total 46 | scale := 0.0 47 | for i := 0; i < total; i++ { 48 | cam := render3d.DirectionalCamera(r.Object, r.direction(i, total), r.Fov) 49 | s := cam.Origin.Dist(r.Object.Min().Mid(r.Object.Max())) 50 | if s > scale { 51 | scale = s 52 | } 53 | } 54 | r.furthestScale = scale 55 | } 56 | 57 | func (r *RotatingCameraGen) direction(i, total int) model3d.Coord3D { 58 | theta := math.Pi * 2 * float64(i) / float64(total) 59 | rotation := model3d.Rotation(r.Axis, theta) 60 | return rotation.Apply(r.Offset) 61 | } 62 | -------------------------------------------------------------------------------- /simple_dataset/main.go: -------------------------------------------------------------------------------- 1 | // Command simple_dataset creates a NeRF dataset from a single-color STL file. 2 | package main 3 | 4 | import ( 5 | "encoding/json" 6 | "flag" 7 | "fmt" 8 | "log" 9 | "math" 10 | "math/rand" 11 | "os" 12 | "path/filepath" 13 | 14 | "github.com/unixpickle/essentials" 15 | "github.com/unixpickle/model3d/model3d" 16 | "github.com/unixpickle/model3d/render3d" 17 | ) 18 | 19 | func main() { 20 | var fov float64 21 | var resolution int 22 | var numImages int 23 | var numLights int 24 | var lightBrightness float64 25 | var seed int64 26 | var noImages bool 27 | var rotate bool 28 | color := VectorFlag{Value: model3d.XYZ(0.8, 0.8, 0.0)} 29 | rotationAxis := VectorFlag{Value: model3d.Z(1.0)} 30 | rotationOffset := VectorFlag{Value: model3d.Y(-1.0)} 31 | 32 | flag.Float64Var(&fov, "fov", 60.0, "field of view in degrees") 33 | flag.IntVar(&resolution, "resolution", 800, "side length of images to render") 34 | flag.IntVar(&numImages, "images", 100, "number of images to render") 35 | flag.IntVar(&numLights, "num-lights", 5, "number of lights to put into the scene") 36 | flag.Float64Var(&lightBrightness, "light-brightness", 0.5, "brightness of lights") 37 | flag.Int64Var(&seed, "seed", 0, "seed for Go's random number generation") 38 | flag.BoolVar(&noImages, "no-images", false, "only save json files, not renderings") 39 | flag.BoolVar(&rotate, "rotate", false, "render a rotating view rather than random views") 40 | flag.Var(&color, "color", "color of the model, as 'r,g,b'") 41 | flag.Var(&rotationAxis, "rotation-axis", "axis of rotation for -rotate") 42 | flag.Var(&rotationOffset, "rotation-offset", "initial offset from center for -rotate") 43 | 44 | flag.Usage = func() { 45 | fmt.Fprintln(os.Stderr, "Usage: simple_dataset [flags] <model.stl> <output_dir>") 46 | fmt.Fprintln(os.Stderr) 47 | fmt.Fprintln(os.Stderr, "Flags:") 48 | flag.PrintDefaults() 49 | os.Exit(1) 50 | } 51 | 52 | flag.Parse() 53 | if
len(flag.Args()) != 2 { 54 | flag.Usage() 55 | } 56 | 57 | rand.Seed(seed) 58 | 59 | outputDir := flag.Args()[1] 60 | log.Printf("Creating output directory: %s...", outputDir) 61 | if stats, err := os.Stat(outputDir); err == nil && !stats.IsDir() { 62 | essentials.Die("output path exists and is not a directory: " + outputDir) 63 | } else if os.IsNotExist(err) { 64 | essentials.Must(os.MkdirAll(outputDir, 0755)) 65 | } 66 | 67 | log.Println("Loading model...") 68 | inputPath := flag.Args()[0] 69 | object := ReadObject(inputPath, color.Value) 70 | 71 | log.Println("Writing metadata...") 72 | WriteGlobalMetadata(outputDir, object) 73 | 74 | log.Println("Creating random lights...") 75 | lights := RandomLights(object, numLights, lightBrightness) 76 | 77 | var cameraGen CameraGen 78 | if rotate { 79 | cameraGen = &RotatingCameraGen{ 80 | Object: object, 81 | Fov: fov * math.Pi / 180, 82 | Axis: rotationAxis.Value, 83 | Offset: rotationOffset.Value, 84 | } 85 | } else { 86 | cameraGen = &RandomCameraGen{Object: object, Fov: fov * math.Pi / 180} 87 | } 88 | 89 | for i := 0; i < numImages; i++ { 90 | log.Printf("Rendering image %d/%d...", i+1, numImages) 91 | camera := cameraGen.Camera(i, numImages) 92 | 93 | if !noImages { 94 | caster := &render3d.RayCaster{ 95 | Camera: camera, 96 | Lights: lights, 97 | } 98 | viewImage := render3d.NewImage(resolution, resolution) 99 | caster.Render(viewImage, object) 100 | 101 | imagePath := filepath.Join(outputDir, fmt.Sprintf("%04d.png", i)) 102 | essentials.Must(viewImage.Save(imagePath)) 103 | } 104 | 105 | metaPath := filepath.Join(outputDir, fmt.Sprintf("%04d.json", i)) 106 | metadata := map[string]interface{}{ 107 | "origin": camera.Origin.Array(), 108 | "x": camera.ScreenX.Array(), 109 | "y": camera.ScreenY.Array(), 110 | "z": camera.ScreenX.Cross(camera.ScreenY).Normalize().Array(), 111 | "x_fov": camera.FieldOfView, 112 | "y_fov": camera.FieldOfView, 113 | } 114 | f, err := os.Create(metaPath) 115 | essentials.Must(err) 116 | essentials.Must(json.NewEncoder(f).Encode(metadata)) 117 | essentials.Must(f.Close()) 118 | } 119 | } 120 | 121 | func ReadObject(path string, color model3d.Coord3D) render3d.Object { 122 | r, err := os.Open(path) 123 | essentials.Must(err) 124 | defer r.Close() 125 | 126 | triangles, err := model3d.ReadSTL(r) 127 | essentials.Must(err) 128 | mesh := normalizeMesh(model3d.NewMeshTriangles(triangles)) 129 | 130 | collider := model3d.MeshToCollider(mesh) 131 | return render3d.Objectify( 132 | collider, 133 | func(c model3d.Coord3D, rc model3d.RayCollision) render3d.Color { 134 | return render3d.NewColorRGB(color.X, color.Y, color.Z) 135 | }, 136 | ) 137 | } 138 | 139 | func normalizeMesh(mesh *model3d.Mesh) *model3d.Mesh { 140 | mesh = mesh.Translate(mesh.Min().Mid(mesh.Max()).Scale(-1)) 141 | m := mesh.Max() 142 | size := math.Max(math.Max(m.X, m.Y), m.Z) 143 | return mesh.Scale(1 / size) 144 | } 145 | 146 | func WriteGlobalMetadata(outputDir string, object render3d.Object) { 147 | globalMetadataPath := filepath.Join(outputDir, "metadata.json") 148 | f, err := os.Create(globalMetadataPath) 149 | essentials.Must(err) 150 | defer f.Close() 151 | globalMetadata := map[string]interface{}{ 152 | "min": object.Min().Array(), 153 | "max": object.Max().Array(), 154 | } 155 | essentials.Must(json.NewEncoder(f).Encode(globalMetadata)) 156 | } 157 | 158 | func RandomLights(object render3d.Object, n int, brightness float64) []*render3d.PointLight { 159 | center := object.Min().Mid(object.Max()) 160 | lights := make([]*render3d.PointLight, n) 161 | for i
:= 0; i < n; i++ { 162 | direction := model3d.NewCoord3DRandUnit() 163 | lights[i] = &render3d.PointLight{ 164 | Origin: center.Add(direction.Scale(1000)), 165 | Color: render3d.NewColor(brightness), 166 | } 167 | } 168 | return lights 169 | } 170 | -------------------------------------------------------------------------------- /simple_dataset/vector_flag.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "strconv" 6 | "strings" 7 | 8 | "github.com/unixpickle/model3d/model3d" 9 | ) 10 | 11 | // A VectorFlag is a flag.Value that parses comma-delimited 12 | // 3D vectors, e.g. "3.0, 2, -1". 13 | type VectorFlag struct { 14 | Value model3d.Coord3D 15 | } 16 | 17 | func (v *VectorFlag) String() string { 18 | var parts [3]string 19 | for i, x := range v.Value.Array() { 20 | parts[i] = strconv.FormatFloat(x, 'f', -1, 64) 21 | } 22 | return strings.Join(parts[:], ",") 23 | } 24 | 25 | func (v *VectorFlag) Set(s string) error { 26 | parts := strings.Split(s, ",") 27 | if len(parts) != 3 { 28 | return fmt.Errorf("vector does not have exactly two commas: %s", s) 29 | } 30 | var res [3]float64 31 | for i, x := range parts { 32 | x = strings.TrimSpace(x) 33 | f, err := strconv.ParseFloat(x, 64) 34 | if err != nil { 35 | return fmt.Errorf("invalid component '%s' in vector '%s': %s", x, s, err.Error()) 36 | } 37 | res[i] = f 38 | } 39 | v.Value = model3d.NewCoord3DArray(res) 40 | return nil 41 | } 42 | --------------------------------------------------------------------------------