├── .flake8 ├── .github └── workflows │ ├── lint.yml │ └── mypy.yml ├── .gitignore ├── README.md ├── dozer_render.gif ├── lego_render.gif ├── mypy.ini ├── render_360.py ├── requirements.txt ├── tensorf ├── __init__.py ├── cameras.py ├── data.py ├── networks.py ├── render.py ├── tensor_vm.py ├── train_config.py ├── training.py └── utils.py ├── train_lego.py └── train_nerfstudio.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | # E203: whitespace before : 3 | # E501: line too long ( characters) 4 | # W503: line break before binary operator 5 | ; ignore = E203,E501,D100,D101,D102,D103,W503 6 | ignore = E203,E501,W503 7 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: lint 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | black-check: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v1 14 | - name: Black Code Formatter 15 | uses: lgeiger/black-action@master 16 | with: 17 | args: ". --check" 18 | -------------------------------------------------------------------------------- /.github/workflows/mypy.yml: -------------------------------------------------------------------------------- 1 | name: mypy 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | mypy: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: ["3.8"] 15 | 16 | steps: 17 | - uses: actions/checkout@v2 18 | - name: Set up Python ${{ matrix.python-version }} 19 | uses: actions/setup-python@v1 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install mypy 26 | pip install -r requirements.txt 27 | - name: Test with mypy 28 | run: | 29 | mypy --install-types --non-interactive . 30 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | *.swo 3 | *.pyc 4 | *.egg-info 5 | __pycache__ 6 | .coverage 7 | htmlcov 8 | .mypy_cache 9 | .dmypy.json 10 | .hypothesis 11 | .ipynb_checkpoints 12 | data 13 | runs/ 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tensorf-jax 2 | 3 | JAX implementation of 4 | [Tensorial Radiance Fields](https://apchenstu.github.io/TensoRF/), written as an 5 | exercise. 6 | 7 | ``` 8 | @misc{TensoRF, 9 | title={TensoRF: Tensorial Radiance Fields}, 10 | author={Anpei Chen and Zexiang Xu and Andreas Geiger and and Jingyi Yu and Hao Su}, 11 | year={2022}, 12 | eprint={2203.09517}, 13 | archivePrefix={arXiv}, 14 | primaryClass={cs.CV} 15 | } 16 | ``` 17 | 18 | We don't attempt to reproduce the original paper exactly, but can achieve decent 19 | results after 5~10 minutes of training: 20 | 21 | ![Lego rendering GIF](./lego_render.gif) 22 | 23 | As proposed, TensoRF only supports scenes that fit in a fixed-size bounding box. 24 | We've also added basic support for unbounded "real" scenes via mip-NeRF 25 | 360-inspired scene contraction[^1]. 
From 26 | [nerfstudio](https://github.com/nerfstudio-project/nerfstudio)'s "dozer" 27 | dataset: 28 | 29 | ![Dozer rendering GIF](./dozer_render.gif) 30 | 31 | [^1]: 32 | Same as [the original](https://jonbarron.info/mipnerf360/), but with an 33 | $L-\infty$ norm instead of $L-2$ norm. 34 | 35 | ## Instructions 36 | 37 | 1. Download `nerf_synthetic` dataset: 38 | [Google Drive](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1). 39 | With the default training script arguments, we expect this to be extracted to 40 | `./data`, eg `./data/nerf_synthetic/lego`. 41 | 42 | 2. Install dependencies. Probably you want the GPU version of JAX; see the 43 | [official instructions](https://github.com/google/jax#Installation). Then: 44 | 45 | ```bash 46 | pip install -r requirements.txt 47 | ``` 48 | 49 | 3. To print training options: 50 | 51 | ```bash 52 | python ./train_lego.py --help 53 | ``` 54 | 55 | 4. To monitor training, we use Tensorboard: 56 | 57 | ```bash 58 | tensorboard --logdir=./runs/ 59 | ``` 60 | 61 | 5. To render: 62 | 63 | ```bash 64 | python ./render_360.py --help 65 | ``` 66 | 67 | ## Differences from the PyTorch implementation 68 | 69 | Things aren't totally matched to the official implementation: 70 | 71 | - The official implementation relies heavily on masking operations to improve 72 | runtime (for example, by using a weight threshold for sampled points). These 73 | require dynamic shapes and are currently difficult to implement in JAX, so we 74 | replace them with workarounds like weighted sampling. 75 | - Several training details that would likely improve performance are not yet 76 | implemented: bounding box refinement, ray filtering, regularization, etc. 77 | - We include mixed-precision training, which can speed training throughput up by 78 | a significant factor. (is this actually faster in terms of wall-clock time? 79 | unclear) 80 | 81 | ## References 82 | 83 | Implementation details are based loosely on the original PyTorch implementation 84 | [apchsenstu/TensoRF](https://github.com/apchenstu/TensoRF). 85 | 86 | [unixpickle/learn-nerf](https://github.com/unixpickle/learn-nerf) and 87 | [google-research/jaxnerf](https://github.com/google-research/google-research/tree/master/jaxnerf) 88 | were also really helpful for understanding core NeRF concepts + connecting them 89 | to JAX! 
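As an aside on the scene contraction mentioned above: the $L-\infty$ variant works out to the sketch below, which mirrors the inline logic in `tensorf/render.py` (the `contract` helper name is just for illustration). Points inside the unit cube are left untouched, and everything outside is squashed into the `[-2, 2]` cube:

```python
from jax import numpy as jnp


def contract(points: jnp.ndarray) -> jnp.ndarray:
    """Map unbounded world-space points of shape (*, 3) into the [-2, 2] cube."""
    # L-infinity norm of each point; the original mip-NeRF 360 formulation uses L-2.
    norm = jnp.linalg.norm(points, ord=jnp.inf, axis=-1, keepdims=True)
    # Identity inside the unit cube; magnitude 2 - 1/norm outside, approaching 2.
    return jnp.where(norm <= 1.0, points, (2.0 - 1.0 / norm) * points / norm)
```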
90 | 91 | ## To-do 92 | 93 | - [x] Main implementation 94 | - [x] Point sampling 95 | - [x] Feature MLP 96 | - [x] Rendering 97 | - [x] VM decomposition 98 | - [x] Basic implementation 99 | - [x] Vectorized 100 | - [x] Dataloading 101 | - [x] Blender 102 | - [x] nerfstudio 103 | - [x] Basics 104 | - [x] Fisheye support 105 | - [x] Compute samples without undistorting images (throws away a lot of 106 | pixels) 107 | - [x] Tricks for real data 108 | - [x] Scene contraction (~mip-NeRF 360) 109 | - [x] Camera embeddings 110 | - [x] Training 111 | - [x] Learning rate scheduler 112 | - [x] ADAM + grouped LR 113 | - [x] Exponential decay 114 | - [x] Reset decay after upsampling 115 | - [x] Running 116 | - [x] Checkpointing 117 | - [x] Logging 118 | - [x] Loss 119 | - [x] PSNR 120 | - [ ] Test metrics 121 | - [ ] Test images 122 | - [ ] Render previews 123 | - [ ] Ray filtering 124 | - [ ] Bounding box refinement 125 | - [x] Incremental upsampling 126 | - [ ] Regularization terms 127 | - [x] Performance 128 | - [x] Weight thresholding for computing appearance features 129 | - [x] per ray top-k 130 | - [x] global top-k (bad & deleted) 131 | - [x] Mixed-precision 132 | - [x] implemented 133 | - [x] stable 134 | - [ ] Multi-GPU (should be quick) 135 | - [x] Rendering 136 | - [x] RGB 137 | - [x] Depth (median) 138 | - [x] Depth (mean) 139 | - [x] Batching 140 | - [x] Generate some GIFs 141 | - [ ] Misc engineering 142 | - [x] Actions 143 | - [ ] Understand vmap performance differences 144 | ([details](https://github.com/google/jax/discussions/10332)) 145 | -------------------------------------------------------------------------------- /dozer_render.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brentyi/tensorf-jax/19f7ab24be969dba771f46126ec8136390d38e3d/dozer_render.gif -------------------------------------------------------------------------------- /lego_render.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brentyi/tensorf-jax/19f7ab24be969dba771f46126ec8136390d38e3d/lego_render.gif -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | ignore_missing_imports = True 3 | -------------------------------------------------------------------------------- /render_360.py: -------------------------------------------------------------------------------- 1 | """Visualization helper. 2 | 3 | Loads a radiance field and rotates a camera around it, rendering a viewpoint from each 4 | angle. 
5 | 6 | For a summary of options: 7 | ``` 8 | python render_360.py --help 9 | ``` 10 | """ 11 | 12 | import dataclasses 13 | import pathlib 14 | 15 | import cv2 16 | import fifteen 17 | import jax 18 | import jax_dataclasses as jdc 19 | import jaxlie 20 | import numpy as onp 21 | import tyro 22 | from jax import numpy as jnp 23 | from PIL import Image 24 | from typing_extensions import Literal, assert_never 25 | 26 | import tensorf.cameras 27 | import tensorf.data 28 | import tensorf.render 29 | import tensorf.train_config 30 | import tensorf.training 31 | 32 | 33 | @dataclasses.dataclass 34 | class Args: 35 | run_dir: pathlib.Path 36 | """Path to training run outputs.""" 37 | 38 | output_dir: pathlib.Path = pathlib.Path("./renders") 39 | """Renders will be saved to `[output_dir]/image_[i].png`.""" 40 | 41 | mode: tensorf.render.RenderMode = tensorf.render.RenderMode.RGB 42 | """Render mode: RGB or depth.""" 43 | 44 | frames: int = 10 45 | """Number of frames to render.""" 46 | 47 | density_samples_per_ray: int = 512 48 | appearance_samples_per_ray: int = 128 49 | ray_batch_size: int = 4096 * 4 50 | 51 | render_width: int = 400 52 | render_height: int = 400 53 | render_fov_x: float = onp.pi / 2.0 54 | render_camera_index: int = 0 55 | """Camera embedding to use, if enabled.""" 56 | 57 | rotation_axis: Literal["world_z", "camera_up"] = "camera_up" 58 | 59 | 60 | def main(args: Args) -> None: 61 | experiment = fifteen.experiments.Experiment(data_dir=args.run_dir) 62 | config = experiment.read_metadata("config", tensorf.train_config.TensorfConfig) 63 | 64 | # Make sure output directory exists. 65 | args.output_dir.mkdir(parents=True, exist_ok=True) 66 | 67 | # Load the training state from a checkpoint. 68 | train_state = tensorf.training.TrainState.initialize( 69 | config=config, 70 | grid_dim=config.grid_dim_final, 71 | prng_key=jax.random.PRNGKey(0), 72 | num_cameras=experiment.read_metadata("num_cameras", int), 73 | ) 74 | train_state = experiment.restore_checkpoint(train_state) 75 | assert train_state.step > 0 76 | 77 | # Load the training dataset... we're only going to use this to grab a camera. 78 | dataset = tensorf.data.make_dataset( 79 | config.dataset_type, 80 | config.dataset_path, 81 | config.scene_scale, 82 | ) 83 | train_cameras = dataset.get_cameras() 84 | 85 | initial_T_camera_world = train_cameras[0].T_camera_world 86 | initial_T_camera_world = jaxlie.SE3.from_rotation_and_translation( 87 | initial_T_camera_world.rotation(), 88 | initial_T_camera_world.translation() * 0.8, 89 | ) 90 | camera = tensorf.cameras.Camera.from_fov( 91 | T_camera_world=initial_T_camera_world, 92 | image_width=args.render_width, 93 | image_height=args.render_height, 94 | fov_x_radians=args.render_fov_x, 95 | ) 96 | 97 | # Get rotation axis. 98 | if args.rotation_axis == "world_z": 99 | rotation_axis = onp.array([0.0, 0.0, 1.0]) 100 | elif args.rotation_axis == "camera_up": 101 | # In the OpenCV convention, the "camera up" is -Y. 102 | up_vectors = onp.array( 103 | [ 104 | camera.T_camera_world.rotation().inverse() @ onp.array([0.0, -1.0, 0.0]) 105 | for camera in train_cameras 106 | ] 107 | ) 108 | rotation_axis = onp.mean(up_vectors, axis=0) 109 | rotation_axis /= onp.linalg.norm(rotation_axis) 110 | else: 111 | assert_never(args.rotation_axis) 112 | 113 | del train_cameras 114 | 115 | # Used for distance rendering. 116 | min_invdist = None 117 | max_invdist = None 118 | 119 | for i in range(args.frames): 120 | print(f"Rendering frame {i + 1}/{args.frames}") 121 | 122 | # Render & save image. 
123 | rendered = tensorf.render.render_rays_batched( 124 | appearance_mlp=train_state.appearance_mlp, 125 | learnable_params=train_state.learnable_params, 126 | aabb=train_state.aabb, 127 | rays_wrt_world=camera.pixel_rays_wrt_world( 128 | camera_index=args.render_camera_index 129 | ), 130 | prng_key=jax.random.PRNGKey(0), 131 | config=tensorf.render.RenderConfig( 132 | near=config.render_near, 133 | far=config.render_far, 134 | mode=args.mode, 135 | density_samples_per_ray=args.density_samples_per_ray, 136 | appearance_samples_per_ray=args.appearance_samples_per_ray, 137 | ), 138 | batch_size=args.ray_batch_size, 139 | ) 140 | if len(rendered.shape) == 3: 141 | # RGB: (H, W, 3) 142 | image = onp.array(rendered) 143 | image = onp.clip(image * 255.0, 0.0, 255.0).astype(onp.uint8) 144 | else: 145 | # Visualizing rendered distances: (H, W) 146 | # For this we use inverse distances, which is similar to disparity. 147 | image = onp.array(rendered) 148 | 149 | # Visualization heuristics for "depths". 150 | image = 1.0 / onp.maximum(image, 1e-4) 151 | 152 | # Compute scaling terms using first frame. 153 | if min_invdist is None or max_invdist is None: 154 | min_invdist = image.min() 155 | max_invdist = image.max() * 0.9 156 | 157 | image -= min_invdist 158 | image /= max_invdist - min_invdist 159 | image = onp.clip(image * 255.0, 0.0, 255.0).astype(onp.uint8) 160 | image = onp.tile(image[:, :, None], reps=(1, 1, 3)) 161 | 162 | Image.fromarray(image).save(args.output_dir / f"image_{i:03}.png") 163 | 164 | # Rotate camera. 165 | camera = jdc.replace( 166 | camera, 167 | T_camera_world=camera.T_camera_world 168 | @ jaxlie.SE3.from_rotation( 169 | jaxlie.SO3.exp(2 * jnp.pi / args.frames * rotation_axis) 170 | ), 171 | ) 172 | 173 | 174 | if __name__ == "__main__": 175 | fifteen.utils.pdb_safety_net() 176 | main(tyro.cli(Args)) 177 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tyro 2 | opencv-python 3 | flax 4 | jax 5 | jaxlib 6 | jaxlie 7 | jax_dataclasses>=1.5.1 8 | optax>=0.1.2 9 | tqdm 10 | # tensorflow is needed for tensorboard logging in JAX 11 | tensorflow 12 | git+https://github.com/brentyi/fifteen.git 13 | -------------------------------------------------------------------------------- /tensorf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brentyi/tensorf-jax/19f7ab24be969dba771f46126ec8136390d38e3d/tensorf/__init__.py -------------------------------------------------------------------------------- /tensorf/cameras.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import jax 4 | import jax_dataclasses as jdc 5 | import jaxlie 6 | import numpy as onp 7 | from jax import numpy as jnp 8 | from typing_extensions import Annotated 9 | 10 | 11 | @jdc.pytree_dataclass 12 | class Rays3D(jdc.EnforcedAnnotationsMixin): 13 | """Structure defining some rays in 3D space. Should contain origin and direction 14 | arrays of the same shape; `(*, 3)`.""" 15 | 16 | origins: Annotated[jnp.ndarray, (3,), jnp.floating] 17 | directions: Annotated[jnp.ndarray, (3,), jnp.floating] 18 | camera_indices: Annotated[ 19 | jnp.ndarray, (), jnp.uint32 20 | ] # Used for per-camera appearance embeddings. 
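# (Added usage sketch, not part of the original file.) Rays3D is a pytree whose
# leaves share leading batch axes; for example, a batch of 8 rays at the origin
# looking down +Z could be built as:
#
#     rays = Rays3D(
#         origins=jnp.zeros((8, 3)),
#         directions=jnp.tile(jnp.array([0.0, 0.0, 1.0]), (8, 1)),
#         camera_indices=jnp.zeros((8,), dtype=jnp.uint32),
#     )
#     assert rays.get_batch_axes() == (8,)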
21 | 22 | 23 | @jdc.pytree_dataclass 24 | class Camera(jdc.EnforcedAnnotationsMixin): 25 | K: Annotated[jnp.ndarray, (3, 3), jnp.floating] 26 | """Intrinsics. alpha * [u v 1]^T = K @ [x_c y_c z_c]^T""" 27 | 28 | T_camera_world: jaxlie.SE3 29 | """Extrinsics.""" 30 | 31 | image_width: jdc.Static[int] 32 | image_height: jdc.Static[int] 33 | 34 | @staticmethod 35 | def from_fov( 36 | T_camera_world: jaxlie.SE3, 37 | image_width: int, 38 | image_height: int, 39 | fov_x_radians: Union[float, jnp.ndarray, None] = None, 40 | fov_y_radians: Union[float, jnp.ndarray, None] = None, 41 | ) -> "Camera": 42 | """Initialize camera parameters from FOV. At least one of `fov_x_radians` or 43 | `fov_y_radians` must be passed in.""" 44 | # Offset by 1/2 pixel because (0,0) in pixel space corresponds actually to a 45 | # square whose upper-left corner is (0,0), and bottom-right corner is (1,1). 46 | cx = image_width / 2.0 - 0.5 47 | cy = image_height / 2.0 - 0.5 48 | 49 | fx = None 50 | fy = None 51 | 52 | if fov_x_radians is not None: 53 | fx = (image_width / 2.0) / jnp.tan(fov_x_radians / 2.0) 54 | if fov_y_radians is not None: 55 | fy = (image_height / 2.0) / jnp.tan(fov_y_radians / 2.0) 56 | 57 | if fx is None: 58 | assert fy is not None 59 | fx = fy 60 | if fy is None: 61 | assert fx is not None 62 | fy = fx 63 | 64 | K = jnp.array( 65 | [ 66 | [fx, 0.0, cx], 67 | [0.0, fy, cy], 68 | [0.0, 0.0, 1.0], 69 | ] 70 | ) 71 | return Camera( 72 | K=K, 73 | T_camera_world=T_camera_world, 74 | image_width=image_width, 75 | image_height=image_height, 76 | ) 77 | 78 | @jdc.jit 79 | def compute_fov_x_radians(self) -> jnp.ndarray: 80 | fx = self.K[0, 0] 81 | return 2.0 * jnp.arctan((self.image_width / 2.0) / fx) 82 | 83 | @jdc.jit 84 | def compute_fov_y_radians(self) -> jnp.ndarray: 85 | fy = self.K[1, 1] 86 | return 2.0 * jnp.arctan((self.image_height / 2.0) / fy) 87 | 88 | @jdc.jit 89 | def resize_with_fixed_fov( 90 | self, image_width: jdc.Static[int], image_height: jdc.Static[int] 91 | ) -> "Camera": 92 | return Camera.from_fov( 93 | self.T_camera_world, 94 | image_width=image_width, 95 | image_height=image_height, 96 | fov_x_radians=self.compute_fov_x_radians(), 97 | fov_y_radians=self.compute_fov_y_radians(), 98 | ) 99 | 100 | @jdc.jit 101 | def ray_wrt_world_from_uv(self, u: float, v: float, camera_index: int) -> Rays3D: 102 | """Input is a scalar u/v coordinate. Output is a Rays struct, with origin and 103 | directions of shape (3,).,""" 104 | 105 | # 2D -> 3D projection: `R_world_camera @ K^-1 @ [u v 1]^T`. 106 | uv_coord_homog = jnp.array([u, v, 1.0]) 107 | T_world_camera = self.T_camera_world.inverse() 108 | ray_direction_wrt_world = ( 109 | T_world_camera.rotation().as_matrix() 110 | @ jnp.linalg.inv(self.K) 111 | @ uv_coord_homog 112 | ) 113 | assert ray_direction_wrt_world.shape == (3,) 114 | 115 | ray_direction_wrt_world /= jnp.linalg.norm(ray_direction_wrt_world) + 1e-8 116 | rays_wrt_world = Rays3D( 117 | origins=T_world_camera.translation(), # type: ignore 118 | directions=ray_direction_wrt_world, 119 | camera_indices=jnp.array(camera_index, dtype=jnp.uint32), 120 | ) 121 | return rays_wrt_world 122 | 123 | # Hacky: JIT this on CPU. Currently used only for data generation. 124 | @jdc.jit(device=jax.devices("cpu")[0]) 125 | def pixel_rays_wrt_world(self, camera_index: int) -> Rays3D: 126 | """Get a length-3 vector for each pixel in image-space. 
Output shape is a ray 127 | structure with an origin field of shape `(image_height, image_width, 3)`, and 128 | direction field of shape `(image_height, image_width, 3)`.""" 129 | 130 | # Get width and height of image. 131 | image_width = self.image_width 132 | image_height = self.image_height 133 | 134 | # Get image-space uv coordinates. 135 | v, u = onp.mgrid[:image_height, :image_width] 136 | assert u.shape == v.shape == (image_height, image_width) 137 | 138 | # Compute world-space rays. 139 | rays_wrt_world = jax.vmap( 140 | jax.vmap(lambda u, v: self.ray_wrt_world_from_uv(u, v, camera_index)) 141 | )(u, v) 142 | assert rays_wrt_world.get_batch_axes() == (image_height, image_width) 143 | return rays_wrt_world 144 | -------------------------------------------------------------------------------- /tensorf/data.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import concurrent.futures 4 | import dataclasses 5 | import functools 6 | import json 7 | import pathlib 8 | from typing import Any, Dict, Iterable, List, Literal, Protocol, TypeVar 9 | 10 | import cv2 11 | import jax 12 | import jax_dataclasses as jdc 13 | import jaxlie 14 | import numpy as onp 15 | import PIL.Image 16 | from jax import numpy as jnp 17 | from optax._src.alias import transform 18 | from tqdm.auto import tqdm 19 | from typing_extensions import Annotated, assert_never 20 | 21 | from . import cameras 22 | 23 | 24 | class NerfDataset(Protocol): 25 | def get_training_rays(self) -> RenderedRays: 26 | ... 27 | 28 | def get_cameras(self) -> List[cameras.Camera]: 29 | ... 30 | 31 | 32 | def make_dataset( 33 | dataset_type: Literal["blender", "nerfstudio"], 34 | dataset_root: pathlib.Path, 35 | scene_scale: float, 36 | ) -> NerfDataset: 37 | if dataset_type == "blender": 38 | assert scene_scale == 1.0 39 | return BlenderDataset(dataset_root) 40 | elif dataset_type == "nerfstudio": 41 | return NerfstudioDataset(dataset_root, scene_scale) 42 | else: 43 | assert_never(dataset_type) 44 | 45 | 46 | @dataclasses.dataclass(frozen=True) 47 | class NerfstudioDataset: 48 | dataset_root: pathlib.Path 49 | scene_scale: float 50 | 51 | def get_training_rays(self) -> RenderedRays: 52 | metadata = self._get_metadata() 53 | 54 | camera_model = metadata["camera_model"] 55 | if metadata["camera_model"] == "OPENCV": 56 | dist_coeffs = onp.array( 57 | [metadata.get(k, 0.0) for k in ("k1", "k2", "p1", "p2")] 58 | ) 59 | elif metadata["camera_model"] == "OPENCV_FISHEYE": 60 | dist_coeffs = onp.array([metadata[k] for k in ("k1", "k2", "k3", "k4")]) 61 | else: 62 | assert False, f"Unsupported camera model {metadata['camera_model']}." 63 | 64 | image_paths = tuple( 65 | map( 66 | lambda frame: self.dataset_root / frame["file_path"], 67 | metadata["frames"], 68 | ) 69 | ) 70 | out: List[RenderedRays] = [] 71 | for i, image, camera in zip( 72 | range(len(image_paths)), 73 | _threaded_image_fetcher(image_paths), 74 | tqdm( 75 | self.get_cameras(), 76 | desc=f"Loading {self.dataset_root.stem}", 77 | ), 78 | ): 79 | h, w = image.shape[:2] 80 | orig_h, orig_w = h, w 81 | 82 | # Resize image to an arbitrary target dimension. 
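# (Added note: with the 1200 * 1600 target below, a 3000 x 4000 input has
# 12,000,000 pixels, so scale = sqrt(1,920,000 / 12,000,000) = 0.4 and the image
# is resized to 1200 x 1600; inputs already at or below the target keep
# scale = 1.0 and are left at their original resolution.)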
83 | target_pixels = 1200 * 1600 84 | scale = 1.0 85 | if h * w > target_pixels: 86 | scale = onp.sqrt(target_pixels / (h * w)) 87 | h = int(h * scale) 88 | w = int(w * scale) 89 | image = cv2.resize(image, (w, h)) 90 | 91 | image = (image / 255.0).astype(onp.float32) 92 | 93 | # (2, w, h) => (h, w, 2) => (h * w, 2) 94 | orig_image_points = ( 95 | onp.mgrid[:w, :h].T.reshape((h * w, 2)) 96 | / onp.array([h, w]) 97 | * onp.array([orig_h, orig_w]) 98 | ) 99 | 100 | if camera_model == "OPENCV": 101 | ray_directions = cv2.undistortPoints( 102 | src=orig_image_points, 103 | cameraMatrix=camera.K, 104 | distCoeffs=dist_coeffs, 105 | ).squeeze(axis=1) 106 | elif camera_model == "OPENCV_FISHEYE": 107 | ray_directions = cv2.fisheye.undistortPoints( 108 | distorted=orig_image_points[:, None, :], 109 | K=camera.K, 110 | D=dist_coeffs, 111 | ).squeeze(axis=1) 112 | else: 113 | assert False 114 | 115 | assert ray_directions.shape == (h * w, 2) 116 | ray_directions = onp.concatenate( 117 | [ray_directions, onp.ones((h * w, 1))], axis=-1 118 | ) 119 | ray_directions /= onp.linalg.norm(ray_directions, axis=-1, keepdims=True) 120 | assert ray_directions.shape == (h * w, 3) 121 | 122 | T_world_camera = camera.T_camera_world.inverse() 123 | out.append( 124 | RenderedRays( 125 | colors=image.reshape((-1, 3)), # type: ignore 126 | rays_wrt_world=cameras.Rays3D( 127 | origins=onp.tile( # type: ignore 128 | T_world_camera.translation()[None, :], (h * w, 1) 129 | ), 130 | directions=ray_directions @ onp.array(T_world_camera.rotation().as_matrix().T), # type: ignore 131 | camera_indices=onp.full( 132 | shape=(h * w), 133 | fill_value=i, 134 | dtype=onp.uint32, 135 | ), # type: ignore 136 | ), 137 | ) 138 | ) 139 | 140 | return jax.tree_map(lambda *leaves: onp.concatenate(leaves, axis=0), *out) 141 | 142 | def get_cameras(self) -> List[cameras.Camera]: 143 | # Transformation from Blender camera coordinates to OpenCV ones. We like the OpenCV 144 | # convention. 145 | T_blendercam_camera = jaxlie.SE3.from_rotation( 146 | jaxlie.SO3.from_x_radians(onp.pi) 147 | ) 148 | 149 | metadata = self._get_metadata() 150 | 151 | transform_matrices: List[onp.ndarray] = [] 152 | for frame in metadata["frames"]: 153 | # Expected keys in each frame. 154 | assert frame.keys() == {"file_path", "transform_matrix"} 155 | 156 | transform_matrices.append( 157 | onp.array(frame["transform_matrix"], dtype=onp.float32) 158 | ) 159 | assert transform_matrices[-1].shape == (4, 4) # Should be in SE(3). 160 | 161 | camera_matrix = onp.eye(3) 162 | camera_matrix[0, 0] = metadata["fl_x"] 163 | camera_matrix[1, 1] = metadata["fl_y"] 164 | camera_matrix[0, 2] = metadata["cx"] 165 | camera_matrix[1, 2] = metadata["cy"] 166 | 167 | # Compute pose bounding box. 168 | positions = onp.array(transform_matrices)[:, :3, 3] 169 | aabb_min = positions.min(axis=0) 170 | aabb_max = positions.max(axis=0) 171 | del positions 172 | assert aabb_min.shape == aabb_max.shape == (3,) 173 | 174 | out = [] 175 | for transform_matrix in transform_matrices: 176 | # Center and scale scene. In the future we might also auto-orient it. 177 | transform_matrix = transform_matrix.copy() 178 | transform_matrix[:3, 3] -= aabb_min 179 | transform_matrix[:3, 3] -= (aabb_max - aabb_min) / 2.0 180 | transform_matrix[:3, 3] /= (aabb_max - aabb_min).max() / 2.0 181 | transform_matrix[:3, 3] *= self.scene_scale 182 | 183 | # Compute extrinsics. 
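# (Added note: `transform_matrix` here is camera-to-world in the Blender/OpenGL
# convention, where the camera looks down -Z with +Y up. Composing with
# `T_blendercam_camera` (a pi rotation about the X axis) flips Y and Z so that
# the result follows the OpenCV convention of +Z forward and +Y down, before
# being inverted into a world-to-camera transform.)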
184 | T_world_blendercam = jaxlie.SE3.from_matrix(transform_matrix) 185 | T_camera_world = (T_world_blendercam @ T_blendercam_camera).inverse() 186 | 187 | out.append( 188 | cameras.Camera( 189 | K=camera_matrix, # type: ignore 190 | T_camera_world=T_camera_world, 191 | image_width=metadata["w"], 192 | image_height=metadata["h"], 193 | ) 194 | ) 195 | 196 | return out 197 | 198 | def _get_metadata(self) -> Dict[str, Any]: 199 | """Read metadata: image paths, transformation matrices, FOV.""" 200 | with open(self.dataset_root / f"transforms.json") as f: 201 | metadata: dict = json.load(f) 202 | return metadata 203 | 204 | 205 | @dataclasses.dataclass(frozen=True) 206 | class BlenderDataset: 207 | dataset_root: pathlib.Path 208 | 209 | def get_training_rays(self) -> RenderedRays: 210 | return rendered_rays_from_views(self._registered_views) 211 | 212 | def get_cameras(self) -> List[cameras.Camera]: 213 | return [v.camera for v in self._registered_views] 214 | 215 | @functools.cached_property 216 | def _registered_views(self) -> List[RegisteredRgbaView]: 217 | metadata = self._get_metadata() 218 | 219 | image_paths: List[pathlib.Path] = [] 220 | transform_matrices: List[onp.ndarray] = [] 221 | for frame in metadata["frames"]: 222 | assert frame.keys() == {"file_path", "rotation", "transform_matrix"} 223 | 224 | image_paths.append(self.dataset_root / f"{frame['file_path']}.png") 225 | transform_matrices.append( 226 | onp.array(frame["transform_matrix"], dtype=onp.float32) 227 | ) 228 | assert transform_matrices[-1].shape == (4, 4) # Should be in SE(3). 229 | fov_x_radians: float = metadata["camera_angle_x"] 230 | del metadata 231 | 232 | # Transformation from Blender camera coordinates to OpenCV ones. We like the OpenCV 233 | # convention. 234 | T_blendercam_camera = jaxlie.SE3.from_rotation( 235 | jaxlie.SO3.from_x_radians(onp.pi) 236 | ) 237 | 238 | out = [] 239 | for image, transform_matrix in zip( 240 | _threaded_image_fetcher(image_paths), 241 | tqdm( 242 | transform_matrices, 243 | desc=f"Loading {self.dataset_root.stem}", 244 | ), 245 | ): 246 | assert image.dtype == onp.uint8 247 | height, width = image.shape[:2] 248 | 249 | # Note that this is RGBA! 250 | assert image.shape == (height, width, 4) 251 | 252 | # [0, 255] => [0, 1] 253 | image = (image / 255.0).astype(onp.float32) 254 | 255 | # Compute extrinsics. 256 | T_world_blendercam = jaxlie.SE3.from_matrix(transform_matrix) 257 | T_camera_world = (T_world_blendercam @ T_blendercam_camera).inverse() 258 | out.append( 259 | RegisteredRgbaView( 260 | image_rgba=image, # type: ignore 261 | camera=cameras.Camera.from_fov( 262 | T_camera_world=T_camera_world, 263 | image_width=width, 264 | image_height=height, 265 | fov_x_radians=fov_x_radians, 266 | ), 267 | ) 268 | ) 269 | return out 270 | 271 | def _get_metadata(self) -> Dict[str, Any]: 272 | """Read metadata: image paths, transformation matrices, FOV.""" 273 | split = "train" 274 | with open(self.dataset_root / f"transforms_{split}.json") as f: 275 | metadata: dict = json.load(f) 276 | return metadata 277 | 278 | 279 | # Helpers. 280 | 281 | 282 | @jdc.pytree_dataclass 283 | class RegisteredRgbaView(jdc.EnforcedAnnotationsMixin): 284 | """Structure containing 2D image + camera pairs.""" 285 | 286 | image_rgba: Annotated[ 287 | jnp.ndarray, 288 | jnp.floating, # Range of contents is [0, 1]. 
289 | ] 290 | camera: cameras.Camera 291 | 292 | 293 | @jdc.pytree_dataclass 294 | class RenderedRays(jdc.EnforcedAnnotationsMixin): 295 | """Structure containing individual 3D rays in world space + colors.""" 296 | 297 | colors: Annotated[jnp.ndarray, (3,), jnp.floating] 298 | rays_wrt_world: cameras.Rays3D 299 | 300 | 301 | def rendered_rays_from_views(views: List[RegisteredRgbaView]) -> RenderedRays: 302 | """Convert a list of registered 2D views into a pytree containing individual 303 | rendered rays.""" 304 | 305 | out = [] 306 | for i, view in enumerate(views): 307 | height = view.camera.image_height 308 | width = view.camera.image_width 309 | 310 | rays = view.camera.pixel_rays_wrt_world(camera_index=i) 311 | assert rays.get_batch_axes() == (height, width) 312 | 313 | rgba = view.image_rgba 314 | assert rgba.shape == (height, width, 4) 315 | 316 | # Add white background color; this is what the standard alpha compositing over 317 | # operator works out to with an opaque white background. 318 | rgb = rgba[..., :3] * rgba[..., 3:4] + (1.0 - rgba[..., 3:4]) 319 | 320 | out.append( 321 | RenderedRays( 322 | colors=rgb.reshape((-1, 3)), 323 | rays_wrt_world=cameras.Rays3D( 324 | origins=rays.origins.reshape((-1, 3)), 325 | directions=rays.directions.reshape((-1, 3)), 326 | camera_indices=rays.camera_indices.reshape((-1,)), 327 | ), 328 | ) 329 | ) 330 | 331 | out_concat: RenderedRays = jax.tree_map( 332 | lambda *children: onp.concatenate(children, axis=0), *out 333 | ) 334 | 335 | # Shape of rays should (N,3), colors should be (N,4), etc. 336 | assert len(out_concat.rays_wrt_world.get_batch_axes()) == 1 337 | return out_concat 338 | 339 | 340 | T = TypeVar("T", bound=Iterable) 341 | 342 | 343 | def _threaded_image_fetcher(paths: Iterable[pathlib.Path]) -> Iterable[onp.ndarray]: 344 | """Maps an iterable over image paths to an iterable over image arrays, which are 345 | opened via PIL. 346 | 347 | Helpful for parallelizing IO.""" 348 | with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: 349 | for image in executor.map( 350 | lambda p: onp.array(PIL.Image.open(p)), 351 | paths, 352 | chunksize=4, 353 | ): 354 | yield image 355 | -------------------------------------------------------------------------------- /tensorf/networks.py: -------------------------------------------------------------------------------- 1 | """Neural networks for volumetric rendering.""" 2 | from __future__ import annotations 3 | 4 | from typing import Any, Optional 5 | 6 | from flax import linen as nn 7 | from jax import numpy as jnp 8 | 9 | relu_layer_init = nn.initializers.kaiming_normal() # variance = 2.0 / fan_in 10 | linear_layer_init = nn.initializers.lecun_normal() # variance = 1.0 / fan_in 11 | 12 | Dtype = Any 13 | 14 | 15 | def _fourier_encode(coords: jnp.ndarray, n_freqs: int) -> jnp.ndarray: 16 | """Fourier feature helper. 17 | 18 | Args: 19 | coords (jnp.ndarray): Coordinates of shape (*, D). 20 | n_freqs (int): Number of fourier frequencies. 21 | 22 | Returns: 23 | jnp.ndarray: Shape (*, n_freqs * 2). 
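(More precisely: for an input of shape (*, D), the flattened output has shape
(*, D * 2 * n_freqs), matching the assertion and reshape at the end of this
function.)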
24 | """ 25 | *batch_axes, D = coords.shape 26 | coeffs = 2 ** jnp.arange(n_freqs, dtype=jnp.float32) 27 | inputs = coords[..., None] * coeffs 28 | assert inputs.shape == (*batch_axes, D, n_freqs) 29 | 30 | out = jnp.sin( 31 | jnp.concatenate( 32 | [inputs, inputs + 0.5 * jnp.pi], 33 | axis=-1, 34 | ) 35 | ) 36 | assert out.shape == (*batch_axes, D, 2 * n_freqs) 37 | return out.reshape((*batch_axes, D * 2 * n_freqs)) 38 | 39 | 40 | class FeatureMlp(nn.Module): 41 | feature_squash_dim: int = 27 42 | units: int = 128 43 | feature_n_freqs: int = 6 44 | viewdir_n_freqs: int = 6 45 | num_cameras: Optional[int] = None # Set to enable camera embeddings. 46 | 47 | @nn.compact 48 | def __call__( # type: ignore 49 | self, 50 | features: jnp.ndarray, 51 | viewdirs: jnp.ndarray, 52 | camera_indices: jnp.ndarray, 53 | # Computation dtype. Main parameters will always be float32. 54 | dtype: Any = jnp.float32, 55 | ) -> jnp.ndarray: 56 | *batch_axes, feat_dim = features.shape 57 | assert viewdirs.shape == (*batch_axes, 3) 58 | 59 | # Layer 0. This is `basis_mat` in the original implementation, and reduces the 60 | # computational requirements of the fourier encoding. 61 | features = nn.Dense( 62 | features=self.feature_squash_dim, 63 | kernel_init=linear_layer_init, 64 | use_bias=False, 65 | dtype=dtype, 66 | )(features) 67 | 68 | # Compute fourier features. 69 | # 70 | # This computes both sines and cosines to match other implementations, but since 71 | # cos(x) == sin(x + pi/2) we could also consider just adding a bias term to the 72 | # dense layer above and only picking one. 73 | x = jnp.concatenate( 74 | [ 75 | features, 76 | viewdirs, 77 | _fourier_encode(features, self.feature_n_freqs), 78 | _fourier_encode(viewdirs, self.viewdir_n_freqs), 79 | ], 80 | axis=-1, 81 | ) 82 | expected_encoded_dim = ( 83 | self.feature_squash_dim 84 | + 3 85 | + 2 * self.feature_n_freqs * self.feature_squash_dim 86 | + 2 * self.viewdir_n_freqs * 3 87 | ) 88 | assert x.shape == (*batch_axes, expected_encoded_dim) 89 | 90 | # Layer 1. 91 | x = nn.Dense( 92 | features=self.units, 93 | kernel_init=relu_layer_init, 94 | dtype=dtype, 95 | )(x) 96 | x = nn.relu(x) 97 | assert x.shape == (*batch_axes, self.units) 98 | 99 | # Layer 2. 100 | x = nn.Dense( 101 | features=self.units, 102 | kernel_init=relu_layer_init, 103 | dtype=dtype, 104 | )(x) 105 | x = nn.relu(x) 106 | assert x.shape == (*batch_axes, self.units) 107 | 108 | # Add FiLM-style conditioning for camera embeddings. 109 | # This is different from (and may be worse than) what published works do. 110 | if self.num_cameras is not None: 111 | conditioner = nn.Embed( 112 | num_embeddings=self.num_cameras, features=self.units 113 | )(camera_indices) 114 | assert conditioner.shape == (*batch_axes, self.units) 115 | x = x.at[..., self.units // 2 :].set( 116 | conditioner[..., : self.units // 2] * x[..., self.units // 2 :] 117 | + conditioner[..., self.units // 2 :] 118 | ) 119 | 120 | # Layer 3. 
121 | x = nn.Dense( 122 | features=3, 123 | kernel_init=linear_layer_init, 124 | dtype=dtype, 125 | )(x) 126 | assert x.shape == (*batch_axes, 3) 127 | 128 | rgb = nn.sigmoid(x) 129 | return rgb 130 | -------------------------------------------------------------------------------- /tensorf/render.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import dataclasses 4 | import enum 5 | import math 6 | from typing import Any, Optional, Tuple, cast 7 | 8 | import flax 9 | import jax 10 | import jax_dataclasses as jdc 11 | import numpy as onp 12 | from jax import numpy as jnp 13 | from tqdm.auto import tqdm 14 | from typing_extensions import Annotated 15 | 16 | from . import cameras, networks, tensor_vm, utils 17 | 18 | 19 | class RenderMode(enum.Enum): 20 | # Note: we currently only support rendering distances from the camera origin, which 21 | # is a bit different from depth. (the latter is typically a local Z value) 22 | RGB = enum.auto() 23 | DIST_MEDIAN = enum.auto() 24 | DIST_MEAN = enum.auto() 25 | 26 | 27 | @dataclasses.dataclass(frozen=True) 28 | class RenderConfig: 29 | near: float 30 | far: float 31 | mode: RenderMode 32 | 33 | density_samples_per_ray: int 34 | """Number of points to sample densities at.""" 35 | 36 | appearance_samples_per_ray: int 37 | """Number of points to sample appearances at.""" 38 | 39 | 40 | @jdc.pytree_dataclass 41 | class LearnableParams: 42 | """Structure containing learnable parameters required for rendering.""" 43 | 44 | appearance_mlp_params: flax.core.FrozenDict 45 | appearance_tensor: tensor_vm.TensorVM 46 | density_tensor: tensor_vm.TensorVM 47 | scene_contraction: jdc.Static[bool] 48 | 49 | 50 | def render_rays_batched( 51 | appearance_mlp: networks.FeatureMlp, 52 | learnable_params: LearnableParams, 53 | aabb: jnp.ndarray, 54 | rays_wrt_world: cameras.Rays3D, 55 | prng_key: Optional[jax.random.KeyArray], 56 | config: RenderConfig, 57 | *, 58 | batch_size: int = 4096, 59 | use_tqdm: bool = True, 60 | ) -> onp.ndarray: 61 | """Render rays. Supports arbitrary batch axes (helpful for inputs with both height 62 | and width leading axes), and automatically splits rays into batches to prevent 63 | out-of-memory errors. 64 | 65 | Possibly this could just take the training state directly as input.""" 66 | batch_axes = rays_wrt_world.get_batch_axes() 67 | rays_wrt_world = ( 68 | cameras.Rays3D( # TODO: feels like this could be done less manually! 69 | origins=rays_wrt_world.origins.reshape((-1, 3)), 70 | directions=rays_wrt_world.directions.reshape((-1, 3)), 71 | camera_indices=rays_wrt_world.camera_indices.reshape((-1,)), 72 | ) 73 | ) 74 | (total_rays,) = rays_wrt_world.get_batch_axes() 75 | 76 | processed = 0 77 | out = [] 78 | 79 | for i in (tqdm if use_tqdm else lambda x: x)( 80 | range(math.ceil(total_rays / batch_size)) 81 | ): 82 | batch = jax.tree_map( 83 | lambda x: x[processed : min(total_rays, processed + batch_size)], 84 | rays_wrt_world, 85 | ) 86 | processed += batch_size 87 | out.append( 88 | render_rays( 89 | appearance_mlp=appearance_mlp, 90 | learnable_params=learnable_params, 91 | aabb=aabb, 92 | rays_wrt_world=batch, 93 | prng_key=prng_key, 94 | config=config, 95 | ) 96 | ) 97 | out_concatenated = onp.concatenate(out, axis=0) 98 | 99 | # Reshape, with generalization to both (*,) for depths and (*, 3) for RGB. 
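# (Added note: for rays generated from a full camera, batch_axes is typically
# (image_height, image_width); an RGB render of shape (H * W, 3) then becomes
# (H, W, 3), while a distance render of shape (H * W,) becomes (H, W).)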
100 | return out_concatenated.reshape(batch_axes + out_concatenated.shape[1:]) 101 | 102 | 103 | @jdc.jit 104 | def render_rays( 105 | appearance_mlp: jdc.Static[networks.FeatureMlp], 106 | learnable_params: LearnableParams, 107 | aabb: jnp.ndarray, 108 | rays_wrt_world: cameras.Rays3D, 109 | prng_key: jax.random.KeyArray, 110 | config: jdc.Static[RenderConfig], 111 | *, 112 | dtype: jdc.Static[Any] = jnp.float32, 113 | ) -> jnp.ndarray: 114 | """Render a set of rays. 115 | 116 | Output should have shape `(ray_count, 3)`.""" 117 | 118 | # Cast everything to the desired dtype. 119 | learnable_params, aabb, rays_wrt_world = jax.tree_map( 120 | lambda x: x.astype(dtype) if jnp.issubdtype(jnp.floating, dtype) else x, 121 | (learnable_params, aabb, rays_wrt_world), 122 | ) 123 | 124 | (ray_count,) = rays_wrt_world.get_batch_axes() 125 | 126 | sample_prng_key, render_rgb_prng_key = jax.random.split(prng_key) 127 | 128 | if learnable_params.scene_contraction: 129 | # Contracted scene: sample linearly for close samples, then start spacing 130 | # samples out. 131 | # 132 | # An occupancy grid or proposal network would really help us here! 133 | close_samples_per_ray = config.density_samples_per_ray // 2 134 | far_samples_per_ray = config.density_samples_per_ray - close_samples_per_ray 135 | 136 | close_ts = jnp.linspace(config.near, config.near + 1.0, close_samples_per_ray) 137 | 138 | # Some heuristics for sampling far points, which should be close to sampling 139 | # linearly in disparity when k=1. This is probably reasonable, but it'd be a 140 | # good idea to look at what real NeRF codebases do. 141 | far_start = config.near + 1.0 + 1.0 / close_samples_per_ray 142 | k = 10.0 143 | far_deltas = ( 144 | 1.0 145 | / ( 146 | 1.0 147 | - onp.linspace( # onp here is important for float64. 148 | 0.0, 149 | 1.0 - 1 / ((config.far - far_start) / k + 1), 150 | far_samples_per_ray, 151 | ) 152 | ) 153 | - 1.0 154 | ) * onp.linspace(1.0, k, far_samples_per_ray) 155 | far_ts = far_start + far_deltas 156 | 157 | ts = jnp.tile(jnp.concatenate([close_ts, far_ts])[None, :], reps=(ray_count, 1)) 158 | 159 | # Compute step sizes. 160 | step_sizes = jnp.roll(ts, -1, axis=-1) - ts # Naive. Could be improved 161 | step_sizes = step_sizes.at[:, -1].set(step_sizes[:, -2]) 162 | 163 | # Jitter samples. 164 | sample_jitter = jax.random.uniform( 165 | sample_prng_key, shape=(ray_count, config.density_samples_per_ray) 166 | ) 167 | ts = ts + step_sizes * sample_jitter 168 | 169 | # Compute points in world space. 170 | points = ( 171 | rays_wrt_world.origins[:, None, :] 172 | + ts[:, :, None] * rays_wrt_world.directions[:, None, :] 173 | ) 174 | 175 | # Contract points to cube. 176 | norm = jnp.linalg.norm(points, ord=jnp.inf, axis=-1, keepdims=True) 177 | points = jnp.where(norm <= 1.0, points, (2.0 - 1.0 / norm) * points / norm) 178 | assert points.shape == (ray_count, config.density_samples_per_ray, 3) 179 | points = jnp.moveaxis(points, -1, 0) 180 | else: 181 | # Bounded scene: we sample points uniformly between the camera origin and bounding 182 | # box limit. 
183 | points, ts, step_sizes = jax.vmap( 184 | lambda ray: sample_points_along_ray_within_bbox( 185 | ray_wrt_world=ray, 186 | aabb=aabb, 187 | samples_per_ray=config.density_samples_per_ray, 188 | prng_key=sample_prng_key, 189 | ), 190 | out_axes=(1, 0, 0), 191 | )(rays_wrt_world) 192 | step_sizes = jnp.tile(step_sizes[:, None], (1, config.density_samples_per_ray)) 193 | 194 | assert points.shape == (3, ray_count, config.density_samples_per_ray) 195 | assert ts.shape == (ray_count, config.density_samples_per_ray) 196 | assert step_sizes.shape == (ray_count, config.density_samples_per_ray) 197 | 198 | # Normalize points to [-1, 1]. 199 | points = ( 200 | (points - aabb[0][:, None, None]) / (aabb[1] - aabb[0])[:, None, None] - 0.5 201 | ) * 2.0 202 | assert points.shape == (3, ray_count, config.density_samples_per_ray) 203 | 204 | # Pull interpolated density features out of tensor decomposition. 205 | density_feat = learnable_params.density_tensor.interpolate(points) 206 | assert density_feat.shape == ( 207 | density_feat.shape[0], 208 | ray_count, 209 | config.density_samples_per_ray, 210 | ) 211 | 212 | # Density from features. 213 | sigmas = jax.nn.softplus(jnp.sum(density_feat, axis=0) + 10.0) 214 | assert sigmas.shape == (ray_count, config.density_samples_per_ray) 215 | 216 | # Compute segment probabilities for each ray. 217 | probs = compute_segment_probabilities(sigmas, step_sizes) 218 | assert ( 219 | probs.get_batch_axes() 220 | == probs.p_exits.shape 221 | == probs.p_terminates.shape 222 | == (ray_count, config.density_samples_per_ray) 223 | ) 224 | 225 | if config.mode is RenderMode.RGB: 226 | # Get RGB array. 227 | rgb, unbias_coeff = _rgb_from_points( 228 | rays_wrt_world=rays_wrt_world, 229 | probs=probs, 230 | learnable_params=learnable_params, 231 | points=points, 232 | appearance_mlp=appearance_mlp, 233 | config=config, 234 | prng_key=render_rgb_prng_key, 235 | dtype=dtype, 236 | ) 237 | assert rgb.shape == (ray_count, config.density_samples_per_ray, 3) 238 | assert unbias_coeff.shape == (ray_count,) 239 | 240 | # No need to backprop through the unbiasing coefficient! This can also cause 241 | # instability in mixed-precision mode. 242 | unbias_coeff = jax.lax.stop_gradient(unbias_coeff) 243 | 244 | # One thing I don't have intuition for: is there something special about RGB 245 | # that makes this weighted average/expected value meaningful? Is this 246 | # because RGB is additive? Can we just do this with any random color space? 247 | expected_rgb = ( 248 | jnp.sum(rgb * probs.p_terminates[:, :, None], axis=-2) 249 | * unbias_coeff[:, None] 250 | ) 251 | assert expected_rgb.shape == (ray_count, 3) 252 | 253 | # Add white background. 254 | assert probs.p_exits.shape == (ray_count, config.density_samples_per_ray) 255 | background_color = jnp.ones(3, dtype=dtype) 256 | expected_rgb_with_background = ( 257 | expected_rgb + probs.p_exits[:, -1:] * background_color 258 | ) 259 | assert expected_rgb_with_background.shape == (ray_count, 3) 260 | return expected_rgb_with_background 261 | 262 | elif config.mode is RenderMode.DIST_MEDIAN: 263 | # Compute depth via median. 
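# (Added note: 1 - p_exits is the cumulative probability that the ray has
# terminated by the end of each segment, so the median distance corresponds to
# the first sample where that cumulative probability crosses 0.5; the XOR trick
# below turns the boolean threshold into a mask that picks out the sample where
# the crossing happens.)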
264 | sample_distances = jnp.concatenate( 265 | [ts, jnp.full((ray_count, 1), jnp.inf, dtype=dtype)], axis=-1 266 | ) 267 | p_not_alive_padded = jnp.concatenate( 268 | [1.0 - probs.p_exits, jnp.ones((ray_count, 1), dtype=dtype)], axis=-1 269 | ) 270 | assert sample_distances.shape == p_not_alive_padded.shape 271 | 272 | median_mask = p_not_alive_padded > 0.5 273 | median_mask = ( 274 | jnp.zeros_like(median_mask) 275 | .at[..., 1:] 276 | .set(jnp.logical_xor(median_mask[..., :-1], median_mask[..., 1:])) 277 | ) 278 | 279 | # Output is medians. 280 | depths = jnp.sum(median_mask * sample_distances, axis=-1) 281 | return depths 282 | 283 | elif config.mode is RenderMode.DIST_MEAN: 284 | # Compute depth via expected value. 285 | sample_distances = jnp.concatenate([ts, ts[:, -1:]], axis=-1) 286 | p_terminates_padded = jnp.concatenate( 287 | [probs.p_terminates, probs.p_exits[:, -1:]], axis=-1 288 | ) 289 | assert sample_distances.shape == p_terminates_padded.shape 290 | return jnp.sum(p_terminates_padded * sample_distances, axis=-1) 291 | 292 | else: 293 | assert False 294 | 295 | 296 | @jdc.pytree_dataclass 297 | class SegmentProbabilities(jdc.EnforcedAnnotationsMixin): 298 | p_exits: Annotated[jnp.ndarray, (), jnp.floating] 299 | """P(ray exits segment s). 300 | 301 | Note that this also implies that the ray has exited (and thus entered) all previous 302 | segments.""" 303 | 304 | p_terminates: Annotated[jnp.ndarray, (), jnp.floating] 305 | """P(ray terminates at s, ray exits s - 1). 306 | 307 | For a ray to terminate in a segment, it must first pass through (and 'exit') all 308 | previous segments.""" 309 | 310 | 311 | def compute_segment_probabilities( 312 | sigmas: jnp.ndarray, step_sizes: jnp.ndarray 313 | ) -> SegmentProbabilities: 314 | r"""Compute some probabilities needed for rendering rays. Expects sigmas of shape 315 | (*, sample_count) and a per-ray step size of shape (*,). 316 | 317 | Each of the ray segments we're rendering is broken up into samples. We can treat the 318 | densities as piecewise constant and use an exponential distribution and compute: 319 | 320 | 1. P(ray exits s) = exp(\sum_{i=1}^s -(sigma_i * l_i) 321 | 2. P(ray terminates in s | ray exits s-1) = 1.0 - exp(-sigma_s * l_s) 322 | 3. P(ray terminates in s, ray exits s-1) 323 | = P(ray terminates at s | ray exits s-1) * P(ray exits s-1) 324 | 325 | where l_i is the length of segment i. 326 | """ 327 | 328 | # Support arbitrary leading batch axes. 329 | (*batch_axes, sample_count) = sigmas.shape 330 | assert step_sizes.shape == (*batch_axes, sample_count) 331 | 332 | # Equation 1. 333 | neg_scaled_sigmas = -sigmas * step_sizes 334 | p_exits = jnp.exp(jnp.cumsum(neg_scaled_sigmas, axis=-1)) 335 | assert p_exits.shape == (*batch_axes, sample_count) 336 | 337 | # Equation 2. Not used outside of this function, and not returned. 338 | p_terminates_given_exits_prev = 1.0 - jnp.exp(neg_scaled_sigmas) 339 | assert p_terminates_given_exits_prev.shape == (*batch_axes, sample_count) 340 | 341 | # Equation 3. 342 | p_terminates = jnp.multiply( 343 | p_terminates_given_exits_prev, 344 | # We prepend 1 because the ray is always alive initially. 
345 | jnp.concatenate( 346 | [ 347 | jnp.ones((*batch_axes, 1), dtype=neg_scaled_sigmas.dtype), 348 | p_exits[..., :-1], 349 | ], 350 | axis=-1, 351 | ), 352 | ) 353 | assert p_terminates.shape == (*batch_axes, sample_count) 354 | 355 | return SegmentProbabilities( 356 | p_exits=p_exits, 357 | p_terminates=p_terminates, 358 | ) 359 | 360 | 361 | def sample_points_along_ray_within_bbox( 362 | ray_wrt_world: cameras.Rays3D, 363 | aabb: jnp.ndarray, 364 | samples_per_ray: int, 365 | prng_key: Optional[jax.random.KeyArray], 366 | ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: 367 | """Return points along a ray. 368 | 369 | Outputs are: 370 | - Points of shape `(3, samples_per_ray)`. 371 | - Distances from the ray origin of shape `(samples_per_ray,)`. 372 | - A scalar step size.""" 373 | assert ray_wrt_world.get_batch_axes() == () 374 | assert ray_wrt_world.origins.shape == ray_wrt_world.directions.shape == (3,) 375 | 376 | # Get segment of ray that's within the bounding box. 377 | segment = ray_segment_from_bounding_box( 378 | ray_wrt_world, aabb=aabb, min_segment_length=1e-3 379 | ) 380 | step_size = (segment.t_max - segment.t_min) / samples_per_ray 381 | 382 | # Get sample points along ray. 383 | ts = jnp.arange(samples_per_ray) 384 | if prng_key is not None: 385 | # Jitter if a PRNG key is passed in. 386 | ts = ts + jax.random.uniform( 387 | key=prng_key, 388 | shape=ts.shape, 389 | dtype=step_size.dtype, 390 | ) 391 | ts = ts * step_size 392 | ts = segment.t_min + ts 393 | 394 | # That's it! 395 | points = ( 396 | ray_wrt_world.origins[:, None] + ray_wrt_world.directions[:, None] * ts[None, :] 397 | ) 398 | assert points.shape == (3, samples_per_ray) 399 | assert ts.shape == (samples_per_ray,) 400 | assert step_size.shape == () 401 | return points, ts, step_size 402 | 403 | 404 | @jdc.pytree_dataclass 405 | class RaySegmentSpecification: 406 | t_min: jnp.ndarray 407 | t_max: jnp.ndarray 408 | 409 | 410 | def ray_segment_from_bounding_box( 411 | ray_wrt_world: cameras.Rays3D, 412 | aabb: jnp.ndarray, 413 | min_segment_length: float, 414 | ) -> RaySegmentSpecification: 415 | """Given a ray and bounding box, compute the near and far t values that define a 416 | segment that lies fully in the box.""" 417 | assert ray_wrt_world.origins.shape == ray_wrt_world.directions.shape == (3,) 418 | assert aabb.shape == (2, 3) 419 | 420 | # Find t for per-axis collision with the bounding box. 421 | # origin + t * direction = bounding box 422 | # t = (bounding box - origin) / direction 423 | offsets = aabb - ray_wrt_world.origins[None, :] 424 | t_intersections = offsets / ( 425 | ray_wrt_world.directions + utils.eps_from_dtype(offsets.dtype) 426 | ) 427 | 428 | # Compute near/far distances. 429 | t_min_per_axis = jnp.min(t_intersections, axis=0) 430 | t_max_per_axis = jnp.max(t_intersections, axis=0) 431 | assert t_min_per_axis.shape == t_max_per_axis.shape == (3,) 432 | 433 | # Clip. 434 | t_min = jnp.maximum(0.0, jnp.max(t_min_per_axis)) 435 | t_max = jnp.min(t_max_per_axis) 436 | t_max_clipped = jnp.maximum(t_max, t_min + min_segment_length) 437 | 438 | # TODO: this should likely be returned as well, and used as a mask for supervision. 439 | # Currently our loss includes rays outside of the bounding box. 
440 | valid_mask = t_min < t_max 441 | 442 | return RaySegmentSpecification( 443 | t_min=jnp.where(valid_mask, t_min, 0.0), 444 | t_max=jnp.where(valid_mask, t_max_clipped, min_segment_length), 445 | ) 446 | 447 | 448 | def _rgb_from_points( 449 | rays_wrt_world: cameras.Rays3D, 450 | probs: SegmentProbabilities, 451 | learnable_params: LearnableParams, 452 | points: jnp.ndarray, 453 | appearance_mlp: networks.FeatureMlp, 454 | config: RenderConfig, 455 | prng_key: jax.random.KeyArray, 456 | dtype: Any, 457 | ) -> Tuple[jnp.ndarray, jnp.ndarray]: 458 | """Helper for rendering RGB values. Returns an RGB array of shape `(ray_count, 459 | config.samples_per_ray, 3)`, and an unbiasing coefficient array of shape 460 | `(ray_count,)`. 461 | 462 | The original PyTorch implementation speeds up training by only rendering RGB values 463 | for which the termination probability (sample weight) exceeds a provided threshold, 464 | but this requires some boolean masking and dynamic shapes which, alas, are quite 465 | difficult in JAX. To reduce the number of appearance computations needed, we instead 466 | resort to a weighted sampling approach. 467 | """ 468 | ray_count = points.shape[1] 469 | assert points.shape == (3, ray_count, config.density_samples_per_ray) 470 | 471 | # Render the most visible points for each ray, with weighted random sampling. 472 | assert probs.p_terminates.shape == (ray_count, config.density_samples_per_ray) 473 | appearance_indices = jax.vmap( 474 | lambda p: jax.random.choice( 475 | key=prng_key, 476 | a=config.density_samples_per_ray, 477 | shape=(config.appearance_samples_per_ray,), 478 | replace=False, 479 | p=p, 480 | ) 481 | )(probs.p_terminates) 482 | assert appearance_indices.shape == (ray_count, config.appearance_samples_per_ray) 483 | 484 | visible_points = points[:, jnp.arange(ray_count)[:, None], appearance_indices] 485 | assert visible_points.shape == (3, ray_count, config.appearance_samples_per_ray) 486 | 487 | appearance_tensor = learnable_params.appearance_tensor 488 | appearance_feat = appearance_tensor.interpolate(visible_points) 489 | assert appearance_feat.shape == ( 490 | appearance_tensor.channel_dim(), 491 | ray_count, 492 | config.appearance_samples_per_ray, 493 | ) 494 | 495 | total_sample_count = ray_count * config.appearance_samples_per_ray 496 | appearance_feat = jnp.moveaxis(appearance_feat, 0, -1).reshape( 497 | (total_sample_count, appearance_tensor.channel_dim()) 498 | ) 499 | viewdirs = jnp.tile( 500 | rays_wrt_world.directions[:, None, :], 501 | (1, config.appearance_samples_per_ray, 1), 502 | ).reshape((-1, 3)) 503 | 504 | camera_indices = rays_wrt_world.camera_indices 505 | assert camera_indices.shape == (ray_count,) 506 | camera_indices = jnp.repeat( 507 | camera_indices, repeats=config.appearance_samples_per_ray, axis=0 508 | ) 509 | assert camera_indices.shape == (ray_count * config.appearance_samples_per_ray,) 510 | 511 | visible_rgb = cast( 512 | jnp.ndarray, 513 | appearance_mlp.apply( 514 | learnable_params.appearance_mlp_params, 515 | features=appearance_feat.reshape( 516 | (ray_count * config.appearance_samples_per_ray, -1) 517 | ), 518 | viewdirs=viewdirs, 519 | camera_indices=camera_indices, 520 | dtype=dtype, 521 | ), 522 | ).reshape((ray_count, config.appearance_samples_per_ray, 3)) 523 | 524 | rgb = ( 525 | jnp.zeros( 526 | (ray_count, config.density_samples_per_ray, 3), 527 | dtype=dtype, 528 | ) 529 | .at[jnp.arange(ray_count)[:, None], appearance_indices, :] 530 | .set(visible_rgb) 531 | ) 532 | assert rgb.shape == 
(ray_count, config.density_samples_per_ray, 3) 533 | 534 | # Coefficients for unbiasing the expected RGB values using the sampling 535 | # probabilities. This is helpful because RGB values for points that are not chosen 536 | # by our appearance sampler are zeroed out. 537 | # 538 | # As an example: if the weights* for all density samples is 0.95** but the sum of 539 | # weights for our appearance samples is only 0.7, we can correct the resulting 540 | # expected RGB value by scaling by (0.95/0.7). 541 | # 542 | # *weight at a segment = termination probability at that segment 543 | # **equivalently: p=0.05 of the ray exiting the last segment and rendering the 544 | # background. 545 | sampled_p_terminates = probs.p_terminates[ 546 | jnp.arange(ray_count)[:, None], appearance_indices 547 | ] 548 | assert sampled_p_terminates.shape == ( 549 | ray_count, 550 | config.appearance_samples_per_ray, 551 | ) 552 | 553 | unbias_coeff = ( 554 | # The 0.95 term in the example. 555 | 1.0 556 | - probs.p_exits[:, -1] 557 | + utils.eps_from_dtype(dtype) 558 | ) / ( 559 | # The 0.7 term in the example. 560 | jnp.sum(sampled_p_terminates, axis=1) 561 | + utils.eps_from_dtype(dtype) 562 | ) 563 | assert unbias_coeff.shape == (ray_count,) 564 | 565 | return rgb, unbias_coeff 566 | -------------------------------------------------------------------------------- /tensorf/tensor_vm.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any, Callable, Tuple, Union 4 | 5 | import jax 6 | import jax.scipy 7 | import jax_dataclasses as jdc 8 | from jax import numpy as jnp 9 | 10 | Shape = Tuple[int, ...] 11 | Dtype = Any 12 | 13 | Scalar = Union[float, jnp.ndarray] 14 | 15 | 16 | @jdc.pytree_dataclass 17 | class TensorVM: 18 | """A tensor decomposition consisted of three vector-matrix pairs.""" 19 | 20 | stacked_single_vm: TensorVMSingle 21 | """Three vector-matrix pairs, stacked along axis 0.""" 22 | 23 | @staticmethod 24 | def initialize( 25 | grid_dim: int, 26 | per_axis_channel_dim: int, 27 | init: Callable[[jax.random.KeyArray, Shape, Dtype], jnp.ndarray], 28 | prng_key: jax.random.KeyArray, 29 | dtype: Dtype, 30 | ) -> TensorVM: 31 | prng_keys = jax.random.split(prng_key, 3) 32 | return TensorVM( 33 | stacked_single_vm=jax.vmap( 34 | lambda prng_key: TensorVMSingle.initialize( 35 | grid_dim, 36 | per_axis_channel_dim, 37 | init, 38 | prng_key=prng_key, 39 | dtype=dtype, 40 | ) 41 | )(prng_keys) 42 | ) 43 | 44 | def interpolate(self, ijk: jnp.ndarray) -> jnp.ndarray: 45 | """Look up a coordinate in our VM decomposition. 46 | 47 | Input should have shape (3, *) and be in the range [-1.0, 1.0]. 48 | Output should have shape (channel_dim * 3, *).""" 49 | 50 | batch_axes = ijk.shape[1:] 51 | assert ijk.shape == (3, *batch_axes) 52 | kij = ijk[jnp.array([2, 0, 1]), ...] 53 | jki = ijk[jnp.array([1, 2, 0]), ...] 54 | indices = jnp.stack([ijk, kij, jki], axis=0) 55 | assert indices.shape == (3, 3, *batch_axes) 56 | 57 | interpolate_func = TensorVMSingle.interpolate 58 | if len(batch_axes) >= 2: 59 | # TODO: this magic vmap is unnecessary and doesn't impact numerical results, 60 | # but enables a massive performance increase. This is 3~4x better training 61 | # throughput for single-precision, ~1.5x in mixed-precision. 62 | # 63 | # I'm not exactly sure why, but it appears to: 64 | # - Shuffle the memory layout and improve access patterns. 65 | # - Reduce the length of the HLO generated during tracing by ~150 lines. 
66 | # 67 | # Some plots/discussion: https://github.com/google/jax/discussions/10332 68 | # 69 | # Setting the axis to -1 also produces a speedup, albeit a slightly smaller 70 | # one. Numerical results are identical in either case. 71 | interpolate_func = jax.vmap( 72 | interpolate_func, 73 | in_axes=(None, -2), 74 | out_axes=-2, 75 | ) 76 | 77 | # Vectorize over axis=0, which will be of size 3. (one for each vector-matrix 78 | # pair) 79 | # 80 | # Empirically, applying this after the magic vmap above is slightly 81 | # faster than applying it before. 82 | interpolate_func = jax.vmap(interpolate_func) 83 | 84 | feature = interpolate_func(self.stacked_single_vm, indices) 85 | assert feature.shape == (3, self.stacked_single_vm.channel_dim(), *batch_axes) 86 | 87 | # Note the original implementation also has a basis matrix that left-multiplies 88 | # here; we fold this into the appearance network. 89 | feature = feature.reshape( 90 | (3 * self.stacked_single_vm.channel_dim(), *batch_axes) 91 | ) 92 | return feature 93 | 94 | @jdc.jit 95 | def resize(self, grid_dim: jdc.Static[int]) -> TensorVM: 96 | """Resize our tensor decomposition.""" 97 | 98 | d: TensorVMSingle 99 | return TensorVM( 100 | stacked_single_vm=jax.vmap( 101 | lambda inner: TensorVMSingle.resize(inner, grid_dim=grid_dim) 102 | )(self.stacked_single_vm) 103 | ) 104 | 105 | def grid_dim(self) -> int: 106 | return self.stacked_single_vm.grid_dim() 107 | 108 | def channel_dim(self) -> int: 109 | return self.stacked_single_vm.channel_dim() * 3 110 | 111 | 112 | @jdc.pytree_dataclass 113 | class TensorVMSingle: 114 | """Helper for 4D tensors decomposed into a vector-matrix pair.""" 115 | 116 | vector: jnp.ndarray 117 | matrix: jnp.ndarray 118 | 119 | @staticmethod 120 | def initialize( 121 | grid_dim: int, 122 | channel_dim: int, 123 | init: Callable[[jax.random.KeyArray, Shape, Dtype], jnp.ndarray], 124 | prng_key: jax.random.KeyArray, 125 | dtype: Dtype, 126 | ) -> TensorVMSingle: 127 | """Initialize a VM-decomposed 4D tensor (depth, width, height, channel). 128 | 129 | For now, we assume that the depth/width/height dimensions are equal.""" 130 | key0, key1 = jax.random.split(prng_key) 131 | 132 | # Note that putting channel dimension first is *much* faster. 133 | # vector_shape = (grid_dim, channel_dim) 134 | # matrix_shape = (grid_dim, grid_dim, channel_dim) 135 | vector_shape = (channel_dim, grid_dim) 136 | matrix_shape = (channel_dim, grid_dim, grid_dim) 137 | 138 | return TensorVMSingle( 139 | vector=init(key0, vector_shape, dtype), 140 | matrix=init(key1, matrix_shape, dtype), 141 | ) 142 | 143 | def interpolate(self, ijk: jnp.ndarray) -> jnp.ndarray: 144 | """Grid lookup with interpolation. 145 | 146 | ijk should have shape (3, *), with all values within [-1, 1]. 147 | Output will have shape (channel_dim, *).""" 148 | batch_axes = ijk.shape[1:] 149 | assert ijk.shape == (3, *batch_axes) 150 | assert jnp.issubdtype(ijk.dtype, jnp.floating) 151 | 152 | # [-1.0, 1.0] => [0.0, 1.0] 153 | ijk = (ijk + 1.0) / 2.0 154 | 155 | # [0.0, 1.0] => [0.0, grid_dim - 1.0] 156 | ijk = ijk * (self.grid_dim() - 1.0) 157 | 158 | vector_coeffs = linear_interpolation_with_channel_axis( 159 | self.vector, coordinates=ijk[0:1, ...] 160 | ) 161 | matrix_coeffs = linear_interpolation_with_channel_axis( 162 | self.matrix, coordinates=ijk[1:3, ...] 
163 | ) 164 | 165 | assert ( 166 | vector_coeffs.shape 167 | == matrix_coeffs.shape 168 | == (self.channel_dim(), *batch_axes) 169 | ) 170 | return vector_coeffs * matrix_coeffs 171 | 172 | def grid_dim(self) -> int: 173 | """Returns the grid dimension.""" 174 | r0, r1 = self.matrix.shape[-2:] 175 | r2 = self.vector.shape[-1] 176 | assert r0 == r1 == r2 177 | return r0 178 | 179 | def channel_dim(self) -> int: 180 | """Returns the channel dimension.""" 181 | c0 = self.matrix.shape[-3] 182 | c1 = self.vector.shape[-2] 183 | assert c0 == c1 184 | return c0 185 | 186 | @jdc.jit 187 | def resize(self, grid_dim: jdc.Static[int]) -> TensorVMSingle: 188 | """Resize our decomposition, while interpolating linearly.""" 189 | 190 | channel_dim = self.channel_dim() 191 | matrix_shape = (channel_dim, grid_dim, grid_dim) 192 | vector_shape = (channel_dim, grid_dim) 193 | 194 | return TensorVMSingle( 195 | # Note that antialiasing only happens when downsampling. 196 | matrix=resize_with_aligned_corners( 197 | self.matrix, matrix_shape, "linear", antialias=True 198 | ), 199 | vector=resize_with_aligned_corners( 200 | self.vector, vector_shape, "linear", antialias=True 201 | ), 202 | ) 203 | 204 | 205 | def resize_with_aligned_corners( 206 | image: jax.Array, 207 | shape: Tuple[int, ...], 208 | method: Union[str, jax.image.ResizeMethod], 209 | antialias: bool, 210 | ): 211 | """Alternative to jax.image.resize(), which emulates align_corners=True in PyTorch's 212 | interpolation functions.""" 213 | spatial_dims = tuple( 214 | i 215 | for i in range(len(shape)) 216 | if not jax.core.symbolic_equal_dim(image.shape[i], shape[i]) 217 | ) 218 | scale = jnp.array([(shape[i] - 1.0) / (image.shape[i] - 1.0) for i in spatial_dims]) 219 | translation = -(scale / 2.0 - 0.5) 220 | return jax.image.scale_and_translate( 221 | image, 222 | shape, 223 | method=method, 224 | scale=scale, 225 | spatial_dims=spatial_dims, 226 | translation=translation, 227 | antialias=antialias, 228 | ) 229 | 230 | 231 | def linear_interpolation_with_channel_axis( 232 | grid: jnp.ndarray, coordinates: jnp.ndarray 233 | ) -> jnp.ndarray: 234 | """Thin wrapper around `jax.scipy.ndimage.map_coordinates()` for linear 235 | interpolation. 236 | 237 | Standard set of shapes might look like: 238 | grid (C, 128, 128, 128) 239 | coordinates (3, *) 240 | 241 | Which would return: 242 | (C, *) 243 | """ 244 | assert len(grid.shape[1:]) == coordinates.shape[0] 245 | # vmap to add a channel axis. 246 | output = jax.vmap( 247 | lambda g: jax.scipy.ndimage.map_coordinates( 248 | g, 249 | coordinates=tuple(coordinates), 250 | order=1, 251 | mode="nearest", 252 | ) 253 | )(grid) 254 | assert output.shape == grid.shape[:1] + coordinates.shape[1:] 255 | return output 256 | -------------------------------------------------------------------------------- /tensorf/train_config.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import pathlib 3 | from typing import Literal, Optional, Tuple 4 | 5 | 6 | @dataclasses.dataclass(frozen=True) 7 | class OptimizerConfig: 8 | lr_init_tensor: float = 0.02 # `lr_init` in the original code. 9 | lr_init_mlp: float = 1e-3 # `lr_basis` in the original code. 10 | lr_decay_iters: Optional[int] = None # If unset, defaults to n_iters. 11 | lr_decay_target_ratio: float = 0.1 12 | lr_upsample_reset: bool = True # Reset learning rate after upsampling. 
13 | 14 | 15 | @dataclasses.dataclass(frozen=True) 16 | class TensorfConfig: 17 | run_dir: pathlib.Path 18 | 19 | # Input data directory. 20 | dataset_path: pathlib.Path 21 | 22 | # Dataset type. 23 | dataset_type: Literal["blender", "nerfstudio"] = "blender" 24 | 25 | # Training options. 26 | minibatch_size: int = 4096 27 | n_iters: int = 30000 28 | 29 | # Optimizer configuration. 30 | optimizer: OptimizerConfig = dataclasses.field(default_factory=OptimizerConfig) 31 | 32 | # Loss options. 33 | # TODO: these are not yet implemented :') 34 | # l1_weight_initial: float = 0.0 35 | # l1_weight_rest: float = 0.0 36 | # ortho_weight: float = 0.0 37 | # tv_weight_density: float = 0.0 38 | # tv_weight_app: float = 0.0 39 | 40 | initial_aabb_min: Tuple[float, float, float] = (-1.0, -1.0, -1.0) 41 | initial_aabb_max: Tuple[float, float, float] = (1.0, 1.0, 1.0) 42 | 43 | # Per-axis tensor decomposition components. 44 | appearance_feat_dim: int = 24 # n_lambd_sh 45 | density_feat_dim: int = 8 # n_lambd_sigma 46 | 47 | # Fourier feature frequency counts for both the interpolated feature vector and view 48 | # direction count; these are used in the appearance MLP. 49 | feature_n_freqs: int = 6 # fea_pe 50 | viewdir_n_freqs: int = 6 # view_pe 51 | 52 | # Grid parameters; we define the initial and final grid dimensions as well as when 53 | # to upsample or update the alpha mask. 54 | grid_dim_init: int = 128 # cbrt(N_voxel_init) 55 | grid_dim_final: int = 300 # cbrt(N_voxel_final) 56 | upsamp_iters: Tuple[int, ...] = (2000, 3000, 4000, 5500, 7000) 57 | 58 | # TODO: unimplemented. (can we even implement this in JAX?) 59 | # update_alphamask_iters: Tuple[int, ...] = (2000, 4000) 60 | 61 | # If enabled, we use mixed-precision training. This seems to work and speeds up 62 | # training throughput by a significant factor, but is disabled by default because we 63 | # haven't fully evaluated stability, impact on convergence, hyperparameters, etc. 64 | # 65 | # Important: if mixed precision is enabled, the loss scale should generally be set 66 | # to something high! 67 | mixed_precision: bool = False 68 | 69 | # Loss scale for preventing gradient underflow. 70 | # 71 | # Applied always but useful mostly for mixed-precision training, where we observe a 72 | # tradeoff where a higher value will produce lower errors and improve convergence, 73 | # but can run slower despite a nearly identical computation graph. (possibly due to 74 | # some reduced sparsity of gradients?) 75 | loss_scale: float = 1.0 76 | 77 | # Apply MipNeRF-360-inspired scene contraction. Useful for real data. 78 | scene_contraction: bool = False 79 | scene_scale: float = 1.0 80 | 81 | # Add NeRF in the wild-inspired camera embeddings. 82 | camera_embeddings: bool = False 83 | 84 | # Near and far limits for rendering. 85 | render_near: float = 0.05 86 | render_far: float = 200.0 87 | train_ray_sample_multiplier: float = 1.0 88 | -------------------------------------------------------------------------------- /tensorf/training.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import math 4 | import random 5 | from typing import Tuple 6 | 7 | import fifteen 8 | import flax 9 | import jax 10 | import jax_dataclasses as jdc 11 | import optax 12 | from jax import numpy as jnp 13 | from tqdm.auto import tqdm 14 | from typing_extensions import Annotated, assert_never 15 | 16 | from . 
import data, networks, render, tensor_vm, train_config, utils 17 | 18 | 19 | @jdc.pytree_dataclass 20 | class TrainState(jdc.EnforcedAnnotationsMixin): 21 | config: jdc.Static[train_config.TensorfConfig] 22 | 23 | # Representation/parameters. 24 | appearance_mlp: jdc.Static[networks.FeatureMlp] 25 | learnable_params: render.LearnableParams 26 | 27 | # Optimizer. 28 | optimizer: jdc.Static[optax.GradientTransformation] 29 | optimizer_state: optax.OptState 30 | 31 | # Current axis-aligned bounding box. 32 | aabb: Annotated[jnp.ndarray, jnp.floating, (2, 3)] 33 | 34 | # Misc. 35 | prng_key: jax.random.KeyArray 36 | step: Annotated[jnp.ndarray, jnp.integer, ()] 37 | 38 | @staticmethod 39 | @jdc.jit 40 | def initialize( 41 | config: jdc.Static[train_config.TensorfConfig], 42 | grid_dim: jdc.Static[int], 43 | prng_key: jax.random.KeyArray, 44 | num_cameras: jdc.Static[int], 45 | ) -> TrainState: 46 | prng_keys = jax.random.split(prng_key, 5) 47 | normal_init = jax.nn.initializers.normal(stddev=0.1) 48 | 49 | def make_mlp() -> Tuple[networks.FeatureMlp, flax.core.FrozenDict]: 50 | dummy_features = jnp.zeros((1, config.appearance_feat_dim * 3)) 51 | dummy_viewdirs = jnp.zeros((1, 3)) 52 | appearance_mlp = networks.FeatureMlp( 53 | feature_n_freqs=config.feature_n_freqs, 54 | viewdir_n_freqs=config.viewdir_n_freqs, 55 | # If num_cameras is set, camera embeddings are enabled. 56 | num_cameras=num_cameras if config.camera_embeddings else None, 57 | ) 58 | dummy_camera_indices = jnp.zeros((1,), dtype=jnp.uint32) 59 | 60 | appearance_mlp_params = appearance_mlp.init( 61 | prng_keys[0], 62 | features=dummy_features, 63 | viewdirs=dummy_viewdirs, 64 | camera_indices=dummy_camera_indices, 65 | ) 66 | return appearance_mlp, appearance_mlp_params 67 | 68 | appearance_mlp, appearance_mlp_params = make_mlp() 69 | 70 | learnable_params = render.LearnableParams( 71 | appearance_mlp_params=appearance_mlp_params, 72 | appearance_tensor=tensor_vm.TensorVM.initialize( 73 | grid_dim=grid_dim, 74 | per_axis_channel_dim=config.appearance_feat_dim, 75 | init=normal_init, 76 | prng_key=prng_keys[1], 77 | dtype=jnp.float32, # Main copy of parameters are always float32. 78 | ), 79 | density_tensor=tensor_vm.TensorVM.initialize( 80 | grid_dim=grid_dim, 81 | per_axis_channel_dim=config.density_feat_dim, 82 | init=normal_init, 83 | prng_key=prng_keys[2], 84 | dtype=jnp.float32, 85 | ), 86 | scene_contraction=config.scene_contraction, 87 | ) 88 | optimizer = TrainState._make_optimizer( 89 | config.optimizer, config.scene_contraction 90 | ) 91 | optimizer_state = optimizer.init(learnable_params) 92 | 93 | return TrainState( 94 | config=config, 95 | appearance_mlp=appearance_mlp, 96 | learnable_params=learnable_params, 97 | optimizer=optimizer, 98 | optimizer_state=optimizer_state, 99 | aabb=jnp.array([config.initial_aabb_min, config.initial_aabb_max]), 100 | prng_key=prng_keys[4], 101 | step=jnp.array(0), 102 | ) 103 | 104 | @jdc.jit(donate_argnums=0) 105 | def training_step( 106 | self, minibatch: data.RenderedRays 107 | ) -> Tuple[TrainState, fifteen.experiments.TensorboardLogData]: 108 | """Single training step.""" 109 | render_prng_key, new_prng_key = jax.random.split(self.prng_key) 110 | 111 | # If in mixed-precision mode, we render and backprop in float16. 
112 | if self.config.mixed_precision: 113 | compute_dtype = jnp.float16 114 | else: 115 | compute_dtype = jnp.float32 116 | 117 | def compute_loss( 118 | learnable_params: render.LearnableParams, 119 | ) -> Tuple[jnp.ndarray, fifteen.experiments.TensorboardLogData]: 120 | # Compute sample counts from grid dimensionality. 121 | # TODO: move heuristics into config? 122 | grid_dim = self.learnable_params.appearance_tensor.grid_dim() 123 | assert grid_dim == self.learnable_params.density_tensor.grid_dim() 124 | density_samples_per_ray = int( 125 | math.sqrt(3 * grid_dim**2) * self.config.train_ray_sample_multiplier 126 | ) 127 | appearance_samples_per_ray = int(0.15 * density_samples_per_ray) 128 | 129 | # Render and compute loss. 130 | rendered = render.render_rays( 131 | appearance_mlp=self.appearance_mlp, 132 | learnable_params=learnable_params, 133 | aabb=self.aabb, 134 | rays_wrt_world=minibatch.rays_wrt_world, 135 | prng_key=render_prng_key, 136 | config=render.RenderConfig( 137 | near=self.config.render_near, 138 | far=self.config.render_far, 139 | mode=render.RenderMode.RGB, 140 | density_samples_per_ray=density_samples_per_ray, 141 | appearance_samples_per_ray=appearance_samples_per_ray, 142 | ), 143 | dtype=compute_dtype, 144 | ) 145 | assert ( 146 | rendered.shape 147 | == minibatch.colors.shape 148 | == minibatch.get_batch_axes() + (3,) 149 | ) 150 | label_colors = minibatch.colors 151 | assert jnp.issubdtype(rendered.dtype, compute_dtype) 152 | assert jnp.issubdtype(label_colors.dtype, jnp.float32) 153 | 154 | mse = jnp.mean((rendered - label_colors) ** 2) 155 | loss = mse # TODO: add regularization terms. 156 | 157 | log_data = fifteen.experiments.TensorboardLogData( 158 | scalars={ 159 | "mse": mse, 160 | "psnr": utils.psnr_from_mse(mse), 161 | } 162 | ) 163 | return loss * self.config.loss_scale, log_data 164 | 165 | # Compute gradients. 166 | log_data: fifteen.experiments.TensorboardLogData 167 | grads: render.LearnableParams 168 | learnable_params = jax.tree_map( 169 | # Cast parameters to desired precision. 170 | lambda x: x.astype(compute_dtype), 171 | self.learnable_params, 172 | ) 173 | (loss, log_data), grads = jax.value_and_grad( 174 | compute_loss, 175 | has_aux=True, 176 | )(learnable_params) 177 | 178 | # To prevent NaNs from momentum computations in mixed-precision mode, it's 179 | # important that gradients are float32 before being passed to the optimizer. 180 | grads_unscaled = jax.tree_map( 181 | lambda x: x.astype(jnp.float32) / self.config.loss_scale, 182 | grads, 183 | ) 184 | assert jnp.issubdtype( 185 | jax.tree_util.tree_leaves(grads_unscaled)[0].dtype, jnp.float32 186 | ), "Gradients should always be float32." 187 | 188 | # Compute learning rate decay. 189 | # We could put this in the optax chain as well, but explicitly computing here 190 | # makes logging & reset handling easier. 191 | if self.config.optimizer.lr_upsample_reset: 192 | # For resetting after upsampling, we find the smallest non-negative value of 193 | # (current step - upsampling iteration #). 
194 | step_deltas = self.step - jnp.array((0,) + self.config.upsamp_iters) 195 | step_deltas = jnp.where( 196 | step_deltas >= 0, step_deltas, jnp.iinfo(step_deltas.dtype).max 197 | ) 198 | resetted_step = jnp.min(step_deltas) 199 | else: 200 | resetted_step = self.step 201 | 202 | decay_iters = self.config.optimizer.lr_decay_iters 203 | if decay_iters is None: 204 | decay_iters = self.config.n_iters 205 | 206 | lr_decay_coeff = optax.exponential_decay( 207 | init_value=1.0, 208 | transition_steps=decay_iters, 209 | decay_rate=self.config.optimizer.lr_decay_target_ratio, 210 | end_value=self.config.optimizer.lr_decay_target_ratio, 211 | )(resetted_step) 212 | 213 | # Propagate gradients through ADAM, learning rate scheduler, etc. 214 | updates, new_optimizer_state = self.optimizer.update( 215 | grads_unscaled, self.optimizer_state, self.learnable_params 216 | ) 217 | updates = jax.tree_map(lambda x: lr_decay_coeff * x, updates) 218 | 219 | # Add learning rates to Tensorboard logs. 220 | log_data = log_data.merge_scalars( 221 | { 222 | "lr_tensor": lr_decay_coeff * self.config.optimizer.lr_init_tensor, 223 | "lr_mlp": lr_decay_coeff * self.config.optimizer.lr_init_mlp, 224 | "grad_norm": optax.global_norm(grads), 225 | } 226 | ) 227 | 228 | with jdc.copy_and_mutate(self, validate=True) as new_state: 229 | new_state.optimizer_state = new_optimizer_state 230 | new_state.learnable_params = optax.apply_updates( 231 | self.learnable_params, updates 232 | ) 233 | new_state.prng_key = new_prng_key 234 | new_state.step = new_state.step + 1 235 | return new_state, log_data.prefix("train/") 236 | 237 | @staticmethod 238 | def _make_optimizer( 239 | config: train_config.OptimizerConfig, 240 | scene_contraction: bool, 241 | ) -> optax.GradientTransformation: 242 | """Set up Adam optimizer.""" 243 | return optax.chain( 244 | # First, we rescale gradients with ADAM. Note that eps=1e-8 is OK because 245 | # gradients are always converted to float32 before being passed to the 246 | # optimizer. 247 | optax.scale_by_adam( 248 | b1=0.9, 249 | b2=0.99, 250 | eps=1e-8, 251 | eps_root=0.0, 252 | ), 253 | # Apply MLP parameter learning rate. Note the negative sign needed for 254 | # gradient descent. 255 | optax.masked( 256 | optax.scale(-config.lr_init_mlp), 257 | mask=render.LearnableParams( 258 | appearance_mlp_params=True, # type: ignore 259 | appearance_tensor=False, # type: ignore 260 | density_tensor=False, # type: ignore 261 | scene_contraction=scene_contraction, 262 | ), 263 | ), 264 | # Apply tensor decomposition learning rate. Note the negative sign needed 265 | # for gradient descent. 266 | optax.masked( 267 | optax.scale(-config.lr_init_tensor), 268 | mask=render.LearnableParams( 269 | appearance_mlp_params=False, # type: ignore 270 | appearance_tensor=True, # type: ignore 271 | density_tensor=True, # type: ignore 272 | scene_contraction=scene_contraction, 273 | ), 274 | ), 275 | ) 276 | 277 | def resize_grid(self, new_grid_dim: int) -> TrainState: 278 | """Resize the grid underlying a training state by linearly interpolating grid 279 | parameters.""" 280 | with jdc.copy_and_mutate(self, validate=False) as resized: 281 | # Resample the feature grids, with linear interpolation. 
282 | resized.learnable_params.density_tensor = ( 283 | resized.learnable_params.density_tensor.resize(new_grid_dim) 284 | ) 285 | resized.learnable_params.appearance_tensor = ( 286 | resized.learnable_params.appearance_tensor.resize(new_grid_dim) 287 | ) 288 | 289 | # Perform some nasty surgery to resample the momentum parameters as well. 290 | adam_state = resized.optimizer_state[0] 291 | assert isinstance(adam_state, optax.ScaleByAdamState) 292 | nu: render.LearnableParams = adam_state.nu 293 | mu: render.LearnableParams = adam_state.mu 294 | resized.optimizer_state = ( 295 | adam_state._replace( # NamedTuple `_replace()`. 296 | nu=jdc.replace( 297 | nu, 298 | density_tensor=nu.density_tensor.resize(new_grid_dim), 299 | appearance_tensor=nu.appearance_tensor.resize(new_grid_dim), 300 | ), 301 | mu=jdc.replace( 302 | mu, 303 | density_tensor=mu.density_tensor.resize(new_grid_dim), 304 | appearance_tensor=mu.appearance_tensor.resize(new_grid_dim), 305 | ), 306 | ), 307 | ) + resized.optimizer_state[1:] 308 | return resized 309 | 310 | 311 | def run_training_loop( 312 | config: train_config.TensorfConfig, 313 | restore_checkpoint: bool = False, 314 | clear_existing: bool = False, 315 | ) -> None: 316 | """Full training loop implementation.""" 317 | 318 | # Set up our experiment: for checkpoints, logs, metadata, etc. 319 | experiment = fifteen.experiments.Experiment(data_dir=config.run_dir) 320 | if restore_checkpoint: 321 | experiment.assert_exists() 322 | config = experiment.read_metadata("config", train_config.TensorfConfig) 323 | else: 324 | if clear_existing: 325 | experiment.clear() 326 | else: 327 | experiment.assert_new() 328 | experiment.write_metadata("config", config) 329 | 330 | # Load dataset. 331 | dataset = data.make_dataset( 332 | config.dataset_type, 333 | config.dataset_path, 334 | config.scene_scale, 335 | ) 336 | num_cameras = len(dataset.get_cameras()) 337 | experiment.write_metadata("num_cameras", num_cameras) 338 | 339 | # Initialize training state. 340 | train_state: TrainState 341 | train_state = TrainState.initialize( 342 | config, 343 | grid_dim=config.grid_dim_init, 344 | prng_key=jax.random.PRNGKey(94709), 345 | num_cameras=num_cameras, 346 | ) 347 | if restore_checkpoint: 348 | train_state = experiment.restore_checkpoint(train_state) 349 | 350 | dataloader = fifteen.data.InMemoryDataLoader( 351 | dataset=dataset.get_training_rays(), 352 | minibatch_size=config.minibatch_size, 353 | ) 354 | minibatches = fifteen.data.cycled_minibatches(dataloader, shuffle_seed=0) 355 | minibatches = iter(minibatches) 356 | 357 | # Run! 358 | print("Training with config:", config) 359 | loop_metrics: fifteen.utils.LoopMetrics 360 | for loop_metrics in tqdm( 361 | fifteen.utils.range_with_metrics(config.n_iters - int(train_state.step)), 362 | desc="Training", 363 | ): 364 | # Load minibatch. 365 | minibatch = next(minibatches) 366 | assert minibatch.get_batch_axes() == (config.minibatch_size,) 367 | assert minibatch.colors.shape == (config.minibatch_size, 3) 368 | 369 | # Training step. 370 | log_data: fifteen.experiments.TensorboardLogData 371 | train_state, log_data = train_state.training_step(minibatch) 372 | 373 | # Log & checkpoint. 
374 | train_step = int(train_state.step) 375 | experiment.log( 376 | log_data.merge_scalars( 377 | {"train/iterations_per_sec": loop_metrics.iterations_per_sec} 378 | ), 379 | step=train_step, 380 | log_scalars_every_n=5, 381 | log_histograms_every_n=100, 382 | ) 383 | if train_step % 1000 == 0: 384 | experiment.save_checkpoint( 385 | train_state, 386 | step=int(train_state.step), 387 | keep_every_n_steps=2000, 388 | ) 389 | 390 | # Grid upsampling. We linearly interpolate between the initial and final grid 391 | # dimensionalities. 392 | if train_step in config.upsamp_iters: 393 | upsamp_index = config.upsamp_iters.index(train_step) 394 | train_state = train_state.resize_grid( 395 | new_grid_dim=int( 396 | config.grid_dim_init 397 | + (config.grid_dim_final - config.grid_dim_init) 398 | * ((upsamp_index + 1) / len(config.upsamp_iters)) 399 | ) 400 | ) 401 | -------------------------------------------------------------------------------- /tensorf/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from jax import numpy as jnp 4 | 5 | 6 | def psnr_from_mse(mse: jnp.ndarray) -> jnp.ndarray: 7 | # Threshold to avoid NaNs. 8 | mse = jnp.maximum( 9 | mse, 10 | eps_from_dtype( 11 | mse.dtype, 12 | eps_f16=1e-7, 13 | eps_f32=1e-10, 14 | ), 15 | ) 16 | 17 | psnr = -10.0 * jnp.log10(mse) 18 | return psnr.astype(jnp.float32) 19 | 20 | 21 | def eps_from_dtype(dtype: Any, eps_f16: float = 1e-4, eps_f32: float = 1e-8) -> float: 22 | """Get precision constants from data-type.""" 23 | if jnp.issubdtype(dtype, jnp.float16): 24 | return eps_f16 25 | elif jnp.issubdtype(dtype, jnp.float32): 26 | return eps_f32 27 | else: 28 | assert False 29 | -------------------------------------------------------------------------------- /train_lego.py: -------------------------------------------------------------------------------- 1 | """Training script for lego dataset. 2 | 3 | For helptext, try running: 4 | ``` 5 | python train_lego.py --help 6 | ``` 7 | """ 8 | 9 | import functools 10 | import pathlib 11 | 12 | import fifteen 13 | import tyro 14 | 15 | import tensorf.train_config 16 | import tensorf.training 17 | 18 | if __name__ == "__main__": 19 | # Open PDB after runtime errors. 20 | fifteen.utils.pdb_safety_net() 21 | 22 | # Default configuration for lego dataset. 23 | default_config = tensorf.train_config.TensorfConfig( 24 | run_dir=pathlib.Path(f"./runs/lego-{fifteen.utils.timestamp()}"), 25 | dataset_path=pathlib.Path("./data/nerf_synthetic/lego"), 26 | dataset_type="blender", 27 | n_iters=30000, 28 | initial_aabb_min=(-0.6585, -1.1833, -0.4651), 29 | initial_aabb_max=(0.6636, 1.1929, 1.0512), 30 | appearance_feat_dim=48, 31 | density_feat_dim=16, 32 | feature_n_freqs=2, 33 | viewdir_n_freqs=2, 34 | grid_dim_init=128, 35 | grid_dim_final=300, 36 | upsamp_iters=(2000, 3000, 4000, 5500, 7000), 37 | ) 38 | 39 | # Run training loop! Note that we can set a default value for a function via 40 | # `functools.partial()`. 41 | tyro.cli( 42 | functools.partial( 43 | tensorf.training.run_training_loop, 44 | config=default_config, 45 | ) 46 | ) 47 | -------------------------------------------------------------------------------- /train_nerfstudio.py: -------------------------------------------------------------------------------- 1 | """Training script for real scenes stored using the nerfstudio format. 
2 | 3 | For helptext, try running: 4 | ``` 5 | python train_nerfstudio.py --help 6 | ``` 7 | """ 8 | 9 | import functools 10 | import pathlib 11 | 12 | import fifteen 13 | import tyro 14 | 15 | import tensorf.train_config 16 | import tensorf.training 17 | 18 | if __name__ == "__main__": 19 | # Open PDB after runtime errors. 20 | fifteen.utils.pdb_safety_net() 21 | 22 | # Default configuration for nerfstudio dataset. 23 | default_config = tensorf.train_config.TensorfConfig( 24 | run_dir=pathlib.Path(f"./runs/nerfstudio-{fifteen.utils.timestamp()}"), 25 | dataset_path=pathlib.Path("./data/dozer"), 26 | dataset_type="nerfstudio", 27 | n_iters=30000, 28 | # Note that the aabb is ignored when scene contraction is on. 29 | initial_aabb_min=(-2.0, -2.0, -2.0), 30 | initial_aabb_max=(2.0, 2.0, 2.0), 31 | appearance_feat_dim=48, 32 | density_feat_dim=32, 33 | feature_n_freqs=6, 34 | viewdir_n_freqs=6, 35 | grid_dim_init=128, 36 | grid_dim_final=300, 37 | upsamp_iters=(2_500, 5_000, 10_000), 38 | scene_contraction=True, 39 | camera_embeddings=True, 40 | render_near=0.05, 41 | render_far=200.0, 42 | train_ray_sample_multiplier=3.0, 43 | minibatch_size=2048, 44 | ) 45 | 46 | # Run training loop! Note that we can set a default value for a function via 47 | # `functools.partial()`. 48 | tyro.cli( 49 | functools.partial( 50 | tensorf.training.run_training_loop, 51 | config=default_config, 52 | ) 53 | ) 54 | --------------------------------------------------------------------------------
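
To make the weighted-sampling trick in `tensorf/render.py` concrete, here is a minimal, self-contained sketch (not part of the repository; the five-segment probabilities and colors are made up) of how the unbiasing coefficient recovers the full expected color when only a subset of segments gets an appearance query:

```python
import jax.numpy as jnp

# Hypothetical per-segment termination probabilities for a single ray. They sum
# to 0.95, i.e. the ray exits past the last segment with probability 0.05.
p_terminates = jnp.array([0.40, 0.30, 0.15, 0.06, 0.04])
rgb = jnp.tile(jnp.array([1.0, 0.0, 0.0]), (5, 1))  # Pretend every segment is pure red.

# Suppose the appearance sampler only evaluated segments 0 and 1.
sampled = jnp.array([0, 1])
rgb_masked = jnp.zeros_like(rgb).at[sampled].set(rgb[sampled])

# The expectation over the masked colors is biased low...
biased = jnp.sum(p_terminates[:, None] * rgb_masked, axis=0)  # ~[0.70, 0.0, 0.0]

# ...so scale by (total weight) / (sampled weight). The numerator plays the role
# of 1 - p_exits[:, -1] in _rgb_from_points().
unbias_coeff = p_terminates.sum() / p_terminates[sampled].sum()  # 0.95 / 0.70
unbiased = biased * unbias_coeff  # ~[0.95, 0.0, 0.0], matching the un-masked expectation.
```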
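
Similarly, the `align_corners`-style resize in `tensorf/tensor_vm.py` can be sanity-checked on a 1D ramp. This is only a sketch that repeats the same `scale`/`translation` arithmetic outside the class, with arbitrary sizes:

```python
import jax
import jax.numpy as jnp

# Resize a 5-sample ramp to 9 samples while keeping the endpoints pinned,
# mirroring resize_with_aligned_corners() (PyTorch align_corners=True semantics).
x = jnp.linspace(0.0, 1.0, 5)
old_dim, new_dim = x.shape[0], 9

scale = jnp.array([(new_dim - 1.0) / (old_dim - 1.0)])
translation = -(scale / 2.0 - 0.5)

y = jax.image.scale_and_translate(
    x,
    shape=(new_dim,),
    spatial_dims=(0,),
    scale=scale,
    translation=translation,
    method="linear",
    antialias=True,  # Antialiasing only kicks in when downsampling.
)
# Expect y[0] ~= 0.0 and y[-1] ~= 1.0, with linear interpolation in between.
```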