├── LICENSE
├── README.md
├── setup.py
├── src
│   └── multires_hash_encoding
│       ├── __init__.py
│       ├── hash_tensor.py
│       ├── interpolate.py
│       └── modules.py
└── tests
    ├── test_hash_tensor.py
    ├── test_interpolate.py
    └── test_modules.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 Penn Jenks
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Neural Hash Encoding
2 | 
3 | This is an unofficial PyTorch implementation of the key data structures from [Instant Neural Graphics Primitives](https://github.com/NVlabs/instant-ngp).
4 | 
5 | ```
6 | @article{mueller2022instant,
7 |     title = {Instant Neural Graphics Primitives with a Multiresolution Hash Encoding},
8 |     author = {Thomas M\"uller and Alex Evans and Christoph Schied and Alexander Keller},
9 |     journal = {arXiv:2201.05989},
10 |     year = {2022},
11 |     month = jan
12 | }
13 | ```
14 | 
15 | For an example of how to create a drop-in replacement for standard NeRF models, take a look at:
16 | - [https://github.com/jenkspt/NeuS/tree/hash](https://github.com/jenkspt/NeuS/tree/hash)
17 | - [https://github.com/jenkspt/NeuS/blob/hash/models/hash_fields.py](https://github.com/jenkspt/NeuS/blob/hash/models/hash_fields.py)
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 | 
3 | with open("README.md", "r", encoding="utf-8") as fh:
4 |     long_description = fh.read()
5 | 
6 | setuptools.setup(
7 |     name="multires-hash-encoding-pytorch",
8 |     version="0.0.1",
9 |     author="Penn Jenks",
10 |     author_email="jenkspt@gmail.com",
11 |     description="",
12 |     long_description=long_description,
13 |     long_description_content_type="text/markdown",
14 |     url="",
15 |     project_urls={
16 |         "Bug Tracker": "",
17 |     },
18 |     classifiers=[
19 |         "Programming Language :: Python :: 3",
20 |         "License :: OSI Approved :: MIT License",
21 |         "Operating System :: OS Independent",
22 |     ],
23 |     package_dir={"": "src"},
24 |     packages=setuptools.find_packages(where="src"),
25 |     python_requires=">=3.6",
26 |     install_requires=[
27 |         'torch',
28 |     ]
29 | )
30 | 
--------------------------------------------------------------------------------
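Building on the README above, here is a minimal usage sketch of the `MultiresEncoding` module defined in `src/multires_hash_encoding/modules.py`. The sizes are illustrative (smaller than the defaults, to keep the hash tables small) and it assumes the package has been installed from this repo, e.g. with `pip install -e .`:

```python
import torch
from multires_hash_encoding.modules import MultiresEncoding, MultiresEncodingConfig

# Illustrative settings: 8 levels, 2 features per level, a 2**16-entry hash
# table per hash level, and a finest resolution of 128^3 (the defaults are larger).
config = MultiresEncodingConfig(nlevels=8, table_size=2**16, maxres=(128, 128, 128))
encoder = MultiresEncoding(**vars(config))

# 3D points normalized to [-1, 1], shape [N, 3].
points = torch.rand(1024, 3) * 2 - 1
features = encoder(points, normalized=True)
print(features.shape)  # torch.Size([1024, 16]) == [N, nlevels * features]
```

Each level contributes `features` channels: the coarsest level reads from a dense grid and the finer levels read from hash tables, and the output is the concatenation over all levels.
--------------------------------------------------------------------------------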
/src/multires_hash_encoding/__init__.py:
--------------------------------------------------------------------------------
1 | from .hash_tensor import HashTensor
2 | from .interpolate import nd_linear_interp, Interpolate
3 | from .modules import *
4 | 
--------------------------------------------------------------------------------
/src/multires_hash_encoding/hash_tensor.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Any, Iterable, List
2 | from math import log, exp
3 | from functools import reduce
4 | import operator
5 | from dataclasses import dataclass, field
6 | import numpy as np
7 | import torch
8 | from torch import Tensor
9 | import torch.nn as nn
10 | 
11 | Shape = Iterable[int]
12 | 
13 | 
14 | #PRIMES = torch.tensor([1, 73856093, 19349663, 83492791])
15 | 
16 | def spatial_hash(coords: List[Tensor]) -> Tensor:
17 |     PRIMES = (1, 2654435761, 805459861, 3674653429)
18 |     #assert len(coords) <= len(PRIMES), "Add more PRIMES!"
19 |     if len(coords) == 1:
20 |         i = (coords[0] ^ PRIMES[1])
21 |     else:
22 |         i = coords[0] ^ PRIMES[0]
23 |         for c, p in zip(coords[1:], PRIMES[1:]):
24 |             i ^= c * p
25 |     return i
26 | 
27 | 
28 | class HashTensor(nn.Module):
29 |     """
30 |     This is a sparse array backed by a simple hash table. It implements a minimal
31 |     array interface so that it can be used for (n-d) linear interpolation.
32 |     There is no collision resolution or even bounds checking.
33 | 
34 |     Attributes:
35 |         data: The hash table represented as a 2D array.
36 |             First dim is the feature dim; the second dim is indexed with the hash index.
37 |         shape: The shape of the array.
38 | 
39 |     NVIDIA implementation of the multi-res hash grid:
40 |     https://github.com/NVlabs/tiny-cuda-nn/blob/master/include/tiny-cuda-nn/encodings/grid.h#L66-L80
41 |     """
42 | 
43 |     def __init__(self, data, shape):
44 |         """
45 |         Args:
46 |             data: The hash table represented as a 2D array.
47 |                 First dim is the feature dim; the second dim is indexed with the hash index.
48 |             shape: The shape of the array.
49 | """ 50 | assert data.ndim == 2, "Hash table data should be 2d" 51 | assert data.shape[0] == shape[0] 52 | super().__init__() 53 | self.data = data 54 | self.shape = shape 55 | 56 | @property 57 | def ndim(self): 58 | return len(self.shape) 59 | 60 | @property 61 | def dtype(self): 62 | return self.data.dtype 63 | 64 | @property 65 | def device(self): 66 | return self.data.device 67 | 68 | def forward(self, index): 69 | #feature_i, *spatial_i = i if len(i) == self.ndim else (Ellipsis, *i) 70 | assert len(index) == self.ndim 71 | feature_i, *spatial_i = index 72 | i = spatial_hash(spatial_i) % self.data.shape[1] 73 | return self.data[feature_i, i] 74 | 75 | def __getitem__(self, index): 76 | return self.forward(index) 77 | 78 | def __array__(self, dtype=None): 79 | _, *S = self.shape 80 | index = torch.meshgrid(*(torch.arange(s) for s in S)) 81 | arr = self[(slice(0, None), *index)].detach().cpu().__array__(dtype) 82 | return arr 83 | 84 | def __repr__(self): 85 | return "HashTensor(" + str(np.asarray(self)) + ")" 86 | 87 | 88 | def growth_factor(levels: int, minres: int, maxres: int): 89 | return exp((log(maxres) - log(minres)) / (levels - 1)) 90 | 91 | 92 | def _get_level_res(levels: int, minres: int, maxres: int): 93 | b = growth_factor(levels, minres, maxres) 94 | res = [int(round(minres * (b ** l))) for l in range(0, levels)] 95 | return res 96 | 97 | 98 | def _get_level_res_nd(levels: int, minres: Iterable[int], maxres: Iterable[int]): 99 | it = (_get_level_res(levels, _min, _max) 100 | for _min, _max in zip(minres, maxres)) 101 | return list(zip(*it)) 102 | -------------------------------------------------------------------------------- /src/multires_hash_encoding/interpolate.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from functools import partial 3 | import torch 4 | from torch import Tensor, dtype, device, Size 5 | import torch.nn as nn 6 | 7 | __all__ = ["nd_linear_interp", "Interpolate"] 8 | 9 | 10 | def nd_corners(d: int, dtype: Optional[dtype]=None, device: Optional[device]=None) -> Tensor: 11 | """ 12 | Generates the corner coordinates for an N dimensional hyper cube. 13 | with d=2 will generate 4 coordinates, for d=3 will generate 8 coordinates, etc ... 
14 | """ 15 | xi = [torch.arange(2, dtype=dtype, device=device) for i in range(d)] 16 | corners = torch.stack(torch.meshgrid(*xi, indexing='ij'), -1) 17 | return corners.reshape(1, 2**d, d) 18 | 19 | 20 | @torch.jit.script 21 | def weights_fn(x: Tensor, i: Tensor) -> Tensor: 22 | return torch.abs(x - torch.floor(x) + i - 1) 23 | 24 | 25 | @torch.jit.script 26 | def index_fn(x: Tensor, i: Tensor) -> Tensor: 27 | return torch.floor(x).to(i.dtype) + i 28 | 29 | 30 | def nearest_index_fn(i: Tensor, shape: Size): 31 | """ Replaces out of bounds index with the nearest valid index """ 32 | high = torch.tensor(shape, dtype=i.dtype, device=i.device) 33 | low = torch.zeros_like(high) 34 | return i.clamp(low, high - 1) 35 | 36 | 37 | def nd_linear_interp( 38 | input: Tensor, 39 | coords: Tensor, 40 | mode: Optional[str] = 'nearest', 41 | corners: Optional[Tensor]=None) -> Tensor: 42 | assert coords.shape[-1] <= input.ndim 43 | 44 | *S, d = coords.shape 45 | corners = nd_corners( 46 | d, torch.int64, coords.device) if corners is None else corners 47 | coords = coords.reshape(-1, 1, d) # Combine broadcast dimensions 48 | 49 | weights = weights_fn(coords, corners).prod(-1) # [N, 2**D] 50 | index = index_fn(coords, corners) # [N, 2**D, D] 51 | if mode == None: 52 | pass 53 | elif mode == 'nearest': 54 | index = nearest_index_fn(index, input.shape[-d:]) 55 | else: 56 | raise ValueError("only `nearest` mode or `None` is currently supported") 57 | values = (weights * input[(..., *index.unbind(-1))]).sum(-1) # [..., N]) 58 | return values.reshape(*input.shape[:-d], *S) 59 | 60 | 61 | class Interpolate(nn.Module): 62 | """ N-d Interpolation class """ 63 | 64 | def __init__(self, input: Tensor, d: int, order=1, mode='nearest'): 65 | """ 66 | Args: 67 | input: The input array 68 | d: Dimension of interpolation. d=2 will perform bilinear interpolation 69 | on the last 2 dimensions of `input` (i.e. image with shape [3, H, W]) 70 | order: Interpolation order. default is 1 71 | mode: determines how the input array is extended beyond its boundaries. 
72 | Default is 'nearest' 73 | """ 74 | super().__init__() 75 | assert order == 1 76 | assert mode in (None, 'nearest') 77 | 78 | self.input = input 79 | self.d = d 80 | self.order = order 81 | self.mode = mode 82 | 83 | self.corners = nd_corners(d, torch.int64, input.device) 84 | self._shape = torch.tensor( 85 | input.shape[-d:], dtype=input.dtype, device=input.device) 86 | 87 | def _unnormalize(self, coords: Tensor) -> Tensor: 88 | """ map from [-1, 1] ---> grid coords """ 89 | return (coords + 1) / 2 * (self._shape - 1) 90 | 91 | def forward(self, coords, normalized=False): 92 | """ If normalized -- assumes coords are in range [-1, 1]""" 93 | if normalized: 94 | coords = self._unnormalize(coords) 95 | return nd_linear_interp(self.input, coords, self.mode, self.corners) -------------------------------------------------------------------------------- /src/multires_hash_encoding/modules.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | from dataclasses import dataclass 3 | import torch 4 | import torch.nn as nn 5 | 6 | from multires_hash_encoding.hash_tensor import HashTensor, _get_level_res_nd 7 | from multires_hash_encoding.interpolate import Interpolate 8 | 9 | __all__ = [ 10 | "DenseEncodingLevel", 11 | "HashEncodingLevel", 12 | "MultiresEncoding", 13 | "MultiresEncodingConfig", 14 | "MLP", 15 | "ViewEncoding", 16 | "MultiresHashNeRF", 17 | ] 18 | 19 | Shape = Iterable[int] 20 | 21 | 22 | class DenseEncodingLevel(nn.Module): 23 | def __init__(self, shape, device=None, dtype=None): 24 | factory_kwargs = dict(device=device, dtype=dtype) 25 | super().__init__() 26 | self.shape = shape 27 | grid = nn.Parameter(torch.empty(shape, **factory_kwargs)) 28 | self.interp = Interpolate(grid, d=len(shape) - 1, mode='nearest') 29 | self.reset_parameters() 30 | 31 | def reset_parameters(self) -> None: 32 | nn.init.uniform_(self.interp.input, -1e-4, 1e-4) 33 | 34 | def forward(self, coords, normalized=True): 35 | return self.interp(coords, normalized).permute(1, 0) 36 | 37 | 38 | class HashEncodingLevel(nn.Module): 39 | def __init__(self, shape, table_size, device=None, dtype=None): 40 | factory_kwargs = dict(device=device, dtype=dtype) 41 | super().__init__() 42 | self.shape = shape 43 | grid = nn.Parameter(torch.empty( 44 | (shape[0], table_size), **factory_kwargs)) 45 | hash_tensor = HashTensor(grid, shape) 46 | assert hash_tensor.shape == shape 47 | self.interp = Interpolate(hash_tensor, d=len(shape) - 1, mode=None) 48 | self.reset_parameters() 49 | 50 | def reset_parameters(self) -> None: 51 | nn.init.uniform_(self.interp.input.data, -1e-4, 1e-4) 52 | 53 | def forward(self, coords, normalized=True): 54 | return self.interp(coords, normalized).permute(1, 0) 55 | 56 | 57 | class MultiresEncoding(nn.Module): 58 | def __init__(self, 59 | nlevels: int = 16, 60 | features: int = 2, 61 | table_size: int = 2**18, 62 | minres: Shape = (16, 16, 16), 63 | maxres: Shape = (512, 512, 512), 64 | device=None, 65 | dtype=None,): 66 | super().__init__() 67 | factory_kwargs = dict(device=device, dtype=dtype) 68 | res_levels = _get_level_res_nd(nlevels, minres, maxres) 69 | level0 = DenseEncodingLevel( 70 | (features, *res_levels[0]), **factory_kwargs) 71 | levelN = (HashEncodingLevel((features, *l), table_size, **factory_kwargs) 72 | for l in res_levels[1:]) 73 | self.levels = nn.ModuleList([level0, *levelN]) 74 | self._maxres = torch.tensor(maxres, **factory_kwargs) 75 | self.encoding_size = nlevels * features 76 | 77 | def forward(self, 
coords, normalized=True): 78 | if not normalized: 79 | coords = coords / (self._maxres - 1) * 2 - 1 80 | # Look up features at each level/resolution 81 | features = [l(coords, True) for l in self.levels] 82 | return torch.cat(features, -1) 83 | 84 | 85 | @dataclass 86 | class MultiresEncodingConfig: 87 | nlevels: int = 16 88 | features: int = 2 89 | table_size: int = 2**22 90 | minres: Shape = (16, 16, 16) 91 | maxres: Shape = (1024, 1024, 1024) 92 | 93 | 94 | class MLP(nn.Sequential): 95 | def __init__(self, *features, activation=nn.ReLU()): 96 | l1, *ln = (nn.Linear(*f) 97 | for f in zip(features[:-1], features[1:])) 98 | activations = (activation for _ in range(len(ln))) 99 | super().__init__(l1, *(m for t in zip(activations, ln) for m in t)) 100 | 101 | 102 | # source: https://github.com/yashbhalgat/HashNeRF-pytorch/blob/a7d64eeb8844b57f5ba90463185a1506e2cbb4b8/hash_encoding.py#L75 103 | class ViewEncoding(nn.Module): 104 | # Spherical Harmonic Coefficients 105 | C0 = 0.28209479177387814 106 | C1 = 0.4886025119029199 107 | C2 = (1.0925484305920792, -1.0925484305920792, 108 | 0.31539156525252005, -1.0925484305920792, 109 | 0.5462742152960396,) 110 | C3 = (-0.5900435899266435, 2.890611442640554, 111 | -0.4570457994644658, 0.3731763325901154, 112 | -0.4570457994644658, 1.445305721320277, 113 | -0.5900435899266435,) 114 | C4 = (2.5033429417967046, -1.7701307697799304, 115 | 0.9461746957575601, -0.6690465435572892, 116 | 0.10578554691520431, -0.6690465435572892, 117 | 0.47308734787878004, -1.7701307697799304, 118 | 0.6258357354491761,) 119 | 120 | def __init__(self, degree=4): 121 | assert degree >= 1 and degree <= 5 122 | super().__init__() 123 | self.degree = degree 124 | self.encoding_size = degree ** 2 125 | 126 | def forward(self, input): 127 | result = torch.empty( 128 | (*input.shape[:-1], self.encoding_size), dtype=input.dtype, device=input.device) 129 | x, y, z = input.unbind(-1) 130 | 131 | result[..., 0] = self.C0 132 | if self.degree > 1: 133 | result[..., 1] = -self.C1 * y 134 | result[..., 2] = self.C1 * z 135 | result[..., 3] = -self.C1 * x 136 | if self.degree > 2: 137 | xx, yy, zz = x * x, y * y, z * z 138 | xy, yz, xz = x * y, y * z, x * z 139 | result[..., 4] = self.C2[0] * xy 140 | result[..., 5] = self.C2[1] * yz 141 | result[..., 6] = self.C2[2] * (2.0 * zz - xx - yy) 142 | result[..., 7] = self.C2[3] * xz 143 | result[..., 8] = self.C2[4] * (xx - yy) 144 | if self.degree > 3: 145 | result[..., 9] = self.C3[0] * y * (3 * xx - yy) 146 | result[..., 10] = self.C3[1] * xy * z 147 | result[..., 11] = self.C3[2] * y * (4 * zz - xx - yy) 148 | result[..., 12] = self.C3[3] * \ 149 | z * (2 * zz - 3 * xx - 3 * yy) 150 | result[..., 13] = self.C3[4] * x * (4 * zz - xx - yy) 151 | result[..., 14] = self.C3[5] * z * (xx - yy) 152 | result[..., 15] = self.C3[6] * x * (xx - 3 * yy) 153 | if self.degree > 4: 154 | result[..., 16] = self.C4[0] * xy * (xx - yy) 155 | result[..., 17] = self.C4[1] * yz * (3 * xx - yy) 156 | result[..., 18] = self.C4[2] * xy * (7 * zz - 1) 157 | result[..., 19] = self.C4[3] * yz * (7 * zz - 3) 158 | result[..., 20] = self.C4[4] * \ 159 | (zz * (35 * zz - 30) + 3) 160 | result[..., 21] = self.C4[5] * xz * (7 * zz - 3) 161 | result[..., 22] = self.C4[6] * (xx - yy) * (7 * zz - 1) 162 | result[..., 23] = self.C4[7] * xz * (xx - 3 * yy) 163 | result[..., 24] = self.C4[8] * \ 164 | (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) 165 | return result 166 | 167 | 168 | class MultiresHashNeRF(nn.Module): 169 | 170 | def __init__(self, mlp_width=64, color_channels=3, 
171 |                  view_encoding_degree=4,
172 |                  multires_encoding_config=MultiresEncodingConfig()):
173 |         super().__init__()
174 |         self.position_encoding = MultiresEncoding(
175 |             **vars(multires_encoding_config))
176 |         self.view_encoding = ViewEncoding(view_encoding_degree)
177 |         self.feature_mlp = MLP(
178 |             self.position_encoding.encoding_size, mlp_width, mlp_width)
179 | 
180 |         self.rgb_mlp = MLP(
181 |             mlp_width + self.view_encoding.encoding_size, mlp_width, color_channels)
182 | 
183 |     def forward(self, x):
184 |         input_pts, input_views = torch.split(
185 |             x, [3, 3], dim=-1)
186 |         h = self.position_encoding(input_pts)
187 |         h = self.feature_mlp(h)
188 |         sigma = h[..., 0]
189 |         h = torch.cat([h, self.view_encoding(input_views)], -1)
190 |         rgb = self.rgb_mlp(h)
191 |         return rgb, sigma
192 | 
--------------------------------------------------------------------------------
/tests/test_hash_tensor.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from multires_hash_encoding.hash_tensor import (
4 |     HashTensor,
5 | )
6 | 
7 | 
8 | def test_HashArray():
9 |     data = torch.ones((1,)).reshape(1, 1)
10 | 
11 |     # 1D
12 |     ha = HashTensor(data, (1, 2))
13 |     assert ha.shape == (1, 2)
14 |     assert ha.ndim == 2
15 |     assert ha[:, 0] == 1
16 | 
17 |     # 2D
18 |     ha = HashTensor(data, (1, 2, 2))
19 |     assert ha[:, 0, 0] == 1
20 | 
21 |     # 3D
22 |     ha = HashTensor(data, (1, 2, 2, 2))
23 |     assert ha[:, 0, 0, 0] == 1
24 | 
--------------------------------------------------------------------------------
/tests/test_interpolate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.ndimage import map_coordinates
3 | import torch
4 | from multires_hash_encoding.interpolate import (
5 |     nd_linear_interp,
6 |     Interpolate,
7 | )
8 | 
9 | 
10 | def test_bilinear_interpolate():
11 |     coords = torch.rand(4, 2)
12 |     arr = torch.ones((1, 2, 2))
13 |     out = nd_linear_interp(arr, coords)
14 |     assert torch.allclose(out, torch.ones_like(out))
15 | 
16 |     arr = torch.tensor([[1., 2.], [1., 2.]]).reshape(1, 2, 2)
17 |     coords = torch.tensor([[.5, .5]])
18 | 
19 |     assert torch.allclose(nd_linear_interp(arr, coords), torch.tensor([1.5]))
20 | 
21 |     coords = torch.tensor([[.5, 0]])
22 |     assert torch.allclose(nd_linear_interp(arr, coords), torch.tensor([1.0]))
23 | 
24 | 
25 | def to_np_coords(coords):
26 |     return tuple(c.numpy() for c in coords.unbind(-1))
27 | 
28 | 
29 | def test_map_coordinates():
30 | 
31 |     img = torch.rand(1, 10, 10)
32 |     coords = torch.tensor([[5., 5.]])
33 | 
34 |     assert np.allclose(
35 |         map_coordinates(img[0].numpy(), to_np_coords(
36 |             coords), order=1, mode='nearest'),
37 |         nd_linear_interp(img, coords, mode='nearest').numpy())
38 | 
39 |     num_coords = 5
40 |     for shape in ((2, 10), (3, 10, 12), (4, 10, 12, 14)):
41 |         for order in [1]:
42 |             for mode in ['nearest']:
43 |                 print(f'Shape: {shape}')
44 |                 n = len(shape) - 1
45 |                 signal = torch.rand(*shape)
46 |                 coords = torch.rand(num_coords, n) * (10 - 1)
47 |                 result = nd_linear_interp(signal, coords, mode)
48 |                 assert result.shape == (shape[0], num_coords)
49 |                 target1 = map_coordinates(
50 |                     signal[0, ...].numpy(), to_np_coords(coords), order=order, mode=mode, prefilter=False)
51 |                 target2 = map_coordinates(
52 |                     signal[1, ...].numpy(), to_np_coords(coords), order=order, mode=mode, prefilter=False)
53 |                 assert np.allclose(result[0, :].numpy(), target1) and np.allclose(
54 |                     result[1, :].numpy(), target2)
55 | 
56 | 
57 | def test_Interpolate():
58 | 
59 |     img = torch.tensor([[1, 2, 3],
[1, 2, 3]]).reshape(1, 2, 3) 60 | coords = torch.tensor([[0, .5]]) 61 | interp = Interpolate(img, d=2, order=1, mode='nearest') 62 | out = interp(coords, normalized=False) 63 | assert torch.allclose(out, torch.tensor([1.5])) 64 | 65 | out = interp(coords, normalized=True) 66 | assert torch.allclose(out, torch.tensor([2.5])) 67 | -------------------------------------------------------------------------------- /tests/test_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from multires_hash_encoding.modules import ( 5 | DenseEncodingLevel, 6 | HashEncodingLevel, 7 | MultiresEncoding, 8 | MLP, 9 | ViewEncoding, 10 | MultiresEncodingConfig, 11 | #MultiResHashNeRF, 12 | ) 13 | 14 | 15 | def test_DenseEncodingLevel(): 16 | l = DenseEncodingLevel((2, 8, 10, 12)) 17 | nn.init.ones_(l.interp.input) 18 | x = torch.zeros(4, 3) 19 | out = l(x) 20 | assert out.shape == (4, 2) 21 | assert torch.allclose(out, torch.ones_like(out)) 22 | 23 | 24 | def test_HashEncodingLevel(): 25 | l = HashEncodingLevel((2, 8, 10, 12), table_size=1) 26 | nn.init.ones_(l.interp.input.data) 27 | x = torch.zeros(4, 3) 28 | out = l(x) 29 | assert out.shape == (4, 2) 30 | assert torch.allclose(out, torch.ones_like(out)) 31 | 32 | 33 | def test_MultiresEncodingLayer(): 34 | config = MultiresEncodingConfig() 35 | model = MultiresEncoding(**vars(config)) 36 | nn.init.ones_(model.levels[0].interp.input) 37 | for l in model.levels[1:]: 38 | nn.init.ones_(l.interp.input.data) 39 | x = torch.zeros(4, 3) 40 | out = model(x, normalized=True) 41 | assert out.shape == (4, config.features * config.nlevels) 42 | assert torch.allclose(out, torch.ones_like(out)) 43 | 44 | 45 | def test_MLP(): 46 | model = MLP(3, 16, 4) 47 | x = torch.rand(10, 3) 48 | out = model(x) 49 | assert out.shape == (10, 4) 50 | 51 | 52 | def test_ViewEncoding(): 53 | model = ViewEncoding(2) 54 | x = torch.rand(10, 3) 55 | out = model(x) 56 | assert out.shape == (10, 4) 57 | --------------------------------------------------------------------------------
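For completeness, a lower-level sketch of how `HashTensor` and `Interpolate` compose, mirroring the tests above. The table and grid sizes are illustrative:

```python
import torch
from multires_hash_encoding.hash_tensor import HashTensor
from multires_hash_encoding.interpolate import Interpolate

# A 2-feature virtual 64x64x64 grid backed by a 4096-entry hash table.
# Lookups go through the spatial hash; there is no collision handling.
table = torch.randn(2, 4096)
grid = HashTensor(table, (2, 64, 64, 64))

# Trilinear interpolation over the 3 spatial dimensions. mode=None skips the
# bounds clamp, since out-of-range indices get hashed anyway (this is how
# HashEncodingLevel uses it).
interp = Interpolate(grid, d=3, mode=None)

coords = torch.rand(5, 3) * 63            # grid-space coordinates in [0, 63)
feats = interp(coords, normalized=False)  # shape [2, 5]: features x query points
```

`DenseEncodingLevel` and `HashEncodingLevel` wrap exactly this pattern (a dense parameter grid or a `HashTensor`, respectively, behind an `Interpolate`), and `MultiresEncoding` stacks one such level per resolution.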