├── LICENSE
├── README.md
├── setup.py
├── src
│   ├── neural_hash_encoding
│   │   ├── __init__.py
│   │   ├── hash_array.py
│   │   ├── interpolate.py
│   │   └── model.py
│   └── train_image.py
└── viz
    ├── viz_step_1000_psnr_30.1.png
    ├── viz_step_100_psnr_25.0.png
    └── viz_step_10_psnr_15.3.png

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Penn Jenks

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Neural Hash Encoding

This is a work-in-progress reimplementation of [Instant Neural Graphics Primitives](https://github.com/NVlabs/instant-ngp).
Currently it can train an implicit representation of a gigapixel image using a multires hash encoding.

FYI: This is brand new -- most parameters in the training script are hard-coded right now.

Check out results in [viz](./viz)

## Setup:

Download [the Tokyo image](https://www.flickr.com/photos/trevor_dobson_inefekt69/29314390837)

```bash
wget -O tokyo.jpg https://live.staticflickr.com/1859/29314390837_b39ae4876e_o_d.jpg
```

Convert to numpy binary format for faster reading (1s w/ .npy vs 14s with .jpg)

```python
import numpy as np
from PIL import Image

Image.MAX_IMAGE_PIXELS = 10**10

img = np.asarray(Image.open("tokyo.jpg"))  # About 3.5 GB
np.save("tokyo.npy", img)
```

## Train:

```bash
python src/train_image.py
```

# Implementation Notes (From the Paper)

### Architecture

> In all tasks, except for NeRF which we will describe later, we use an MLP
> with two hidden layers that have a width of 64 neurons and rectified linear
> unit (ReLU) activation functions.

### 4. Initialization

- Initialize hash table entries with a uniform distribution [-1e-4, 1e-4]

### 4. Training

- Optimizer
  - Adam: β1 = 0.9, β2 = 0.99, ϵ = 1e−15
  - Learning rate: 1e-2 ([ref: tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn/blob/master/samples/mlp_learning_an_image.cu#L130))
- Regularization:
  - L2: 1e-6, applied to the MLP weights, not the hash table weights

> we skip Adam steps for hash table entries whose gradient is exactly 0.
> This saves ∼10% performance when gradients are sparse.
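
These settings map naturally onto optax. A minimal sketch, not the actual training script: `mlp_mask` is a hypothetical helper, and the `'decoder'` key assumes the params tree produced by `src/neural_hash_encoding/model.py`:

```python
import optax

# AdamW with the paper's betas/epsilon; the mask restricts weight decay
# to the MLP ('decoder') parameters and leaves the hash tables untouched.
mlp_mask = lambda params: {k: (k == 'decoder') for k in params}
tx = optax.adamw(learning_rate=1e-2, b1=0.9, b2=0.99, eps=1e-15,
                 weight_decay=1e-6, mask=mlp_mask)
```

(`train_image.py` below instead uses `eps=1e-10` and adds the L2 penalty to the loss by hand.)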
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import setuptools

with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

setuptools.setup(
    name="neural-hash-encoding",
    version="0.0.1",
    author="Penn Jenks",
    author_email="jenkspt@gmail.com",
    description="",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="",
    project_urls={
        "Bug Tracker": "",
    },
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    package_dir={"": "src"},
    packages=setuptools.find_packages(where="src"),
    python_requires=">=3.6",
    install_requires=[
        'jax',
        'jaxlib',
        'jaxopt',
        'flax',
        'optax',
    ],
)
--------------------------------------------------------------------------------
/src/neural_hash_encoding/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jenkspt/neural_hash_encoding/f874ea9a743b5d5112c2963594d44404421018ab/src/neural_hash_encoding/__init__.py
--------------------------------------------------------------------------------
/src/neural_hash_encoding/hash_array.py:
--------------------------------------------------------------------------------
from typing import Tuple, Any, Iterable
from functools import reduce
import operator
import numpy as np
from dataclasses import dataclass, field
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

"""
Multi-res grid code from the NVIDIA paper:
https://github.com/NVlabs/tiny-cuda-nn/blob/d0639158dd64b5d146d659eb881702304abaaa24/include/tiny-cuda-nn/encodings/grid.h
"""

Shape = Iterable[int]
Dtype = Any  # this could be a real type?
Array = Any

PRIMES = (1, 73856093, 19349663, 83492791)


@register_pytree_node_class
@dataclass
class HashArray:
    """
    This is a sparse array backed by a simple hash table. It minimally implements
    an array interface so it can be used for (nd) linear interpolation.
    There is no collision resolution or even bounds checking.

    Attributes:
        data: The hash table represented as a 2D array.
            The first dim is indexed with the hash index; the second dim is the feature.
        shape: The shape of the array.

    NVIDIA implementation of the multi-res hash grid:
    https://github.com/NVlabs/tiny-cuda-nn/blob/master/include/tiny-cuda-nn/encodings/grid.h#L66-L80
    """
    data: Array
    shape: Shape

    def __post_init__(self):
        assert self.data.ndim == 2, "Hash table data should be 2d"
        assert self.data.shape[1] == self.shape[-1]

    @property
    def ndim(self):
        return len(self.shape)

    @property
    def dtype(self):
        return self.data.dtype
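
    # Spatial hash from Eq. (4) of the paper: h(x) = (x_1·π_1 ⊕ ... ⊕ x_d·π_d) mod T,
    # with one large prime π_i per dimension and T = data.shape[0] table rows.
    # (π_1 = 1 so the first dimension maps to contiguous rows, per the paper.)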
    def spatial_hash(self, coords):
        assert len(coords) <= len(PRIMES), "Add more PRIMES!"
        if len(coords) == 1:
            i = (coords[0] ^ PRIMES[1])
        else:
            i = reduce(operator.xor, (c * p for c, p in zip(coords, PRIMES)))
        return i % self.data.shape[0]

    def __getitem__(self, i):
        *spatial_i, feature_i = i if len(i) == self.ndim else (*i, Ellipsis)
        i = self.spatial_hash(spatial_i)
        return self.data[i, feature_i]

    def __array__(self, dtype=None):
        H, W, _ = self.shape
        y, x = jnp.mgrid[0:H:1, 0:W:1]
        arr = self[y, x, :].__array__(dtype)
        return arr

    def __repr__(self):
        return "HashArray(" + str(np.asarray(self)) + ")"

    def tree_flatten(self):
        # pytree protocol: (children, aux_data) -- the logical shape is static
        return ((self.data,), self.shape)

    @classmethod
    def tree_unflatten(cls, shape, data):
        return cls(data[0], shape)


def growth_factor(levels: int, minres: int, maxres: int):
    return np.exp((np.log(maxres) - np.log(minres)) / (levels - 1))


def _get_level_res(levels: int, minres: int, maxres: int):
    b = growth_factor(levels, minres, maxres)
    res = [int(round(minres * (b ** l))) for l in range(0, levels)]
    return res


def _get_level_res_nd(levels: int, minres: Tuple[int, ...], maxres: Tuple[int, ...]):
    it = (_get_level_res(levels, _min, _max)
          for _min, _max in zip(minres, maxres))
    return list(zip(*it))
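
# Example: levels=16, minres=16, maxres=512 gives growth factor
# b = 32**(1/15) ≈ 1.26 (Eq. 3 of the paper), i.e. per-axis resolutions
# [16, 20, 25, 32, 40, 51, 64, 81, 102, 128, 161, 203, 256, 323, 406, 512].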
--------------------------------------------------------------------------------
/src/neural_hash_encoding/interpolate.py:
--------------------------------------------------------------------------------
from typing import Any, Sequence
from dataclasses import dataclass
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

import operator
import itertools
import functools
from jax._src.scipy.ndimage import (
    _nonempty_prod,
    _nonempty_sum,
    _INDEX_FIXERS,
    _round_half_away_from_zero,
    _nearest_indices_and_weights,
    _linear_indices_and_weights,
)

Array = Any


def bilinear_interpolate(arr, x, y, clip_to_bounds=False):
    assert len(arr.shape) == 3
    H, W, _ = arr.shape

    x = jnp.asarray(x)
    y = jnp.asarray(y)

    x0 = jnp.floor(x).astype(int)
    x1 = x0 + 1
    y0 = jnp.floor(y).astype(int)
    y1 = y0 + 1

    if clip_to_bounds:
        x0 = jnp.clip(x0, 0, W - 1)
        x1 = jnp.clip(x1, 0, W - 1)
        y0 = jnp.clip(y0, 0, H - 1)
        y1 = jnp.clip(y1, 0, H - 1)

    Ia = arr[y0, x0, :]
    Ib = arr[y1, x0, :]
    Ic = arr[y0, x1, :]
    Id = arr[y1, x1, :]

    wa = ((x1 - x) * (y1 - y))[..., None]
    wb = ((x1 - x) * (y - y0))[..., None]
    wc = ((x - x0) * (y1 - y))[..., None]
    wd = ((x - x0) * (y - y0))[..., None]

    return wa * Ia + wb * Ib + wc * Ic + wd * Id


def map_coordinates(input, coordinates, order, mode='constant', cval=0.0):
    """
    Adapted from jax.scipy.ndimage.map_coordinates, with a few key differences:

    1.) Interpolations are always broadcast along the last dimension of `input`,
        i.e. a 3-channel RGB image with shape [H, W, 3] is interpolated with 2D
        coordinates and broadcast across the channel dimension.

    2.) `input` isn't required to be a jax `DeviceArray` -- it can be any type
        that supports numpy fancy indexing.

    Note on interpolation: `map_coordinates` indexes in the order of the axes,
    so for an image it indexes the coordinates as [y, x].
    """
    coordinates = [jnp.asarray(c) for c in coordinates]
    cval = jnp.asarray(cval, input.dtype)

    if len(coordinates) != input.ndim - 1:
        raise ValueError('coordinates must be a sequence of length input.ndim - 1, but '
                         '{} != {}'.format(len(coordinates), input.ndim - 1))

    index_fixer = _INDEX_FIXERS.get(mode)
    if index_fixer is None:
        raise NotImplementedError(
            'map_coordinates does not support mode {}. '
            'Currently supported modes are {}.'.format(mode, set(_INDEX_FIXERS)))

    if mode == 'constant':
        is_valid = lambda index, size: (0 <= index) & (index < size)
    else:
        is_valid = lambda index, size: True

    if order == 0:
        interp_fun = _nearest_indices_and_weights
    elif order == 1:
        interp_fun = _linear_indices_and_weights
    else:
        raise NotImplementedError(
            'map_coordinates currently requires order<=1')

    valid_1d_interpolations = []
    for coordinate, size in zip(coordinates, input.shape[:-1]):
        interp_nodes = interp_fun(coordinate)
        valid_interp = []
        for index, weight in interp_nodes:
            fixed_index = index_fixer(index, size)
            valid = is_valid(index, size)
            valid_interp.append((fixed_index, valid, weight))
        valid_1d_interpolations.append(valid_interp)

    outputs = []
    for items in itertools.product(*valid_1d_interpolations):
        indices, validities, weights = zip(*items)
        if all(valid is True for valid in validities):
            # fast path
            contribution = input[(*indices, Ellipsis)]
        else:
            all_valid = functools.reduce(operator.and_, validities)
            contribution = jnp.where(all_valid[..., None], input[(*indices, Ellipsis)], cval)
        outputs.append(_nonempty_prod(weights)[..., None] * contribution)

    result = _nonempty_sum(outputs)
    if jnp.issubdtype(input.dtype, jnp.integer):
        result = _round_half_away_from_zero(result)
    return result.astype(input.dtype)


@dataclass
@register_pytree_node_class
class Interpolate:
    arr: Array
    order: int
    mode: str
    cval: float = 0.0

    def __call__(self, coords, normalized=True):
        coords = [jnp.asarray(c) for c in coords]
        assert len(coords) == (self.arr.ndim - 1)
        if normalized:
            # un-normalize from [0, 1] to index space
            coords = [c * (s - 1) for c, s in zip(coords, self.arr.shape)]
        return map_coordinates(self.arr, coords, order=self.order, mode=self.mode, cval=self.cval)

    def tree_flatten(self):
        # pytree protocol: (children, aux_data) -- interpolation settings are static
        return ((self.arr,), (self.order, self.mode, self.cval))

    @classmethod
    def tree_unflatten(cls, aux_data, data):
        return cls(data[0], *aux_data)
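
# Example usage (hypothetical values): sample an RGB image at normalized
# [y, x] coordinates in [0, 1]:
#   f = Interpolate(img, order=1, mode='nearest')
#   rgb = f([0.25, 0.75])   # bilinear sample at pixel (0.25*(H-1), 0.75*(W-1))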
--------------------------------------------------------------------------------
/src/neural_hash_encoding/model.py:
--------------------------------------------------------------------------------
from typing import Tuple, Callable, Any, Iterable
from dataclasses import field

import jax
import jax.numpy as jnp
from flax import linen as nn

from neural_hash_encoding.hash_array import HashArray, _get_level_res_nd
from neural_hash_encoding.interpolate import Interpolate

# Copied from flax
PRNGKey = Any
Shape = Iterable[int]
Dtype = Any  # this could be a real type?
Array = Any


def uniform_init(minval=0, maxval=0.01, dtype=jnp.float64):
    def init(key, shape, dtype=dtype):
        return jax.random.uniform(key, shape, dtype, minval, maxval)
    return init


class DenseEncodingLevel(nn.Module):
    res: Shape
    features: int = 2
    dtype: Dtype = jnp.float32
    param_dtype: Dtype = jnp.float32
    table_init: Callable[[PRNGKey, Shape, Dtype], Array] = uniform_init(-1e-4, 1e-4)

    interp: Interpolate = field(init=False)

    def setup(self):
        array = self.param('table',
                           self.table_init,
                           (*self.res, self.features),
                           self.param_dtype)
        self.interp = Interpolate(jnp.asarray(array), order=1, mode='nearest')

    def __call__(self, coords):
        assert len(coords) == (self.interp.arr.ndim - 1)
        return self.interp(coords, normalized=True)


class HashEncodingLevel(nn.Module):
    res: Shape
    features: int = 2
    table_size: int = 2**14
    dtype: Dtype = jnp.float32
    param_dtype: Dtype = jnp.float32
    table_init: Callable[[PRNGKey, Shape, Dtype], Array] = uniform_init(-1e-4, 1e-4)

    interp: Interpolate = field(init=False)

    def setup(self):
        table = self.param('table',
                           self.table_init,
                           (self.table_size, self.features),
                           self.param_dtype)
        shape = (*self.res, self.features)
        array = HashArray(jnp.asarray(table), shape)
        self.interp = Interpolate(array, order=1, mode='nearest')

    def __call__(self, coords):
        assert len(coords) == (self.interp.arr.ndim - 1)
        return self.interp(coords, normalized=True)


class MultiResEncoding(nn.Module):
    levels: int = 16
    table_size: int = 2**14
    features: int = 2
    minres: Shape = (16, 16)
    maxres: Shape = (512, 512)
    dtype: Dtype = jnp.float32
    param_dtype: Dtype = jnp.float32
    param_init: Callable[[PRNGKey, Shape, Dtype], Array] = uniform_init(-1e-4, 1e-4)

    L: Tuple[nn.Module, ...] = field(init=False)

    def setup(self):
        res_levels = _get_level_res_nd(self.levels, self.minres, self.maxres)
        kwargs = dict(
            features=self.features, dtype=self.dtype,
            param_dtype=self.param_dtype, table_init=self.param_init)
        # The first (coarsest) level is always dense
        L0 = DenseEncodingLevel(res_levels[0], **kwargs)
        # The rest are sparse hash arrays
        self.L = tuple([L0, *(HashEncodingLevel(r, table_size=self.table_size, **kwargs)
                              for r in res_levels[1:])])

    def __call__(self, coords):
        features = [l(coords) for l in self.L]
        features = jnp.concatenate(features, -1)
        return features
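
# e.g. with the defaults above (16 levels × 2 features per level), each query
# point yields a 32-dim concatenated feature vector for the MLP decoder below.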

class MLP(nn.Module):
    features: Tuple[int, ...] = (64, 64, 3)

    @nn.compact
    def __call__(self, x):
        assert len(self.features) >= 2
        *hidden, linear = self.features
        for h in hidden:
            x = nn.relu(nn.Dense(h)(x))
        x = nn.Dense(linear)(x)
        return x


class ImageModel(nn.Module):
    res: Shape
    channels: int = 3
    levels: int = 16
    table_size: int = 2**14
    features: int = 2
    minres: Shape = (16, 16)

    def setup(self):
        self.embedding = MultiResEncoding(self.levels, self.table_size,
                                          self.features, self.minres, self.res)
        self.decoder = MLP((64, 64, self.channels))

    def __call__(self, coords):
        features = self.embedding(coords)
        color = self.decoder(features)
        return color
--------------------------------------------------------------------------------
/src/train_image.py:
--------------------------------------------------------------------------------
import numpy as np

import jax
import jax.numpy as jnp

import flax
from flax import linen as nn
from flax.training import train_state
import optax
from jaxopt.tree_util import tree_l2_norm

from torch.utils.data import IterableDataset, DataLoader

import matplotlib.pyplot as plt

from neural_hash_encoding.model import ImageModel


class RandomPixelData(IterableDataset):
    def __init__(self, img, batch_size):
        self.img = img
        self.H, self.W, _ = img.shape
        self.batch_size = batch_size
        num_complete_batches, leftover = divmod(self.H * self.W, batch_size)
        self._len = num_complete_batches + bool(leftover)

    def __len__(self):
        return self._len

    def __iter__(self):
        H, W = self.H, self.W
        for _ in range(len(self)):
            # randint's upper bound is exclusive, so use W/H (not W-1/H-1)
            x = np.random.randint(0, W, self.batch_size)
            y = np.random.randint(0, H, self.batch_size)
            rgb = self.img[y, x, :]
            # Normalize coordinates to [0, 1]
            yx = [y / (H - 1), x / (W - 1)]
            yield yx, rgb / 255


class RandomPixelLoader(DataLoader):
    def __init__(self, dataset, *args, **kwargs):
        super().__init__(dataset,
                         *args,
                         batch_size=None,
                         shuffle=False,
                         sampler=None,
                         batch_sampler=None,
                         collate_fn=lambda batch: batch,
                         pin_memory=False,
                         drop_last=False,
                         **kwargs)


def PSNR(mse):
    return -10.0 * jnp.log(mse) / jnp.log(10.0)
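
# Sanity check: an MSE of 1e-3 on [0, 1] images corresponds to a PSNR of 30 dB.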

def mse_loss(preds, targets):
    return jnp.mean((targets - preds)**2)


def relative_mse(preds, targets):
    return jnp.mean((targets - preds)**2 / (preds + 0.01)**2)


def l2_loss(params, l2=1e-6):
    sqnorm = tree_l2_norm(params, squared=True)
    return .5 * l2 * sqnorm


if __name__ == "__main__":
    img = np.load("tokyo.npy")
    img.flags.writeable = False  # be safe!
    H, W, C = img.shape
    print(f"Image shape: {img.shape}")
    table_size = 2**22

    def create_train_state(rng, learning_rate):
        """Creates the initial `TrainState`."""
        image_model = ImageModel(res=(H, W), table_size=table_size)
        x = jnp.ones((2, 1))  # Dummy [y, x] coordinates for shape inference
        params = image_model.init(rng, x)['params']
        tx = optax.adamw(learning_rate, b1=.9, b2=.99, eps=1e-10)
        return train_state.TrainState.create(
            apply_fn=image_model.apply, params=params, tx=tx)

    @jax.jit
    def train_step(state, batch):
        """Train for a single step."""
        yx, colors_targ = batch

        def loss_fn(params, weight_decay=1e-6):
            colors_pred = ImageModel((H, W), table_size=table_size).apply({'params': params}, yx)
            mlp_params = params['decoder']
            mse = mse_loss(colors_pred, colors_targ)
            loss = mse + l2_loss(mlp_params, weight_decay)
            return loss, mse

        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, mse), grads = grad_fn(state.params)
        state = state.apply_gradients(grads=grads)
        # PSNR is reported on the reconstruction MSE, not the regularized loss
        metrics = {'loss': loss, 'psnr': PSNR(mse)}
        return state, metrics
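
    # Note: the L2 penalty above is applied only to the 'decoder' (MLP) subtree,
    # matching the paper's recipe of leaving the hash tables unregularized.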
") 130 | epochs = 1 131 | batch_size = 2**22 # number of pixels 132 | ds = RandomPixelData(img, batch_size) 133 | loader = RandomPixelLoader(ds, num_workers=7) 134 | for epoch in range(epochs): 135 | for i, batch in enumerate(loader): 136 | step = epoch * len(ds) + i 137 | state, metrics = train_step(state, batch) 138 | loss, psnr = metrics['loss'], metrics['psnr'] 139 | if step > 1 and (np.log10(step) == int(np.log10(step))): # exponential logging 140 | path = f'viz/viz_step_{step}_psnr_{psnr:.1f}.png' 141 | write_region_plot(path, state.params, img) 142 | if step % 100 == 0: 143 | print(f'step: {step}, loss (mse): {loss:.4f}, psnr: {psnr:.4f}') -------------------------------------------------------------------------------- /viz/viz_step_1000_psnr_30.1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jenkspt/neural_hash_encoding/f874ea9a743b5d5112c2963594d44404421018ab/viz/viz_step_1000_psnr_30.1.png -------------------------------------------------------------------------------- /viz/viz_step_100_psnr_25.0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jenkspt/neural_hash_encoding/f874ea9a743b5d5112c2963594d44404421018ab/viz/viz_step_100_psnr_25.0.png -------------------------------------------------------------------------------- /viz/viz_step_10_psnr_15.3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jenkspt/neural_hash_encoding/f874ea9a743b5d5112c2963594d44404421018ab/viz/viz_step_10_psnr_15.3.png --------------------------------------------------------------------------------